remove separate oauth tokens

- merge oauth token fields into APITokens
- create oauth client 'jupyterhub' which owns current API tokens
- db upgrade is currently to drop both token tables, and force recreation on next start
This commit is contained in:
Min RK
2021-03-10 16:41:37 +01:00
parent 2fdf820fe5
commit 0b56fd9e62
15 changed files with 264 additions and 202 deletions

View File

@@ -0,0 +1,119 @@
"""rbac
Revision ID: 833da8570507
Revises: 4dc2d5a8c53c
Create Date: 2021-02-17 15:03:04.360368
"""
# revision identifiers, used by Alembic.
revision = '833da8570507'
down_revision = '4dc2d5a8c53c'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
# FIXME: currently drops all api tokens and forces recreation!
# this ensures a consistent database, but requires:
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
# 2. any manually issued/stored tokens to be re-issued
# tokens loaded via configuration will be recreated on launch and unaffected
op.drop_table('api_tokens')
op.drop_table('oauth_access_tokens')
return
# TODO: explore in-place migration. This seems hard!
# 1. add new columns in api tokens
# 2. fill default fields (client_id='jupyterhub') for all api tokens
# 3. copy oauth tokens into api tokens
# 4. give oauth tokens 'identify' scopes
c = op.get_bind()
naming_convention = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}
with op.batch_alter_table(
"api_tokens",
naming_convention=naming_convention,
) as batch_op:
batch_op.add_column(
sa.Column(
'client_id',
sa.Unicode(255),
# sa.ForeignKey('oauth_clients.identifier', ondelete='CASCADE'),
nullable=True,
),
)
# batch_cursor = op.get_bind()
# batch_cursor.execute(
# """
# UPDATE api_tokens
# SET client_id='jupyterhub'
# WHERE client_id IS NULL
# """
# )
batch_op.create_foreign_key(
"fk_api_token_client_id",
# 'api_tokens',
'oauth_clients',
['client_id'],
['identifier'],
ondelete='CASCADE',
)
c.execute(
"""
UPDATE api_tokens
SET client_id='jupyterhub'
WHERE client_id IS NULL
"""
)
op.add_column(
'api_tokens',
sa.Column(
'grant_type',
sa.Enum(
'authorization_code',
'implicit',
'password',
'client_credentials',
'refresh_token',
name='granttype',
),
server_default='authorization_code',
nullable=False,
),
)
op.add_column(
'api_tokens', sa.Column('refresh_token', sa.Unicode(length=255), nullable=True)
)
op.add_column(
'api_tokens', sa.Column('session_id', sa.Unicode(length=255), nullable=True)
)
# TODO: migrate OAuth tokens into APIToken table
op.drop_index('ix_oauth_access_tokens_prefix', table_name='oauth_access_tokens')
op.drop_table('oauth_access_tokens')
def downgrade():
# delete OAuth tokens for non-jupyterhub clients
# drop new columns from api tokens
op.drop_constraint(None, 'api_tokens', type_='foreignkey')
op.drop_column('api_tokens', 'session_id')
op.drop_column('api_tokens', 'refresh_token')
op.drop_column('api_tokens', 'grant_type')
op.drop_column('api_tokens', 'client_id')
# FIXME: only drop tokens whose client id is not 'jupyterhub'
# until then, drop all tokens
op.drop_table("api_tokens")
op.drop_table('api_token_role_map')
op.drop_table('service_role_map')
op.drop_table('user_role_map')
op.drop_table('roles')

View File

@@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler):
"/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user"
)
orm_token = orm.APIToken.find(self.db, token)
if orm_token is None:
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
raise web.HTTPError(404)

View File

