Merge pull request #3380 from minrk/rm-oauth-tokens

Merge OAuth and API tokens
This commit is contained in:
Min RK
2021-04-14 16:27:14 +02:00
committed by GitHub
19 changed files with 278 additions and 252 deletions

View File

@@ -20,7 +20,7 @@ fi
# Configure a set of databases in the database server for upgrade tests
set -x
for SUFFIX in '' _upgrade_072 _upgrade_081 _upgrade_094; do
for SUFFIX in '' _upgrade_100 _upgrade_122 _upgrade_130; do
$SQL_CLIENT "DROP DATABASE jupyterhub${SUFFIX};" 2>/dev/null || true
$SQL_CLIENT "CREATE DATABASE jupyterhub${SUFFIX} ${EXTRA_CREATE_DATABASE_ARGS:-};"
done

View File

@@ -3,8 +3,8 @@
# Distributed under the terms of the Modified BSD License.
version_info = (
1,
4,
2,
0,
0,
"", # release (b1, rc1, or "" for final or dev)
"dev", # dev or nothing for beta/rc/stable releases

View File

@@ -0,0 +1,49 @@
"""rbac
Revision ID: 833da8570507
Revises: 4dc2d5a8c53c
Create Date: 2021-02-17 15:03:04.360368
"""
# revision identifiers, used by Alembic.
revision = '833da8570507'
down_revision = '4dc2d5a8c53c'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
# FIXME, maybe: currently drops all api tokens and forces recreation!
# this ensures a consistent database, but requires:
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
# 2. any manually issued/stored tokens to be re-issued
# tokens loaded via configuration will be recreated on launch and unaffected
op.drop_table('api_tokens')
op.drop_table('oauth_access_tokens')
return
# TODO: explore in-place migration. This seems hard!
# 1. add new columns in api tokens
# 2. fill default fields (client_id='jupyterhub') for all api tokens
# 3. copy oauth tokens into api tokens
# 4. give oauth tokens 'identify' scopes
def downgrade():
# delete OAuth tokens for non-jupyterhub clients
# drop new columns from api tokens
op.drop_constraint(None, 'api_tokens', type_='foreignkey')
op.drop_column('api_tokens', 'session_id')
op.drop_column('api_tokens', 'client_id')
# FIXME: only drop tokens whose client id is not 'jupyterhub'
# until then, drop all tokens
op.drop_table("api_tokens")
op.drop_table('api_token_role_map')
op.drop_table('service_role_map')
op.drop_table('user_role_map')
op.drop_table('roles')

View File

@@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler):
"/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user"
)
orm_token = orm.APIToken.find(self.db, token)
if orm_token is None:
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
raise web.HTTPError(404)

View File

@@ -205,23 +205,6 @@ class APIHandler(BaseHandler):
def token_model(self, token):
"""Get the JSON model for an APIToken"""
expires_at = None
if isinstance(token, orm.APIToken):
kind = 'api_token'
roles = [r.name for r in token.roles]
extra = {'note': token.note}
expires_at = token.expires_at
elif isinstance(token, orm.OAuthAccessToken):
kind = 'oauth'
# oauth tokens do not bear roles
roles = []
extra = {'oauth_client': token.client.description or token.client.client_id}
if token.expires_at:
expires_at = datetime.fromtimestamp(token.expires_at)
else:
raise TypeError(
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
)
if token.user:
owner_key = 'user'
@@ -234,13 +217,14 @@ class APIHandler(BaseHandler):
model = {
owner_key: owner,
'id': token.api_id,
'kind': kind,
'roles': [role for role in roles],
'kind': 'api_token',
'roles': [r.name for r in token.roles],
'created': isoformat(token.created),
'last_activity': isoformat(token.last_activity),
'expires_at': isoformat(expires_at),
'expires_at': isoformat(token.expires_at),
'note': token.note,
'oauth_client': token.client.description or token.client.client_id,
}
model.update(extra)
return model
def user_model(self, user):

View File

