Try enforcing methods overwrite at import time.

Currently Spawners need to overwrite start, stop, poll. When this is not
done, it will fail at runtime.

This replicate this check at class definition time, meaning that
potential errors will be caught way earlier. It also have not runtime
cost as the check is a class definition time (ie often import time).

This takes only effect on Python 3.6+ which introduce __init_subclass__,
we could do it with metaclasses, but that's might be too complicated.

If one want to create a class the avoid these restriction they can
overwrite __init_subclass__ and not call the super() method.
This commit is contained in:
Matthias Bussonnier
2017-07-27 11:47:02 -07:00
parent cc24f36e80
commit feae3eacb1
3 changed files with 44 additions and 3 deletions

View File

@@ -56,6 +56,18 @@ class Spawner(LoggingConfigurable):
orm_spawner = Any() orm_spawner = Any()
user = Any() user = Any()
def __init_subclass__(cls, **kwargs):
super().__init_subclass__()
missing = []
for attr in ('start','stop', 'poll'):
if getattr(Spawner, attr) is getattr(cls, attr):
missing.append(attr)
if missing:
raise NotImplementedError("class `{}` needs to redefine the `start`,"
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing)))
@property @property
def server(self): def server(self):
if self.orm_spawner and self.orm_spawner.server: if self.orm_spawner and self.orm_spawner.server:

View File

@@ -85,6 +85,14 @@ class NeverSpawner(MockSpawner):
"""Return a Future that will never finish""" """Return a Future that will never finish"""
return Future() return Future()
@gen.coroutine
def stop(self):
pass
@gen.coroutine
def poll(self):
return 0
class FormSpawner(MockSpawner): class FormSpawner(MockSpawner):
"""A spawner that has an options form defined""" """A spawner that has an options form defined"""

View File

@@ -18,7 +18,7 @@ from tornado import gen
from ..user import User from ..user import User
from ..objects import Hub from ..objects import Hub
from .. import spawner as spawnermod from .. import spawner as spawnermod
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner, Spawner
from .. import orm from .. import orm
from .utils import async_requests from .utils import async_requests
@@ -246,3 +246,24 @@ def test_shell_cmd(db, tmpdir, request):
r.raise_for_status() r.raise_for_status()
env = r.json() env = r.json()
assert env['TESTVAR'] == 'foo' assert env['TESTVAR'] == 'foo'
def test_inherit_overwrite():
"""On 3.6+ we check things are overwritten at import time
"""
if sys.version_info >= (3,6):
with pytest.raises(NotImplementedError):
class S(Spawner):
pass
def test_inherit_ok():
class S(Spawner):
def start():
pass
def stop():
pass
def poll():
pass