mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-17 23:13:00 +00:00
synchronize implementation of expiring values
- base Expiring class - ensures expiring values (OAuthCode, OAuthAccessToken, APIToken) are not returned from `find` - all expire appropriately via purge_expired
This commit is contained in:
@@ -1832,17 +1832,27 @@ class JupyterHub(Application):
|
|||||||
# purge expired tokens hourly
|
# purge expired tokens hourly
|
||||||
purge_expired_tokens_interval = 3600
|
purge_expired_tokens_interval = 3600
|
||||||
|
|
||||||
|
def purge_expired_tokens(self):
|
||||||
|
"""purge all expiring token objects from the database
|
||||||
|
|
||||||
|
run periodically
|
||||||
|
"""
|
||||||
|
# this should be all the subclasses of Expiring
|
||||||
|
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
|
||||||
|
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
|
||||||
|
cls.purge_expired(self.db)
|
||||||
|
|
||||||
async def init_api_tokens(self):
|
async def init_api_tokens(self):
|
||||||
"""Load predefined API tokens (for services) into database"""
|
"""Load predefined API tokens (for services) into database"""
|
||||||
await self._add_tokens(self.service_tokens, kind='service')
|
await self._add_tokens(self.service_tokens, kind='service')
|
||||||
await self._add_tokens(self.api_tokens, kind='user')
|
await self._add_tokens(self.api_tokens, kind='user')
|
||||||
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
|
|
||||||
purge_expired_tokens()
|
self.purge_expired_tokens()
|
||||||
# purge expired tokens hourly
|
# purge expired tokens hourly
|
||||||
# we don't need to be prompt about this
|
# we don't need to be prompt about this
|
||||||
# because expired tokens cannot be used anyway
|
# because expired tokens cannot be used anyway
|
||||||
pc = PeriodicCallback(
|
pc = PeriodicCallback(
|
||||||
purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
|
self.purge_expired_tokens, 1e3 * self.purge_expired_tokens_interval
|
||||||
)
|
)
|
||||||
pc.start()
|
pc.start()
|
||||||
|
|
||||||
|
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from oauthlib import uri_validate
|
from oauthlib import uri_validate
|
||||||
@@ -250,7 +249,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
|||||||
client=orm_client,
|
client=orm_client,
|
||||||
code=code['code'],
|
code=code['code'],
|
||||||
# oauth has 5 minutes to complete
|
# oauth has 5 minutes to complete
|
||||||
expires_at=int(datetime.utcnow().timestamp() + 300),
|
expires_at=int(orm.OAuthCode.now() + 300),
|
||||||
# TODO: persist oauth scopes
|
# TODO: persist oauth scopes
|
||||||
# scopes=request.scopes,
|
# scopes=request.scopes,
|
||||||
user=request.user.orm_user,
|
user=request.user.orm_user,
|
||||||
@@ -347,7 +346,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
|||||||
orm_access_token = orm.OAuthAccessToken(
|
orm_access_token = orm.OAuthAccessToken(
|
||||||
client=client,
|
client=client,
|
||||||
grant_type=orm.GrantType.authorization_code,
|
grant_type=orm.GrantType.authorization_code,
|
||||||
expires_at=datetime.utcnow().timestamp() + token['expires_in'],
|
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
|
||||||
refresh_token=token['refresh_token'],
|
refresh_token=token['refresh_token'],
|
||||||
# TODO: save scopes,
|
# TODO: save scopes,
|
||||||
# scopes=scopes,
|
# scopes=scopes,
|
||||||
@@ -441,7 +440,7 @@ class JupyterHubRequestValidator(RequestValidator):
|
|||||||
Method is used by:
|
Method is used by:
|
||||||
- Authorization Code Grant
|
- Authorization Code Grant
|
||||||
"""
|
"""
|
||||||
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
|
orm_code = orm.OAuthCode.find(self.db, code=code)
|
||||||
if orm_code is None:
|
if orm_code is None:
|
||||||
app_log.debug("No such code: %s", code)
|
app_log.debug("No such code: %s", code)
|
||||||
return False
|
return False
|
||||||
|
@@ -311,7 +311,46 @@ class Service(Base):
|
|||||||
return db.query(cls).filter(cls.name == name).first()
|
return db.query(cls).filter(cls.name == name).first()
|
||||||
|
|
||||||
|
|
||||||
class Hashed(object):
|
class Expiring:
|
||||||
|
"""Mixin for expiring entries
|
||||||
|
|
||||||
|
Subclass must define at least expires_at property,
|
||||||
|
which should be unix timestamp or datetime object
|
||||||
|
"""
|
||||||
|
|
||||||
|
now = utcnow # funciton, must return float timestamp or datetime
|
||||||
|
expires_at = None # must be defined
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expires_in(self):
|
||||||
|
"""Property returning expiration in seconds from now
|
||||||
|
|
||||||
|
or None
|
||||||
|
"""
|
||||||
|
if self.expires_at:
|
||||||
|
delta = self.expires_at - self.now()
|
||||||
|
if isinstance(delta, timedelta):
|
||||||
|
delta = delta.total_seconds()
|
||||||
|
return delta
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def purge_expired(cls, db):
|
||||||
|
"""Purge expired API Tokens from the database"""
|
||||||
|
now = cls.now()
|
||||||
|
deleted = False
|
||||||
|
for obj in (
|
||||||
|
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
|
||||||
|
):
|
||||||
|
app_log.debug("Purging expired %s", obj)
|
||||||
|
deleted = True
|
||||||
|
db.delete(obj)
|
||||||
|
if deleted:
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class Hashed(Expiring):
|
||||||
"""Mixin for tables with hashed tokens"""
|
"""Mixin for tables with hashed tokens"""
|
||||||
|
|
||||||
prefix_length = 4
|
prefix_length = 4
|
||||||
@@ -368,11 +407,21 @@ class Hashed(object):
|
|||||||
"""Start the query for matching token.
|
"""Start the query for matching token.
|
||||||
|
|
||||||
Returns an SQLAlchemy query already filtered by prefix-matches.
|
Returns an SQLAlchemy query already filtered by prefix-matches.
|
||||||
|
|
||||||
|
.. versionchanged:: 1.2
|
||||||
|
|
||||||
|
Excludes expired matches.
|
||||||
"""
|
"""
|
||||||
prefix = token[: cls.prefix_length]
|
prefix = token[: cls.prefix_length]
|
||||||
# since we can't filter on hashed values, filter on prefix
|
# since we can't filter on hashed values, filter on prefix
|
||||||
# so we aren't comparing with all tokens
|
# so we aren't comparing with all tokens
|
||||||
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
prefix_match = db.query(cls).filter(
|
||||||
|
bindparam('prefix', prefix).startswith(cls.prefix)
|
||||||
|
)
|
||||||
|
prefix_match = prefix_match.filter(
|
||||||
|
or_(cls.expires_at == None, cls.expires_at >= cls.now())
|
||||||
|
)
|
||||||
|
return prefix_match
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, db, token):
|
def find(cls, db, token):
|
||||||
@@ -408,6 +457,7 @@ class APIToken(Hashed, Base):
|
|||||||
return 'a%i' % self.id
|
return 'a%i' % self.id
|
||||||
|
|
||||||
# token metadata for bookkeeping
|
# token metadata for bookkeeping
|
||||||
|
now = datetime.utcnow # for expiry
|
||||||
created = Column(DateTime, default=datetime.utcnow)
|
created = Column(DateTime, default=datetime.utcnow)
|
||||||
expires_at = Column(DateTime, default=None, nullable=True)
|
expires_at = Column(DateTime, default=None, nullable=True)
|
||||||
last_activity = Column(DateTime)
|
last_activity = Column(DateTime)
|
||||||
@@ -428,20 +478,6 @@ class APIToken(Hashed, Base):
|
|||||||
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
|
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def purge_expired(cls, db):
|
|
||||||
"""Purge expired API Tokens from the database"""
|
|
||||||
now = utcnow()
|
|
||||||
deleted = False
|
|
||||||
for token in (
|
|
||||||
db.query(cls).filter(cls.expires_at != None).filter(cls.expires_at < now)
|
|
||||||
):
|
|
||||||
app_log.debug("Purging expired %s", token)
|
|
||||||
deleted = True
|
|
||||||
db.delete(token)
|
|
||||||
if deleted:
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, db, token, *, kind=None):
|
def find(cls, db, token, *, kind=None):
|
||||||
"""Find a token object by value.
|
"""Find a token object by value.
|
||||||
@@ -452,9 +488,6 @@ class APIToken(Hashed, Base):
|
|||||||
`kind='service'` only returns API tokens for services
|
`kind='service'` only returns API tokens for services
|
||||||
"""
|
"""
|
||||||
prefix_match = cls.find_prefix(db, token)
|
prefix_match = cls.find_prefix(db, token)
|
||||||
prefix_match = prefix_match.filter(
|
|
||||||
or_(cls.expires_at == None, cls.expires_at >= utcnow())
|
|
||||||
)
|
|
||||||
if kind == 'user':
|
if kind == 'user':
|
||||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||||
elif kind == 'service':
|
elif kind == 'service':
|
||||||
@@ -497,7 +530,7 @@ class APIToken(Hashed, Base):
|
|||||||
assert service.id is not None
|
assert service.id is not None
|
||||||
orm_token.service = service
|
orm_token.service = service
|
||||||
if expires_in is not None:
|
if expires_in is not None:
|
||||||
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
|
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
|
||||||
db.add(orm_token)
|
db.add(orm_token)
|
||||||
db.commit()
|
db.commit()
|
||||||
return token
|
return token
|
||||||
@@ -521,6 +554,10 @@ class OAuthAccessToken(Hashed, Base):
|
|||||||
__tablename__ = 'oauth_access_tokens'
|
__tablename__ = 'oauth_access_tokens'
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def now():
|
||||||
|
return datetime.utcnow().timestamp()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api_id(self):
|
def api_id(self):
|
||||||
return 'o%i' % self.id
|
return 'o%i' % self.id
|
||||||
@@ -547,11 +584,12 @@ class OAuthAccessToken(Hashed, Base):
|
|||||||
last_activity = Column(DateTime, nullable=True)
|
last_activity = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
|
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
|
||||||
cls=self.__class__.__name__,
|
cls=self.__class__.__name__,
|
||||||
client_id=self.client_id,
|
client_id=self.client_id,
|
||||||
user=self.user and self.user.name,
|
user=self.user and self.user.name,
|
||||||
prefix=self.prefix,
|
prefix=self.prefix,
|
||||||
|
expires_in=self.expires_in,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -568,8 +606,9 @@ class OAuthAccessToken(Hashed, Base):
|
|||||||
return orm_token
|
return orm_token
|
||||||
|
|
||||||
|
|
||||||
class OAuthCode(Base):
|
class OAuthCode(Expiring, Base):
|
||||||
__tablename__ = 'oauth_codes'
|
__tablename__ = 'oauth_codes'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
client_id = Column(
|
client_id = Column(
|
||||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||||
@@ -581,6 +620,19 @@ class OAuthCode(Base):
|
|||||||
# state = Column(Unicode(1023))
|
# state = Column(Unicode(1023))
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def now():
|
||||||
|
return datetime.utcnow().timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find(cls, db, code):
|
||||||
|
return (
|
||||||
|
db.query(cls)
|
||||||
|
.filter(cls.code == code)
|
||||||
|
.filter(or_(cls.expires_at == None, cls.expires_at >= cls.now()))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OAuthClient(Base):
|
class OAuthClient(Base):
|
||||||
__tablename__ = 'oauth_clients'
|
__tablename__ = 'oauth_clients'
|
||||||
|
@@ -134,7 +134,7 @@ def test_token_expiry(db):
|
|||||||
assert orm_token.expires_at > now + timedelta(seconds=50)
|
assert orm_token.expires_at > now + timedelta(seconds=50)
|
||||||
assert orm_token.expires_at < now + timedelta(seconds=70)
|
assert orm_token.expires_at < now + timedelta(seconds=70)
|
||||||
the_future = mock.patch(
|
the_future = mock.patch(
|
||||||
'jupyterhub.orm.utcnow', lambda: now + timedelta(seconds=70)
|
'jupyterhub.orm.APIToken.now', lambda: now + timedelta(seconds=70)
|
||||||
)
|
)
|
||||||
with the_future:
|
with the_future:
|
||||||
found = orm.APIToken.find(db, token=token)
|
found = orm.APIToken.find(db, token=token)
|
||||||
@@ -482,3 +482,78 @@ def test_group_delete_cascade(db):
|
|||||||
db.delete(user1)
|
db.delete(user1)
|
||||||
db.commit()
|
db.commit()
|
||||||
assert user1 not in group1.users
|
assert user1 not in group1.users
|
||||||
|
|
||||||
|
|
||||||
|
def test_expiring_api_token(app, user):
|
||||||
|
db = app.db
|
||||||
|
token = orm.APIToken.new(expires_in=30, user=user)
|
||||||
|
orm_token = orm.APIToken.find(db, token, kind='user')
|
||||||
|
assert orm_token
|
||||||
|
|
||||||
|
# purge_expired doesn't delete non-expired
|
||||||
|
orm.APIToken.purge_expired(db)
|
||||||
|
found = orm.APIToken.find(db, token)
|
||||||
|
assert found is orm_token
|
||||||
|
|
||||||
|
with mock.patch.object(
|
||||||
|
orm.APIToken, 'now', lambda: datetime.utcnow() + timedelta(seconds=60)
|
||||||
|
):
|
||||||
|
found = orm.APIToken.find(db, token)
|
||||||
|
assert found is None
|
||||||
|
assert orm_token in db.query(orm.APIToken)
|
||||||
|
orm.APIToken.purge_expired(db)
|
||||||
|
assert orm_token not in db.query(orm.APIToken)
|
||||||
|
|
||||||
|
|
||||||
|
def test_expiring_oauth_token(app, user):
|
||||||
|
db = app.db
|
||||||
|
token = "abc123"
|
||||||
|
now = orm.OAuthAccessToken.now
|
||||||
|
client = orm.OAuthClient(identifier="xxx", secret="yyy")
|
||||||
|
db.add(client)
|
||||||
|
orm_token = orm.OAuthAccessToken(
|
||||||
|
token=token,
|
||||||
|
grant_type=orm.GrantType.authorization_code,
|
||||||
|
client=client,
|
||||||
|
user=user,
|
||||||
|
expires_at=now() + 30,
|
||||||
|
)
|
||||||
|
db.add(orm_token)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
found = orm.OAuthAccessToken.find(db, token)
|
||||||
|
assert found is orm_token
|
||||||
|
# purge_expired doesn't delete non-expired
|
||||||
|
orm.OAuthAccessToken.purge_expired(db)
|
||||||
|
found = orm.OAuthAccessToken.find(db, token)
|
||||||
|
assert found is orm_token
|
||||||
|
|
||||||
|
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
|
||||||
|
found = orm.OAuthAccessToken.find(db, token)
|
||||||
|
assert found is None
|
||||||
|
assert orm_token in db.query(orm.OAuthAccessToken)
|
||||||
|
orm.OAuthAccessToken.purge_expired(db)
|
||||||
|
assert orm_token not in db.query(orm.OAuthAccessToken)
|
||||||
|
|
||||||
|
|
||||||
|
def test_expiring_oauth_code(app, user):
|
||||||
|
db = app.db
|
||||||
|
code = "abc123"
|
||||||
|
now = orm.OAuthCode.now
|
||||||
|
orm_code = orm.OAuthCode(code=code, expires_at=now() + 30)
|
||||||
|
db.add(orm_code)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
found = orm.OAuthCode.find(db, code)
|
||||||
|
assert found is orm_code
|
||||||
|
# purge_expired doesn't delete non-expired
|
||||||
|
orm.OAuthCode.purge_expired(db)
|
||||||
|
found = orm.OAuthCode.find(db, code)
|
||||||
|
assert found is orm_code
|
||||||
|
|
||||||
|
with mock.patch.object(orm.OAuthCode, 'now', lambda: now() + 60):
|
||||||
|
found = orm.OAuthCode.find(db, code)
|
||||||
|
assert found is None
|
||||||
|
assert orm_code in db.query(orm.OAuthCode)
|
||||||
|
orm.OAuthCode.purge_expired(db)
|
||||||
|
assert orm_code not in db.query(orm.OAuthCode)
|
||||||
|
Reference in New Issue
Block a user