@@ -205,23 +205,6 @@ class APIHandler(BaseHandler):
def token_model(self, token):
"""Get the JSON model for an APIToken"""
expires_at = None
if isinstance(token, orm.APIToken):
kind = 'api_token'
roles = [r.name for r in token.roles]
extra = {'note': token.note}
expires_at = token.expires_at
elif isinstance(token, orm.OAuthAccessToken):
kind = 'oauth'
# oauth tokens do not bear roles
roles = []
extra = {'oauth_client': token.client.description or token.client.client_id}
if token.expires_at:
expires_at = datetime.fromtimestamp(token.expires_at)
else:
raise TypeError(
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
)
if token.user:
owner_key = 'user'
@@ -234,13 +217,14 @@ class APIHandler(BaseHandler):
model = {
owner_key: owner,
'id': token.api_id,
'kind': kind,
'roles': [role for role in roles],
'kind': 'api_token',
'roles': [r.name for r in token.roles],
'created': isoformat(token.created),
'last_activity': isoformat(token.last_activity),
'expires_at': isoformat(expires_at),
'expires_at': isoformat(token.expires_at),
'note': token.note,
'oauth_client': token.client.description or token.client.client_id,
}
model.update(extra)
return model
def user_model(self, user):

View File

@@ -32,9 +32,6 @@ class SelfAPIHandler(APIHandler):
async def get(self):
user = self.current_user
if user is None:
# whoami can be accessed via oauth token
user = self.get_current_user_oauth_token()
if user is None:
raise web.HTTPError(403)
if isinstance(user, orm.Service):
@@ -316,17 +313,7 @@ class UserTokenListAPIHandler(APIHandler):
continue
api_tokens.append(self.token_model(token))
oauth_tokens = []
# OAuth tokens use integer timestamps
now_timestamp = now.timestamp()
for token in sorted(user.oauth_tokens, key=sort_key):
if token.expires_at and token.expires_at < now_timestamp:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
oauth_tokens.append(self.token_model(token))
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
self.write(json.dumps({'api_tokens': api_tokens}))
# Todo: Set to @needs_scope('users:tokens')
async def post(self, user_name):
@@ -410,19 +397,15 @@ class UserTokenAPIHandler(APIHandler):
(e.g. wrong owner, invalid key format, etc.)
"""
not_found = "No such token %s for user %s" % (token_id, user.name)
prefix, id_ = token_id[0], token_id[1:]
if prefix == 'a':
Token = orm.APIToken
elif prefix == 'o':
Token = orm.OAuthAccessToken
else:
prefix, id_ = token_id[:1], token_id[1:]
if prefix != 'a':
raise web.HTTPError(404, not_found)
try:
id_ = int(id_)
except ValueError:
raise web.HTTPError(404, not_found)
orm_token = self.db.query(Token).filter(Token.id == id_).first()
orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first()
if orm_token is None or orm_token.user is not user.orm_user:
raise web.HTTPError(404, "Token not found %s", orm_token)
return orm_token
@@ -444,10 +427,10 @@ class UserTokenAPIHandler(APIHandler):
raise web.HTTPError(404, "No such user: %s" % user_name)
token = self.find_token_by_id(user, token_id)
# deleting an oauth token deletes *all* oauth tokens for that client
if isinstance(token, orm.OAuthAccessToken):
client_id = token.client_id
client_id = token.client_id
if token.client_id != "jupyterhub":
tokens = [
token for token in user.oauth_tokens if token.client_id == client_id
token for token in user.api_tokens if token.client_id == client_id
]
else:
tokens = [token]

View File

@@ -2014,12 +2014,13 @@ class JupyterHub(Application):
run periodically
"""
# this should be all the subclasses of Expiring
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
for cls in (orm.APIToken, orm.OAuthCode):
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
cls.purge_expired(self.db)
async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user')
@@ -2292,13 +2293,30 @@ class JupyterHub(Application):
login_url=url_path_join(base_url, 'login'),
token_expires_in=self.oauth_token_expires_in,
)
# ensure the default oauth client exists
if (
not self.db.query(orm.OAuthClient)
.filter_by(identifier="jupyterhub")
.first()
):
# create the oauth client for jupyterhub itself
# this allows us to distinguish between orphaned tokens
# (failed cascade deletion) and tokens issued by the hub
# it has no client_secret, which means it cannot be used
# to make requests
self.oauth_provider.add_client(
client_id="jupyterhub",
client_secret="",
redirect_uri="",
description="JupyterHub",
)
def cleanup_oauth_clients(self):
"""Cleanup any OAuth clients that shouldn't be in the database.
This should mainly be services that have been removed from configuration or renamed.
"""
oauth_client_ids = set()
oauth_client_ids = {"jupyterhub"}
for service in self._service_map.values():
if service.oauth_available:
oauth_client_ids.add(service.oauth_client_id)

