From feae3eacb1ce5ff175e36cd4d94314454bc4010a Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Thu, 27 Jul 2017 11:47:02 -0700 Subject: [PATCH] Try enforcing methods overwrite at import time. Currently Spawners need to overwrite start, stop, poll. When this is not done, it will fail at runtime. This replicate this check at class definition time, meaning that potential errors will be caught way earlier. It also have not runtime cost as the check is a class definition time (ie often import time). This takes only effect on Python 3.6+ which introduce __init_subclass__, we could do it with metaclasses, but that's might be too complicated. If one want to create a class the avoid these restriction they can overwrite __init_subclass__ and not call the super() method. --- jupyterhub/spawner.py | 12 ++++++++++++ jupyterhub/tests/mocking.py | 12 ++++++++++-- jupyterhub/tests/test_spawner.py | 23 ++++++++++++++++++++++- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 2ac5c8c8..b0e8e60c 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -56,6 +56,18 @@ class Spawner(LoggingConfigurable): orm_spawner = Any() user = Any() + def __init_subclass__(cls, **kwargs): + super().__init_subclass__() + + missing = [] + for attr in ('start','stop', 'poll'): + if getattr(Spawner, attr) is getattr(cls, attr): + missing.append(attr) + + if missing: + raise NotImplementedError("class `{}` needs to redefine the `start`," + "`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing))) + @property def server(self): if self.orm_spawner and self.orm_spawner.server: diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index 11d6472c..9e753bb5 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -76,15 +76,23 @@ class SlowSpawner(MockSpawner): class NeverSpawner(MockSpawner): """A spawner that will never start""" - + @default('start_timeout') def _start_timeout_default(self): return 1 - + def start(self): """Return a Future that will never finish""" return Future() + @gen.coroutine + def stop(self): + pass + + @gen.coroutine + def poll(self): + return 0 + class FormSpawner(MockSpawner): """A spawner that has an options form defined""" diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index e2e01eb7..c062e3da 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -18,7 +18,7 @@ from tornado import gen from ..user import User from ..objects import Hub from .. import spawner as spawnermod -from ..spawner import LocalProcessSpawner +from ..spawner import LocalProcessSpawner, Spawner from .. import orm from .utils import async_requests @@ -246,3 +246,24 @@ def test_shell_cmd(db, tmpdir, request): r.raise_for_status() env = r.json() assert env['TESTVAR'] == 'foo' + + +def test_inherit_overwrite(): + """On 3.6+ we check things are overwritten at import time + """ + if sys.version_info >= (3,6): + with pytest.raises(NotImplementedError): + class S(Spawner): + pass + + +def test_inherit_ok(): + class S(Spawner): + def start(): + pass + + def stop(): + pass + + def poll(): + pass