Try enforcing methods overwrite at import time.

Currently Spawners need to overwrite start, stop, poll. When this is not
done, it will fail at runtime.

This replicate this check at class definition time, meaning that
potential errors will be caught way earlier. It also have not runtime
cost as the check is a class definition time (ie often import time).

This takes only effect on Python 3.6+ which introduce __init_subclass__,
we could do it with metaclasses, but that's might be too complicated.

If one want to create a class the avoid these restriction they can
overwrite __init_subclass__ and not call the super() method.
This commit is contained in:
Matthias Bussonnier
2017-07-27 11:47:02 -07:00
parent cc24f36e80
commit feae3eacb1
3 changed files with 44 additions and 3 deletions

View File

@@ -56,6 +56,18 @@ class Spawner(LoggingConfigurable):
orm_spawner = Any()
user = Any()
def __init_subclass__(cls, **kwargs):
super().__init_subclass__()
missing = []
for attr in ('start','stop', 'poll'):
if getattr(Spawner, attr) is getattr(cls, attr):
missing.append(attr)
if missing:
raise NotImplementedError("class `{}` needs to redefine the `start`,"
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing)))
@property
def server(self):
if self.orm_spawner and self.orm_spawner.server:

View File

@@ -76,15 +76,23 @@ class SlowSpawner(MockSpawner):
class NeverSpawner(MockSpawner):
"""A spawner that will never start"""
@default('start_timeout')
def _start_timeout_default(self):
return 1
def start(self):
"""Return a Future that will never finish"""
return Future()
@gen.coroutine
def stop(self):
pass
@gen.coroutine
def poll(self):
return 0
class FormSpawner(MockSpawner):
"""A spawner that has an options form defined"""

View File

@@ -18,7 +18,7 @@ from tornado import gen
from ..user import User
from ..objects import Hub
from .. import spawner as spawnermod
from ..spawner import LocalProcessSpawner
from ..spawner import LocalProcessSpawner, Spawner
from .. import orm
from .utils import async_requests
@@ -246,3 +246,24 @@ def test_shell_cmd(db, tmpdir, request):
r.raise_for_status()
env = r.json()
assert env['TESTVAR'] == 'foo'
def test_inherit_overwrite():
"""On 3.6+ we check things are overwritten at import time
"""
if sys.version_info >= (3,6):
with pytest.raises(NotImplementedError):
class S(Spawner):
pass
def test_inherit_ok():
class S(Spawner):
def start():
pass
def stop():
pass
def poll():
pass