diff --git a/jupyterhub/alembic/versions/896818069c98_token_expires.py b/jupyterhub/alembic/versions/896818069c98_token_expires.py new file mode 100644 index 00000000..cdd5f5c0 --- /dev/null +++ b/jupyterhub/alembic/versions/896818069c98_token_expires.py @@ -0,0 +1,24 @@ +"""Add APIToken.expires_at + +Revision ID: 896818069c98 +Revises: d68c98b66cd4 +Create Date: 2018-05-07 11:35:58.050542 + +""" + +# revision identifiers, used by Alembic. +revision = '896818069c98' +down_revision = 'd68c98b66cd4' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('api_tokens', sa.Column('expires_at', sa.DateTime(), nullable=True)) + + +def downgrade(): + op.drop_column('api_tokens', 'expires_at') diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index ba569edb..e368746d 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -4,6 +4,7 @@ # Distributed under the terms of the Modified BSD License. import asyncio +from datetime import datetime import json from async_generator import aclosing @@ -201,13 +202,30 @@ class UserTokenListAPIHandler(APIHandler): user = self.find_user(name) if not user: raise web.HTTPError(404, "No such user: %s" % name) + + now = datetime.utcnow() + api_tokens = [] def sort_key(token): return token.last_activity or token.created + for token in sorted(user.api_tokens, key=sort_key): + if token.expires_at and token.expires_at < now: + # exclude expired tokens + self.db.delete(token) + self.db.commit() + continue api_tokens.append(self.token_model(token)) + oauth_tokens = [] + # OAuth tokens use integer timestamps + now_timestamp = now.timestamp() for token in sorted(user.oauth_tokens, key=sort_key): + if token.expires_at and token.expires_at < now_timestamp: + # exclude expired tokens + self.db.delete(token) + self.db.commit() + continue oauth_tokens.append(self.token_model(token)) self.write(json.dumps({ 'api_tokens': api_tokens, @@ -252,7 +270,7 @@ class UserTokenListAPIHandler(APIHandler): if requester is not user: note += " by %s %s" % (kind, requester.name) - api_token = user.new_api_token(note=note) + api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None)) if requester is not user: self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name) else: diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 57503d35..569560cf 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -9,6 +9,7 @@ import atexit import binascii from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone +from functools import partial from getpass import getuser import logging from operator import itemgetter @@ -1249,10 +1250,23 @@ class JupyterHub(Application): self.log.debug("Not duplicating token %s", orm_token) db.commit() + # purge expired tokens hourly + purge_expired_tokens_interval = 3600 + async def init_api_tokens(self): """Load predefined API tokens (for services) into database""" await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') + purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db) + purge_expired_tokens() + # purge expired tokens hourly + # we don't need to be prompt about this + # because expired tokens cannot be used anyway + pc = PeriodicCallback( + purge_expired_tokens, + 1e3 * self.purge_expired_tokens_interval, + ) + pc.start() def init_services(self): self._service_map.clear() diff --git a/jupyterhub/handlers/pages.py b/jupyterhub/handlers/pages.py index f17fb1be..10fc1e2c 100644 --- a/jupyterhub/handlers/pages.py +++ b/jupyterhub/handlers/pages.py @@ -3,6 +3,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from collections import defaultdict from datetime import datetime from http.client import responses @@ -229,12 +230,24 @@ class TokenPageHandler(BaseHandler): token.last_activity or never, token.created or never, ) - api_tokens = sorted(user.api_tokens, key=sort_key, reverse=True) + + now = datetime.utcnow() + api_tokens = [] + for token in sorted(user.api_tokens, key=sort_key, reverse=True): + if token.expires_at and token.expires_at < now: + self.db.delete(token) + self.db.commit() + continue + api_tokens.append(token) # group oauth client tokens by client id - from collections import defaultdict oauth_tokens = defaultdict(list) for token in user.oauth_tokens: + if token.expires_at and token.expires_at < now: + self.log.warning("Deleting expired token") + self.db.delete(token) + self.db.commit() + continue if not token.client_id: # token should have been deleted when client was deleted self.log.warning("Deleting stale oauth token for %s", user.name) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 84d1956f..658f67e2 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -3,7 +3,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from datetime import datetime +from datetime import datetime, timedelta import enum import json @@ -14,7 +14,7 @@ from tornado.log import app_log from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( - create_engine, event, inspect, + create_engine, event, inspect, or_, Column, Integer, ForeignKey, Unicode, Boolean, DateTime, Enum, Table, ) @@ -33,6 +33,9 @@ from .utils import ( new_token, hash_token, compare_token, ) +# top-level variable for easier mocking in tests +utcnow = datetime.utcnow + class JSONDict(TypeDecorator): """Represents an immutable structure as a json-encoded string. @@ -176,12 +179,12 @@ class User(Base): running=sum(bool(s.server) for s in self._orm_spawners), ) - def new_api_token(self, token=None, generated=True, note=''): + def new_api_token(self, token=None, **kwargs): """Create a new API token If `token` is given, load that token. """ - return APIToken.new(token=token, user=self, note=note, generated=generated) + return APIToken.new(token=token, user=self, **kwargs) @classmethod def find(cls, db, name): @@ -242,11 +245,11 @@ class Service(Base): server = relationship(Server, cascade='all') pid = Column(Integer) - def new_api_token(self, token=None, generated=True, note=''): + def new_api_token(self, token=None, **kwargs): """Create a new API token If `token` is given, load that token. """ - return APIToken.new(token=token, service=self, note=note, generated=generated) + return APIToken.new(token=token, service=self, **kwargs) @classmethod def find(cls, db, name): @@ -348,6 +351,7 @@ class APIToken(Hashed, Base): # token metadata for bookkeeping created = Column(DateTime, default=datetime.utcnow) + expires_at = Column(DateTime, default=None, nullable=True) last_activity = Column(DateTime) note = Column(Unicode(1023)) @@ -369,6 +373,22 @@ class APIToken(Hashed, Base): name=name, ) + @classmethod + def purge_expired(cls, db): + """Purge expired API Tokens from the database""" + now = utcnow() + deleted = False + for token in ( + db.query(cls) + .filter(cls.expires_at != None) + .filter(cls.expires_at < now) + ): + app_log.debug("Purging expired %s", token) + deleted = True + db.delete(token) + if deleted: + db.commit() + @classmethod def find(cls, db, token, *, kind=None): """Find a token object by value. @@ -379,6 +399,9 @@ class APIToken(Hashed, Base): `kind='service'` only returns API tokens for services """ prefix_match = cls.find_prefix(db, token) + prefix_match = prefix_match.filter( + or_(cls.expires_at == None, cls.expires_at >= utcnow()) + ) if kind == 'user': prefix_match = prefix_match.filter(cls.user_id != None) elif kind == 'service': @@ -390,7 +413,8 @@ class APIToken(Hashed, Base): return orm_token @classmethod - def new(cls, token=None, user=None, service=None, note='', generated=True): + def new(cls, token=None, user=None, service=None, note='', generated=True, + expires_in=None): """Generate a new API token for a user or service""" assert user or service assert not (user and service) @@ -412,6 +436,8 @@ class APIToken(Hashed, Base): else: assert service.id is not None orm_token.service = service + if expires_in is not None: + orm_token.expires_at = utcnow() + timedelta(seconds=expires_in) db.add(orm_token) db.commit() return token diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 7edbf3e3..2ccc99a1 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -3,8 +3,10 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from datetime import datetime, timedelta import os import socket +from unittest import mock import pytest from tornado import gen @@ -99,6 +101,32 @@ def test_tokens(db): assert len(user.api_tokens) == 3 +def test_token_expiry(db): + user = orm.User(name='parker') + db.add(user) + db.commit() + now = datetime.utcnow() + token = user.new_api_token(expires_in=60) + orm_token = orm.APIToken.find(db, token=token) + assert orm_token + assert orm_token.expires_at is not None + # approximate range + assert orm_token.expires_at > now + timedelta(seconds=50) + assert orm_token.expires_at < now + timedelta(seconds=70) + the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70)) + with the_future: + found = orm.APIToken.find(db, token=token) + assert found is None + # purging shouldn't delete non-expired tokens + orm.APIToken.purge_expired(db) + assert orm.APIToken.find(db, token=token) + with the_future: + orm.APIToken.purge_expired(db) + assert orm.APIToken.find(db, token=token) is None + # after purging, make sure we aren't in the user token list + assert orm_token not in user.api_tokens + + def test_service_tokens(db): service = orm.Service(name='secret') db.add(service)