mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
implement API token expiry
This commit is contained in:
24
jupyterhub/alembic/versions/896818069c98_token_expires.py
Normal file
24
jupyterhub/alembic/versions/896818069c98_token_expires.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Add APIToken.expires_at
|
||||
|
||||
Revision ID: 896818069c98
|
||||
Revises: d68c98b66cd4
|
||||
Create Date: 2018-05-07 11:35:58.050542
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '896818069c98'
|
||||
down_revision = 'd68c98b66cd4'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('api_tokens', sa.Column('expires_at', sa.DateTime(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column('api_tokens', 'expires_at')
|
@@ -4,6 +4,7 @@
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from async_generator import aclosing
|
||||
@@ -201,13 +202,30 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
user = self.find_user(name)
|
||||
if not user:
|
||||
raise web.HTTPError(404, "No such user: %s" % name)
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
api_tokens = []
|
||||
def sort_key(token):
|
||||
return token.last_activity or token.created
|
||||
|
||||
for token in sorted(user.api_tokens, key=sort_key):
|
||||
if token.expires_at and token.expires_at < now:
|
||||
# exclude expired tokens
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
api_tokens.append(self.token_model(token))
|
||||
|
||||
oauth_tokens = []
|
||||
# OAuth tokens use integer timestamps
|
||||
now_timestamp = now.timestamp()
|
||||
for token in sorted(user.oauth_tokens, key=sort_key):
|
||||
if token.expires_at and token.expires_at < now_timestamp:
|
||||
# exclude expired tokens
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
oauth_tokens.append(self.token_model(token))
|
||||
self.write(json.dumps({
|
||||
'api_tokens': api_tokens,
|
||||
@@ -252,7 +270,7 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
if requester is not user:
|
||||
note += " by %s %s" % (kind, requester.name)
|
||||
|
||||
api_token = user.new_api_token(note=note)
|
||||
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
|
||||
if requester is not user:
|
||||
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
|
||||
else:
|
||||
|
@@ -9,6 +9,7 @@ import atexit
|
||||
import binascii
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from getpass import getuser
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
@@ -1249,10 +1250,23 @@ class JupyterHub(Application):
|
||||
self.log.debug("Not duplicating token %s", orm_token)
|
||||
db.commit()
|
||||
|
||||
# purge expired tokens hourly
|
||||
purge_expired_tokens_interval = 3600
|
||||
|
||||
async def init_api_tokens(self):
|
||||
"""Load predefined API tokens (for services) into database"""
|
||||
await self._add_tokens(self.service_tokens, kind='service')
|
||||
await self._add_tokens(self.api_tokens, kind='user')
|
||||
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
|
||||
purge_expired_tokens()
|
||||
# purge expired tokens hourly
|
||||
# we don't need to be prompt about this
|
||||
# because expired tokens cannot be used anyway
|
||||
pc = PeriodicCallback(
|
||||
purge_expired_tokens,
|
||||
1e3 * self.purge_expired_tokens_interval,
|
||||
)
|
||||
pc.start()
|
||||
|
||||
def init_services(self):
|
||||
self._service_map.clear()
|
||||
|
@@ -3,6 +3,7 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from http.client import responses
|
||||
|
||||
@@ -229,12 +230,24 @@ class TokenPageHandler(BaseHandler):
|
||||
token.last_activity or never,
|
||||
token.created or never,
|
||||
)
|
||||
api_tokens = sorted(user.api_tokens, key=sort_key, reverse=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
api_tokens = []
|
||||
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
|
||||
if token.expires_at and token.expires_at < now:
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
api_tokens.append(token)
|
||||
|
||||
# group oauth client tokens by client id
|
||||
from collections import defaultdict
|
||||
oauth_tokens = defaultdict(list)
|
||||
for token in user.oauth_tokens:
|
||||
if token.expires_at and token.expires_at < now:
|
||||
self.log.warning("Deleting expired token")
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
if not token.client_id:
|
||||
# token should have been deleted when client was deleted
|
||||
self.log.warning("Deleting stale oauth token for %s", user.name)
|
||||
|
@@ -3,7 +3,7 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import enum
|
||||
import json
|
||||
|
||||
@@ -14,7 +14,7 @@ from tornado.log import app_log
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
||||
from sqlalchemy import (
|
||||
create_engine, event, inspect,
|
||||
create_engine, event, inspect, or_,
|
||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||
DateTime, Enum, Table,
|
||||
)
|
||||
@@ -33,6 +33,9 @@ from .utils import (
|
||||
new_token, hash_token, compare_token,
|
||||
)
|
||||
|
||||
# top-level variable for easier mocking in tests
|
||||
utcnow = datetime.utcnow
|
||||
|
||||
|
||||
class JSONDict(TypeDecorator):
|
||||
"""Represents an immutable structure as a json-encoded string.
|
||||
@@ -176,12 +179,12 @@ class User(Base):
|
||||
running=sum(bool(s.server) for s in self._orm_spawners),
|
||||
)
|
||||
|
||||
def new_api_token(self, token=None, generated=True, note=''):
|
||||
def new_api_token(self, token=None, **kwargs):
|
||||
"""Create a new API token
|
||||
|
||||
If `token` is given, load that token.
|
||||
"""
|
||||
return APIToken.new(token=token, user=self, note=note, generated=generated)
|
||||
return APIToken.new(token=token, user=self, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, name):
|
||||
@@ -242,11 +245,11 @@ class Service(Base):
|
||||
server = relationship(Server, cascade='all')
|
||||
pid = Column(Integer)
|
||||
|
||||
def new_api_token(self, token=None, generated=True, note=''):
|
||||
def new_api_token(self, token=None, **kwargs):
|
||||
"""Create a new API token
|
||||
If `token` is given, load that token.
|
||||
"""
|
||||
return APIToken.new(token=token, service=self, note=note, generated=generated)
|
||||
return APIToken.new(token=token, service=self, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, name):
|
||||
@@ -348,6 +351,7 @@ class APIToken(Hashed, Base):
|
||||
|
||||
# token metadata for bookkeeping
|
||||
created = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime, default=None, nullable=True)
|
||||
last_activity = Column(DateTime)
|
||||
note = Column(Unicode(1023))
|
||||
|
||||
@@ -369,6 +373,22 @@ class APIToken(Hashed, Base):
|
||||
name=name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def purge_expired(cls, db):
|
||||
"""Purge expired API Tokens from the database"""
|
||||
now = utcnow()
|
||||
deleted = False
|
||||
for token in (
|
||||
db.query(cls)
|
||||
.filter(cls.expires_at != None)
|
||||
.filter(cls.expires_at < now)
|
||||
):
|
||||
app_log.debug("Purging expired %s", token)
|
||||
deleted = True
|
||||
db.delete(token)
|
||||
if deleted:
|
||||
db.commit()
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, token, *, kind=None):
|
||||
"""Find a token object by value.
|
||||
@@ -379,6 +399,9 @@ class APIToken(Hashed, Base):
|
||||
`kind='service'` only returns API tokens for services
|
||||
"""
|
||||
prefix_match = cls.find_prefix(db, token)
|
||||
prefix_match = prefix_match.filter(
|
||||
or_(cls.expires_at == None, cls.expires_at >= utcnow())
|
||||
)
|
||||
if kind == 'user':
|
||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||
elif kind == 'service':
|
||||
@@ -390,7 +413,8 @@ class APIToken(Hashed, Base):
|
||||
return orm_token
|
||||
|
||||
@classmethod
|
||||
def new(cls, token=None, user=None, service=None, note='', generated=True):
|
||||
def new(cls, token=None, user=None, service=None, note='', generated=True,
|
||||
expires_in=None):
|
||||
"""Generate a new API token for a user or service"""
|
||||
assert user or service
|
||||
assert not (user and service)
|
||||
@@ -412,6 +436,8 @@ class APIToken(Hashed, Base):
|
||||
else:
|
||||
assert service.id is not None
|
||||
orm_token.service = service
|
||||
if expires_in is not None:
|
||||
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
return token
|
||||
|
@@ -3,8 +3,10 @@
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import socket
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from tornado import gen
|
||||
@@ -99,6 +101,32 @@ def test_tokens(db):
|
||||
assert len(user.api_tokens) == 3
|
||||
|
||||
|
||||
def test_token_expiry(db):
|
||||
user = orm.User(name='parker')
|
||||
db.add(user)
|
||||
db.commit()
|
||||
now = datetime.utcnow()
|
||||
token = user.new_api_token(expires_in=60)
|
||||
orm_token = orm.APIToken.find(db, token=token)
|
||||
assert orm_token
|
||||
assert orm_token.expires_at is not None
|
||||
# approximate range
|
||||
assert orm_token.expires_at > now + timedelta(seconds=50)
|
||||
assert orm_token.expires_at < now + timedelta(seconds=70)
|
||||
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
|
||||
with the_future:
|
||||
found = orm.APIToken.find(db, token=token)
|
||||
assert found is None
|
||||
# purging shouldn't delete non-expired tokens
|
||||
orm.APIToken.purge_expired(db)
|
||||
assert orm.APIToken.find(db, token=token)
|
||||
with the_future:
|
||||
orm.APIToken.purge_expired(db)
|
||||
assert orm.APIToken.find(db, token=token) is None
|
||||
# after purging, make sure we aren't in the user token list
|
||||
assert orm_token not in user.api_tokens
|
||||
|
||||
|
||||
def test_service_tokens(db):
|
||||
service = orm.Service(name='secret')
|
||||
db.add(service)
|
||||
|
Reference in New Issue
Block a user