View File

@@ -247,26 +247,6 @@ class BaseHandler(RequestHandler):
return None
return match.group(1)
def get_current_user_oauth_token(self):
"""Get the current user identified by OAuth access token
Separate from API token because OAuth access tokens
can only be used for identifying users,
not using the API.
"""
token = self.get_auth_token()
if token is None:
return None
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
return None
now = datetime.utcnow()
recorded = self._record_activity(orm_token, now)
if self._record_activity(orm_token.user, now) or recorded:
self.db.commit()
return self._user_from_orm(orm_token.user)
def _record_activity(self, obj, timestamp=None):
"""record activity on an ORM object
@@ -373,7 +353,7 @@ class BaseHandler(RequestHandler):
# FIXME: scopes should give us better control than this
# don't consider API requests originating from a server
# to be activity from the user
if not orm_token.note.startswith("Server at "):
if not orm_token.note or not orm_token.note.startswith("Server at "):
recorded = self._record_activity(orm_token.user, now) or recorded
if recorded:
self.db.commit()
@@ -501,10 +481,8 @@ class BaseHandler(RequestHandler):
# don't clear session tokens if not logged in,
# because that could be a malicious logout request!
count = 0
for access_token in (
self.db.query(orm.OAuthAccessToken)
.filter(orm.OAuthAccessToken.user_id == user.id)
.filter(orm.OAuthAccessToken.session_id == session_id)
for access_token in self.db.query(orm.APIToken).filter_by(
user_id=user.id, session_id=session_id
):
self.db.delete(access_token)
count += 1

View File

@@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler):
return (token.last_activity or never, token.created or never)
now = datetime.utcnow()
api_tokens = []
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(token)
# group oauth client tokens by client id
# AccessTokens have expires_at as an integer timestamp
now_timestamp = now.timestamp()
oauth_tokens = defaultdict(list)
for token in user.oauth_tokens:
if token.expires_at and token.expires_at < now_timestamp:
self.log.warning("Deleting expired token")
all_tokens = defaultdict(list)
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.log.warning(f"Deleting expired token {token}")
self.db.delete(token)
self.db.commit()
continue
if not token.client_id:
# token should have been deleted when client was deleted
self.log.warning("Deleting stale oauth token for %s", user.name)
self.log.warning("Deleting stale oauth token {token}")
self.db.delete(token)
self.db.commit()
continue
oauth_tokens[token.client_id].append(token)
all_tokens[token.client_id].append(token)
# individually list tokens issued by jupyterhub itself
api_tokens = all_tokens.pop("jupyterhub", [])
# group all other tokens issued under their owners
# get the earliest created and latest last_activity
# timestamp for a given oauth client
oauth_clients = []
for client_id, tokens in oauth_tokens.items():
for client_id, tokens in all_tokens.items():
created = tokens[0].created
last_activity = tokens[0].last_activity
for token in tokens[1:]:

View File

@@ -2,18 +2,18 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
"""
from datetime import timedelta
from oauthlib import uri_validate
from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
from oauthlib.oauth2.rfc6749.grant_types import base
from tornado.escape import url_escape
from tornado.log import app_log
from .. import orm
from ..utils import compare_token
from ..utils import hash_token
from ..utils import url_path_join
# patch absolute-uri check
# because we want to allow relative uri oauth
@@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator):
)
if oauth_client is None:
return False
if not client_secret or not oauth_client.secret:
# disallow authentication with no secret
return False
if not compare_token(oauth_client.secret, client_secret):
app_log.warning("Client secret mismatch for %s", client_id)
return False
@@ -339,10 +342,10 @@ class JupyterHubRequestValidator(RequestValidator):
.filter_by(identifier=request.client.client_id)
.first()
)
orm_access_token = orm.OAuthAccessToken(
client=client,
orm_access_token = orm.APIToken.new(
client_id=client.identifier,
grant_type=orm.GrantType.authorization_code,
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']),
refresh_token=token['refresh_token'],
# TODO: save scopes,
# scopes=scopes,
@@ -412,6 +415,8 @@ class JupyterHubRequestValidator(RequestValidator):
)
if orm_client is None:
return False
if not orm_client.secret:
return False
request.client = orm_client
return True
@@ -574,14 +579,16 @@ class JupyterHubOAuthServer(WebApplicationServer):
app_log.info(f'Creating oauth client {client_id}')
else:
app_log.info(f'Updating oauth client {client_id}')
orm_client.secret = hash_token(client_secret)
orm_client.secret = hash_token(client_secret) if client_secret else ""
orm_client.redirect_uri = redirect_uri
orm_client.description = description
self.db.commit()
def fetch_by_client_id(self, client_id):
"""Find a client by its id"""
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
if client and client.secret:
return client
def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs):

