Merge pull request #185 from minrk/outer-user

move non-persisted User objects (spawner-related) off of orm.User
This commit is contained in:
Min RK
2015-12-02 12:44:47 +01:00
10 changed files with 299 additions and 172 deletions

View File

@@ -18,7 +18,7 @@ class TokenAPIHandler(APIHandler):
orm_token = orm.APIToken.find(self.db, token) orm_token = orm.APIToken.find(self.db, token)
if orm_token is None: if orm_token is None:
raise web.HTTPError(404) raise web.HTTPError(404)
self.write(json.dumps(self.user_model(orm_token.user))) self.write(json.dumps(self.user_model(self.users[orm_token.user])))
class CookieAPIHandler(APIHandler): class CookieAPIHandler(APIHandler):

View File

@@ -15,7 +15,7 @@ from .base import APIHandler
class UserListAPIHandler(APIHandler): class UserListAPIHandler(APIHandler):
@admin_only @admin_only
def get(self): def get(self):
users = self.db.query(orm.User) users = [ self._user_from_orm(u) for u in self.db.query(orm.User) ]
data = [ self.user_model(u) for u in users ] data = [ self.user_model(u) for u in users ]
self.write(json.dumps(data)) self.write(json.dumps(data))
@@ -104,6 +104,8 @@ class UserAPIHandler(APIHandler):
yield gen.maybe_future(self.authenticator.add_user(user)) yield gen.maybe_future(self.authenticator.add_user(user))
except Exception: except Exception:
self.log.error("Failed to create user: %s" % name, exc_info=True) self.log.error("Failed to create user: %s" % name, exc_info=True)
# remove from registry
self.users.pop(user.id, None)
self.db.delete(user) self.db.delete(user)
self.db.commit() self.db.commit()
raise web.HTTPError(400, "Failed to create user: %s" % name) raise web.HTTPError(400, "Failed to create user: %s" % name)
@@ -127,7 +129,8 @@ class UserAPIHandler(APIHandler):
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name) raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
yield gen.maybe_future(self.authenticator.delete_user(user)) yield gen.maybe_future(self.authenticator.delete_user(user))
# remove from registry
self.users.pop(user.id, None)
# remove from the db # remove from the db
self.db.delete(user) self.db.delete(user)
self.db.commit() self.db.commit()

View File

