diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py index 4fe77383..aefc98bf 100644 --- a/jupyterhub/apihandlers/auth.py +++ b/jupyterhub/apihandlers/auth.py @@ -18,7 +18,7 @@ class TokenAPIHandler(APIHandler): orm_token = orm.APIToken.find(self.db, token) if orm_token is None: raise web.HTTPError(404) - self.write(json.dumps(self.user_model(orm_token.user))) + self.write(json.dumps(self.user_model(self.users[orm_token.user]))) class CookieAPIHandler(APIHandler): diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index b207509b..00d0585b 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -15,7 +15,7 @@ from .base import APIHandler class UserListAPIHandler(APIHandler): @admin_only def get(self): - users = self.db.query(orm.User) + users = [ self._user_from_orm(u) for u in self.db.query(orm.User) ] data = [ self.user_model(u) for u in users ] self.write(json.dumps(data)) @@ -104,6 +104,8 @@ class UserAPIHandler(APIHandler): yield gen.maybe_future(self.authenticator.add_user(user)) except Exception: self.log.error("Failed to create user: %s" % name, exc_info=True) + # remove from registry + self.users.pop(user.id, None) self.db.delete(user) self.db.commit() raise web.HTTPError(400, "Failed to create user: %s" % name) @@ -127,7 +129,8 @@ class UserAPIHandler(APIHandler): raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name) yield gen.maybe_future(self.authenticator.delete_user(user)) - + # remove from registry + self.users.pop(user.id, None) # remove from the db self.db.delete(user) self.db.commit() diff --git a/jupyterhub/app.py b/jupyterhub/app.py index c84e645b..e1f7a0c5 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -45,6 +45,7 @@ from . import handlers, apihandlers from .handlers.static import CacheControlStaticFilesHandler from . import orm +from .user import User, UserDict from ._data import DATA_FILES_PATH from .log import CoroutineLogFormatter, log_request from .traitlets import URLPrefix, Command @@ -349,6 +350,10 @@ class JupyterHub(Application): ) session_factory = Any() + users = Instance(UserDict) + def _users_default(self): + return UserDict(db_factory=lambda : self.db) + admin_access = Bool(False, config=True, help="""Grant admin users permission to access single-user servers. @@ -699,7 +704,8 @@ class JupyterHub(Application): yield self.proxy.delete_user(user) yield user.stop() - for user in db.query(orm.User): + for orm_user in db.query(orm.User): + self.users[orm_user.id] = user = User(orm_user) if not user.state: # without spawner state, server isn't valid user.server = None @@ -854,6 +860,7 @@ class JupyterHub(Application): proxy=self.proxy, hub=self.hub, admin_users=self.authenticator.admin_users, + users=self.users, admin_access=self.admin_access, authenticator=self.authenticator, spawner_class=self.spawner_class, @@ -921,7 +928,7 @@ class JupyterHub(Application): if self.cleanup_servers: self.log.info("Cleaning up single-user servers...") # request (async) process termination - for user in self.db.query(orm.User): + for uid, user in self.users.items(): if user.spawner is not None: futures.append(user.stop()) else: diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index f6914ed8..987bfdf2 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -4,7 +4,7 @@ # Distributed under the terms of the Modified BSD License. import re -from datetime import datetime, timedelta +from datetime import timedelta from http.client import responses from jinja2 import TemplateNotFound @@ -16,6 +16,7 @@ from tornado.web import RequestHandler from tornado import gen, web from .. import orm +from ..user import User from ..spawner import LocalProcessSpawner from ..utils import url_path_join @@ -53,7 +54,11 @@ class BaseHandler(RequestHandler): @property def db(self): return self.settings['db'] - + + @property + def users(self): + return self.settings.setdefault('users', {}) + @property def hub(self): return self.settings['hub'] @@ -145,13 +150,20 @@ class BaseHandler(RequestHandler): clear() return cookie_id = cookie_id.decode('utf8', 'replace') - user = self.db.query(orm.User).filter(orm.User.cookie_id==cookie_id).first() + u = self.db.query(orm.User).filter(orm.User.cookie_id==cookie_id).first() + user = self._user_from_orm(u) if user is None: self.log.warn("Invalid cookie token") # have cookie, but it's not valid. Clear it and start over. clear() return user + def _user_from_orm(self, orm_user): + """return User wrapper from orm.User object""" + if orm_user is None: + return + return self.users[orm_user] + def get_current_user_cookie(self): """get_current_user from a cookie token""" return self._user_for_cookie(self.hub.server.cookie_name) @@ -168,15 +180,18 @@ class BaseHandler(RequestHandler): return None if no such user """ - return orm.User.find(self.db, name) + orm_user = orm.User.find(db=self.db, name=name) + return self._user_from_orm(orm_user) def user_from_username(self, username): - """Get ORM User for username""" + """Get User for username, creating if it doesn't exist""" user = self.find_user(username) if user is None: - user = orm.User(name=username) - self.db.add(user) + # not found, create and register user + u = orm.User(name=username) + self.db.add(u) self.db.commit() + user = self._user_from_orm(u) return user def clear_login_cookie(self, name=None): @@ -259,7 +274,7 @@ class BaseHandler(RequestHandler): if user.spawn_pending: raise RuntimeError("Spawn already pending for: %s" % user.name) tic = IOLoop.current().time() - + f = user.spawn( spawner_class=self.spawner_class, base_url=self.base_url, diff --git a/jupyterhub/handlers/pages.py b/jupyterhub/handlers/pages.py index 2d13db8e..84330600 100644 --- a/jupyterhub/handlers/pages.py +++ b/jupyterhub/handlers/pages.py @@ -90,7 +90,8 @@ class AdminHandler(BaseHandler): ordered = [ getattr(c, o)() for c, o in zip(cols, orders) ] users = self.db.query(orm.User).order_by(*ordered) - running = users.filter(orm.User.server != None) + users = [ self._user_from_orm(u) for u in users ] + running = [ u for u in users if u.running ] html = self.render_template('admin.html', user=self.get_current_user(), diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 141857a6..bcbf6308 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -3,15 +3,14 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from datetime import datetime, timedelta +from datetime import datetime import errno import json import socket -from urllib.parse import quote from tornado import gen from tornado.log import app_log -from tornado.httpclient import HTTPRequest, AsyncHTTPClient, HTTPError +from tornado.httpclient import HTTPRequest, AsyncHTTPClient from sqlalchemy.types import TypeDecorator, VARCHAR from sqlalchemy import ( @@ -271,7 +270,7 @@ class User(Base): used for restoring state of a Spawner. """ __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode) # should we allow multiple servers per user? _server_id = Column(Integer, ForeignKey('servers.id')) @@ -282,12 +281,9 @@ class User(Base): api_tokens = relationship("APIToken", backref="user") cookie_id = Column(Unicode, default=new_token) state = Column(JSONDict) - spawner = None - spawn_pending = False - stop_pending = False other_user_cookies = set([]) - + def __repr__(self): if self.server: return "<{cls}({name}@{ip}:{port})>".format( @@ -302,20 +298,6 @@ class User(Base): name=self.name, ) - @property - def escaped_name(self): - """My name, escaped for use in URLs, cookies, etc.""" - return quote(self.name, safe='@') - - @property - def running(self): - """property for whether a user has a running server""" - if self.spawner is None: - return False - if self.server is None: - return False - return True - def new_api_token(self): """Create a new API token""" assert self.id is not None @@ -326,7 +308,7 @@ class User(Base): db.add(orm_token) db.commit() return token - + @classmethod def find(cls, db, name): """Find a user by name. @@ -334,125 +316,6 @@ class User(Base): Returns None if not found. """ return db.query(cls).filter(cls.name==name).first() - - @gen.coroutine - def spawn(self, spawner_class, base_url='/', hub=None, authenticator=None, config=None): - """Start the user's spawner""" - db = inspect(self).session - if hub is None: - hub = db.query(Hub).first() - - self.server = Server( - cookie_name='%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')), - base_url=url_path_join(base_url, 'user', self.escaped_name), - ) - db.add(self.server) - db.commit() - - api_token = self.new_api_token() - db.commit() - - spawner = self.spawner = spawner_class( - config=config, - user=self, - hub=hub, - db=db, - authenticator=authenticator, - ) - # we are starting a new server, make sure it doesn't restore state - spawner.clear_state() - spawner.api_token = api_token - - # trigger pre-spawn hook on authenticator - if (authenticator): - yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner)) - self.spawn_pending = True - # wait for spawner.start to return - try: - f = spawner.start() - yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f) - except Exception as e: - if isinstance(e, gen.TimeoutError): - self.log.warn("{user}'s server failed to start in {s} seconds, giving up".format( - user=self.name, s=spawner.start_timeout, - )) - e.reason = 'timeout' - else: - self.log.error("Unhandled error starting {user}'s server: {error}".format( - user=self.name, error=e, - )) - e.reason = 'error' - try: - yield self.stop() - except Exception: - self.log.error("Failed to cleanup {user}'s server that failed to start".format( - user=self.name, - ), exc_info=True) - # raise original exception - raise e - spawner.start_polling() - - # store state - self.state = spawner.get_state() - self.last_activity = datetime.utcnow() - db.commit() - try: - yield self.server.wait_up(http=True, timeout=spawner.http_timeout) - except Exception as e: - if isinstance(e, TimeoutError): - self.log.warn( - "{user}'s server never showed up at {url} " - "after {http_timeout} seconds. Giving up".format( - user=self.name, - url=self.server.url, - http_timeout=spawner.http_timeout, - ) - ) - e.reason = 'timeout' - else: - e.reason = 'error' - self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format( - user=self.name, url=self.server.url, error=e, - )) - try: - yield self.stop() - except Exception: - self.log.error("Failed to cleanup {user}'s server that failed to start".format( - user=self.name, - ), exc_info=True) - # raise original TimeoutError - raise e - self.spawn_pending = False - return self - - @gen.coroutine - def stop(self): - """Stop the user's spawner - - and cleanup after it. - """ - self.spawn_pending = False - spawner = self.spawner - if spawner is None: - return - spawner.stop_polling() - self.stop_pending = True - try: - status = yield spawner.poll() - if status is None: - yield self.spawner.stop() - spawner.clear_state() - self.state = spawner.get_state() - self.server = None - inspect(self).session.commit() - finally: - self.stop_pending = False - # trigger post-spawner hook on authenticator - auth = spawner.authenticator - if auth: - yield gen.maybe_future( - auth.post_spawn_stop(self, spawner) - ) class APIToken(Base): """An API token""" diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 117b5644..d438801f 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -10,8 +10,9 @@ import requests from tornado import gen -from ..utils import url_path_join as ujoin from .. import orm +from ..user import User +from ..utils import url_path_join as ujoin from . import mocking @@ -41,11 +42,15 @@ def check_db_locks(func): def find_user(db, name): return db.query(orm.User).filter(orm.User.name==name).first() -def add_user(db, **kwargs): - user = orm.User(**kwargs) - db.add(user) +def add_user(db, app=None, **kwargs): + orm_user = orm.User(**kwargs) + db.add(orm_user) db.commit() - return user + if app: + user = app.users[orm_user.id] = User(orm_user) + return user + else: + return orm_user def auth_header(db, name): user = find_user(db, name) @@ -310,16 +315,18 @@ def get_app_user(app, name): No ORM methods should be called on the result. """ q = Queue() - def get_user(): + def get_user_id(): user = find_user(app.db, name) - q.put(user) - app.io_loop.add_callback(get_user) - return q.get(timeout=2) + q.put(user.id) + app.io_loop.add_callback(get_user_id) + user_id = q.get(timeout=2) + return app.users[user_id] def test_spawn(app, io_loop): db = app.db name = 'wash' - user = add_user(db, name=name) + user = add_user(db, app=app, name=name) + r = api_request(app, 'users', name, 'server', method='post') assert r.status_code == 201 assert 'pid' in user.state @@ -354,7 +361,7 @@ def test_slow_spawn(app, io_loop): db = app.db name = 'zoe' - user = add_user(db, name=name) + user = add_user(db, app=app, name=name) r = api_request(app, 'users', name, 'server', method='post') r.raise_for_status() assert r.status_code == 202 @@ -403,7 +410,7 @@ def test_never_spawn(app, io_loop): db = app.db name = 'badger' - user = add_user(db, name=name) + user = add_user(db, app=app, name=name) r = api_request(app, 'users', name, 'server', method='post') app_user = get_app_user(app, name) assert app_user.spawner is not None diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 9aa6dc22..8d4f5bc2 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -7,6 +7,7 @@ import pytest from tornado import gen from .. import orm +from ..user import User from .mocking import MockSpawner @@ -94,9 +95,10 @@ def test_tokens(db): def test_spawn_fails(db, io_loop): - user = orm.User(name='aeofel') - db.add(user) + orm_user = orm.User(name='aeofel') + db.add(orm_user) db.commit() + user = User(orm_user) class BadSpawner(MockSpawner): @gen.coroutine diff --git a/jupyterhub/user.py b/jupyterhub/user.py new file mode 100644 index 00000000..15c20945 --- /dev/null +++ b/jupyterhub/user.py @@ -0,0 +1,229 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from datetime import datetime, timedelta +from urllib.parse import quote + +from tornado import gen +from tornado.log import app_log + +from sqlalchemy import inspect + +from .utils import url_path_join + +from . import orm +from IPython.utils.traitlets import HasTraits, Any + +class UserDict(dict): + """Like defaultdict, but for users + + Getting by a user id OR an orm.User instance returns a User wrapper around the orm user. + """ + def __init__(self, db_factory): + self.db_factory = db_factory + super().__init__() + + @property + def db(self): + return self.db_factory() + + def __getitem__(self, key): + if isinstance(key, orm.User): + # users[orm_user] returns User(orm_user) + orm_user = key + if orm_user.id not in self: + user = self[orm_user.id] = User(orm_user) + return user + user = dict.__getitem__(self, orm_user.id) + user.db = self.db + return user + elif isinstance(key, int): + id = key + if id not in self: + orm_user = self.db.query(orm.User).filter(orm.User.id==id).first() + if orm_user is None: + raise KeyError("No such user: %s" % id) + user = self[id] = User(orm_user) + return dict.__getitem__(self, id) + else: + raise KeyError(repr(key)) + + +class User(HasTraits): + + def _log_default(self): + return app_log + + db = Any(allow_none=True) + def _db_default(self): + if self.orm_user: + return inspect(self.orm_user).session + + def _db_changed(self, name, old, new): + """Changing db session reacquires ORM User object""" + # db session changed, re-get orm User + if self.orm_user: + id = self.orm_user.id + self.orm_user = new.query(orm.User).filter(orm.User.id==id).first() + + orm_user = None + spawner = None + spawn_pending = False + stop_pending = False + + def __init__(self, orm_user, **kwargs): + self.orm_user = orm_user + super().__init__(**kwargs) + + # pass get/setattr to ORM user + + def __getattr__(self, attr): + if hasattr(self.orm_user, attr): + return getattr(self.orm_user, attr) + else: + raise AttributeError(attr) + + def __setattr__(self, attr, value): + if self.orm_user and hasattr(self.orm_user, attr): + setattr(self.orm_user, attr, value) + else: + super().__setattr__(attr, value) + + def __repr__(self): + return repr(self.orm_user) + + @property + def running(self): + """property for whether a user has a running server""" + if self.spawner is None: + return False + if self.server is None: + return False + return True + + @property + def escaped_name(self): + """My name, escaped for use in URLs, cookies, etc.""" + return quote(self.name, safe='@') + + @gen.coroutine + def spawn(self, spawner_class, base_url='/', hub=None, config=None, authenticator=None): + """Start the user's spawner""" + db = self.db + if hub is None: + hub = db.query(orm.Hub).first() + + self.server = orm.Server( + cookie_name='%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')), + base_url=url_path_join(base_url, 'user', self.escaped_name), + ) + db.add(self.server) + db.commit() + + api_token = self.new_api_token() + db.commit() + + spawner = self.spawner = spawner_class( + config=config, + user=self, + hub=hub, + db=db, + authenticator=authenticator, + ) + # we are starting a new server, make sure it doesn't restore state + spawner.clear_state() + spawner.api_token = api_token + + # trigger pre-spawn hook on authenticator + if (authenticator): + yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner)) + + self.spawn_pending = True + # wait for spawner.start to return + try: + f = spawner.start() + yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f) + except Exception as e: + if isinstance(e, gen.TimeoutError): + self.log.warn("{user}'s server failed to start in {s} seconds, giving up".format( + user=self.name, s=spawner.start_timeout, + )) + e.reason = 'timeout' + else: + self.log.error("Unhandled error starting {user}'s server: {error}".format( + user=self.name, error=e, + )) + e.reason = 'error' + try: + yield self.stop() + except Exception: + self.log.error("Failed to cleanup {user}'s server that failed to start".format( + user=self.name, + ), exc_info=True) + # raise original exception + raise e + spawner.start_polling() + + # store state + self.state = spawner.get_state() + self.last_activity = datetime.utcnow() + db.commit() + try: + yield self.server.wait_up(http=True, timeout=spawner.http_timeout) + except Exception as e: + if isinstance(e, TimeoutError): + self.log.warn( + "{user}'s server never showed up at {url} " + "after {http_timeout} seconds. Giving up".format( + user=self.name, + url=self.server.url, + http_timeout=spawner.http_timeout, + ) + ) + e.reason = 'timeout' + else: + e.reason = 'error' + self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format( + user=self.name, url=self.server.url, error=e, + )) + try: + yield self.stop() + except Exception: + self.log.error("Failed to cleanup {user}'s server that failed to start".format( + user=self.name, + ), exc_info=True) + # raise original TimeoutError + raise e + self.spawn_pending = False + return self + + @gen.coroutine + def stop(self): + """Stop the user's spawner + + and cleanup after it. + """ + self.spawn_pending = False + spawner = self.spawner + if spawner is None: + return + self.spawner.stop_polling() + self.stop_pending = True + try: + status = yield spawner.poll() + if status is None: + yield self.spawner.stop() + spawner.clear_state() + self.state = spawner.get_state() + self.last_activity = datetime.utcnow() + self.server = None + self.db.commit() + finally: + self.stop_pending = False + # trigger post-spawner hook on authenticator + auth = spawner.authenticator + if auth: + yield gen.maybe_future( + auth.post_spawn_stop(self, spawner) + ) + diff --git a/share/jupyter/hub/templates/admin.html b/share/jupyter/hub/templates/admin.html index d777e781..c8101a64 100644 --- a/share/jupyter/hub/templates/admin.html +++ b/share/jupyter/hub/templates/admin.html @@ -22,10 +22,10 @@ {% block thead %} - {{ th("User (%i)" % users.count(), 'name') }} + {{ th("User (%i)" % users|length, 'name') }} {{ th("Admin", 'admin') }} {{ th("Last Seen", 'last_activity') }} - {{ th("Running (%i)" % running.count(), 'running', colspan=2) }} + {{ th("Running (%i)" % running|length, 'running', colspan=2) }} {% endblock thead %}