@@ -14,6 +14,7 @@ from tornado import web
from tornado.iostream import StreamClosedError
from .. import orm
from .. import scopes
from ..roles import assign_default_roles
from ..scopes import needs_scope
from ..user import User
@@ -32,14 +33,16 @@ class SelfAPIHandler(APIHandler):
async def get(self):
user = self.current_user
if user is None:
# whoami can be accessed via oauth token
user = self.get_current_user_oauth_token()
if user is None:
raise web.HTTPError(403)
if isinstance(user, orm.Service):
# ensure we have the minimal 'identify' scopes for the token owner
self.raw_scopes.update(scopes.identify_scopes(user))
self.parsed_scopes = scopes.parse_scopes(self.raw_scopes)
model = self.service_model(user)
else:
self.raw_scopes.update(scopes.identify_scopes(user.orm_user))
self.parsed_scopes = scopes.parse_scopes(self.raw_scopes)
model = self.user_model(user)
self.write(json.dumps(model))
@@ -316,17 +319,7 @@ class UserTokenListAPIHandler(APIHandler):
continue
api_tokens.append(self.token_model(token))
oauth_tokens = []
# OAuth tokens use integer timestamps
now_timestamp = now.timestamp()
for token in sorted(user.oauth_tokens, key=sort_key):
if token.expires_at and token.expires_at < now_timestamp:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
oauth_tokens.append(self.token_model(token))
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
self.write(json.dumps({'api_tokens': api_tokens}))
# Todo: Set to @needs_scope('users:tokens')
async def post(self, user_name):
@@ -410,19 +403,15 @@ class UserTokenAPIHandler(APIHandler):
(e.g. wrong owner, invalid key format, etc.)
"""
not_found = "No such token %s for user %s" % (token_id, user.name)
prefix, id_ = token_id[0], token_id[1:]
if prefix == 'a':
Token = orm.APIToken
elif prefix == 'o':
Token = orm.OAuthAccessToken
else:
prefix, id_ = token_id[:1], token_id[1:]
if prefix != 'a':
raise web.HTTPError(404, not_found)
try:
id_ = int(id_)
except ValueError:
raise web.HTTPError(404, not_found)
orm_token = self.db.query(Token).filter(Token.id == id_).first()
orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first()
if orm_token is None or orm_token.user is not user.orm_user:
raise web.HTTPError(404, "Token not found %s", orm_token)
return orm_token
@@ -444,10 +433,10 @@ class UserTokenAPIHandler(APIHandler):
raise web.HTTPError(404, "No such user: %s" % user_name)
token = self.find_token_by_id(user, token_id)
# deleting an oauth token deletes *all* oauth tokens for that client
if isinstance(token, orm.OAuthAccessToken):
client_id = token.client_id
if token.client_id != "jupyterhub":
tokens = [
token for token in user.oauth_tokens if token.client_id == client_id
token for token in user.api_tokens if token.client_id == client_id
]
else:
tokens = [token]

View File

@@ -1692,6 +1692,26 @@ class JupyterHub(Application):
except orm.DatabaseSchemaMismatch as e:
self.exit(e)
# ensure the default oauth client exists
if (
not self.db.query(orm.OAuthClient)
.filter_by(identifier="jupyterhub")
.one_or_none()
):
# create the oauth client for jupyterhub itself
# this allows us to distinguish between orphaned tokens
# (failed cascade deletion) and tokens issued by the hub
# it has no client_secret, which means it cannot be used
# to make requests
client = orm.OAuthClient(
identifier="jupyterhub",
secret="",
redirect_uri="",
description="JupyterHub",
)
self.db.add(client)
self.db.commit()
def init_hub(self):
"""Load the Hub URL config"""
hub_args = dict(
@@ -2014,12 +2034,13 @@ class JupyterHub(Application):
run periodically
"""
# this should be all the subclasses of Expiring
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
for cls in (orm.APIToken, orm.OAuthCode):
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
cls.purge_expired(self.db)
async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user')
@@ -2298,7 +2319,7 @@ class JupyterHub(Application):
This should mainly be services that have been removed from configuration or renamed.
"""
oauth_client_ids = set()
oauth_client_ids = {"jupyterhub"}
for service in self._service_map.values():
if service.oauth_available:
oauth_client_ids.add(service.oauth_client_id)

View File

@@ -247,26 +247,6 @@ class BaseHandler(RequestHandler):
return None
return match.group(1)
def get_current_user_oauth_token(self):
"""Get the current user identified by OAuth access token
Separate from API token because OAuth access tokens
can only be used for identifying users,
not using the API.
"""
token = self.get_auth_token()
if token is None:
return None
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
return None
now = datetime.utcnow()
recorded = self._record_activity(orm_token, now)
if self._record_activity(orm_token.user, now) or recorded:
self.db.commit()
return self._user_from_orm(orm_token.user)
def _record_activity(self, obj, timestamp=None):
"""record activity on an ORM object
@@ -373,7 +353,7 @@ class BaseHandler(RequestHandler):
# FIXME: scopes should give us better control than this
# don't consider API requests originating from a server
# to be activity from the user
if not orm_token.note.startswith("Server at "):
if not orm_token.note or not orm_token.note.startswith("Server at "):
recorded = self._record_activity(orm_token.user, now) or recorded
if recorded:
self.db.commit()
@@ -439,17 +419,10 @@ class BaseHandler(RequestHandler):
def _resolve_scopes(self):
self.raw_scopes = set()
app_log.debug("Loading and parsing scopes")
if not self.current_user:
# check for oauth tokens as long as #3380 not merged
user_from_oauth = self.get_current_user_oauth_token()
if user_from_oauth is not None:
self.raw_scopes = {f'read:users!user={user_from_oauth.name}'}
else:
app_log.debug("No user found, no scopes loaded")
else:
api_token = self.get_token()
if api_token:
self.raw_scopes = scopes.get_scopes_for(api_token)
if self.current_user:
orm_token = self.get_token()
if orm_token:
self.raw_scopes = scopes.get_scopes_for(orm_token)
else:
self.raw_scopes = scopes.get_scopes_for(self.current_user)
self.parsed_scopes = scopes.parse_scopes(self.raw_scopes)
@@ -501,10 +474,8 @@ class BaseHandler(RequestHandler):
# don't clear session tokens if not logged in,
# because that could be a malicious logout request!
count = 0
for access_token in (
self.db.query(orm.OAuthAccessToken)
.filter(orm.OAuthAccessToken.user_id == user.id)
.filter(orm.OAuthAccessToken.session_id == session_id)
for access_token in self.db.query(orm.APIToken).filter_by(
user_id=user.id, session_id=session_id
):
self.db.delete(access_token)
count += 1

View File

@@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler):
return (token.last_activity or never, token.created or never)
now = datetime.utcnow()
api_tokens = []
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(token)
# group oauth client tokens by client id
# AccessTokens have expires_at as an integer timestamp
now_timestamp = now.timestamp()
oauth_tokens = defaultdict(list)
for token in user.oauth_tokens:
if token.expires_at and token.expires_at < now_timestamp:
self.log.warning("Deleting expired token")
all_tokens = defaultdict(list)
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.log.warning(f"Deleting expired token {token}")
self.db.delete(token)
self.db.commit()
continue
if not token.client_id:
# token should have been deleted when client was deleted
self.log.warning("Deleting stale oauth token for %s", user.name)
self.log.warning("Deleting stale oauth token {token}")
self.db.delete(token)
self.db.commit()
continue
oauth_tokens[token.client_id].append(token)
all_tokens[token.client_id].append(token)
# individually list tokens issued by jupyterhub itself
api_tokens = all_tokens.pop("jupyterhub", [])
# group all other tokens issued under their owners
# get the earliest created and latest last_activity
# timestamp for a given oauth client
oauth_clients = []
for client_id, tokens in oauth_tokens.items():
for client_id, tokens in all_tokens.items():
created = tokens[0].created
last_activity = tokens[0].last_activity
for token in tokens[1:]:

