diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 2ac5c8c8..b0e8e60c 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -56,6 +56,18 @@ class Spawner(LoggingConfigurable): orm_spawner = Any() user = Any() + def __init_subclass__(cls, **kwargs): + super().__init_subclass__() + + missing = [] + for attr in ('start','stop', 'poll'): + if getattr(Spawner, attr) is getattr(cls, attr): + missing.append(attr) + + if missing: + raise NotImplementedError("class `{}` needs to redefine the `start`," + "`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing))) + @property def server(self): if self.orm_spawner and self.orm_spawner.server: diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index 11d6472c..9e753bb5 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -76,15 +76,23 @@ class SlowSpawner(MockSpawner): class NeverSpawner(MockSpawner): """A spawner that will never start""" - + @default('start_timeout') def _start_timeout_default(self): return 1 - + def start(self): """Return a Future that will never finish""" return Future() + @gen.coroutine + def stop(self): + pass + + @gen.coroutine + def poll(self): + return 0 + class FormSpawner(MockSpawner): """A spawner that has an options form defined""" diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index e2e01eb7..c062e3da 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -18,7 +18,7 @@ from tornado import gen from ..user import User from ..objects import Hub from .. import spawner as spawnermod -from ..spawner import LocalProcessSpawner +from ..spawner import LocalProcessSpawner, Spawner from .. import orm from .utils import async_requests @@ -246,3 +246,24 @@ def test_shell_cmd(db, tmpdir, request): r.raise_for_status() env = r.json() assert env['TESTVAR'] == 'foo' + + +def test_inherit_overwrite(): + """On 3.6+ we check things are overwritten at import time + """ + if sys.version_info >= (3,6): + with pytest.raises(NotImplementedError): + class S(Spawner): + pass + + +def test_inherit_ok(): + class S(Spawner): + def start(): + pass + + def stop(): + pass + + def poll(): + pass