mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
synchronize implementation of expiring values
- base Expiring class - ensures expiring values (OAuthCode, OAuthAccessToken, APIToken) are not returned from `find` - all expire appropriately via purge_expired
This commit is contained in:
@@ -1832,17 +1832,27 @@ class JupyterHub(Application):
|
||||
# purge expired tokens hourly
|
||||
purge_expired_tokens_interval = 3600
|
||||
|
||||
def purge_expired_tokens(self):
|
||||
"""purge all expiring token objects from the database
|
||||
|
||||
run periodically
|
||||
"""
|
||||
# this should be all the subclasses of Expiring
|
||||
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
|
||||
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
|
||||
cls.purge_expired(self.db)
|
||||
|
||||
async def init_api_tokens(self):
|
||||
"""Load predefined API tokens (for services) into database"""
|
||||
await self._add_tokens(self.service_tokens, kind='service')
|
||||
await self._add_tokens(self.api_tokens, kind='user')
|
||||
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
|
||||
purge_expired_tokens()
|
||||
|
||||
self.purge_expired_tokens()
|
||||
# purge expired tokens hourly
|
||||
# we don't need to be prompt about this
|
||||
# because expired tokens cannot be used anyway
|
||||
pc = PeriodicCallback(
|
||||
purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
|
||||
self.purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
|
||||
)
|
||||
pc.start()
|
||||
|
||||
|
@@ -2,7 +2,6 @@
|
||||
|
||||
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
||||
"""
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from oauthlib import uri_validate
|
||||
@@ -250,7 +249,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
client=orm_client,
|
||||
code=code['code'],
|
||||
# oauth has 5 minutes to complete
|
||||
expires_at=int(datetime.utcnow().timestamp() + 300),
|
||||
expires_at=int(orm.OAuthCode.now() + 300),
|
||||
# TODO: persist oauth scopes
|
||||
# scopes=request.scopes,
|
||||
user=request.user.orm_user,
|
||||
@@ -347,7 +346,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client=client,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
expires_at=datetime.utcnow().timestamp() + token['expires_in'],
|
||||
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
|
||||
refresh_token=token['refresh_token'],
|
||||
# TODO: save scopes,
|
||||
# scopes=scopes,
|
||||
@@ -441,7 +440,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
Method is used by:
|
||||
- Authorization Code Grant
|
||||
"""
|
||||
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
|
||||
orm_code = orm.OAuthCode.find(self.db, code=code)
|
||||
if orm_code is None:
|
||||
app_log.debug("No such code: %s", code)
|
||||
return False
|
||||
|
@@ -311,7 +311,46 @@ class Service(Base):
|
||||
return db.query(cls).filter(cls.name == name).first()
|
||||
|
||||
|
||||
class Hashed(object):
|
||||
class Expiring:
|
||||
"""Mixin for expiring entries
|
||||
|
||||
Subclass must define at least expires_at property,
|
||||
which should be unix timestamp or datetime object
|
||||
"""
|
||||
|
||||
now = utcnow # funciton, must return float timestamp or datetime
|
||||
expires_at = None # must be defined
|
||||
|
||||
@property
|
||||
def expires_in(self):
|
||||
"""Property returning expiration in seconds from now
|
||||
|
||||
or None
|
||||
"""
|
||||
if self.expires_at:
|
||||
delta = self.expires_at - self.now()
|
||||
if isinstance(delta, timedelta):
|
||||
delta = delta.total_seconds()
|
||||
return delta
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def purge_expired(cls, db):
|
||||
"""Purge expired API Tokens from the database"""
|
||||
now = cls.now()
|
||||
deleted = False
|
||||
for obj in (
|
||||
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
|
||||
):
|
||||
app_log.debug("Purging expired %s", obj)
|
||||
deleted = True
|
||||
db.delete(obj)
|
||||
if deleted:
|
||||
db.commit()
|
||||
|
||||
|
||||
class Hashed(Expiring):
|
||||
"""Mixin for tables with hashed tokens"""
|
||||
|
||||
prefix_length = 4
|
||||
@@ -368,11 +407,21 @@ class Hashed(object):
|
||||
"""Start the query for matching token.
|
||||
|
||||
Returns an SQLAlchemy query already filtered by prefix-matches.
|
||||
|
||||
.. versionchanged:: 1.2
|
||||
|
||||
Excludes expired matches.
|
||||
"""
|
||||
prefix = token[: cls.prefix_length]
|
||||
# since we can't filter on hashed values, filter on prefix
|
||||
# so we aren't comparing with all tokens
|
||||
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
||||
prefix_match = db.query(cls).filter(
|
||||
bindparam('prefix', prefix).startswith(cls.prefix)
|
||||
)
|
||||
prefix_match = prefix_match.filter(
|
||||
or_(cls.expires_at == None, cls.expires_at >= cls.now())
|
||||
)
|
||||
return prefix_match
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, token):
|
||||
@@ -408,6 +457,7 @@ class APIToken(Hashed, Base):
|
||||
return 'a%i' % self.id
|
||||
|
||||
# token metadata for bookkeeping
|
||||
now = datetime.utcnow # for expiry
|
||||
created = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime, default=None, nullable=True)
|
||||
last_activity = Column(DateTime)
|
||||
@@ -428,20 +478,6 @@ class APIToken(Hashed, Base):
|
||||
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def purge_expired(cls, db):
|
||||
"""Purge expired API Tokens from the database"""
|
||||
now = utcnow()
|
||||
deleted = False
|
||||
for token in (
|
||||
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
|
||||
):
|
||||
app_log.debug("Purging expired %s", token)
|
||||
deleted = True
|
||||
db.delete(token)
|
||||
if deleted:
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, token, *, kind=None):
|
||||
"""Find a token object by value.
|
||||
@@ -452,9 +488,6 @@ class APIToken(Hashed, Base):
|
||||
`kind='service'` only returns API tokens for services
|
||||
"""
|
||||
prefix_match = cls.find_prefix(db, token)
|
||||
prefix_match = prefix_match.filter(
|
||||
or_(cls.expires_at == None, cls.expires_at >= utcnow())
|
||||
)
|
||||
if kind == 'user':
|
||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||
elif kind == 'service':
|
||||
@@ -497,7 +530,7 @@ class APIToken(Hashed, Base):
|
||||
assert service.id is not None
|
||||
orm_token.service = service
|
||||
if expires_in is not None:
|
||||
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
|
||||
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
return token
|
||||
@@ -521,6 +554,10 @@ class OAuthAccessToken(Hashed, Base):
|
||||
__tablename__ = 'oauth_access_tokens'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@staticmethod
|
||||
def now():
|
||||
return datetime.utcnow().timestamp()
|
||||
|
||||
@property
|
||||
def api_id(self):
|
||||
return 'o%i' % self.id
|
||||
@@ -547,11 +584,12 @@ class OAuthAccessToken(Hashed, Base):
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
|
||||
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
|
||||
cls=self.__class__.__name__,
|
||||
client_id=self.client_id,
|
||||
user=self.user and self.user.name,
|
||||
prefix=self.prefix,
|
||||
expires_in=self.expires_in,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -568,8 +606,9 @@ class OAuthAccessToken(Hashed, Base):
|
||||
return orm_token
|
||||
|
||||
|
||||
class OAuthCode(Base):
|
||||
class OAuthCode(Expiring, Base):
|
||||
__tablename__ = 'oauth_codes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
client_id = Column(
|
||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||
@@ -581,6 +620,19 @@ class OAuthCode(Base):
|
||||
# state = Column(Unicode(1023))
|
||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||
|
||||
@staticmethod
|
||||
def now():
|
||||
return datetime.utcnow().timestamp()
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, code):
|
||||
return (
|
||||
db.query(cls)
|
||||
.filter(cls.code == code)
|
||||
.filter(or_(cls.expires_at == None, cls.expires_at >= cls.now()))
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
class OAuthClient(Base):
|
||||
__tablename__ = 'oauth_clients'
|
||||
|
@@ -134,7 +134,7 @@ def test_token_expiry(db):
|
||||
assert orm_token.expires_at > now + timedelta(seconds=50)
|
||||
assert orm_token.expires_at < now + timedelta(seconds=70)
|
||||
the_future = mock.patch(
|
||||
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
|
||||
'jupyterhub.orm.APIToken.now', lambda: now + timedelta(seconds=70)
|
||||
)
|
||||
with the_future:
|
||||
found = orm.APIToken.find(db, token=token)
|
||||
@@ -482,3 +482,78 @@ def test_group_delete_cascade(db):
|
||||
db.delete(user1)
|
||||
db.commit()
|
||||
assert user1 not in group1.users
|
||||
|
||||
|
||||
def test_expiring_api_token(app, user):
|
||||
db = app.db
|
||||
token = orm.APIToken.new(expires_in=30, user=user)
|
||||
orm_token = orm.APIToken.find(db, token, kind='user')
|
||||
assert orm_token
|
||||
|
||||
# purge_expired doesn't delete non-expired
|
||||
orm.APIToken.purge_expired(db)
|
||||
found = orm.APIToken.find(db, token)
|
||||
assert found is orm_token
|
||||
|
||||
with mock.patch.object(
|
||||
orm.APIToken, 'now', lambda: datetime.utcnow() + timedelta(seconds=60)
|
||||
):
|
||||
found = orm.APIToken.find(db, token)
|
||||
assert found is None
|
||||
assert orm_token in db.query(orm.APIToken)
|
||||
orm.APIToken.purge_expired(db)
|
||||
assert orm_token not in db.query(orm.APIToken)
|
||||
|
||||
|
||||
def test_expiring_oauth_token(app, user):
|
||||
db = app.db
|
||||
token = "abc123"
|
||||
now = orm.OAuthAccessToken.now
|
||||
client = orm.OAuthClient(identifier="xxx", secret="yyy")
|
||||
db.add(client)
|
||||
orm_token = orm.OAuthAccessToken(
|
||||
token=token,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
client=client,
|
||||
user=user,
|
||||
expires_at=now() + 30,
|
||||
)
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
assert found is orm_token
|
||||
# purge_expired doesn't delete non-expired
|
||||
orm.OAuthAccessToken.purge_expired(db)
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
assert found is orm_token
|
||||
|
||||
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
assert found is None
|
||||
assert orm_token in db.query(orm.OAuthAccessToken)
|
||||
orm.OAuthAccessToken.purge_expired(db)
|
||||
assert orm_token not in db.query(orm.OAuthAccessToken)
|
||||
|
||||
|
||||
def test_expiring_oauth_code(app, user):
|
||||
db = app.db
|
||||
code = "abc123"
|
||||
now = orm.OAuthCode.now
|
||||
orm_code = orm.OAuthCode(code=code, expires_at=now() + 30)
|
||||
db.add(orm_code)
|
||||
db.commit()
|
||||
|
||||
found = orm.OAuthCode.find(db, code)
|
||||
assert found is orm_code
|
||||
# purge_expired doesn't delete non-expired
|
||||
orm.OAuthCode.purge_expired(db)
|
||||
found = orm.OAuthCode.find(db, code)
|
||||
assert found is orm_code
|
||||
|
||||
with mock.patch.object(orm.OAuthCode, 'now', lambda: now() + 60):
|
||||
found = orm.OAuthCode.find(db, code)
|
||||
assert found is None
|
||||
assert orm_code in db.query(orm.OAuthCode)
|
||||
orm.OAuthCode.purge_expired(db)
|
||||
assert orm_code not in db.query(orm.OAuthCode)
|
||||
|
Reference in New Issue
Block a user