synchronize implementation of expiring values

- base Expiring class
- ensures expiring values (OAuthCode, OAuthAccessToken, APIToken) are not returned from `find`
- all expire appropriately via purge_expired
This commit is contained in:
Min RK
2020-06-04 12:04:44 +02:00
parent fd28e224f2
commit d4b5373c05
4 changed files with 166 additions and 30 deletions

View File

@@ -1832,17 +1832,27 @@ class JupyterHub(Application):
# purge expired tokens hourly
purge_expired_tokens_interval = 3600
def purge_expired_tokens(self):
"""purge all expiring token objects from the database
run periodically
"""
# this should be all the subclasses of Expiring
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
cls.purge_expired(self.db)
async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user')
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
purge_expired_tokens()
self.purge_expired_tokens()
# purge expired tokens hourly
# we don't need to be prompt about this
# because expired tokens cannot be used anyway
pc = PeriodicCallback(
purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
self.purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
)
pc.start()

View File

@@ -2,7 +2,6 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
"""
from datetime import datetime
from urllib.parse import urlparse
from oauthlib import uri_validate
@@ -250,7 +249,7 @@ class JupyterHubRequestValidator(RequestValidator):
client=orm_client,
code=code['code'],
# oauth has 5 minutes to complete
expires_at=int(datetime.utcnow().timestamp() + 300),
expires_at=int(orm.OAuthCode.now() + 300),
# TODO: persist oauth scopes
# scopes=request.scopes,
user=request.user.orm_user,
@@ -347,7 +346,7 @@ class JupyterHubRequestValidator(RequestValidator):
orm_access_token = orm.OAuthAccessToken(
client=client,
grant_type=orm.GrantType.authorization_code,
expires_at=datetime.utcnow().timestamp() + token['expires_in'],
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
refresh_token=token['refresh_token'],
# TODO: save scopes,
# scopes=scopes,
@@ -441,7 +440,7 @@ class JupyterHubRequestValidator(RequestValidator):
Method is used by:
- Authorization Code Grant
"""
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
orm_code = orm.OAuthCode.find(self.db, code=code)
if orm_code is None:
app_log.debug("No such code: %s", code)
return False

View File

@@ -311,7 +311,46 @@ class Service(Base):
return db.query(cls).filter(cls.name == name).first()
class Hashed(object):
class Expiring:
"""Mixin for expiring entries
Subclass must define at least expires_at property,
which should be unix timestamp or datetime object
"""
now = utcnow # funciton, must return float timestamp or datetime
expires_at = None # must be defined
@property
def expires_in(self):
"""Property returning expiration in seconds from now
or None
"""
if self.expires_at:
delta = self.expires_at - self.now()
if isinstance(delta, timedelta):
delta = delta.total_seconds()
return delta
else:
return None
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = cls.now()
deleted = False
for obj in (
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", obj)
deleted = True
db.delete(obj)
if deleted:
db.commit()
class Hashed(Expiring):
"""Mixin for tables with hashed tokens"""
prefix_length = 4
@@ -368,11 +407,21 @@ class Hashed(object):
"""Start the query for matching token.
Returns an SQLAlchemy query already filtered by prefix-matches.
.. versionchanged:: 1.2
Excludes expired matches.
"""
prefix = token[: cls.prefix_length]
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
prefix_match = db.query(cls).filter(
bindparam('prefix', prefix).startswith(cls.prefix)
)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= cls.now())
)
return prefix_match
@classmethod
def find(cls, db, token):
@@ -408,6 +457,7 @@ class APIToken(Hashed, Base):
return 'a%i' % self.id
# token metadata for bookkeeping
now = datetime.utcnow # for expiry
created = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, default=None, nullable=True)
last_activity = Column(DateTime)
@@ -428,20 +478,6 @@ class APIToken(Hashed, Base):
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
)
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = utcnow()
deleted = False
for token in (
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", token)
deleted = True
db.delete(token)
if deleted:
db.commit()
@classmethod
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
@@ -452,9 +488,6 @@ class APIToken(Hashed, Base):
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= utcnow())
)
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
@@ -497,7 +530,7 @@ class APIToken(Hashed, Base):
assert service.id is not None
orm_token.service = service
if expires_in is not None:
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
db.add(orm_token)
db.commit()
return token
@@ -521,6 +554,10 @@ class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
@staticmethod
def now():
return datetime.utcnow().timestamp()
@property
def api_id(self):
return 'o%i' % self.id
@@ -547,11 +584,12 @@ class OAuthAccessToken(Hashed, Base):
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
expires_in=self.expires_in,
)
@classmethod
@@ -568,8 +606,9 @@ class OAuthAccessToken(Hashed, Base):
return orm_token
class OAuthCode(Base):
class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes'
id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
@@ -581,6 +620,19 @@ class OAuthCode(Base):
# state = Column(Unicode(1023))
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
@staticmethod
def now():
return datetime.utcnow().timestamp()
@classmethod
def find(cls, db, code):
return (
db.query(cls)
.filter(cls.code == code)
.filter(or_(cls.expires_at == None, cls.expires_at >= cls.now()))
.first()
)
class OAuthClient(Base):
__tablename__ = 'oauth_clients'

