mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
implement API token expiry
This commit is contained in:
24
jupyterhub/alembic/versions/896818069c98_token_expires.py
Normal file
24
jupyterhub/alembic/versions/896818069c98_token_expires.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Add APIToken.expires_at
|
||||||
|
|
||||||
|
Revision ID: 896818069c98
|
||||||
|
Revises: d68c98b66cd4
|
||||||
|
Create Date: 2018-05-07 11:35:58.050542
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '896818069c98'
|
||||||
|
down_revision = 'd68c98b66cd4'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
op.add_column('api_tokens', sa.Column('expires_at', sa.DateTime(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_column('api_tokens', 'expires_at')
|
@@ -4,6 +4,7 @@
|
|||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from async_generator import aclosing
|
from async_generator import aclosing
|
||||||
@@ -201,13 +202,30 @@ class UserTokenListAPIHandler(APIHandler):
|
|||||||
user = self.find_user(name)
|
user = self.find_user(name)
|
||||||
if not user:
|
if not user:
|
||||||
raise web.HTTPError(404, "No such user: %s" % name)
|
raise web.HTTPError(404, "No such user: %s" % name)
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
api_tokens = []
|
api_tokens = []
|
||||||
def sort_key(token):
|
def sort_key(token):
|
||||||
return token.last_activity or token.created
|
return token.last_activity or token.created
|
||||||
|
|
||||||
for token in sorted(user.api_tokens, key=sort_key):
|
for token in sorted(user.api_tokens, key=sort_key):
|
||||||
|
if token.expires_at and token.expires_at < now:
|
||||||
|
# exclude expired tokens
|
||||||
|
self.db.delete(token)
|
||||||
|
self.db.commit()
|
||||||
|
continue
|
||||||
api_tokens.append(self.token_model(token))
|
api_tokens.append(self.token_model(token))
|
||||||
|
|
||||||
oauth_tokens = []
|
oauth_tokens = []
|
||||||
|
# OAuth tokens use integer timestamps
|
||||||
|
now_timestamp = now.timestamp()
|
||||||
for token in sorted(user.oauth_tokens, key=sort_key):
|
for token in sorted(user.oauth_tokens, key=sort_key):
|
||||||
|
if token.expires_at and token.expires_at < now_timestamp:
|
||||||
|
# exclude expired tokens
|
||||||
|
self.db.delete(token)
|
||||||
|
self.db.commit()
|
||||||
|
continue
|
||||||
oauth_tokens.append(self.token_model(token))
|
oauth_tokens.append(self.token_model(token))
|
||||||
self.write(json.dumps({
|
self.write(json.dumps({
|
||||||
'api_tokens': api_tokens,
|
'api_tokens': api_tokens,
|
||||||
@@ -252,7 +270,7 @@ class UserTokenListAPIHandler(APIHandler):
|
|||||||
if requester is not user:
|
if requester is not user:
|
||||||
note += " by %s %s" % (kind, requester.name)
|
note += " by %s %s" % (kind, requester.name)
|
||||||
|
|
||||||
api_token = user.new_api_token(note=note)
|
api_token = user.new_api_token(note=note, expires_in=body.get('expires_in', None))
|
||||||
if requester is not user:
|
if requester is not user:
|
||||||
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
|
self.log.info("%s %s requested API token for %s", kind.title(), requester.name, user.name)
|
||||||
else:
|
else:
|
||||||
|
@@ -9,6 +9,7 @@ import atexit
|
|||||||
import binascii
|
import binascii
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from functools import partial
|
||||||
from getpass import getuser
|
from getpass import getuser
|
||||||
import logging
|
import logging
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@@ -1249,10 +1250,23 @@ class JupyterHub(Application):
|
|||||||
self.log.debug("Not duplicating token %s", orm_token)
|
self.log.debug("Not duplicating token %s", orm_token)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
# purge expired tokens hourly
|
||||||
|
purge_expired_tokens_interval = 3600
|
||||||
|
|
||||||
async def init_api_tokens(self):
|
async def init_api_tokens(self):
|
||||||
"""Load predefined API tokens (for services) into database"""
|
"""Load predefined API tokens (for services) into database"""
|
||||||
await self._add_tokens(self.service_tokens, kind='service')
|
await self._add_tokens(self.service_tokens, kind='service')
|
||||||
await self._add_tokens(self.api_tokens, kind='user')
|
await self._add_tokens(self.api_tokens, kind='user')
|
||||||
|
purge_expired_tokens = partial(orm.APIToken.purge_expired, self.db)
|
||||||
|
purge_expired_tokens()
|
||||||
|
# purge expired tokens hourly
|
||||||
|
# we don't need to be prompt about this
|
||||||
|
# because expired tokens cannot be used anyway
|
||||||
|
pc = PeriodicCallback(
|
||||||
|
purge_expired_tokens,
|
||||||
|
1e3 * self.purge_expired_tokens_interval,
|
||||||
|
)
|
||||||
|
pc.start()
|
||||||
|
|
||||||
def init_services(self):
|
def init_services(self):
|
||||||
self._service_map.clear()
|
self._service_map.clear()
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
# Copyright (c) Jupyter Development Team.
|
# Copyright (c) Jupyter Development Team.
|
||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http.client import responses
|
from http.client import responses
|
||||||
|
|
||||||
@@ -229,12 +230,24 @@ class TokenPageHandler(BaseHandler):
|
|||||||
token.last_activity or never,
|
token.last_activity or never,
|
||||||
token.created or never,
|
token.created or never,
|
||||||
)
|
)
|
||||||
api_tokens = sorted(user.api_tokens, key=sort_key, reverse=True)
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
api_tokens = []
|
||||||
|
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
|
||||||
|
if token.expires_at and token.expires_at < now:
|
||||||
|
self.db.delete(token)
|
||||||
|
self.db.commit()
|
||||||
|
continue
|
||||||
|
api_tokens.append(token)
|
||||||
|
|
||||||
# group oauth client tokens by client id
|
# group oauth client tokens by client id
|
||||||
from collections import defaultdict
|
|
||||||
oauth_tokens = defaultdict(list)
|
oauth_tokens = defaultdict(list)
|
||||||
for token in user.oauth_tokens:
|
for token in user.oauth_tokens:
|
||||||
|
if token.expires_at and token.expires_at < now:
|
||||||
|
self.log.warning("Deleting expired token")
|
||||||
|
self.db.delete(token)
|
||||||
|
self.db.commit()
|
||||||
|
continue
|
||||||
if not token.client_id:
|
if not token.client_id:
|
||||||
# token should have been deleted when client was deleted
|
# token should have been deleted when client was deleted
|
||||||
self.log.warning("Deleting stale oauth token for %s", user.name)
|
self.log.warning("Deleting stale oauth token for %s", user.name)
|
||||||
|
@@ -3,7 +3,7 @@
|
|||||||
# Copyright (c) Jupyter Development Team.
|
# Copyright (c) Jupyter Development Team.
|
||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ from tornado.log import app_log
|
|||||||
|
|
||||||
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
create_engine, event, inspect,
|
create_engine, event, inspect, or_,
|
||||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||||
DateTime, Enum, Table,
|
DateTime, Enum, Table,
|
||||||
)
|
)
|
||||||
@@ -33,6 +33,9 @@ from .utils import (
|
|||||||
new_token, hash_token, compare_token,
|
new_token, hash_token, compare_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# top-level variable for easier mocking in tests
|
||||||
|
utcnow = datetime.utcnow
|
||||||
|
|
||||||
|
|
||||||
class JSONDict(TypeDecorator):
|
class JSONDict(TypeDecorator):
|
||||||
"""Represents an immutable structure as a json-encoded string.
|
"""Represents an immutable structure as a json-encoded string.
|
||||||
@@ -176,12 +179,12 @@ class User(Base):
|
|||||||
running=sum(bool(s.server) for s in self._orm_spawners),
|
running=sum(bool(s.server) for s in self._orm_spawners),
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_api_token(self, token=None, generated=True, note=''):
|
def new_api_token(self, token=None, **kwargs):
|
||||||
"""Create a new API token
|
"""Create a new API token
|
||||||
|
|
||||||
If `token` is given, load that token.
|
If `token` is given, load that token.
|
||||||
"""
|
"""
|
||||||
return APIToken.new(token=token, user=self, note=note, generated=generated)
|
return APIToken.new(token=token, user=self, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, db, name):
|
def find(cls, db, name):
|
||||||
@@ -242,11 +245,11 @@ class Service(Base):
|
|||||||
server = relationship(Server, cascade='all')
|
server = relationship(Server, cascade='all')
|
||||||
pid = Column(Integer)
|
pid = Column(Integer)
|
||||||
|
|
||||||
def new_api_token(self, token=None, generated=True, note=''):
|
def new_api_token(self, token=None, **kwargs):
|
||||||
"""Create a new API token
|
"""Create a new API token
|
||||||
If `token` is given, load that token.
|
If `token` is given, load that token.
|
||||||
"""
|
"""
|
||||||
return APIToken.new(token=token, service=self, note=note, generated=generated)
|
return APIToken.new(token=token, service=self, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, db, name):
|
def find(cls, db, name):
|
||||||
@@ -348,6 +351,7 @@ class APIToken(Hashed, Base):
|
|||||||
|
|
||||||
# token metadata for bookkeeping
|
# token metadata for bookkeeping
|
||||||
created = Column(DateTime, default=datetime.utcnow)
|
created = Column(DateTime, default=datetime.utcnow)
|
||||||
|
expires_at = Column(DateTime, default=None, nullable=True)
|
||||||
last_activity = Column(DateTime)
|
last_activity = Column(DateTime)
|
||||||
note = Column(Unicode(1023))
|
note = Column(Unicode(1023))
|
||||||
|
|
||||||
@@ -369,6 +373,22 @@ class APIToken(Hashed, Base):
|
|||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def purge_expired(cls, db):
|
||||||
|
"""Purge expired API Tokens from the database"""
|
||||||
|
now = utcnow()
|
||||||
|
deleted = False
|
||||||
|
for token in (
|
||||||
|
db.query(cls)
|
||||||
|
.filter(cls.expires_at != None)
|
||||||
|
.filter(cls.expires_at < now)
|
||||||
|
):
|
||||||
|
app_log.debug("Purging expired %s", token)
|
||||||
|
deleted = True
|
||||||
|
db.delete(token)
|
||||||
|
if deleted:
|
||||||
|
db.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, db, token, *, kind=None):
|
def find(cls, db, token, *, kind=None):
|
||||||
"""Find a token object by value.
|
"""Find a token object by value.
|
||||||
@@ -379,6 +399,9 @@ class APIToken(Hashed, Base):
|
|||||||
`kind='service'` only returns API tokens for services
|
`kind='service'` only returns API tokens for services
|
||||||
"""
|
"""
|
||||||
prefix_match = cls.find_prefix(db, token)
|
prefix_match = cls.find_prefix(db, token)
|
||||||
|
prefix_match = prefix_match.filter(
|
||||||
|
or_(cls.expires_at == None, cls.expires_at >= utcnow())
|
||||||
|
)
|
||||||
if kind == 'user':
|
if kind == 'user':
|
||||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||||
elif kind == 'service':
|
elif kind == 'service':
|
||||||
@@ -390,7 +413,8 @@ class APIToken(Hashed, Base):
|
|||||||
return orm_token
|
return orm_token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls, token=None, user=None, service=None, note='', generated=True):
|
def new(cls, token=None, user=None, service=None, note='', generated=True,
|
||||||
|
expires_in=None):
|
||||||
"""Generate a new API token for a user or service"""
|
"""Generate a new API token for a user or service"""
|
||||||
assert user or service
|
assert user or service
|
||||||
assert not (user and service)
|
assert not (user and service)
|
||||||
@@ -412,6 +436,8 @@ class APIToken(Hashed, Base):
|
|||||||
else:
|
else:
|
||||||
assert service.id is not None
|
assert service.id is not None
|
||||||
orm_token.service = service
|
orm_token.service = service
|
||||||
|
if expires_in is not None:
|
||||||
|
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
|
||||||
db.add(orm_token)
|
db.add(orm_token)
|
||||||
db.commit()
|
db.commit()
|
||||||
return token
|
return token
|
||||||
|
@@ -3,8 +3,10 @@
|
|||||||
# Copyright (c) Jupyter Development Team.
|
# Copyright (c) Jupyter Development Team.
|
||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
@@ -99,6 +101,32 @@ def test_tokens(db):
|
|||||||
assert len(user.api_tokens) == 3
|
assert len(user.api_tokens) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_expiry(db):
|
||||||
|
user = orm.User(name='parker')
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
now = datetime.utcnow()
|
||||||
|
token = user.new_api_token(expires_in=60)
|
||||||
|
orm_token = orm.APIToken.find(db, token=token)
|
||||||
|
assert orm_token
|
||||||
|
assert orm_token.expires_at is not None
|
||||||
|
# approximate range
|
||||||
|
assert orm_token.expires_at > now + timedelta(seconds=50)
|
||||||
|
assert orm_token.expires_at < now + timedelta(seconds=70)
|
||||||
|
the_future = mock.patch('jupyterhub.orm.utcnow', lambda : now + timedelta(seconds=70))
|
||||||
|
with the_future:
|
||||||
|
found = orm.APIToken.find(db, token=token)
|
||||||
|
assert found is None
|
||||||
|
# purging shouldn't delete non-expired tokens
|
||||||
|
orm.APIToken.purge_expired(db)
|
||||||
|
assert orm.APIToken.find(db, token=token)
|
||||||
|
with the_future:
|
||||||
|
orm.APIToken.purge_expired(db)
|
||||||
|
assert orm.APIToken.find(db, token=token) is None
|
||||||
|
# after purging, make sure we aren't in the user token list
|
||||||
|
assert orm_token not in user.api_tokens
|
||||||
|
|
||||||
|
|
||||||
def test_service_tokens(db):
|
def test_service_tokens(db):
|
||||||
service = orm.Service(name='secret')
|
service = orm.Service(name='secret')
|
||||||
db.add(service)
|
db.add(service)
|
||||||
|
Reference in New Issue
Block a user