diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index f6108818..dd6ce51b 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -73,7 +73,7 @@ class BaseHandler(RequestHandler): if not match: return None token = match.group(1) - orm_token = self.db.query(orm.APIToken).filter(orm.APIToken.token == token).first() + orm_token = orm.APIToken.find(self.db, token) if orm_token is None: return None else: @@ -83,8 +83,7 @@ class BaseHandler(RequestHandler): """get_current_user from a cookie token""" token = self.get_cookie(self.hub.server.cookie_name, None) if token: - cookie_token = self.db.query(orm.CookieToken).filter( - orm.CookieToken.token==token).first() + cookie_token = orm.CookieToken.find(self.db, token) if cookie_token: return cookie_token.user else: @@ -103,7 +102,7 @@ class BaseHandler(RequestHandler): return None if no such user """ - return self.db.query(orm.User).filter(orm.User.name==name).first() + return orm.User.find(self.db, name) def user_from_username(self, username): """Get ORM User for username""" diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 5233750b..34399a1b 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -262,6 +262,14 @@ class User(Base): """Return a new cookie token""" return self._new_token(CookieToken) + @classmethod + def find(cls, db, name): + """Find a user by name. + + Returns None if not found. + """ + return db.query(cls).filter(cls.name==name).first() + @gen.coroutine def spawn(self, spawner_class, base_url='/', hub=None, config=None): db = inspect(self).session @@ -321,6 +329,15 @@ class Token(object): u=self.user.name, ) + @classmethod + def find(cls, db, token): + """Find a token object by value. + + Returns None if not found. + """ + return db.query(cls).filter(cls.token==token).first() + + class APIToken(Token, Base): """An API token""" diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 8990f110..bd651942 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -71,6 +71,11 @@ def test_user(db): assert user.server.ip == u'localhost' assert user.state == {'pid': 4234} + found = orm.User.find(db, u'kaylee') + assert found.name == user.name + found = orm.User.find(db, u'badger') + assert found is None + def test_tokens(db): user = orm.User(name=u'inara') @@ -87,3 +92,7 @@ def test_tokens(db): assert len(user.api_tokens) == 1 assert len(user.cookie_tokens) == 3 + found = orm.CookieToken.find(db, token=token.token) + assert found.token == token.token + found = orm.APIToken.find(db, token.token) + assert found is None