implement API token expiry

This commit is contained in:
Min RK
2018-05-07 11:59:12 +02:00
parent a17f5e4f1b
commit 58c91e3fd4
6 changed files with 133 additions and 10 deletions

View File

@@ -0,0 +1,24 @@
"""Add APIToken.expires_at
Revision ID: 896818069c98
Revises: d68c98b66cd4
Create Date: 2018-05-07 11:35:58.050542
"""
# revision identifiers, used by Alembic.
revision = '896818069c98'
down_revision = 'd68c98b66cd4'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('api_tokens', sa.Column('expires_at', sa.DateTime(), nullable=True))
def downgrade():
op.drop_column('api_tokens', 'expires_at')

View File

@@ -4,6 +4,7 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import asyncio import asyncio
from datetime import datetime
import json import json
from async_generator import aclosing from async_generator import aclosing
@@ -201,13 +202,30 @@ class UserTokenListAPIHandler(APIHandler):
user = self.find_user(name) user = self.find_user(name)
if not user: if not user:
raise web.HTTPError(404, "No such user: %s" % name) raise web.HTTPError(404, "No such user: %s" % name)
now = datetime.utcnow()
api_tokens = [] api_tokens = []
def sort_key(token): def sort_key(token):
return token.last_activity or token.created return token.last_activity or token.created
for token in sorted(user.api_tokens, key=sort_key): for token in sorted(user.api_tokens, key=sort_key):
if token.expires_at and token.expires_at < now:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(self.token_model(token)) api_tokens.append(self.token_model(token))
oauth_tokens = [] oauth_tokens = []
# OAuth tokens use integer timestamps
now_timestamp = now.timestamp()
for token in sorted(user.oauth_tokens, key=sort_key): for token in sorted(user.oauth_tokens, key=sort_key):
if token.expires_at and token.expires_at < now_timestamp:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
oauth_tokens.append(self.token_model(token)) oauth_tokens.append(self.token_model(token))
self.write(json.dumps({ self.write(json.dumps({
'api_tokens': api_tokens, 'api_tokens': api_tokens,
@@ -252,7 +270,7 @@ class UserTokenListAPIHandler(APIHandler):
if requester is not user: if requester is not user:
note += " by %s %s" % (kind, requester.name) note += " by %s %s" % (kind, requester.name)
api_token = user.new_api_token(note=note) api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
if requester is not user: if requester is not user:
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name) self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
else: else:

View File

@@ -9,6 +9,7 @@ import atexit
import binascii import binascii
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial
from getpass import getuser from getpass import getuser
import logging import logging
from operator import itemgetter from operator import itemgetter
@@ -1249,10 +1250,23 @@ class JupyterHub(Application):
self.log.debug("Not duplicating token %s", orm_token) self.log.debug("Not duplicating token %s", orm_token)
db.commit() db.commit()
# purge expired tokens hourly
purge_expired_tokens_interval = 3600
async def init_api_tokens(self): async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database""" """Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user') await self._add_tokens(self.api_tokens, kind='user')
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
purge_expired_tokens()
# purge expired tokens hourly
# we don't need to be prompt about this
# because expired tokens cannot be used anyway
pc = PeriodicCallback(
purge_expired_tokens,
1e3 * self.purge_expired_tokens_interval,
)
pc.start()
def init_services(self): def init_services(self):
self._service_map.clear() self._service_map.clear()

View File

@@ -3,6 +3,7 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from collections import defaultdict
from datetime import datetime from datetime import datetime
from http.client import responses from http.client import responses
@@ -229,12 +230,24 @@ class TokenPageHandler(BaseHandler):
token.last_activity or never, token.last_activity or never,
token.created or never, token.created or never,
) )
api_tokens = sorted(user.api_tokens, key=sort_key, reverse=True)
now = datetime.utcnow()
api_tokens = []
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(token)
# group oauth client tokens by client id # group oauth client tokens by client id
from collections import defaultdict
oauth_tokens = defaultdict(list) oauth_tokens = defaultdict(list)
for token in user.oauth_tokens: for token in user.oauth_tokens:
if token.expires_at and token.expires_at < now:
self.log.warning("Deleting expired token")
self.db.delete(token)
self.db.commit()
continue
if not token.client_id: if not token.client_id:
# token should have been deleted when client was deleted # token should have been deleted when client was deleted
self.log.warning("Deleting stale oauth token for %s", user.name) self.log.warning("Deleting stale oauth token for %s", user.name)

View File

@@ -3,7 +3,7 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime from datetime import datetime, timedelta
import enum import enum
import json import json
@@ -14,7 +14,7 @@ from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import ( from sqlalchemy import (
create_engine, event, inspect, create_engine, event, inspect, or_,
Column, Integer, ForeignKey, Unicode, Boolean, Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum, Table, DateTime, Enum, Table,
) )
@@ -33,6 +33,9 @@ from .utils import (
new_token, hash_token, compare_token, new_token, hash_token, compare_token,
) )
# top-level variable for easier mocking in tests
utcnow = datetime.utcnow
class JSONDict(TypeDecorator): class JSONDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string. """Represents an immutable structure as a json-encoded string.
@@ -176,12 +179,12 @@ class User(Base):
running=sum(bool(s.server) for s in self._orm_spawners), running=sum(bool(s.server) for s in self._orm_spawners),
) )
def new_api_token(self, token=None, generated=True, note=''): def new_api_token(self, token=None, **kwargs):
"""Create a new API token """Create a new API token
If `token` is given, load that token. If `token` is given, load that token.
""" """
return APIToken.new(token=token, user=self, note=note, generated=generated) return APIToken.new(token=token, user=self, **kwargs)
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
@@ -242,11 +245,11 @@ class Service(Base):
server = relationship(Server, cascade='all') server = relationship(Server, cascade='all')
pid = Column(Integer) pid = Column(Integer)
def new_api_token(self, token=None, generated=True, note=''): def new_api_token(self, token=None, **kwargs):
"""Create a new API token """Create a new API token
If `token` is given, load that token. If `token` is given, load that token.
""" """
return APIToken.new(token=token, service=self, note=note, generated=generated) return APIToken.new(token=token, service=self, **kwargs)
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
@@ -348,6 +351,7 @@ class APIToken(Hashed, Base):
# token metadata for bookkeeping # token metadata for bookkeeping
created = Column(DateTime, default=datetime.utcnow) created = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, default=None, nullable=True)
last_activity = Column(DateTime) last_activity = Column(DateTime)
note = Column(Unicode(1023)) note = Column(Unicode(1023))
@@ -369,6 +373,22 @@ class APIToken(Hashed, Base):
name=name, name=name,
) )
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = utcnow()
deleted = False
for token in (
db.query(cls)
.filter(cls.expires_at != None)
.filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", token)
deleted = True
db.delete(token)
if deleted:
db.commit()
@classmethod @classmethod
def find(cls, db, token, *, kind=None): def find(cls, db, token, *, kind=None):
"""Find a token object by value. """Find a token object by value.
@@ -379,6 +399,9 @@ class APIToken(Hashed, Base):
`kind='service'` only returns API tokens for services `kind='service'` only returns API tokens for services
""" """
prefix_match = cls.find_prefix(db, token) prefix_match = cls.find_prefix(db, token)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= utcnow())
)
if kind == 'user': if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None) prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service': elif kind == 'service':
@@ -390,7 +413,8 @@ class APIToken(Hashed, Base):
return orm_token return orm_token
@classmethod @classmethod
def new(cls, token=None, user=None, service=None, note='', generated=True): def new(cls, token=None, user=None, service=None, note='', generated=True,
expires_in=None):
"""Generate a new API token for a user or service""" """Generate a new API token for a user or service"""
assert user or service assert user or service
assert not (user and service) assert not (user and service)
@@ -412,6 +436,8 @@ class APIToken(Hashed, Base):
else: else:
assert service.id is not None assert service.id is not None
orm_token.service = service orm_token.service = service
if expires_in is not None:
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
db.add(orm_token) db.add(orm_token)
db.commit() db.commit()
return token return token

View File

@@ -3,8 +3,10 @@
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
from datetime import datetime, timedelta
import os import os
import socket import socket
from unittest import mock
import pytest import pytest
from tornado import gen from tornado import gen
@@ -99,6 +101,32 @@ def test_tokens(db):
assert len(user.api_tokens) == 3 assert len(user.api_tokens) == 3
def test_token_expiry(db):
user = orm.User(name='parker')
db.add(user)
db.commit()
now = datetime.utcnow()
token = user.new_api_token(expires_in=60)
orm_token = orm.APIToken.find(db, token=token)
assert orm_token
assert orm_token.expires_at is not None
# approximate range
assert orm_token.expires_at > now + timedelta(seconds=50)
assert orm_token.expires_at < now + timedelta(seconds=70)
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
with the_future:
found = orm.APIToken.find(db, token=token)
assert found is None
# purging shouldn't delete non-expired tokens
orm.APIToken.purge_expired(db)
assert orm.APIToken.find(db, token=token)
with the_future:
orm.APIToken.purge_expired(db)
assert orm.APIToken.find(db, token=token) is None
# after purging, make sure we aren't in the user token list
assert orm_token not in user.api_tokens
def test_service_tokens(db): def test_service_tokens(db):
service = orm.Service(name='secret') service = orm.Service(name='secret')
db.add(service) db.add(service)