View File

@@ -2,18 +2,18 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
"""
from datetime import timedelta
from oauthlib import uri_validate
from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
from oauthlib.oauth2.rfc6749.grant_types import base
from tornado.escape import url_escape
from tornado.log import app_log
from .. import orm
from ..utils import compare_token
from ..utils import hash_token
from ..utils import url_path_join
# patch absolute-uri check
# because we want to allow relative uri oauth
@@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator):
)
if oauth_client is None:
return False
if not client_secret or not oauth_client.secret:
# disallow authentication with no secret
return False
if not compare_token(oauth_client.secret, client_secret):
app_log.warning("Client secret mismatch for %s", client_id)
return False
@@ -339,19 +342,22 @@ class JupyterHubRequestValidator(RequestValidator):
.filter_by(identifier=request.client.client_id)
.first()
)
orm_access_token = orm.OAuthAccessToken(
client=client,
grant_type=orm.GrantType.authorization_code,
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
refresh_token=token['refresh_token'],
# TODO: save scopes,
# scopes=scopes,
# FIXME: pick a role
# this will be empty for now
roles = list(self.db.query(orm.Role).filter_by(name='identify'))
# FIXME: support refresh tokens
# These should be in a new table
token.pop("refresh_token", None)
# APIToken.new commits the token to the db
orm.APIToken.new(
client_id=client.identifier,
expires_in=token['expires_in'],
roles=roles,
token=token['access_token'],
session_id=request.session_id,
user=request.user,
)
self.db.add(orm_access_token)
self.db.commit()
return client.redirect_uri
def validate_bearer_token(self, token, scopes, request):
@@ -412,6 +418,8 @@ class JupyterHubRequestValidator(RequestValidator):
)
if orm_client is None:
return False
if not orm_client.secret:
return False
request.client = orm_client
return True
@@ -574,14 +582,16 @@ class JupyterHubOAuthServer(WebApplicationServer):
app_log.info(f'Creating oauth client {client_id}')
else:
app_log.info(f'Updating oauth client {client_id}')
orm_client.secret = hash_token(client_secret)
orm_client.secret = hash_token(client_secret) if client_secret else ""
orm_client.redirect_uri = redirect_uri
orm_client.description = description
self.db.commit()
def fetch_by_client_id(self, client_id):
"""Find a client by its id"""
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
if client and client.secret:
return client
def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs):

View File

@@ -277,9 +277,6 @@ class User(Base):
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
oauth_tokens = relationship(
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
)
oauth_codes = relationship(
"OAuthCode", backref="user", cascade="all, delete-orphan"
)
@@ -485,7 +482,9 @@ class Hashed(Expiring):
@classmethod
def check_token(cls, db, token):
"""Check if a token is acceptable"""
print("checking", cls, token, len(token), cls.min_length)
if len(token) < cls.min_length:
print("raising")
raise ValueError(
"Tokens must be at least %i characters, got %r"
% (cls.min_length, token)
@@ -530,14 +529,34 @@ class Hashed(Expiring):
return orm_token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class APIToken(Hashed, Base):
"""An API token"""
__tablename__ = 'api_tokens'
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
user_id = Column(
Integer,
ForeignKey('users.id', ondelete="CASCADE"),
nullable=True,
)
service_id = Column(
Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True
Integer,
ForeignKey('services.id', ondelete="CASCADE"),
nullable=True,
)
id = Column(Integer, primary_key=True)
@@ -548,6 +567,26 @@ class APIToken(Hashed, Base):
def api_id(self):
return 'a%i' % self.id
# added in 2.0
client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='CASCADE',
),
)
# FIXME: refresh_tokens not implemented
# should be a relation to another token table
# refresh_token = Column(
# Integer,
# ForeignKey('refresh_tokens.id', ondelete="CASCADE"),
# nullable=True,
# )
# the browser session id associated with a given token,
# if issued during oauth to be stored in a cookie
session_id = Column(Unicode(255), nullable=True)
# token metadata for bookkeeping
now = datetime.utcnow # for expiry
created = Column(DateTime, default=datetime.utcnow)
@@ -566,8 +605,12 @@ class APIToken(Hashed, Base):
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format(
cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
client_id=self.client_id,
)
@classmethod
@@ -588,6 +631,14 @@ class APIToken(Hashed, Base):
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match:
if orm_token.match(token):
if not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
@classmethod
@@ -599,7 +650,10 @@ class APIToken(Hashed, Base):
roles=None,
note='',
generated=True,
session_id=None,
expires_in=None,
client_id='jupyterhub',
return_orm=False,
):
"""Generate a new API token for a user or service"""
assert user or service
@@ -614,7 +668,12 @@ class APIToken(Hashed, Base):
cls.check_token(db, token)
# two stages to ensure orm_token.generated has been set
# before token setter is called
orm_token = cls(generated=generated, note=note or '')
orm_token = cls(
generated=generated,
note=note or '',
client_id=client_id,
session_id=session_id,
)
orm_token.token = token
if user:
assert user.id is not None
@@ -641,76 +700,6 @@ class APIToken(Hashed, Base):
return token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
@staticmethod
def now():
return datetime.utcnow().timestamp()
@property
def api_id(self):
return 'o%i' % self.id
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(255))
# TODO: drop refresh_expires_at. Refresh tokens shouldn't expire
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
service = None # for API-equivalence with APIToken
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# from Hashed
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
expires_in=self.expires_in,
)
@classmethod
def find(cls, db, token):
orm_token = super().find(db, token)
if orm_token and not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes'
@@ -752,7 +741,7 @@ class OAuthClient(Base):
return self.identifier
access_tokens = relationship(
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
APIToken, backref='client', cascade='all, delete-orphan'
)
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')

View File

@@ -51,6 +51,7 @@ from traitlets import Dict
from traitlets import HasTraits
from traitlets import Instance
from traitlets import Unicode
from traitlets import validate
from traitlets.config import LoggingConfigurable
from .. import orm
@@ -284,6 +285,15 @@ class Service(LoggingConfigurable):
def _default_client_id(self):
return 'service-%s' % self.name
@validate("oauth_client_id")
def _validate_client_id(self, proposal):
if not proposal.value.startswith("service-"):
raise ValueError(
f"service {self.name} has oauth_client_id='{proposal.value}'."
" Service oauth client ids must start with 'service-'"
)
return proposal.value
oauth_redirect_uri = Unicode(
help="""OAuth redirect URI for this service.

