mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
Try enforcing methods overwrite at import time.
Currently Spawners need to overwrite start, stop, poll. When this is not done, it will fail at runtime. This replicate this check at class definition time, meaning that potential errors will be caught way earlier. It also have not runtime cost as the check is a class definition time (ie often import time). This takes only effect on Python 3.6+ which introduce __init_subclass__, we could do it with metaclasses, but that's might be too complicated. If one want to create a class the avoid these restriction they can overwrite __init_subclass__ and not call the super() method.
This commit is contained in:
@@ -56,6 +56,18 @@ class Spawner(LoggingConfigurable):
|
|||||||
orm_spawner = Any()
|
orm_spawner = Any()
|
||||||
user = Any()
|
user = Any()
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__()
|
||||||
|
|
||||||
|
missing = []
|
||||||
|
for attr in ('start','stop', 'poll'):
|
||||||
|
if getattr(Spawner, attr) is getattr(cls, attr):
|
||||||
|
missing.append(attr)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
raise NotImplementedError("class `{}` needs to redefine the `start`,"
|
||||||
|
"`stop` and `poll` methods. `{}` not redefined.".format(cls.__name__, '`, `'.join(missing)))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def server(self):
|
def server(self):
|
||||||
if self.orm_spawner and self.orm_spawner.server:
|
if self.orm_spawner and self.orm_spawner.server:
|
||||||
|
@@ -76,15 +76,23 @@ class SlowSpawner(MockSpawner):
|
|||||||
|
|
||||||
class NeverSpawner(MockSpawner):
|
class NeverSpawner(MockSpawner):
|
||||||
"""A spawner that will never start"""
|
"""A spawner that will never start"""
|
||||||
|
|
||||||
@default('start_timeout')
|
@default('start_timeout')
|
||||||
def _start_timeout_default(self):
|
def _start_timeout_default(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Return a Future that will never finish"""
|
"""Return a Future that will never finish"""
|
||||||
return Future()
|
return Future()
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def stop(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def poll(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class FormSpawner(MockSpawner):
|
class FormSpawner(MockSpawner):
|
||||||
"""A spawner that has an options form defined"""
|
"""A spawner that has an options form defined"""
|
||||||
|
@@ -18,7 +18,7 @@ from tornado import gen
|
|||||||
from ..user import User
|
from ..user import User
|
||||||
from ..objects import Hub
|
from ..objects import Hub
|
||||||
from .. import spawner as spawnermod
|
from .. import spawner as spawnermod
|
||||||
from ..spawner import LocalProcessSpawner
|
from ..spawner import LocalProcessSpawner, Spawner
|
||||||
from .. import orm
|
from .. import orm
|
||||||
from .utils import async_requests
|
from .utils import async_requests
|
||||||
|
|
||||||
@@ -246,3 +246,24 @@ def test_shell_cmd(db, tmpdir, request):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
env = r.json()
|
env = r.json()
|
||||||
assert env['TESTVAR'] == 'foo'
|
assert env['TESTVAR'] == 'foo'
|
||||||
|
|
||||||
|
|
||||||
|
def test_inherit_overwrite():
|
||||||
|
"""On 3.6+ we check things are overwritten at import time
|
||||||
|
"""
|
||||||
|
if sys.version_info >= (3,6):
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
class S(Spawner):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_inherit_ok():
|
||||||
|
class S(Spawner):
|
||||||
|
def start():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def poll():
|
||||||
|
pass
|
||||||
|
Reference in New Issue
Block a user