diff --git a/jupyterhub/app.py b/jupyterhub/app.py index baa65f65..e1610211 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -454,7 +454,7 @@ class JupyterHub(Application): 'name': 'formgrader', 'url': 'http://127.0.0.1:1234', 'token': 'super-secret', - 'environment': + 'environment': } ] """ @@ -1095,7 +1095,10 @@ class JupyterHub(Application): # if user.server is defined. log = self.log.warning if user.server else self.log.debug log("%s not running.", user.name) - user.server = None + # remove all server or servers entry from db related to the user + for server in user.servers: + db.delete(server) + db.commit() user_summaries.append(_user_summary(user)) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index f9827aad..40fccd5e 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -17,8 +17,10 @@ from sqlalchemy import ( DateTime, ) from sqlalchemy.ext.declarative import declarative_base, declared_attr -from sqlalchemy.orm import sessionmaker, relationship +from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.pool import StaticPool +from sqlalchemy.schema import Index, UniqueConstraint +from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.sql.expression import bindparam from sqlalchemy import create_engine, Table @@ -68,6 +70,9 @@ class Server(Base): base_url = Column(Unicode(255), default='/') cookie_name = Column(Unicode(255), default='cookie') + # added to handle multi-server feature + last_activity = Column(DateTime, default=datetime.utcnow) + def __repr__(self): return "" % (self.ip, self.port) @@ -181,6 +186,8 @@ class Proxy(Base): client=client, ) + # FIX-ME + # we need to add a reference to a specific server @gen.coroutine def add_user(self, user, client=None): """Add a user's server to the proxy table.""" @@ -248,6 +255,8 @@ class Proxy(Base): resp = yield self.api_request('', client=client) return json.loads(resp.body.decode('utf8', 'replace')) + # FIX-ME + # we need to add a reference to a specific server @gen.coroutine def check_routes(self, user_dict, service_dict, routes=None): """Check that all users are properly routed on the proxy""" @@ -268,7 +277,7 @@ class Proxy(Base): if user.name in user_routes: self.log.warning("Removing route for not running %s", user.name) futures.append(self.delete_user(user)) - + # check service routes service_routes = { r['service'] for r in routes.values() if 'service' in r } for orm_service in db.query(Service).filter(Service.server != None): @@ -326,7 +335,7 @@ class Group(Base): id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(1023), unique=True) users = relationship('User', secondary='user_group_map', back_populates='groups') - + def __repr__(self): return "<%s %s (%i users)>" % ( self.__class__.__name__, self.name, len(self.users) @@ -334,7 +343,6 @@ class Group(Base): @classmethod def find(cls, db, name): """Find a group by name. - Returns None if not found. """ return db.query(cls).filter(cls.name==name).first() @@ -343,8 +351,10 @@ class Group(Base): class User(Base): """The User table - Each user has a single server, - and multiple tokens used for authorization. + Each user can have one or more single user notebook servers. + + Each single user notebook server will have a unique token for authorization. + Therefore, a user with multiple notebook servers will have multiple tokens. API tokens grant access to the Hub's REST API. These are used by single-user servers to authenticate requests, @@ -355,13 +365,17 @@ class User(Base): A `state` column contains a JSON dict, used for restoring state of a Spawner. + + + `servers` is a list that contains a reference for each of the user's single user notebook servers. + The method `server` returns the first entry in the user's `servers` list. """ __tablename__ = 'users' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(1023), unique=True) - # should we allow multiple servers per user? - _server_id = Column(Integer, ForeignKey('servers.id', ondelete="SET NULL")) - server = relationship(Server, primaryjoin=_server_id == Server.id) + + servers = association_proxy("user_to_servers", "server", creator=lambda server: UserServer(server=server)) + admin = Column(Boolean, default=False) last_activity = Column(DateTime, default=datetime.utcnow) @@ -377,6 +391,16 @@ class User(Base): other_user_cookies = set([]) + @property + def server(self): + """Returns the first element of servers. + Returns None if the list is empty. + """ + if len(self.servers) == 0: + return None + else: + return self.servers[0] + def __repr__(self): if self.server: return "<{cls}({name}@{ip}:{port})>".format( @@ -393,7 +417,7 @@ class User(Base): def new_api_token(self, token=None): """Create a new API token - + If `token` is given, load that token. """ return APIToken.new(token=token, user=self) @@ -401,12 +425,40 @@ class User(Base): @classmethod def find(cls, db, name): """Find a user by name. - Returns None if not found. """ return db.query(cls).filter(cls.name==name).first() +class UserServer(Base): + """The UserServer table + + A table storing the One-To-Many relationship between a user and servers. + Each user may have one or more servers. + A server can have only one (1) user. This condition is maintained by UniqueConstraint. + """ + __tablename__ = 'users_servers' + + _user_id = Column(Integer, ForeignKey('users.id'), primary_key=True) + _server_id = Column(Integer, ForeignKey('servers.id'), primary_key=True) + + user = relationship(User, backref=backref('user_to_servers', cascade='all, delete-orphan')) + server = relationship(Server, backref=backref('server_to_users', cascade='all, delete-orphan') + ) + + __table_args__ = ( + UniqueConstraint('_server_id'), + Index('server_user_index', '_server_id', '_user_id'),) + + def __repr__(self): + return "<{cls}({name}@{ip}:{port})>".format( + cls=self.__class__.__name__, + name=self.user.name, + ip=self.server.ip, + port=self.server.port, + ) + + class Service(Base): """A service run with JupyterHub @@ -414,7 +466,6 @@ class Service(Base): A service can have API tokens for accessing the Hub's API It has: - - name - admin - api tokens @@ -427,7 +478,7 @@ class Service(Base): """ __tablename__ = 'services' id = Column(Integer, primary_key=True, autoincrement=True) - + # common user interface: name = Column(Unicode(1023), unique=True) admin = Column(Boolean, default=False) @@ -441,11 +492,10 @@ class Service(Base): def new_api_token(self, token=None): """Create a new API token - If `token` is given, load that token. """ return APIToken.new(token=token, service=self) - + @classmethod def find(cls, db, name): """Find a service by name. @@ -458,7 +508,7 @@ class Service(Base): class APIToken(Base): """An API token""" __tablename__ = 'api_tokens' - + # _constraint = ForeignKeyConstraint(['user_id', 'server_id'], ['users.id', 'services.id']) @declared_attr def user_id(cls): @@ -509,7 +559,7 @@ class APIToken(Base): """Find a token object by value. Returns None if not found. - + `kind='user'` only returns API tokens for users `kind='service'` only returns API tokens for services """ diff --git a/jupyterhub/tests/conftest.py b/jupyterhub/tests/conftest.py index 4d298d3e..1ae99b39 100644 --- a/jupyterhub/tests/conftest.py +++ b/jupyterhub/tests/conftest.py @@ -30,8 +30,8 @@ def db(): _db = orm.new_session_factory('sqlite:///:memory:', echo=True)() user = orm.User( name=getuser(), - server=orm.Server(), ) + user.servers.append(orm.Server()) hub = orm.Hub( server=orm.Server(), ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 182160e4..f201e8d8 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -65,9 +65,10 @@ def test_hub(db): def test_user(db): user = orm.User(name='kaylee', - server=orm.Server(), state={'pid': 4234}, ) + server = orm.Server() + user.servers.append(server) db.add(user) db.commit() assert user.name == 'kaylee' diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 78bac2f4..0a9883d6 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -18,23 +18,23 @@ from .spawner import LocalProcessSpawner class UserDict(dict): """Like defaultdict, but for users - + Getting by a user id OR an orm.User instance returns a User wrapper around the orm user. """ def __init__(self, db_factory, settings): self.db_factory = db_factory self.settings = settings super().__init__() - + @property def db(self): return self.db_factory() - + def __contains__(self, key): if isinstance(key, (User, orm.User)): key = key.id return dict.__contains__(self, key) - + def __getitem__(self, key): if isinstance(key, User): key = key.id @@ -63,7 +63,7 @@ class UserDict(dict): return dict.__getitem__(self, id) else: raise KeyError(repr(key)) - + def __delitem__(self, key): user = self[key] user_id = user.id @@ -74,13 +74,13 @@ class UserDict(dict): class User(HasTraits): - + @default('log') def _log_default(self): return app_log - + settings = Dict() - + db = Any(allow_none=True) @default('db') def _db_default(self): @@ -94,32 +94,32 @@ class User(HasTraits): id = self.orm_user.id self.orm_user = change['new'].query(orm.User).filter(orm.User.id==id).first() self.spawner.db = self.db - + orm_user = None spawner = None spawn_pending = False stop_pending = False waiting_for_response = False - + @property def authenticator(self): return self.settings.get('authenticator', None) - + @property def spawner_class(self): return self.settings.get('spawner_class', LocalProcessSpawner) - + def __init__(self, orm_user, settings, **kwargs): self.orm_user = orm_user self.settings = settings super().__init__(**kwargs) - + hub = self.db.query(orm.Hub).first() - + self.cookie_name = '%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')) self.base_url = url_path_join( self.settings.get('base_url', '/'), 'user', self.escaped_name) - + self.spawner = self.spawner_class( user=self, db=self.db, @@ -127,24 +127,24 @@ class User(HasTraits): authenticator=self.authenticator, config=self.settings.get('config'), ) - + # pass get/setattr to ORM user - + def __getattr__(self, attr): if hasattr(self.orm_user, attr): return getattr(self.orm_user, attr) else: raise AttributeError(attr) - + def __setattr__(self, attr, value): if self.orm_user and hasattr(self.orm_user, attr): setattr(self.orm_user, attr, value) else: super().__setattr__(attr, value) - + def __repr__(self): return repr(self.orm_user) - + @property def running(self): """property for whether a user has a running server""" @@ -153,25 +153,25 @@ class User(HasTraits): if self.server is None: return False return True - + @property def escaped_name(self): """My name, escaped for use in URLs, cookies, etc.""" return quote(self.name, safe='@') - + @property def proxy_path(self): if self.settings.get('subdomain_host'): return url_path_join('/' + self.domain, self.base_url) else: return self.base_url - + @property def domain(self): """Get the domain for my server.""" # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment return self.escaped_name + '.' + self.settings['domain'] - + @property def host(self): """Get the *host* for my server (proto://domain[:port])""" @@ -179,11 +179,11 @@ class User(HasTraits): parsed = urlparse(self.settings['subdomain_host']) h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc) return h - + @property def url(self): """My URL - + Full name.domain/path if using subdomains, otherwise just my /base/url """ if self.settings.get('subdomain_host'): @@ -193,22 +193,22 @@ class User(HasTraits): ) else: return self.base_url - + @gen.coroutine def spawn(self, options=None): """Start the user's spawner""" db = self.db - - self.server = orm.Server( + server = orm.Server( cookie_name=self.cookie_name, base_url=self.base_url, ) - db.add(self.server) + self.servers.append(server) + db.add(self) db.commit() - + api_token = self.new_api_token() db.commit() - + spawner = self.spawner spawner.user_options = options or {} # we are starting a new server, make sure it doesn't restore state @@ -300,7 +300,7 @@ class User(HasTraits): @gen.coroutine def stop(self): """Stop the user's spawner - + and cleanup after it. """ self.spawn_pending = False @@ -315,11 +315,10 @@ class User(HasTraits): spawner.clear_state() self.state = spawner.get_state() self.last_activity = datetime.utcnow() - # cleanup server entry, API token from defunct server - if self.server: - # cleanup server entry from db - self.db.delete(self.server) - self.server = None + # Cleanup defunct servers: delete entry and API token for each server + for server in self.servers: + # remove server entry from db + self.db.delete(server) if not spawner.will_resume: # find and remove the API token if the spawner isn't # going to re-use it next time @@ -335,4 +334,3 @@ class User(HasTraits): yield gen.maybe_future( auth.post_spawn_stop(self, spawner) ) -