View File

@@ -134,7 +134,7 @@ def test_token_expiry(db):
assert orm_token.expires_at > now + timedelta(seconds=50)
assert orm_token.expires_at < now + timedelta(seconds=70)
the_future = mock.patch(
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
'jupyterhub.orm.APIToken.now', lambda: now + timedelta(seconds=70)
)
with the_future:
found = orm.APIToken.find(db, token=token)
@@ -482,3 +482,78 @@ def test_group_delete_cascade(db):
db.delete(user1)
db.commit()
assert user1 not in group1.users
def test_expiring_api_token(app, user):
db = app.db
token = orm.APIToken.new(expires_in=30, user=user)
orm_token = orm.APIToken.find(db, token, kind='user')
assert orm_token
# purge_expired doesn't delete non-expired
orm.APIToken.purge_expired(db)
found = orm.APIToken.find(db, token)
assert found is orm_token
with mock.patch.object(
orm.APIToken, 'now', lambda: datetime.utcnow() + timedelta(seconds=60)
):
found = orm.APIToken.find(db, token)
assert found is None
assert orm_token in db.query(orm.APIToken)
orm.APIToken.purge_expired(db)
assert orm_token not in db.query(orm.APIToken)
def test_expiring_oauth_token(app, user):
db = app.db
token = "abc123"
now = orm.OAuthAccessToken.now
client = orm.OAuthClient(identifier="xxx", secret="yyy")
db.add(client)
orm_token = orm.OAuthAccessToken(
token=token,
grant_type=orm.GrantType.authorization_code,
client=client,
user=user,
expires_at=now() + 30,
)
db.add(orm_token)
db.commit()
found = orm.OAuthAccessToken.find(db, token)
assert found is orm_token
# purge_expired doesn't delete non-expired
orm.OAuthAccessToken.purge_expired(db)
found = orm.OAuthAccessToken.find(db, token)
assert found is orm_token
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
found = orm.OAuthAccessToken.find(db, token)
assert found is None
assert orm_token in db.query(orm.OAuthAccessToken)
orm.OAuthAccessToken.purge_expired(db)
assert orm_token not in db.query(orm.OAuthAccessToken)
def test_expiring_oauth_code(app, user):
db = app.db
code = "abc123"
now = orm.OAuthCode.now
orm_code = orm.OAuthCode(code=code, expires_at=now() + 30)
db.add(orm_code)
db.commit()
found = orm.OAuthCode.find(db, code)
assert found is orm_code
# purge_expired doesn't delete non-expired
orm.OAuthCode.purge_expired(db)
found = orm.OAuthCode.find(db, code)
assert found is orm_code
with mock.patch.object(orm.OAuthCode, 'now', lambda: now() + 60):
found = orm.OAuthCode.find(db, code)
assert found is None
assert orm_code in db.query(orm.OAuthCode)
orm.OAuthCode.purge_expired(db)
assert orm_code not in db.query(orm.OAuthCode)