user.state is keyed by server name

This commit is contained in:
Min RK
2017-06-21 16:52:54 +02:00
parent 3145011004
commit 5263e4ceae
2 changed files with 19 additions and 9 deletions

View File

@@ -55,6 +55,7 @@ class Spawner(LoggingConfigurable):
user = Any() user = Any()
hub = Any() hub = Any()
authenticator = Any() authenticator = Any()
server = Any()
admin_access = Bool(False) admin_access = Bool(False)
api_token = Unicode() api_token = Unicode()
oauth_client_id = Unicode() oauth_client_id = Unicode()
@@ -355,11 +356,6 @@ class Spawner(LoggingConfigurable):
""" """
).tag(config=True) ).tag(config=True)
def __init__(self, **kwargs):
super(Spawner, self).__init__(**kwargs)
if self.user.state:
self.load_state(self.user.state)
def load_state(self, state): def load_state(self, state):
"""Restore state of spawner from database. """Restore state of spawner from database.

View File

@@ -1,7 +1,6 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from urllib.parse import quote, urlparse from urllib.parse import quote, urlparse
@@ -75,6 +74,15 @@ class UserDict(dict):
dict.__delitem__(self, user_id) dict.__delitem__(self, user_id)
class _SpawnerDict(dict):
def __init__(self, spawner_factory):
self.spawner_factory = spawner_factory
def __getitem__(self, key):
if key not in self:
self[key] = self.spawner_factory(key)
return super().__getitem__(key)
class User(HasTraits): class User(HasTraits):
@default('log') @default('log')
@@ -101,6 +109,7 @@ class User(HasTraits):
spawner.db = self.db spawner.db = self.db
orm_user = None orm_user = None
spawners = None
@property @property
def authenticator(self): def authenticator(self):
@@ -121,17 +130,19 @@ class User(HasTraits):
self.base_url = url_path_join( self.base_url = url_path_join(
self.settings.get('base_url', '/'), 'user', self.escaped_name) self.settings.get('base_url', '/'), 'user', self.escaped_name)
self.spawners = defaultdict(self._new_spawner) self.spawners = _SpawnerDict(self._new_spawner)
def _new_spawner(self): def _new_spawner(self, name):
"""Create a new spawner""" """Create a new spawner"""
return self.spawner_class( spawner = self.spawner_class(
user=self, user=self,
db=self.db, db=self.db,
hub=self.settings.get('hub'), hub=self.settings.get('hub'),
authenticator=self.authenticator, authenticator=self.authenticator,
config=self.settings.get('config'), config=self.settings.get('config'),
) )
spawner.load_state((self.state or {}).get(name, {}))
return spawner
# singleton property, self.spawner maps onto spawner with empty server_name # singleton property, self.spawner maps onto spawner with empty server_name
@property @property
@@ -324,6 +335,8 @@ class User(HasTraits):
spawner.start_polling() spawner.start_polling()
# store state # store state
if self.state is None:
self.state = {}
self.state[server_name] = spawner.get_state() self.state[server_name] = spawner.get_state()
self.last_activity = datetime.utcnow() self.last_activity = datetime.utcnow()
db.commit() db.commit()
@@ -379,6 +392,7 @@ class User(HasTraits):
self.last_activity = datetime.utcnow() self.last_activity = datetime.utcnow()
# remove server entry from db # remove server entry from db
self.db.delete(spawner.server.orm_server) self.db.delete(spawner.server.orm_server)
spawner.server = None
if not spawner.will_resume: if not spawner.will_resume:
# find and remove the API token if the spawner isn't # find and remove the API token if the spawner isn't
# going to re-use it next time # going to re-use it next time