@@ -45,6 +45,7 @@ from . import handlers, apihandlers
from .handlers.static import CacheControlStaticFilesHandler from .handlers.static import CacheControlStaticFilesHandler
from . import orm from . import orm
from .user import User, UserDict
from ._data import DATA_FILES_PATH from ._data import DATA_FILES_PATH
from .log import CoroutineLogFormatter, log_request from .log import CoroutineLogFormatter, log_request
from .traitlets import URLPrefix, Command from .traitlets import URLPrefix, Command
@@ -349,6 +350,10 @@ class JupyterHub(Application):
) )
session_factory = Any() session_factory = Any()
users = Instance(UserDict)
def _users_default(self):
return UserDict(db_factory=lambda : self.db)
admin_access = Bool(False, config=True, admin_access = Bool(False, config=True,
help="""Grant admin users permission to access single-user servers. help="""Grant admin users permission to access single-user servers.
@@ -699,7 +704,8 @@ class JupyterHub(Application):
yield self.proxy.delete_user(user) yield self.proxy.delete_user(user)
yield user.stop() yield user.stop()
for user in db.query(orm.User): for orm_user in db.query(orm.User):
self.users[orm_user.id] = user = User(orm_user)
if not user.state: if not user.state:
# without spawner state, server isn't valid # without spawner state, server isn't valid
user.server = None user.server = None
@@ -854,6 +860,7 @@ class JupyterHub(Application):
proxy=self.proxy, proxy=self.proxy,
hub=self.hub, hub=self.hub,
admin_users=self.authenticator.admin_users, admin_users=self.authenticator.admin_users,
users=self.users,
admin_access=self.admin_access, admin_access=self.admin_access,
authenticator=self.authenticator, authenticator=self.authenticator,
spawner_class=self.spawner_class, spawner_class=self.spawner_class,
@@ -921,7 +928,7 @@ class JupyterHub(Application):
if self.cleanup_servers: if self.cleanup_servers:
self.log.info("Cleaning up single-user servers...") self.log.info("Cleaning up single-user servers...")
# request (async) process termination # request (async) process termination
for user in self.db.query(orm.User): for uid, user in self.users.items():
if user.spawner is not None: if user.spawner is not None:
futures.append(user.stop()) futures.append(user.stop())
else: else:

View File

@@ -4,7 +4,7 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import re import re
from datetime import datetime, timedelta from datetime import timedelta
from http.client import responses from http.client import responses
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
@@ -16,6 +16,7 @@ from tornado.web import RequestHandler
from tornado import gen, web from tornado import gen, web
from .. import orm from .. import orm
from ..user import User
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner
from ..utils import url_path_join from ..utils import url_path_join
@@ -53,7 +54,11 @@ class BaseHandler(RequestHandler):
@property @property
def db(self): def db(self):
return self.settings['db'] return self.settings['db']
@property
def users(self):
return self.settings.setdefault('users', {})
@property @property
def hub(self): def hub(self):
return self.settings['hub'] return self.settings['hub']
@@ -145,13 +150,20 @@ class BaseHandler(RequestHandler):
clear() clear()
return return
cookie_id = cookie_id.decode('utf8', 'replace') cookie_id = cookie_id.decode('utf8', 'replace')
user = self.db.query(orm.User).filter(orm.User.cookie_id==cookie_id).first() u = self.db.query(orm.User).filter(orm.User.cookie_id==cookie_id).first()
user = self._user_from_orm(u)
if user is None: if user is None:
self.log.warn("Invalid cookie token") self.log.warn("Invalid cookie token")
# have cookie, but it's not valid. Clear it and start over. # have cookie, but it's not valid. Clear it and start over.
clear() clear()
return user return user
def _user_from_orm(self, orm_user):
"""return User wrapper from orm.User object"""
if orm_user is None:
return
return self.users[orm_user]
def get_current_user_cookie(self): def get_current_user_cookie(self):
"""get_current_user from a cookie token""" """get_current_user from a cookie token"""
return self._user_for_cookie(self.hub.server.cookie_name) return self._user_for_cookie(self.hub.server.cookie_name)
@@ -168,15 +180,18 @@ class BaseHandler(RequestHandler):
return None if no such user return None if no such user
""" """
return orm.User.find(self.db, name) orm_user = orm.User.find(db=self.db, name=name)
return self._user_from_orm(orm_user)
def user_from_username(self, username): def user_from_username(self, username):
"""Get ORM User for username""" """Get User for username, creating if it doesn't exist"""
user = self.find_user(username) user = self.find_user(username)
if user is None: if user is None:
user = orm.User(name=username) # not found, create and register user
self.db.add(user) u = orm.User(name=username)
self.db.add(u)
self.db.commit() self.db.commit()
user = self._user_from_orm(u)
return user return user
def clear_login_cookie(self, name=None): def clear_login_cookie(self, name=None):
@@ -259,7 +274,7 @@ class BaseHandler(RequestHandler):
if user.spawn_pending: if user.spawn_pending:
raise RuntimeError("Spawn already pending for: %s" % user.name) raise RuntimeError("Spawn already pending for: %s" % user.name)
tic = IOLoop.current().time() tic = IOLoop.current().time()
f = user.spawn( f = user.spawn(
spawner_class=self.spawner_class, spawner_class=self.spawner_class,
base_url=self.base_url, base_url=self.base_url,

View File

@@ -90,7 +90,8 @@ class AdminHandler(BaseHandler):
ordered = [ getattr(c, o)() for c, o in zip(cols, orders) ] ordered = [ getattr(c, o)() for c, o in zip(cols, orders) ]
users = self.db.query(orm.User).order_by(*ordered) users = self.db.query(orm.User).order_by(*ordered)
running = users.filter(orm.User.server != None) users = [ self._user_from_orm(u) for u in users ]
running = [ u for u in users if u.running ]
html = self.render_template('admin.html', html = self.render_template('admin.html',
user=self.get_current_user(), user=self.get_current_user(),

View File

@@ -3,15 +3,14 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta from datetime import datetime
import errno import errno
import json import json
import socket import socket
from urllib.parse import quote
from tornado import gen from tornado import gen
from tornado.log import app_log from tornado.log import app_log
from tornado.httpclient import HTTPRequest, AsyncHTTPClient, HTTPError from tornado.httpclient import HTTPRequest, AsyncHTTPClient
from sqlalchemy.types import TypeDecorator, VARCHAR from sqlalchemy.types import TypeDecorator, VARCHAR
from sqlalchemy import ( from sqlalchemy import (
@@ -271,7 +270,7 @@ class User(Base):
used for restoring state of a Spawner. used for restoring state of a Spawner.
""" """
__tablename__ = 'users' __tablename__ = 'users'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode) name = Column(Unicode)
# should we allow multiple servers per user? # should we allow multiple servers per user?
_server_id = Column(Integer, ForeignKey('servers.id')) _server_id = Column(Integer, ForeignKey('servers.id'))
@@ -282,12 +281,9 @@ class User(Base):
api_tokens = relationship("APIToken", backref="user") api_tokens = relationship("APIToken", backref="user")
cookie_id = Column(Unicode, default=new_token) cookie_id = Column(Unicode, default=new_token)
state = Column(JSONDict) state = Column(JSONDict)
spawner = None
spawn_pending = False
stop_pending = False
other_user_cookies = set([]) other_user_cookies = set([])
def __repr__(self): def __repr__(self):
if self.server: if self.server:
return "<{cls}({name}@{ip}:{port})>".format( return "<{cls}({name}@{ip}:{port})>".format(
@@ -302,20 +298,6 @@ class User(Base):
name=self.name, name=self.name,
) )
@property
def escaped_name(self):
"""My name, escaped for use in URLs, cookies, etc."""
return quote(self.name, safe='@')
@property
def running(self):
"""property for whether a user has a running server"""
if self.spawner is None:
return False
if self.server is None:
return False
return True
def new_api_token(self): def new_api_token(self):
"""Create a new API token""" """Create a new API token"""
assert self.id is not None assert self.id is not None
@@ -326,7 +308,7 @@ class User(Base):
db.add(orm_token) db.add(orm_token)
db.commit() db.commit()
return token return token
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
"""Find a user by name. """Find a user by name.
@@ -334,125 +316,6 @@ class User(Base):
Returns None if not found. Returns None if not found.
""" """
return db.query(cls).filter(cls.name==name).first() return db.query(cls).filter(cls.name==name).first()
@gen.coroutine
def spawn(self, spawner_class, base_url='/', hub=None, authenticator=None, config=None):
"""Start the user's spawner"""
db = inspect(self).session
if hub is None:
hub = db.query(Hub).first()
self.server = Server(
cookie_name='%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')),
base_url=url_path_join(base_url, 'user', self.escaped_name),
)
db.add(self.server)
db.commit()
api_token = self.new_api_token()
db.commit()
spawner = self.spawner = spawner_class(
config=config,
user=self,
hub=hub,
db=db,
authenticator=authenticator,
)
# we are starting a new server, make sure it doesn't restore state
spawner.clear_state()
spawner.api_token = api_token
# trigger pre-spawn hook on authenticator
if (authenticator):
yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner))
self.spawn_pending = True
# wait for spawner.start to return
try:
f = spawner.start()
yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f)
except Exception as e:
if isinstance(e, gen.TimeoutError):
self.log.warn("{user}'s server failed to start in {s} seconds, giving up".format(
user=self.name, s=spawner.start_timeout,
))
e.reason = 'timeout'
else:
self.log.error("Unhandled error starting {user}'s server: {error}".format(
user=self.name, error=e,
))
e.reason = 'error'
try:
yield self.stop()
except Exception:
self.log.error("Failed to cleanup {user}'s server that failed to start".format(
user=self.name,
), exc_info=True)
# raise original exception
raise e
spawner.start_polling()
# store state
self.state = spawner.get_state()
self.last_activity = datetime.utcnow()
db.commit()
try:
yield self.server.wait_up(http=True, timeout=spawner.http_timeout)
except Exception as e:
if isinstance(e, TimeoutError):
self.log.warn(
"{user}'s server never showed up at {url} "
"after {http_timeout} seconds. Giving up".format(
user=self.name,
url=self.server.url,
http_timeout=spawner.http_timeout,
)
)
e.reason = 'timeout'
else:
e.reason = 'error'
self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format(
user=self.name, url=self.server.url, error=e,
))
try:
yield self.stop()
except Exception:
self.log.error("Failed to cleanup {user}'s server that failed to start".format(
user=self.name,
), exc_info=True)
# raise original TimeoutError
raise e
self.spawn_pending = False
return self
@gen.coroutine
def stop(self):
"""Stop the user's spawner
and cleanup after it.
"""
self.spawn_pending = False
spawner = self.spawner
if spawner is None:
return
spawner.stop_polling()
self.stop_pending = True
try:
status = yield spawner.poll()
if status is None:
yield self.spawner.stop()
spawner.clear_state()
self.state = spawner.get_state()
self.server = None
inspect(self).session.commit()
finally:
self.stop_pending = False
# trigger post-spawner hook on authenticator
auth = spawner.authenticator
if auth:
yield gen.maybe_future(
auth.post_spawn_stop(self, spawner)
)
class APIToken(Base): class APIToken(Base):
"""An API token""" """An API token"""

View File

@@ -10,8 +10,9 @@ import requests
from tornado import gen from tornado import gen
from ..utils import url_path_join as ujoin
from .. import orm from .. import orm
from ..user import User
from ..utils import url_path_join as ujoin
from . import mocking from . import mocking
@@ -41,11 +42,15 @@ def check_db_locks(func):
def find_user(db, name): def find_user(db, name):
return db.query(orm.User).filter(orm.User.name==name).first() return db.query(orm.User).filter(orm.User.name==name).first()
def add_user(db, **kwargs): def add_user(db, app=None, **kwargs):
user = orm.User(**kwargs) orm_user = orm.User(**kwargs)
db.add(user) db.add(orm_user)
db.commit() db.commit()
return user if app:
user = app.users[orm_user.id] = User(orm_user)
return user
else:
return orm_user
def auth_header(db, name): def auth_header(db, name):
user = find_user(db, name) user = find_user(db, name)
@@ -310,16 +315,18 @@ def get_app_user(app, name):
No ORM methods should be called on the result. No ORM methods should be called on the result.
""" """
q = Queue() q = Queue()
def get_user(): def get_user_id():
user = find_user(app.db, name) user = find_user(app.db, name)
q.put(user) q.put(user.id)
app.io_loop.add_callback(get_user) app.io_loop.add_callback(get_user_id)
return q.get(timeout=2) user_id = q.get(timeout=2)
return app.users[user_id]
def test_spawn(app, io_loop): def test_spawn(app, io_loop):
db = app.db db = app.db
name = 'wash' name = 'wash'
user = add_user(db, name=name) user = add_user(db, app=app, name=name)
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
assert r.status_code == 201 assert r.status_code == 201
assert 'pid' in user.state assert 'pid' in user.state
@@ -354,7 +361,7 @@ def test_slow_spawn(app, io_loop):
db = app.db db = app.db
name = 'zoe' name = 'zoe'
user = add_user(db, name=name) user = add_user(db, app=app, name=name)
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
r.raise_for_status() r.raise_for_status()
assert r.status_code == 202 assert r.status_code == 202
@@ -403,7 +410,7 @@ def test_never_spawn(app, io_loop):
db = app.db db = app.db
name = 'badger' name = 'badger'
user = add_user(db, name=name) user = add_user(db, app=app, name=name)
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
app_user = get_app_user(app, name) app_user = get_app_user(app, name)
assert app_user.spawner is not None assert app_user.spawner is not None

View File

@@ -7,6 +7,7 @@ import pytest
from tornado import gen from tornado import gen
from .. import orm from .. import orm
from ..user import User
from .mocking import MockSpawner from .mocking import MockSpawner
@@ -94,9 +95,10 @@ def test_tokens(db):
def test_spawn_fails(db, io_loop): def test_spawn_fails(db, io_loop):
user = orm.User(name='aeofel') orm_user = orm.User(name='aeofel')
db.add(user) db.add(orm_user)
db.commit() db.commit()
user = User(orm_user)
class BadSpawner(MockSpawner): class BadSpawner(MockSpawner):
@gen.coroutine @gen.coroutine

229
jupyterhub/user.py Normal file
View File

@@ -0,0 +1,229 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
from urllib.parse import quote
from tornado import gen
from tornado.log import app_log
from sqlalchemy import inspect
from .utils import url_path_join
from . import orm
from IPython.utils.traitlets import HasTraits, Any
class UserDict(dict):
"""Like defaultdict, but for users
Getting by a user id OR an orm.User instance returns a User wrapper around the orm user.
"""
def __init__(self, db_factory):
self.db_factory = db_factory
super().__init__()
@property
def db(self):
return self.db_factory()
def __getitem__(self, key):
if isinstance(key, orm.User):
# users[orm_user] returns User(orm_user)
orm_user = key
if orm_user.id not in self:
user = self[orm_user.id] = User(orm_user)
return user
user = dict.__getitem__(self, orm_user.id)
user.db = self.db
return user
elif isinstance(key, int):
id = key
if id not in self:
orm_user = self.db.query(orm.User).filter(orm.User.id==id).first()
if orm_user is None:
raise KeyError("No such user: %s" % id)
user = self[id] = User(orm_user)
return dict.__getitem__(self, id)
else:
raise KeyError(repr(key))
class User(HasTraits):
def _log_default(self):
return app_log
db = Any(allow_none=True)
def _db_default(self):
if self.orm_user:
return inspect(self.orm_user).session
def _db_changed(self, name, old, new):
"""Changing db session reacquires ORM User object"""
# db session changed, re-get orm User
if self.orm_user:
id = self.orm_user.id
self.orm_user = new.query(orm.User).filter(orm.User.id==id).first()
orm_user = None
spawner = None
spawn_pending = False
stop_pending = False
def __init__(self, orm_user, **kwargs):
self.orm_user = orm_user
super().__init__(**kwargs)
# pass get/setattr to ORM user
def __getattr__(self, attr):
if hasattr(self.orm_user, attr):
return getattr(self.orm_user, attr)
else:
raise AttributeError(attr)
def __setattr__(self, attr, value):
if self.orm_user and hasattr(self.orm_user, attr):
setattr(self.orm_user, attr, value)
else:
super().__setattr__(attr, value)
def __repr__(self):
return repr(self.orm_user)
@property
def running(self):
"""property for whether a user has a running server"""
if self.spawner is None:
return False
if self.server is None:
return False
return True
@property
def escaped_name(self):
"""My name, escaped for use in URLs, cookies, etc."""
return quote(self.name, safe='@')
@gen.coroutine
def spawn(self, spawner_class, base_url='/', hub=None, config=None, authenticator=None):
"""Start the user's spawner"""
db = self.db
if hub is None:
hub = db.query(orm.Hub).first()
self.server = orm.Server(
cookie_name='%s-%s' % (hub.server.cookie_name, quote(self.name, safe='')),
base_url=url_path_join(base_url, 'user', self.escaped_name),
)
db.add(self.server)
db.commit()
api_token = self.new_api_token()
db.commit()
spawner = self.spawner = spawner_class(
config=config,
user=self,
hub=hub,
db=db,
authenticator=authenticator,
)
# we are starting a new server, make sure it doesn't restore state
spawner.clear_state()
spawner.api_token = api_token
# trigger pre-spawn hook on authenticator
if (authenticator):
yield gen.maybe_future(authenticator.pre_spawn_start(self, spawner))
self.spawn_pending = True
# wait for spawner.start to return
try:
f = spawner.start()
yield gen.with_timeout(timedelta(seconds=spawner.start_timeout), f)
except Exception as e:
if isinstance(e, gen.TimeoutError):
self.log.warn("{user}'s server failed to start in {s} seconds, giving up".format(
user=self.name, s=spawner.start_timeout,
))
e.reason = 'timeout'
else:
self.log.error("Unhandled error starting {user}'s server: {error}".format(
user=self.name, error=e,
))
e.reason = 'error'
try:
yield self.stop()
except Exception:
self.log.error("Failed to cleanup {user}'s server that failed to start".format(
user=self.name,
), exc_info=True)
# raise original exception
raise e
spawner.start_polling()
# store state
self.state = spawner.get_state()
self.last_activity = datetime.utcnow()
db.commit()
try:
yield self.server.wait_up(http=True, timeout=spawner.http_timeout)
except Exception as e:
if isinstance(e, TimeoutError):
self.log.warn(
"{user}'s server never showed up at {url} "
"after {http_timeout} seconds. Giving up".format(
user=self.name,
url=self.server.url,
http_timeout=spawner.http_timeout,
)
)
e.reason = 'timeout'
else:
e.reason = 'error'
self.log.error("Unhandled error waiting for {user}'s server to show up at {url}: {error}".format(
user=self.name, url=self.server.url, error=e,
))
try:
yield self.stop()
except Exception:
self.log.error("Failed to cleanup {user}'s server that failed to start".format(
user=self.name,
), exc_info=True)
# raise original TimeoutError
raise e
self.spawn_pending = False
return self
@gen.coroutine
def stop(self):
"""Stop the user's spawner
and cleanup after it.
"""
self.spawn_pending = False
spawner = self.spawner
if spawner is None:
return
self.spawner.stop_polling()
self.stop_pending = True
try:
status = yield spawner.poll()
if status is None:
yield self.spawner.stop()
spawner.clear_state()
self.state = spawner.get_state()
self.last_activity = datetime.utcnow()
self.server = None
self.db.commit()
finally:
self.stop_pending = False
# trigger post-spawner hook on authenticator
auth = spawner.authenticator
if auth:
yield gen.maybe_future(
auth.post_spawn_stop(self, spawner)
)

View File

@@ -22,10 +22,10 @@
<thead> <thead>
<tr> <tr>
{% block thead %} {% block thead %}
{{ th("User (%i)" % users.count(), 'name') }} {{ th("User (%i)" % users|length, 'name') }}
{{ th("Admin", 'admin') }} {{ th("Admin", 'admin') }}
{{ th("Last Seen", 'last_activity') }} {{ th("Last Seen", 'last_activity') }}
{{ th("Running (%i)" % running.count(), 'running', colspan=2) }} {{ th("Running (%i)" % running|length, 'running', colspan=2) }}
{% endblock thead %} {% endblock thead %}
</tr> </tr>
</thead> </thead>