View File

@@ -125,7 +125,11 @@ def db():
"""Get a db session"""
global _db
if _db is None:
_db = orm.new_session_factory('sqlite:///:memory:')()
# make sure some initial db contents are filled out
# specifically, the 'default' jupyterhub oauth client
app = MockHub(db_url='sqlite:///:memory:')
app.init_db()
_db = app.db
user = orm.User(name=getuser())
_db.add(user)
_db.commit()
@@ -164,9 +168,14 @@ def cleanup_after(request, io_loop):
allows cleanup of servers between tests
without having to launch a whole new app
"""
try:
yield
finally:
if _db is not None:
# cleanup after failed transactions
_db.rollback()
if not MockHub.initialized():
return
app = MockHub.instance()

View File

@@ -6,6 +6,7 @@ used in test_db.py
"""
import os
from datetime import datetime
from functools import partial
import jupyterhub
from jupyterhub import orm
@@ -62,24 +63,27 @@ def populate_db(url):
db.commit()
# create some oauth objects
if jupyterhub.version_info >= (0, 8):
# create oauth client
client = orm.OAuthClient(identifier='oauth-client')
db.add(client)
db.commit()
code = orm.OAuthCode(client_id=client.identifier)
db.add(code)
db.commit()
access_token = orm.OAuthAccessToken(
if jupyterhub.version_info < (2, 0):
Token = partial(
orm.OAuthAccessToken,
grant_type=orm.GrantType.authorization_code,
)
else:
Token = orm.APIToken
access_token = Token(
client_id=client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(access_token)
db.commit()
# set some timestamps added in 0.9
if jupyterhub.version_info >= (0, 9):
assert user.created
assert admin.created
# set last_activity

View File

@@ -273,11 +273,10 @@ async def test_get_self(app):
oauth_client = orm.OAuthClient(identifier='eurydice')
db.add(oauth_client)
db.commit()
oauth_token = orm.OAuthAccessToken(
oauth_token = orm.APIToken(
user=u.orm_user,
client=oauth_client,
token=token,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
@@ -1423,12 +1422,11 @@ async def test_token_list(app, as_user, for_user, status):
if status != 200:
return
reply = r.json()
assert sorted(reply) == ['api_tokens', 'oauth_tokens']
assert sorted(reply) == ['api_tokens']
assert len(reply['api_tokens']) == len(for_user_obj.api_tokens)
assert all(token['user'] == for_user for token in reply['api_tokens'])
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
# validate individual token ids
for token in reply['api_tokens'] + reply['oauth_tokens']:
for token in reply['api_tokens']:
r = await api_request(
app, 'users', for_user, 'tokens', token['id'], headers=headers
)

View File

@@ -36,7 +36,7 @@ def generate_old_db(env_dir, hub_version, db_url):
check_call([env_py, populate_db, db_url])
@pytest.mark.parametrize('hub_version', ['0.7.2', '0.8.1', '0.9.4'])
@pytest.mark.parametrize('hub_version', ['1.0.0', "1.2.2", "1.3.0"])
async def test_upgrade(tmpdir, hub_version):
db_url = os.getenv('JUPYTERHUB_TEST_DB_URL')
if db_url:

View File

@@ -355,8 +355,9 @@ def test_user_delete_cascade(db):
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
oauth_token = orm.APIToken(
client=oauth_client,
user=user,
)
db.add(oauth_token)
db.commit()
@@ -377,7 +378,7 @@ def test_user_delete_cascade(db):
assert_not_found(db, orm.Spawner, spawner_id)
assert_not_found(db, orm.Server, server_id)
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert_not_found(db, orm.APIToken, oauth_token_id)
def test_oauth_client_delete_cascade(db):
@@ -391,12 +392,13 @@ def test_oauth_client_delete_cascade(db):
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
oauth_token = orm.APIToken(
client=oauth_client,
user=user,
)
db.add(oauth_token)
db.commit()
assert user.oauth_tokens == [oauth_token]
assert user.api_tokens == [oauth_token]
# record all of the ids
oauth_code_id = oauth_code.id
@@ -408,8 +410,8 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert user.oauth_tokens == []
assert_not_found(db, orm.APIToken, oauth_token_id)
assert user.api_tokens == []
assert user.oauth_codes == []
@@ -510,32 +512,31 @@ def test_expiring_api_token(app, user):
def test_expiring_oauth_token(app, user):
db = app.db
token = "abc123"
now = orm.OAuthAccessToken.now
now = orm.APIToken.now
client = orm.OAuthClient(identifier="xxx", secret="yyy")
db.add(client)
orm_token = orm.OAuthAccessToken(
orm_token = orm.APIToken(
token=token,
grant_type=orm.GrantType.authorization_code,
client=client,
user=user,
expires_at=now() + 30,
expires_at=now() + timedelta(seconds=30),
)
db.add(orm_token)
db.commit()
found = orm.OAuthAccessToken.find(db, token)
found = orm.APIToken.find(db, token)
assert found is orm_token
# purge_expired doesn't delete non-expired
orm.OAuthAccessToken.purge_expired(db)
found = orm.OAuthAccessToken.find(db, token)
orm.APIToken.purge_expired(db)
found = orm.APIToken.find(db, token)
assert found is orm_token
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
found = orm.OAuthAccessToken.find(db, token)
with mock.patch.object(orm.APIToken, 'now', lambda: now() + timedelta(seconds=60)):
found = orm.APIToken.find(db, token)
assert found is None
assert orm_token in db.query(orm.OAuthAccessToken)
orm.OAuthAccessToken.purge_expired(db)
assert orm_token not in db.query(orm.OAuthAccessToken)
assert orm_token in db.query(orm.APIToken)
orm.APIToken.purge_expired(db)
assert orm_token not in db.query(orm.APIToken)
def test_expiring_oauth_code(app, user):

View File

@@ -869,8 +869,9 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token')
app.db.add(client)
oauth_token = orm.OAuthAccessToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code
oauth_token = orm.APIToken(
client=client,
user=user,
)
app.db.add(oauth_token)
app.db.commit()

View File

@@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url):
def auth_tokens():
"""Return list of OAuth access tokens for the user"""
return list(
app.db.query(orm.OAuthAccessToken).filter(
orm.OAuthAccessToken.user_id == app_user.id
)
)
return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id))
# ensure we start empty
assert auth_tokens() == []