implement API token expiry

This commit is contained in:
Min RK
2018-05-07 11:59:12 +02:00
parent a17f5e4f1b
commit 58c91e3fd4
6 changed files with 133 additions and 10 deletions

View File

@@ -0,0 +1,24 @@
"""Add APIToken.expires_at
Revision ID: 896818069c98
Revises: d68c98b66cd4
Create Date: 2018-05-07 11:35:58.050542
"""
# revision identifiers, used by Alembic.
revision = '896818069c98'
down_revision = 'd68c98b66cd4'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('api_tokens', sa.Column('expires_at', sa.DateTime(), nullable=True))
def downgrade():
op.drop_column('api_tokens', 'expires_at')

View File

@@ -4,6 +4,7 @@
# Distributed under the terms of the Modified BSD License.
import asyncio
from datetime import datetime
import json
from async_generator import aclosing
@@ -201,13 +202,30 @@ class UserTokenListAPIHandler(APIHandler):
user = self.find_user(name)
if not user:
raise web.HTTPError(404, "No such user: %s" % name)
now = datetime.utcnow()
api_tokens = []
def sort_key(token):
return token.last_activity or token.created
for token in sorted(user.api_tokens, key=sort_key):
if token.expires_at and token.expires_at < now:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(self.token_model(token))
oauth_tokens = []
# OAuth tokens use integer timestamps
now_timestamp = now.timestamp()
for token in sorted(user.oauth_tokens, key=sort_key):
if token.expires_at and token.expires_at < now_timestamp:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
oauth_tokens.append(self.token_model(token))
self.write(json.dumps({
'api_tokens': api_tokens,
@@ -252,7 +270,7 @@ class UserTokenListAPIHandler(APIHandler):
if requester is not user:
note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note)
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
if requester is not user:
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
else:

View File

@@ -9,6 +9,7 @@ import atexit
import binascii
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from functools import partial
from getpass import getuser
import logging
from operator import itemgetter
@@ -1249,10 +1250,23 @@ class JupyterHub(Application):
self.log.debug("Not duplicating token %s", orm_token)
db.commit()
# purge expired tokens hourly
purge_expired_tokens_interval = 3600
async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user')
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
purge_expired_tokens()
# purge expired tokens hourly
# we don't need to be prompt about this
# because expired tokens cannot be used anyway
pc = PeriodicCallback(
purge_expired_tokens,
1e3 * self.purge_expired_tokens_interval,
)
pc.start()
def init_services(self):
self._service_map.clear()

View File

@@ -3,6 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from collections import defaultdict
from datetime import datetime
from http.client import responses
@@ -229,12 +230,24 @@ class TokenPageHandler(BaseHandler):
token.last_activity or never,
token.created or never,
)
api_tokens = sorted(user.api_tokens, key=sort_key, reverse=True)
now = datetime.utcnow()
api_tokens = []
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(token)
# group oauth client tokens by client id
from collections import defaultdict
oauth_tokens = defaultdict(list)
for token in user.oauth_tokens:
if token.expires_at and token.expires_at < now:
self.log.warning("Deleting expired token")
self.db.delete(token)
self.db.commit()
continue
if not token.client_id:
# token should have been deleted when client was deleted
self.log.warning("Deleting stale oauth token for %s", user.name)

View File

@@ -3,7 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
from datetime import datetime, timedelta
import enum
import json
@@ -14,7 +14,7 @@ from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import (
create_engine, event, inspect,
create_engine, event, inspect, or_,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum, Table,
)
@@ -33,6 +33,9 @@ from .utils import (
new_token, hash_token, compare_token,
)
# top-level variable for easier mocking in tests
utcnow = datetime.utcnow
class JSONDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string.
@@ -176,12 +179,12 @@ class User(Base):
running=sum(bool(s.server) for s in self._orm_spawners),
)
def new_api_token(self, token=None, generated=True, note=''):
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, user=self, note=note, generated=generated)
return APIToken.new(token=token, user=self, **kwargs)
@classmethod
def find(cls, db, name):
@@ -242,11 +245,11 @@ class Service(Base):
server = relationship(Server, cascade='all')
pid = Column(Integer)
def new_api_token(self, token=None, generated=True, note=''):
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, service=self, note=note, generated=generated)
return APIToken.new(token=token, service=self, **kwargs)
@classmethod
def find(cls, db, name):
@@ -348,6 +351,7 @@ class APIToken(Hashed, Base):
# token metadata for bookkeeping
created = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, default=None, nullable=True)
last_activity = Column(DateTime)
note = Column(Unicode(1023))
@@ -369,6 +373,22 @@ class APIToken(Hashed, Base):
name=name,
)
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = utcnow()
deleted = False
for token in (
db.query(cls)
.filter(cls.expires_at != None)
.filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", token)
deleted = True
db.delete(token)
if deleted:
db.commit()
@classmethod
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
@@ -379,6 +399,9 @@ class APIToken(Hashed, Base):
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= utcnow())
)
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
@@ -390,7 +413,8 @@ class APIToken(Hashed, Base):
return orm_token
@classmethod
def new(cls, token=None, user=None, service=None, note='', generated=True):
def new(cls, token=None, user=None, service=None, note='', generated=True,
expires_in=None):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
@@ -412,6 +436,8 @@ class APIToken(Hashed, Base):
else:
assert service.id is not None
orm_token.service = service
if expires_in is not None:
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
db.add(orm_token)
db.commit()
return token

View File

@@ -3,8 +3,10 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import os
import socket
from unittest import mock
import pytest
from tornado import gen
@@ -99,6 +101,32 @@ def test_tokens(db):
assert len(user.api_tokens) == 3
def test_token_expiry(db):
user = orm.User(name='parker')
db.add(user)
db.commit()
now = datetime.utcnow()
token = user.new_api_token(expires_in=60)
orm_token = orm.APIToken.find(db, token=token)
assert orm_token
assert orm_token.expires_at is not None
# approximate range
assert orm_token.expires_at > now + timedelta(seconds=50)
assert orm_token.expires_at < now + timedelta(seconds=70)
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
with the_future:
found = orm.APIToken.find(db, token=token)
assert found is None
# purging shouldn't delete non-expired tokens
orm.APIToken.purge_expired(db)
assert orm.APIToken.find(db, token=token)
with the_future:
orm.APIToken.purge_expired(db)
assert orm.APIToken.find(db, token=token) is None
# after purging, make sure we aren't in the user token list
assert orm_token not in user.api_tokens
def test_service_tokens(db):
service = orm.Service(name='secret')
db.add(service)