add orm.User.find and orm.[Foo]Token.find

for simple get-by-name access
This commit is contained in:
MinRK
2014-09-20 12:09:23 -07:00
parent cf64828d32
commit 2eb42eb0b3
3 changed files with 29 additions and 4 deletions

View File

@@ -73,7 +73,7 @@ class BaseHandler(RequestHandler):
if not match: if not match:
return None return None
token = match.group(1) token = match.group(1)
orm_token = self.db.query(orm.APIToken).filter(orm.APIToken.token == token).first() orm_token = orm.APIToken.find(self.db, token)
if orm_token is None: if orm_token is None:
return None return None
else: else:
@@ -83,8 +83,7 @@ class BaseHandler(RequestHandler):
"""get_current_user from a cookie token""" """get_current_user from a cookie token"""
token = self.get_cookie(self.hub.server.cookie_name, None) token = self.get_cookie(self.hub.server.cookie_name, None)
if token: if token:
cookie_token = self.db.query(orm.CookieToken).filter( cookie_token = orm.CookieToken.find(self.db, token)
orm.CookieToken.token==token).first()
if cookie_token: if cookie_token:
return cookie_token.user return cookie_token.user
else: else:
@@ -103,7 +102,7 @@ class BaseHandler(RequestHandler):
return None if no such user return None if no such user
""" """
return self.db.query(orm.User).filter(orm.User.name==name).first() return orm.User.find(self.db, name)
def user_from_username(self, username): def user_from_username(self, username):
"""Get ORM User for username""" """Get ORM User for username"""

View File

@@ -262,6 +262,14 @@ class User(Base):
"""Return a new cookie token""" """Return a new cookie token"""
return self._new_token(CookieToken) return self._new_token(CookieToken)
@classmethod
def find(cls, db, name):
"""Find a user by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name==name).first()
@gen.coroutine @gen.coroutine
def spawn(self, spawner_class, base_url='/', hub=None, config=None): def spawn(self, spawner_class, base_url='/', hub=None, config=None):
db = inspect(self).session db = inspect(self).session
@@ -321,6 +329,15 @@ class Token(object):
u=self.user.name, u=self.user.name,
) )
@classmethod
def find(cls, db, token):
"""Find a token object by value.
Returns None if not found.
"""
return db.query(cls).filter(cls.token==token).first()
class APIToken(Token, Base): class APIToken(Token, Base):
"""An API token""" """An API token"""

View File

@@ -71,6 +71,11 @@ def test_user(db):
assert user.server.ip == u'localhost' assert user.server.ip == u'localhost'
assert user.state == {'pid': 4234} assert user.state == {'pid': 4234}
found = orm.User.find(db, u'kaylee')
assert found.name == user.name
found = orm.User.find(db, u'badger')
assert found is None
def test_tokens(db): def test_tokens(db):
user = orm.User(name=u'inara') user = orm.User(name=u'inara')
@@ -87,3 +92,7 @@ def test_tokens(db):
assert len(user.api_tokens) == 1 assert len(user.api_tokens) == 1
assert len(user.cookie_tokens) == 3 assert len(user.cookie_tokens) == 3
found = orm.CookieToken.find(db, token=token.token)
assert found.token == token.token
found = orm.APIToken.find(db, token.token)
assert found is None