View File

@@ -277,9 +277,6 @@ class User(Base):
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
oauth_tokens = relationship(
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
)
oauth_codes = relationship(
"OAuthCode", backref="user", cascade="all, delete-orphan"
)
@@ -485,7 +482,9 @@ class Hashed(Expiring):
@classmethod
def check_token(cls, db, token):
"""Check if a token is acceptable"""
print("checking", cls, token, len(token), cls.min_length)
if len(token) < cls.min_length:
print("raising")
raise ValueError(
"Tokens must be at least %i characters, got %r"
% (cls.min_length, token)
@@ -530,6 +529,20 @@ class Hashed(Expiring):
return orm_token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class APIToken(Hashed, Base):
"""An API token"""
@@ -548,6 +561,15 @@ class APIToken(Hashed, Base):
def api_id(self):
return 'a%i' % self.id
# added in 2.0
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
refresh_token = Column(Unicode(255))
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# token metadata for bookkeeping
now = datetime.utcnow # for expiry
created = Column(DateTime, default=datetime.utcnow)
@@ -566,8 +588,12 @@ class APIToken(Hashed, Base):
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format(
cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
client_id=self.client_id,
)
@classmethod
@@ -588,6 +614,14 @@ class APIToken(Hashed, Base):
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match:
if orm_token.match(token):
if not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
@classmethod
@@ -600,6 +634,7 @@ class APIToken(Hashed, Base):
note='',
generated=True,
expires_in=None,
client_id='jupyterhub',
):
"""Generate a new API token for a user or service"""
assert user or service
@@ -614,7 +649,12 @@ class APIToken(Hashed, Base):
cls.check_token(db, token)
# two stages to ensure orm_token.generated has been set
# before token setter is called
orm_token = cls(generated=generated, note=note or '')
orm_token = cls(
generated=generated,
note=note or '',
grant_type=GrantType.authorization_code,
client_id=client_id,
)
orm_token.token = token
if user:
assert user.id is not None
@@ -641,76 +681,6 @@ class APIToken(Hashed, Base):
return token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
@staticmethod
def now():
return datetime.utcnow().timestamp()
@property
def api_id(self):
return 'o%i' % self.id
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(255))
# TODO: drop refresh_expires_at. Refresh tokens shouldn't expire
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
service = None # for API-equivalence with APIToken
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# from Hashed
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
expires_in=self.expires_in,
)
@classmethod
def find(cls, db, token):
orm_token = super().find(db, token)
if orm_token and not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes'
@@ -752,7 +722,7 @@ class OAuthClient(Base):
return self.identifier
access_tokens = relationship(
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
APIToken, backref='client', cascade='all, delete-orphan'
)
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')

View File

