diff --git a/jupyterhub/app.py b/jupyterhub/app.py index f2b17fee..8a61d1ba 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -146,18 +146,18 @@ class NewToken(Application): class UpgradeDB(Application): """Upgrade the JupyterHub database schema.""" - + name = 'jupyterhub-upgrade-db' version = jupyterhub.__version__ description = """Upgrade the JupyterHub database to the current schema. - + Usage: jupyterhub upgrade-db """ aliases = common_aliases classes = [] - + def _backup_db_file(self, db_file): """Backup a database file""" if not os.path.exists(db_file): @@ -171,7 +171,7 @@ class UpgradeDB(Application): backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i) if os.path.exists(backup_db_file): self.exit("backup db file already exists: %s" % backup_db_file) - + self.log.info("Backing up %s => %s", db_file, backup_db_file) shutil.copy(db_file, backup_db_file) @@ -222,12 +222,12 @@ class JupyterHub(Application): Authenticator, PAMAuthenticator, ]) - + load_groups = Dict(List(Unicode()), help="""Dict of 'group': ['usernames'] to load at startup. - + This strictly *adds* groups and users to groups. - + Loading one set of groups, then starting JupyterHub again with a different set will not remove users or groups from previous launches. That must be done through the API. @@ -414,7 +414,7 @@ class JupyterHub(Application): api_tokens = Dict(Unicode(), help="""PENDING DEPRECATION: consider using service_tokens - + Dict of token:username to be loaded into the database. Allows ahead-of-time generation of API tokens for use by externally managed services, @@ -437,14 +437,14 @@ class JupyterHub(Application): Allows ahead-of-time generation of API tokens for use by externally managed services. """ ).tag(config=True) - + services = List(Dict(), help="""List of service specification dictionaries. - + A service - + For instance:: - + services = [ { 'name': 'cull_idle', @@ -454,7 +454,7 @@ class JupyterHub(Application): 'name': 'formgrader', 'url': 'http://127.0.0.1:1234', 'token': 'super-secret', - 'env': + 'env': } ] """ @@ -608,7 +608,7 @@ class JupyterHub(Application): Instance(logging.Handler), help="Extra log handlers to set on JupyterHub logger", ).tag(config=True) - + statsd = Any(allow_none=False, help="The statsd client, if any. A mock will be used if we aren't using statsd") @default('statsd') def _statsd(self): @@ -919,7 +919,7 @@ class JupyterHub(Application): # The whitelist set and the users in the db are now the same. # From this point on, any user changes should be done simultaneously # to the whitelist set and user db, unless the whitelist is empty (all users allowed). - + @gen.coroutine def init_groups(self): """Load predefined groups into the database""" @@ -941,7 +941,7 @@ class JupyterHub(Application): db.add(user) group.users.append(user) db.commit() - + @gen.coroutine def _add_tokens(self, token_dict, kind): """Add tokens for users or services to the database""" @@ -982,13 +982,13 @@ class JupyterHub(Application): else: self.log.debug("Not duplicating token %s", orm_token) db.commit() - + @gen.coroutine def init_api_tokens(self): """Load predefined API tokens (for services) into database""" yield self._add_tokens(self.service_tokens, kind='service') yield self._add_tokens(self.api_tokens, kind='user') - + def init_services(self): self._service_map.clear() if self.domain: @@ -1095,7 +1095,9 @@ class JupyterHub(Application): # if user.server is defined. log = self.log.warning if user.server else self.log.debug log("%s not running.", user.name) - user.server = None + for server in user.servers: + db.delete(server) + db.commit() user_summaries.append(_user_summary(user)) @@ -1458,7 +1460,7 @@ class JupyterHub(Application): except Exception as e: self.log.critical("Failed to start proxy", exc_info=True) self.exit(1) - + for service_name, service in self._service_map.items(): if not service.managed: continue diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index d2cc3c17..3f8e3baa 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -19,7 +19,7 @@ from sqlalchemy import ( from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.pool import StaticPool -from sqlalchemy.schema import Index +from sqlalchemy.schema import Index, UniqueConstraint from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.sql.expression import bindparam from sqlalchemy import create_engine, Table @@ -32,11 +32,8 @@ from .utils import ( class JSONDict(TypeDecorator): """Represents an immutable structure as a json-encoded string. - Usage:: - JSONEncodedDict(255) - """ impl = TEXT @@ -59,7 +56,6 @@ Base.log = app_log class Server(Base): """The basic state of a server - connection and cookie info """ __tablename__ = 'servers' @@ -69,6 +65,7 @@ class Server(Base): port = Column(Integer, default=random_port) base_url = Column(Unicode(255), default='/') cookie_name = Column(Unicode(255), default='cookie') + # added to handle multi-server feature last_activity = Column(DateTime, default=datetime.utcnow) @@ -101,10 +98,8 @@ class Server(Base): @property def bind_url(self): """representation of URL used for binding - Never used in APIs, only logging, - since it can be non-connectable value, such as '', - meaning all interfaces. + since it can be non-connectable value, such as '', meaning all interfaces. """ if self.ip in {'', '0.0.0.0'}: return self.url.replace('127.0.0.1', self.ip or '*', 1) @@ -116,8 +111,7 @@ class Server(Base): if http: yield wait_for_http_server(self.url, timeout=timeout) else: - yield wait_for_server(self.ip or '127.0.0.1', self.port, - timeout=timeout) + yield wait_for_server(self.ip or '127.0.0.1', self.port, timeout=timeout) def is_up(self): """Is the server accepting connections?""" @@ -126,7 +120,6 @@ class Server(Base): class Proxy(Base): """A configurable-http-proxy instance. - A proxy consists of the API server info and the public-facing server info, plus an auth token for configuring the proxy table. """ @@ -155,10 +148,10 @@ class Proxy(Base): body = json.dumps(body) self.log.debug("Fetching %s %s", method, url) req = HTTPRequest(url, - method=method, - headers={'Authorization': 'token {}'.format(self.auth_token)}, - body=body, - ) + method=method, + headers={'Authorization': 'token {}'.format(self.auth_token)}, + body=body, + ) return client.fetch(req) @@ -170,48 +163,48 @@ class Proxy(Base): "Service %s does not have an http endpoint to add to the proxy.", service.name) self.log.info("Adding service %s to proxy %s => %s", - service.name, service.proxy_path, service.server.host, - ) + service.name, service.proxy_path, service.server.host, + ) yield self.api_request(service.proxy_path, - method='POST', - body=dict( - target=service.server.host, - service=service.name, - ), - client=client, - ) + method='POST', + body=dict( + target=service.server.host, + service=service.name, + ), + client=client, + ) @gen.coroutine def delete_service(self, service, client=None): """Remove a service's server from the proxy table.""" self.log.info("Removing service %s from proxy", service.name) yield self.api_request(service.proxy_path, - method='DELETE', - client=client, - ) + method='DELETE', + client=client, + ) - @gen.coroutine # FIX-ME # we need to add a reference to a specific server + @gen.coroutine def add_user(self, user, client=None): """Add a user's server to the proxy table.""" self.log.info("Adding user %s to proxy %s => %s", - user.name, user.proxy_path, user.server.host, - ) + user.name, user.proxy_path, user.server.host, + ) if user.spawn_pending: raise RuntimeError( - "User %s's spawn is pending, shouldn't be added to the proxy yet!", user.name) + "User %s's spawn is pending, shouldn't be added to the proxy yet!", user.name) yield self.api_request(user.proxy_path, - method='POST', - body=dict( - target=user.server.host, - user=user.name, - ), - client=client, - ) + method='POST', + body=dict( + target=user.server.host, + user=user.name, + ), + client=client, + ) @gen.coroutine def delete_user(self, user, client=None): @@ -225,7 +218,6 @@ class Proxy(Base): @gen.coroutine def add_all_services(self, service_dict): """Update the proxy table from the database. - Used when loading up a new proxy. """ db = inspect(self).session @@ -241,7 +233,6 @@ class Proxy(Base): @gen.coroutine def add_all_users(self, user_dict): """Update the proxy table from the database. - Used when loading up a new proxy. """ db = inspect(self).session @@ -299,11 +290,10 @@ class Proxy(Base): yield f + class Hub(Base): """Bring it all together at the hub. - The Hub is a server, plus its API path suffix - the api_url is the full URL plus the api_path suffix on the end of the server base_url. """ @@ -329,10 +319,9 @@ class Hub(Base): # user:group many:many mapping table user_group_map = Table('user_group_map', Base.metadata, - Column('user_id', ForeignKey('users.id'), primary_key=True), - Column('group_id', ForeignKey('groups.id'), primary_key=True), - ) - + Column('user_id', ForeignKey('users.id'), primary_key=True), + Column('group_id', ForeignKey('groups.id'), primary_key=True), +) class Group(Base): """User Groups""" @@ -345,47 +334,41 @@ class Group(Base): return "<%s %s (%i users)>" % ( self.__class__.__name__, self.name, len(self.users) ) - @classmethod def find(cls, db, name): """Find a group by name. - Returns None if not found. """ - return db.query(cls).filter(cls.name == name).first() + return db.query(cls).filter(cls.name==name).first() class User(Base): """The User table - - Each user can have more than one server, + Each user can have more than a single server, and multiple tokens used for authorization. - API tokens grant access to the Hub's REST API. These are used by single-user servers to authenticate requests, and external services to manipulate the Hub. - Cookies are set with a single ID. Resetting the Cookie ID invalidates all cookies, forcing user to login again. - - `server` returns the first entry for the users' servers. - A `state` column contains a JSON dict, used for restoring state of a Spawner. + `server` returns the first entry for the users' servers. + 'servers' is a list that contains a reference to the ser's Servers. """ __tablename__ = 'users' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(1023), unique=True) - admin = Column(Boolean, default=False) - last_activity = Column(DateTime, default=datetime.utcnow) servers = association_proxy("user_to_servers", "server", creator=lambda server: UserServer(server=server)) + admin = Column(Boolean, default=False) + last_activity = Column(DateTime, default=datetime.utcnow) + api_tokens = relationship("APIToken", backref="user") cookie_id = Column(Unicode(1023), default=new_token) # User.state is actually Spawner state - # We will need to figure something else - # out if/when we have multiple spawners per user + # We will need to figure something else out if/when we have multiple spawners per user state = Column(JSONDict) # Authenticators can store their state here: auth_state = Column(JSONDict) @@ -397,7 +380,6 @@ class User(Base): @property def server(self): """Returns the first element of servers. - Returns None if the list is empty. """ if len(self.servers) == 0: @@ -429,24 +411,19 @@ class User(Base): @classmethod def find(cls, db, name): """Find a user by name. - Returns None if not found. """ - return db.query(cls).filter(cls.name == name).first() + return db.query(cls).filter(cls.name==name).first() class UserServer(Base): """The UserServer table - Each user can have have more than one server, - we use this table to mantain the Many-To-Many - relationship between Users and Servers. + we use this table to mantain the Many-To-One + relationship between Users and Servers tables. - Cookies are set with a single ID. - Resetting the Cookie ID invalidates all cookies, forcing user to login again. - - A `state` column contains a JSON dict, - used for restoring state of a Spawner. + Servers can have only 1 user, this condition is mantained + by UniqueConstraint """ __tablename__ = 'users_servers' @@ -454,9 +431,12 @@ class UserServer(Base): _server_id = Column(Integer, ForeignKey('servers.id'), primary_key=True) user = relationship(User, backref=backref('user_to_servers', cascade='all, delete-orphan')) - server = relationship(Server, backref=backref('server_to_users', cascade='all, delete-orphan')) + server = relationship(Server, backref=backref('server_to_users', cascade='all, delete-orphan') + ) - __table_args__ = (Index('server_user_index', '_server_id', '_user_id'),) + __table_args__ = ( + UniqueConstraint('_server_id'), + Index('server_user_index', '_server_id', '_user_id'),) def __repr__(self): return "<{cls}({name}@{ip}:{port})>".format( @@ -469,21 +449,15 @@ class UserServer(Base): class Service(Base): """A service run with JupyterHub - A service is similar to a User without a Spawner. A service can have API tokens for accessing the Hub's API - It has: - - name - admin - api tokens - server (if proxied http endpoint) - In addition to what it has in common with users, a Service has extra info: - - pid: the process id (if managed) - """ __tablename__ = 'services' id = Column(Integer, primary_key=True, autoincrement=True) @@ -501,7 +475,6 @@ class Service(Base): def new_api_token(self, token=None): """Create a new API token - If `token` is given, load that token. """ return APIToken.new(token=token, service=self) @@ -509,7 +482,6 @@ class Service(Base): @classmethod def find(cls, db, name): """Find a service by name. - Returns None if not found. """ return db.query(cls).filter(cls.name==name).first() @@ -567,7 +539,6 @@ class APIToken(Base): @classmethod def find(cls, db, token, *, kind=None): """Find a token object by value. - Returns None if not found. `kind='user'` only returns API tokens for users diff --git a/jupyterhub/tests/conftest.py b/jupyterhub/tests/conftest.py index 4d298d3e..1ae99b39 100644 --- a/jupyterhub/tests/conftest.py +++ b/jupyterhub/tests/conftest.py @@ -30,8 +30,8 @@ def db(): _db = orm.new_session_factory('sqlite:///:memory:', echo=True)() user = orm.User( name=getuser(), - server=orm.Server(), ) + user.servers.append(orm.Server()) hub = orm.Hub( server=orm.Server(), ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 338e2c3d..f201e8d8 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -54,7 +54,7 @@ def test_hub(db): port = 1234, base_url='/hubtest/', ), - + ) db.add(hub) db.commit() @@ -129,14 +129,14 @@ def test_service_server(db): service = orm.Service(name='has_servers') db.add(service) db.commit() - + assert service.server is None server = service.server = orm.Server() assert service assert server.id is None db.commit() assert isinstance(server.id, int) - + def test_token_find(db): service = db.query(orm.Service).first() @@ -172,17 +172,17 @@ def test_spawn_fails(db, io_loop): orm_user = orm.User(name='aeofel') db.add(orm_user) db.commit() - + class BadSpawner(MockSpawner): @gen.coroutine def start(self): raise RuntimeError("Split the party") - + user = User(orm_user, { 'spawner_class': BadSpawner, 'config': None, }) - + with pytest.raises(Exception) as exc: io_loop.run_sync(user.spawn) assert user.server is None @@ -192,7 +192,7 @@ def test_spawn_fails(db, io_loop): def test_groups(db): user = orm.User.find(db, name='aeofel') db.add(user) - + group = orm.Group(name='lives') db.add(group) db.commit() diff --git a/jupyterhub/user.py b/jupyterhub/user.py index eed8216f..7fcc2e29 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -18,23 +18,23 @@ from .spawner import LocalProcessSpawner class UserDict(dict): """Like defaultdict, but for users - + Getting by a user id OR an orm.User instance returns a User wrapper around the orm user. """ def __init__(self, db_factory, settings): self.db_factory = db_factory self.settings = settings super().__init__() - + @property def db(self): return self.db_factory() - + def __contains__(self, key): if isinstance(key, (User, orm.User)): key = key.id return dict.__contains__(self, key) - + def __getitem__(self, key): if isinstance(key, User): key = key.id @@ -63,7 +63,7 @@ class UserDict(dict): return dict.__getitem__(self, id) else: raise KeyError(repr(key)) - + def __delitem__(self, key): user = self[key] user_id = user.id @@ -74,13 +74,13 @@ class UserDict(dict): class User(HasTraits): - + @default('log') def _log_default(self): return app_log - + settings = Dict() - + db = Any(allow_none=True) @default('db') def _db_default(self): @@ -94,32 +94,32 @@ class User(HasTraits): id = self.orm_user.id self.orm_user = change['new'].query(orm.User).filter(orm.User.id==id).first() self.spawner.db = self.db - + orm_user = None spawner = None spawn_pending = False stop_pending = False waiting_for_response = False - + @property def authenticator(self): return self.settings.get('authenticator', None) - + @property def spawner_class(self): return self.settings.get('spawner_class', LocalProcessSpawner) - + def __init__(self, orm_user, settings, **kwargs): self.orm_user = orm_user self.settings = settings super().__init__(**kwargs) - + hub = self.db.query(orm.Hub).first() - + self.cookie_name = '%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')) self.base_url = url_path_join( self.settings.get('base_url', '/'), 'user', self.escaped_name) - + self.spawner = self.spawner_class( user=self, db=self.db, @@ -127,24 +127,24 @@ class User(HasTraits): authenticator=self.authenticator, config=self.settings.get('config'), ) - + # pass get/setattr to ORM user - + def __getattr__(self, attr): if hasattr(self.orm_user, attr): return getattr(self.orm_user, attr) else: raise AttributeError(attr) - + def __setattr__(self, attr, value): if self.orm_user and hasattr(self.orm_user, attr): setattr(self.orm_user, attr, value) else: super().__setattr__(attr, value) - + def __repr__(self): return repr(self.orm_user) - + @property def running(self): """property for whether a user has a running server""" @@ -153,25 +153,25 @@ class User(HasTraits): if self.server is None: return False return True - + @property def escaped_name(self): """My name, escaped for use in URLs, cookies, etc.""" return quote(self.name, safe='@') - + @property def proxy_path(self): if self.settings.get('subdomain_host'): return url_path_join('/' + self.domain, self.base_url) else: return self.base_url - + @property def domain(self): """Get the domain for my server.""" # FIXME: escaped_name probably isn't escaped enough in general for a domain fragment return self.escaped_name + '.' + self.settings['domain'] - + @property def host(self): """Get the *host* for my server (proto://domain[:port])""" @@ -179,11 +179,11 @@ class User(HasTraits): parsed = urlparse(self.settings['subdomain_host']) h = '%s://%s.%s' % (parsed.scheme, self.escaped_name, parsed.netloc) return h - + @property def url(self): """My URL - + Full name.domain/path if using subdomains, otherwise just my /base/url """ if self.settings.get('subdomain_host'): @@ -193,22 +193,22 @@ class User(HasTraits): ) else: return self.base_url - + @gen.coroutine def spawn(self, options=None): """Start the user's spawner""" db = self.db - - self.server = orm.Server( + server = orm.Server( cookie_name=self.cookie_name, base_url=self.base_url, ) - db.add(self.server) + self.servers.append(server) + db.add(self) db.commit() - + api_token = self.new_api_token() db.commit() - + spawner = self.spawner spawner.user_options = options or {} # we are starting a new server, make sure it doesn't restore state @@ -294,7 +294,7 @@ class User(HasTraits): @gen.coroutine def stop(self): """Stop the user's spawner - + and cleanup after it. """ self.spawn_pending = False @@ -316,7 +316,8 @@ class User(HasTraits): orm_token = orm.APIToken.find(self.db, api_token) if orm_token: self.db.delete(orm_token) - self.server = None + for server in self.servers: + self.db.delete(server) self.db.commit() finally: self.stop_pending = False @@ -326,4 +327,3 @@ class User(HasTraits): yield gen.maybe_future( auth.post_spawn_stop(self, spawner) ) -