implement API token expiry

This commit is contained in:
Min RK
2018-05-07 11:59:12 +02:00
parent a17f5e4f1b
commit 58c91e3fd4
6 changed files with 133 additions and 10 deletions

View File

@@ -3,7 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
from datetime import datetime, timedelta
import enum
import json
@@ -14,7 +14,7 @@ from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import (
create_engine, event, inspect,
create_engine, event, inspect, or_,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum, Table,
)
@@ -33,6 +33,9 @@ from .utils import (
new_token, hash_token, compare_token,
)
# top-level variable for easier mocking in tests
utcnow = datetime.utcnow
class JSONDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string.
@@ -176,12 +179,12 @@ class User(Base):
running=sum(bool(s.server) for s in self._orm_spawners),
)
def new_api_token(self, token=None, generated=True, note=''):
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, user=self, note=note, generated=generated)
return APIToken.new(token=token, user=self, **kwargs)
@classmethod
def find(cls, db, name):
@@ -242,11 +245,11 @@ class Service(Base):
server = relationship(Server, cascade='all')
pid = Column(Integer)
def new_api_token(self, token=None, generated=True, note=''):
def new_api_token(self, token=None, **kwargs):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, service=self, note=note, generated=generated)
return APIToken.new(token=token, service=self, **kwargs)
@classmethod
def find(cls, db, name):
@@ -348,6 +351,7 @@ class APIToken(Hashed, Base):
# token metadata for bookkeeping
created = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime, default=None, nullable=True)
last_activity = Column(DateTime)
note = Column(Unicode(1023))
@@ -369,6 +373,22 @@ class APIToken(Hashed, Base):
name=name,
)
@classmethod
def purge_expired(cls, db):
"""Purge expired API Tokens from the database"""
now = utcnow()
deleted = False
for token in (
db.query(cls)
.filter(cls.expires_at != None)
.filter(cls.expires_at < now)
):
app_log.debug("Purging expired %s", token)
deleted = True
db.delete(token)
if deleted:
db.commit()
@classmethod
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
@@ -379,6 +399,9 @@ class APIToken(Hashed, Base):
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token)
prefix_match = prefix_match.filter(
or_(cls.expires_at == None, cls.expires_at >= utcnow())
)
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
@@ -390,7 +413,8 @@ class APIToken(Hashed, Base):
return orm_token
@classmethod
def new(cls, token=None, user=None, service=None, note='', generated=True):
def new(cls, token=None, user=None, service=None, note='', generated=True,
expires_in=None):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
@@ -412,6 +436,8 @@ class APIToken(Hashed, Base):
else:
assert service.id is not None
orm_token.service = service
if expires_in is not None:
orm_token.expires_at = utcnow() + timedelta(seconds=expires_in)
db.add(orm_token)
db.commit()
return token