@@ -51,6 +51,7 @@ from traitlets import Dict
from traitlets import HasTraits
from traitlets import Instance
from traitlets import Unicode
from traitlets import validate
from traitlets.config import LoggingConfigurable
from .. import orm
@@ -284,6 +285,15 @@ class Service(LoggingConfigurable):
def _default_client_id(self):
return 'service-%s' % self.name
@validate("oauth_client_id")
def _validate_client_id(self, proposal):
if not proposal.value.startswith("service-"):
raise ValueError(
f"service {self.name} has oauth_client_id='{proposal.value}'."
" Service oauth client ids must start with 'service-'"
)
return proposal.value
oauth_redirect_uri = Unicode(
help="""OAuth redirect URI for this service.

View File

@@ -70,7 +70,11 @@ def populate_db(url):
code = orm.OAuthCode(client_id=client.identifier)
db.add(code)
db.commit()
access_token = orm.OAuthAccessToken(
if jupyterhub.version_info < (2, 0):
Token = orm.OAuthAccessToken
else:
Token = orm.APIToken
access_token = Token(
client_id=client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,

View File

@@ -273,7 +273,7 @@ async def test_get_self(app):
oauth_client = orm.OAuthClient(identifier='eurydice')
db.add(oauth_client)
db.commit()
oauth_token = orm.OAuthAccessToken(
oauth_token = orm.APIToken(
user=u.orm_user,
client=oauth_client,
token=token,
@@ -1423,12 +1423,11 @@ async def test_token_list(app, as_user, for_user, status):
if status != 200:
return
reply = r.json()
assert sorted(reply) == ['api_tokens', 'oauth_tokens']
assert sorted(reply) == ['api_tokens']
assert len(reply['api_tokens']) == len(for_user_obj.api_tokens)
assert all(token['user'] == for_user for token in reply['api_tokens'])
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
# validate individual token ids
for token in reply['api_tokens'] + reply['oauth_tokens']:
for token in reply['api_tokens']:
r = await api_request(
app, 'users', for_user, 'tokens', token['id'], headers=headers
)

View File

@@ -355,7 +355,7 @@ def test_user_delete_cascade(db):
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
)
db.add(oauth_token)
@@ -377,7 +377,7 @@ def test_user_delete_cascade(db):
assert_not_found(db, orm.Spawner, spawner_id)
assert_not_found(db, orm.Server, server_id)
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert_not_found(db, orm.APIToken, oauth_token_id)
def test_oauth_client_delete_cascade(db):
@@ -391,12 +391,12 @@ def test_oauth_client_delete_cascade(db):
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
)
db.add(oauth_token)
db.commit()
assert user.oauth_tokens == [oauth_token]
assert user.tokens == [oauth_token]
# record all of the ids
oauth_code_id = oauth_code.id
@@ -408,8 +408,8 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert user.oauth_tokens == []
assert_not_found(db, orm.APIToken, oauth_token_id)
assert user.tokens == []
assert user.oauth_codes == []
@@ -510,32 +510,32 @@ def test_expiring_api_token(app, user):
def test_expiring_oauth_token(app, user):
db = app.db
token = "abc123"
now = orm.OAuthAccessToken.now
now = orm.APIToken.now
client = orm.OAuthClient(identifier="xxx", secret="yyy")
db.add(client)
orm_token = orm.OAuthAccessToken(
orm_token = orm.APIToken(
token=token,
grant_type=orm.GrantType.authorization_code,
client=client,
user=user,
expires_at=now() + 30,
expires_at=now() + datetime.timedelta(seconds=30),
)
db.add(orm_token)
db.commit()
found = orm.OAuthAccessToken.find(db, token)
found = orm.APIToken.find(db, token)
assert found is orm_token
# purge_expired doesn't delete non-expired
orm.OAuthAccessToken.purge_expired(db)
found = orm.OAuthAccessToken.find(db, token)
orm.APIToken.purge_expired(db)
found = orm.APIToken.find(db, token)
assert found is orm_token
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
found = orm.OAuthAccessToken.find(db, token)
with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60):
found = orm.APIToken.find(db, token)
assert found is None
assert orm_token in db.query(orm.OAuthAccessToken)
orm.OAuthAccessToken.purge_expired(db)
assert orm_token not in db.query(orm.OAuthAccessToken)
assert orm_token in db.query(orm.APIToken)
orm.APIToken.purge_expired(db)
assert orm_token not in db.query(orm.APIToken)
def test_expiring_oauth_code(app, user):

View File

@@ -869,7 +869,7 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token')
app.db.add(client)
oauth_token = orm.OAuthAccessToken(
oauth_token = orm.APIToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code
)
app.db.add(oauth_token)

View File

@@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url):
def auth_tokens():
"""Return list of OAuth access tokens for the user"""
return list(
app.db.query(orm.OAuthAccessToken).filter(
orm.OAuthAccessToken.user_id == app_user.id
)
)
return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id))
# ensure we start empty
assert auth_tokens() == []