From d4b5373c05c2614cb1af01128f18bc241d48e02c Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 4 Jun 2020 12:04:44 +0200 Subject: [PATCH] synchronize implementation of expiring values - base Expiring class - ensures expiring values (OAuthCode, OAuthAccessToken, APIToken) are not returned from `find` - all expire appropriately via purge_expired --- jupyterhub/app.py | 16 ++++-- jupyterhub/oauth/provider.py | 7 ++- jupyterhub/orm.py | 96 +++++++++++++++++++++++++++--------- jupyterhub/tests/test_orm.py | 77 ++++++++++++++++++++++++++++- 4 files changed, 166 insertions(+), 30 deletions(-) diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 50809e19..174d5c71 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -1832,17 +1832,27 @@ class JupyterHub(Application): # purge expired tokens hourly purge_expired_tokens_interval = 3600 + def purge_expired_tokens(self): + """purge all expiring token objects from the database + + run periodically + """ + # this should be all the subclasses of Expiring + for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode): + self.log.debug("Purging expired {name}s".format(name=cls.__name__)) + cls.purge_expired(self.db) + async def init_api_tokens(self): """Load predefined API tokens (for services) into database""" await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') - purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db) - purge_expired_tokens() + + self.purge_expired_tokens() # purge expired tokens hourly # we don't need to be prompt about this # because expired tokens cannot be used anyway pc = PeriodicCallback( - purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval + self.purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval ) pc.start() diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index 6157223f..7c31d90d 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -2,7 +2,6 @@ implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html """ -from datetime import datetime from urllib.parse import urlparse from oauthlib import uri_validate @@ -250,7 +249,7 @@ class JupyterHubRequestValidator(RequestValidator): client=orm_client, code=code['code'], # oauth has 5 minutes to complete - expires_at=int(datetime.utcnow().timestamp() + 300), + expires_at=int(orm.OAuthCode.now() + 300), # TODO: persist oauth scopes # scopes=request.scopes, user=request.user.orm_user, @@ -347,7 +346,7 @@ class JupyterHubRequestValidator(RequestValidator): orm_access_token = orm.OAuthAccessToken( client=client, grant_type=orm.GrantType.authorization_code, - expires_at=datetime.utcnow().timestamp() + token['expires_in'], + expires_at=orm.OAuthAccessToken.now() + token['expires_in'], refresh_token=token['refresh_token'], # TODO: save scopes, # scopes=scopes, @@ -441,7 +440,7 @@ class JupyterHubRequestValidator(RequestValidator): Method is used by: - Authorization Code Grant """ - orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first() + orm_code = orm.OAuthCode.find(self.db, code=code) if orm_code is None: app_log.debug("No such code: %s", code) return False diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 130d16f0..cc96ca88 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -311,7 +311,46 @@ class Service(Base): return db.query(cls).filter(cls.name == name).first() -class Hashed(object): +class Expiring: + """Mixin for expiring entries + + Subclass must define at least expires_at property, + which should be unix timestamp or datetime object + """ + + now = utcnow # funciton, must return float timestamp or datetime + expires_at = None # must be defined + + @property + def expires_in(self): + """Property returning expiration in seconds from now + + or None + """ + if self.expires_at: + delta = self.expires_at - self.now() + if isinstance(delta, timedelta): + delta = delta.total_seconds() + return delta + else: + return None + + @classmethod + def purge_expired(cls, db): + """Purge expired API Tokens from the database""" + now = cls.now() + deleted = False + for obj in ( + db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now) + ): + app_log.debug("Purging expired %s", obj) + deleted = True + db.delete(obj) + if deleted: + db.commit() + + +class Hashed(Expiring): """Mixin for tables with hashed tokens""" prefix_length = 4 @@ -368,11 +407,21 @@ class Hashed(object): """Start the query for matching token. Returns an SQLAlchemy query already filtered by prefix-matches. + + .. versionchanged:: 1.2 + + Excludes expired matches. """ prefix = token[: cls.prefix_length] # since we can't filter on hashed values, filter on prefix # so we aren't comparing with all tokens - return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix)) + prefix_match = db.query(cls).filter( + bindparam('prefix', prefix).startswith(cls.prefix) + ) + prefix_match = prefix_match.filter( + or_(cls.expires_at == None, cls.expires_at >= cls.now()) + ) + return prefix_match @classmethod def find(cls, db, token): @@ -408,6 +457,7 @@ class APIToken(Hashed, Base): return 'a%i' % self.id # token metadata for bookkeeping + now = datetime.utcnow # for expiry created = Column(DateTime, default=datetime.utcnow) expires_at = Column(DateTime, default=None, nullable=True) last_activity = Column(DateTime) @@ -428,20 +478,6 @@ class APIToken(Hashed, Base): cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name ) - @classmethod - def purge_expired(cls, db): - """Purge expired API Tokens from the database""" - now = utcnow() - deleted = False - for token in ( - db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now) - ): - app_log.debug("Purging expired %s", token) - deleted = True - db.delete(token) - if deleted: - db.commit() - @classmethod def find(cls, db, token, *, kind=None): """Find a token object by value. @@ -452,9 +488,6 @@ class APIToken(Hashed, Base): `kind='service'` only returns API tokens for services """ prefix_match = cls.find_prefix(db, token) - prefix_match = prefix_match.filter( - or_(cls.expires_at == None, cls.expires_at >= utcnow()) - ) if kind == 'user': prefix_match = prefix_match.filter(cls.user_id != None) elif kind == 'service': @@ -497,7 +530,7 @@ class APIToken(Hashed, Base): assert service.id is not None orm_token.service = service if expires_in is not None: - orm_token.expires_at = utcnow() + timedelta(seconds=expires_in) + orm_token.expires_at = cls.now() + timedelta(seconds=expires_in) db.add(orm_token) db.commit() return token @@ -521,6 +554,10 @@ class OAuthAccessToken(Hashed, Base): __tablename__ = 'oauth_access_tokens' id = Column(Integer, primary_key=True, autoincrement=True) + @staticmethod + def now(): + return datetime.utcnow().timestamp() + @property def api_id(self): return 'o%i' % self.id @@ -547,11 +584,12 @@ class OAuthAccessToken(Hashed, Base): last_activity = Column(DateTime, nullable=True) def __repr__(self): - return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format( + return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format( cls=self.__class__.__name__, client_id=self.client_id, user=self.user and self.user.name, prefix=self.prefix, + expires_in=self.expires_in, ) @classmethod @@ -568,8 +606,9 @@ class OAuthAccessToken(Hashed, Base): return orm_token -class OAuthCode(Base): +class OAuthCode(Expiring, Base): __tablename__ = 'oauth_codes' + id = Column(Integer, primary_key=True, autoincrement=True) client_id = Column( Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE') @@ -581,6 +620,19 @@ class OAuthCode(Base): # state = Column(Unicode(1023)) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + @staticmethod + def now(): + return datetime.utcnow().timestamp() + + @classmethod + def find(cls, db, code): + return ( + db.query(cls) + .filter(cls.code == code) + .filter(or_(cls.expires_at == None, cls.expires_at >= cls.now())) + .first() + ) + class OAuthClient(Base): __tablename__ = 'oauth_clients' diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 4790f386..0c125c5a 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -134,7 +134,7 @@ def test_token_expiry(db): assert orm_token.expires_at > now + timedelta(seconds=50) assert orm_token.expires_at < now + timedelta(seconds=70) the_future = mock.patch( - 'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70) + 'jupyterhub.orm.APIToken.now', lambda: now + timedelta(seconds=70) ) with the_future: found = orm.APIToken.find(db, token=token) @@ -482,3 +482,78 @@ def test_group_delete_cascade(db): db.delete(user1) db.commit() assert user1 not in group1.users + + +def test_expiring_api_token(app, user): + db = app.db + token = orm.APIToken.new(expires_in=30, user=user) + orm_token = orm.APIToken.find(db, token, kind='user') + assert orm_token + + # purge_expired doesn't delete non-expired + orm.APIToken.purge_expired(db) + found = orm.APIToken.find(db, token) + assert found is orm_token + + with mock.patch.object( + orm.APIToken, 'now', lambda: datetime.utcnow() + timedelta(seconds=60) + ): + found = orm.APIToken.find(db, token) + assert found is None + assert orm_token in db.query(orm.APIToken) + orm.APIToken.purge_expired(db) + assert orm_token not in db.query(orm.APIToken) + + +def test_expiring_oauth_token(app, user): + db = app.db + token = "abc123" + now = orm.OAuthAccessToken.now + client = orm.OAuthClient(identifier="xxx", secret="yyy") + db.add(client) + orm_token = orm.OAuthAccessToken( + token=token, + grant_type=orm.GrantType.authorization_code, + client=client, + user=user, + expires_at=now() + 30, + ) + db.add(orm_token) + db.commit() + + found = orm.OAuthAccessToken.find(db, token) + assert found is orm_token + # purge_expired doesn't delete non-expired + orm.OAuthAccessToken.purge_expired(db) + found = orm.OAuthAccessToken.find(db, token) + assert found is orm_token + + with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60): + found = orm.OAuthAccessToken.find(db, token) + assert found is None + assert orm_token in db.query(orm.OAuthAccessToken) + orm.OAuthAccessToken.purge_expired(db) + assert orm_token not in db.query(orm.OAuthAccessToken) + + +def test_expiring_oauth_code(app, user): + db = app.db + code = "abc123" + now = orm.OAuthCode.now + orm_code = orm.OAuthCode(code=code, expires_at=now() + 30) + db.add(orm_code) + db.commit() + + found = orm.OAuthCode.find(db, code) + assert found is orm_code + # purge_expired doesn't delete non-expired + orm.OAuthCode.purge_expired(db) + found = orm.OAuthCode.find(db, code) + assert found is orm_code + + with mock.patch.object(orm.OAuthCode, 'now', lambda: now() + 60): + found = orm.OAuthCode.find(db, code) + assert found is None + assert orm_code in db.query(orm.OAuthCode) + orm.OAuthCode.purge_expired(db) + assert orm_code not in db.query(orm.OAuthCode)