diff --git a/jupyterhub/objects.py b/jupyterhub/objects.py index 15e2b420..83aab18f 100644 --- a/jupyterhub/objects.py +++ b/jupyterhub/objects.py @@ -62,6 +62,11 @@ class Server(HasTraits): return self.connect_port return self.port + @classmethod + def from_orm(cls, orm_server): + """Create a server from an orm.Server""" + return cls(orm_server=orm_server) + @classmethod def from_url(cls, url): """Create a Server from a given URL""" diff --git a/jupyterhub/services/service.py b/jupyterhub/services/service.py index a30f2532..4df13e04 100644 --- a/jupyterhub/services/service.py +++ b/jupyterhub/services/service.py @@ -235,7 +235,7 @@ class Service(LoggingConfigurable): @property def server(self): if self.orm.server: - return Server(orm_server=self.orm.server) + return Server.from_orm(self.orm.server) else: return None diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 096e23f7..8272b4d2 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -15,13 +15,15 @@ import warnings from subprocess import Popen from tempfile import mkdtemp +from sqlalchemy import inspect + from tornado import gen from tornado.ioloop import PeriodicCallback, IOLoop from traitlets.config import LoggingConfigurable from traitlets import ( Any, Bool, Dict, Instance, Integer, Float, List, Unicode, - validate, + observe, validate, ) from .objects import Server @@ -89,6 +91,14 @@ class Spawner(LoggingConfigurable): authenticator = Any() hub = Any() orm_spawner = Any() + + @observe('orm_spawner') + def _orm_spawner_changed(self, change): + if change.new and change.new.server: + self._server = Server(orm_server=change.new.server) + else: + self._server = None + user = Any() def __init_subclass__(cls, **kwargs): @@ -105,8 +115,24 @@ class Spawner(LoggingConfigurable): @property def server(self): + if hasattr(self, '_server'): + return self._server if self.orm_spawner and self.orm_spawner.server: return Server(orm_server=self.orm_spawner.server) + + @server.setter + def server(self, server): + self._server = server + if self.orm_spawner: + if self.orm_spawner.server is not None: + # delete the old value + db = inspect(self.orm_spawner.server).session + db.delete(self.orm_spawner.server) + if server is None: + self.orm_spawner.server = None + else: + self.orm_spawner.server = server.orm_server + @property def name(self): if self.orm_spawner: diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index c062e3da..dc238572 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -16,7 +16,7 @@ import pytest from tornado import gen from ..user import User -from ..objects import Hub +from ..objects import Hub, Server from .. import spawner as spawnermod from ..spawner import LocalProcessSpawner, Spawner from .. import orm @@ -234,7 +234,10 @@ def test_shell_cmd(db, tmpdir, request): cmd=[sys.executable, '-m', 'jupyterhub.tests.mocksu'], shell_cmd=['bash', '--rcfile', str(f), '-i', '-c'], ) - s.orm_spawner.server = orm.Server() + server = orm.Server() + db.add(server) + db.commit() + s.server = Server.from_orm(server) db.commit() (ip, port) = yield s.start() request.addfinalizer(s.stop) diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 128ab3fc..247c0245 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -9,12 +9,13 @@ from oauth2.error import ClientNotFoundError from sqlalchemy import inspect from tornado import gen from tornado.log import app_log +from traitlets import HasTraits, Any, Dict, default from .utils import url_path_join, default_server_name from . import orm from ._version import _check_version, __version__ -from traitlets import HasTraits, Any, Dict, observe, default +from .objects import Server from .spawner import LocalProcessSpawner from .crypto import encrypt, decrypt, CryptKeeper, EncryptionUnavailable, InvalidToken @@ -314,8 +315,8 @@ class User(HasTraits): spawner = self.spawners[server_name] - spawner.server = Server(orm_server=orm_server) - assert orm_spawner.server is orm_server + spawner.server = server = Server(orm_server=orm_server) + assert spawner.orm_spawner.server is orm_server # Passing user_options to the spawner spawner.user_options = options or {} @@ -452,8 +453,6 @@ class User(HasTraits): spawner.orm_spawner.state = spawner.get_state() self.last_activity = datetime.utcnow() # remove server entry from db - if spawner.server is not None: - self.db.delete(spawner.orm_spawner.server) spawner.server = None if not spawner.will_resume: # find and remove the API token if the spawner isn't