remove separate oauth tokens

- merge oauth token fields into APITokens
- create oauth client 'jupyterhub' which owns current API tokens
- db upgrade is currently to drop both token tables, and force recreation on next start
This commit is contained in:
Min RK
2021-03-10 16:41:37 +01:00
parent 2fdf820fe5
commit 0b56fd9e62
15 changed files with 264 additions and 202 deletions

View File

@@ -0,0 +1,119 @@
"""rbac
Revision ID: 833da8570507
Revises: 4dc2d5a8c53c
Create Date: 2021-02-17 15:03:04.360368
"""
# revision identifiers, used by Alembic.
revision = '833da8570507'
down_revision = '4dc2d5a8c53c'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
# FIXME: currently drops all api tokens and forces recreation!
# this ensures a consistent database, but requires:
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
# 2. any manually issued/stored tokens to be re-issued
# tokens loaded via configuration will be recreated on launch and unaffected
op.drop_table('api_tokens')
op.drop_table('oauth_access_tokens')
return
# TODO: explore in-place migration. This seems hard!
# 1. add new columns in api tokens
# 2. fill default fields (client_id='jupyterhub') for all api tokens
# 3. copy oauth tokens into api tokens
# 4. give oauth tokens 'identify' scopes
c = op.get_bind()
naming_convention = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}
with op.batch_alter_table(
"api_tokens",
naming_convention=naming_convention,
) as batch_op:
batch_op.add_column(
sa.Column(
'client_id',
sa.Unicode(255),
# sa.ForeignKey('oauth_clients.identifier', ondelete='CASCADE'),
nullable=True,
),
)
# batch_cursor = op.get_bind()
# batch_cursor.execute(
# """
# UPDATE api_tokens
# SET client_id='jupyterhub'
# WHERE client_id IS NULL
# """
# )
batch_op.create_foreign_key(
"fk_api_token_client_id",
# 'api_tokens',
'oauth_clients',
['client_id'],
['identifier'],
ondelete='CASCADE',
)
c.execute(
"""
UPDATE api_tokens
SET client_id='jupyterhub'
WHERE client_id IS NULL
"""
)
op.add_column(
'api_tokens',
sa.Column(
'grant_type',
sa.Enum(
'authorization_code',
'implicit',
'password',
'client_credentials',
'refresh_token',
name='granttype',
),
server_default='authorization_code',
nullable=False,
),
)
op.add_column(
'api_tokens', sa.Column('refresh_token', sa.Unicode(length=255), nullable=True)
)
op.add_column(
'api_tokens', sa.Column('session_id', sa.Unicode(length=255), nullable=True)
)
# TODO: migrate OAuth tokens into APIToken table
op.drop_index('ix_oauth_access_tokens_prefix', table_name='oauth_access_tokens')
op.drop_table('oauth_access_tokens')
def downgrade():
# delete OAuth tokens for non-jupyterhub clients
# drop new columns from api tokens
op.drop_constraint(None, 'api_tokens', type_='foreignkey')
op.drop_column('api_tokens', 'session_id')
op.drop_column('api_tokens', 'refresh_token')
op.drop_column('api_tokens', 'grant_type')
op.drop_column('api_tokens', 'client_id')
# FIXME: only drop tokens whose client id is not 'jupyterhub'
# until then, drop all tokens
op.drop_table("api_tokens")
op.drop_table('api_token_role_map')
op.drop_table('service_role_map')
op.drop_table('user_role_map')
op.drop_table('roles')

View File

@@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler):
"/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user" "/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user"
) )
orm_token = orm.APIToken.find(self.db, token) orm_token = orm.APIToken.find(self.db, token)
if orm_token is None:
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None: if orm_token is None:
raise web.HTTPError(404) raise web.HTTPError(404)

View File

@@ -205,23 +205,6 @@ class APIHandler(BaseHandler):
def token_model(self, token): def token_model(self, token):
"""Get the JSON model for an APIToken""" """Get the JSON model for an APIToken"""
expires_at = None
if isinstance(token, orm.APIToken):
kind = 'api_token'
roles = [r.name for r in token.roles]
extra = {'note': token.note}
expires_at = token.expires_at
elif isinstance(token, orm.OAuthAccessToken):
kind = 'oauth'
# oauth tokens do not bear roles
roles = []
extra = {'oauth_client': token.client.description or token.client.client_id}
if token.expires_at:
expires_at = datetime.fromtimestamp(token.expires_at)
else:
raise TypeError(
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
)
if token.user: if token.user:
owner_key = 'user' owner_key = 'user'
@@ -234,13 +217,14 @@ class APIHandler(BaseHandler):
model = { model = {
owner_key: owner, owner_key: owner,
'id': token.api_id, 'id': token.api_id,
'kind': kind, 'kind': 'api_token',
'roles': [role for role in roles], 'roles': [r.name for r in token.roles],
'created': isoformat(token.created), 'created': isoformat(token.created),
'last_activity': isoformat(token.last_activity), 'last_activity': isoformat(token.last_activity),
'expires_at': isoformat(expires_at), 'expires_at': isoformat(token.expires_at),
'note': token.note,
'oauth_client': token.client.description or token.client.client_id,
} }
model.update(extra)
return model return model
def user_model(self, user): def user_model(self, user):

View File

@@ -32,9 +32,6 @@ class SelfAPIHandler(APIHandler):
async def get(self): async def get(self):
user = self.current_user user = self.current_user
if user is None:
# whoami can be accessed via oauth token
user = self.get_current_user_oauth_token()
if user is None: if user is None:
raise web.HTTPError(403) raise web.HTTPError(403)
if isinstance(user, orm.Service): if isinstance(user, orm.Service):
@@ -316,17 +313,7 @@ class UserTokenListAPIHandler(APIHandler):
continue continue
api_tokens.append(self.token_model(token)) api_tokens.append(self.token_model(token))
oauth_tokens = [] self.write(json.dumps({'api_tokens': api_tokens}))
# OAuth tokens use integer timestamps
now_timestamp = now.timestamp()
for token in sorted(user.oauth_tokens, key=sort_key):
if token.expires_at and token.expires_at < now_timestamp:
# exclude expired tokens
self.db.delete(token)
self.db.commit()
continue
oauth_tokens.append(self.token_model(token))
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
# Todo: Set to @needs_scope('users:tokens') # Todo: Set to @needs_scope('users:tokens')
async def post(self, user_name): async def post(self, user_name):
@@ -410,19 +397,15 @@ class UserTokenAPIHandler(APIHandler):
(e.g. wrong owner, invalid key format, etc.) (e.g. wrong owner, invalid key format, etc.)
""" """
not_found = "No such token %s for user %s" % (token_id, user.name) not_found = "No such token %s for user %s" % (token_id, user.name)
prefix, id_ = token_id[0], token_id[1:] prefix, id_ = token_id[:1], token_id[1:]
if prefix == 'a': if prefix != 'a':
Token = orm.APIToken
elif prefix == 'o':
Token = orm.OAuthAccessToken
else:
raise web.HTTPError(404, not_found) raise web.HTTPError(404, not_found)
try: try:
id_ = int(id_) id_ = int(id_)
except ValueError: except ValueError:
raise web.HTTPError(404, not_found) raise web.HTTPError(404, not_found)
orm_token = self.db.query(Token).filter(Token.id == id_).first() orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first()
if orm_token is None or orm_token.user is not user.orm_user: if orm_token is None or orm_token.user is not user.orm_user:
raise web.HTTPError(404, "Token not found %s", orm_token) raise web.HTTPError(404, "Token not found %s", orm_token)
return orm_token return orm_token
@@ -444,10 +427,10 @@ class UserTokenAPIHandler(APIHandler):
raise web.HTTPError(404, "No such user: %s" % user_name) raise web.HTTPError(404, "No such user: %s" % user_name)
token = self.find_token_by_id(user, token_id) token = self.find_token_by_id(user, token_id)
# deleting an oauth token deletes *all* oauth tokens for that client # deleting an oauth token deletes *all* oauth tokens for that client
if isinstance(token, orm.OAuthAccessToken):
client_id = token.client_id client_id = token.client_id
if token.client_id != "jupyterhub":
tokens = [ tokens = [
token for token in user.oauth_tokens if token.client_id == client_id token for token in user.api_tokens if token.client_id == client_id
] ]
else: else:
tokens = [token] tokens = [token]

View File

@@ -2014,12 +2014,13 @@ class JupyterHub(Application):
run periodically run periodically
""" """
# this should be all the subclasses of Expiring # this should be all the subclasses of Expiring
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode): for cls in (orm.APIToken, orm.OAuthCode):
self.log.debug("Purging expired {name}s".format(name=cls.__name__)) self.log.debug("Purging expired {name}s".format(name=cls.__name__))
cls.purge_expired(self.db) cls.purge_expired(self.db)
async def init_api_tokens(self): async def init_api_tokens(self):
"""Load predefined API tokens (for services) into database""" """Load predefined API tokens (for services) into database"""
await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user') await self._add_tokens(self.api_tokens, kind='user')
@@ -2292,13 +2293,30 @@ class JupyterHub(Application):
login_url=url_path_join(base_url, 'login'), login_url=url_path_join(base_url, 'login'),
token_expires_in=self.oauth_token_expires_in, token_expires_in=self.oauth_token_expires_in,
) )
# ensure the default oauth client exists
if (
not self.db.query(orm.OAuthClient)
.filter_by(identifier="jupyterhub")
.first()
):
# create the oauth client for jupyterhub itself
# this allows us to distinguish between orphaned tokens
# (failed cascade deletion) and tokens issued by the hub
# it has no client_secret, which means it cannot be used
# to make requests
self.oauth_provider.add_client(
client_id="jupyterhub",
client_secret="",
redirect_uri="",
description="JupyterHub",
)
def cleanup_oauth_clients(self): def cleanup_oauth_clients(self):
"""Cleanup any OAuth clients that shouldn't be in the database. """Cleanup any OAuth clients that shouldn't be in the database.
This should mainly be services that have been removed from configuration or renamed. This should mainly be services that have been removed from configuration or renamed.
""" """
oauth_client_ids = set() oauth_client_ids = {"jupyterhub"}
for service in self._service_map.values(): for service in self._service_map.values():
if service.oauth_available: if service.oauth_available:
oauth_client_ids.add(service.oauth_client_id) oauth_client_ids.add(service.oauth_client_id)

View File

@@ -247,26 +247,6 @@ class BaseHandler(RequestHandler):
return None return None
return match.group(1) return match.group(1)
def get_current_user_oauth_token(self):
"""Get the current user identified by OAuth access token
Separate from API token because OAuth access tokens
can only be used for identifying users,
not using the API.
"""
token = self.get_auth_token()
if token is None:
return None
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
return None
now = datetime.utcnow()
recorded = self._record_activity(orm_token, now)
if self._record_activity(orm_token.user, now) or recorded:
self.db.commit()
return self._user_from_orm(orm_token.user)
def _record_activity(self, obj, timestamp=None): def _record_activity(self, obj, timestamp=None):
"""record activity on an ORM object """record activity on an ORM object
@@ -373,7 +353,7 @@ class BaseHandler(RequestHandler):
# FIXME: scopes should give us better control than this # FIXME: scopes should give us better control than this
# don't consider API requests originating from a server # don't consider API requests originating from a server
# to be activity from the user # to be activity from the user
if not orm_token.note.startswith("Server at "): if not orm_token.note or not orm_token.note.startswith("Server at "):
recorded = self._record_activity(orm_token.user, now) or recorded recorded = self._record_activity(orm_token.user, now) or recorded
if recorded: if recorded:
self.db.commit() self.db.commit()
@@ -501,10 +481,8 @@ class BaseHandler(RequestHandler):
# don't clear session tokens if not logged in, # don't clear session tokens if not logged in,
# because that could be a malicious logout request! # because that could be a malicious logout request!
count = 0 count = 0
for access_token in ( for access_token in self.db.query(orm.APIToken).filter_by(
self.db.query(orm.OAuthAccessToken) user_id=user.id, session_id=session_id
.filter(orm.OAuthAccessToken.user_id == user.id)
.filter(orm.OAuthAccessToken.session_id == session_id)
): ):
self.db.delete(access_token) self.db.delete(access_token)
count += 1 count += 1

View File

@@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler):
return (token.last_activity or never, token.created or never) return (token.last_activity or never, token.created or never)
now = datetime.utcnow() now = datetime.utcnow()
api_tokens = []
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
if token.expires_at and token.expires_at < now:
self.db.delete(token)
self.db.commit()
continue
api_tokens.append(token)
# group oauth client tokens by client id # group oauth client tokens by client id
# AccessTokens have expires_at as an integer timestamp all_tokens = defaultdict(list)
now_timestamp = now.timestamp() for token in sorted(user.api_tokens, key=sort_key, reverse=True):
oauth_tokens = defaultdict(list) if token.expires_at and token.expires_at < now:
for token in user.oauth_tokens: self.log.warning(f"Deleting expired token {token}")
if token.expires_at and token.expires_at < now_timestamp:
self.log.warning("Deleting expired token")
self.db.delete(token) self.db.delete(token)
self.db.commit() self.db.commit()
continue continue
if not token.client_id: if not token.client_id:
# token should have been deleted when client was deleted # token should have been deleted when client was deleted
self.log.warning("Deleting stale oauth token for %s", user.name) self.log.warning("Deleting stale oauth token {token}")
self.db.delete(token) self.db.delete(token)
self.db.commit() self.db.commit()
continue continue
oauth_tokens[token.client_id].append(token) all_tokens[token.client_id].append(token)
# individually list tokens issued by jupyterhub itself
api_tokens = all_tokens.pop("jupyterhub", [])
# group all other tokens issued under their owners
# get the earliest created and latest last_activity # get the earliest created and latest last_activity
# timestamp for a given oauth client # timestamp for a given oauth client
oauth_clients = [] oauth_clients = []
for client_id, tokens in oauth_tokens.items():
for client_id, tokens in all_tokens.items():
created = tokens[0].created created = tokens[0].created
last_activity = tokens[0].last_activity last_activity = tokens[0].last_activity
for token in tokens[1:]: for token in tokens[1:]:

View File

@@ -2,18 +2,18 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
""" """
from datetime import timedelta
from oauthlib import uri_validate from oauthlib import uri_validate
from oauthlib.oauth2 import RequestValidator from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer from oauthlib.oauth2 import WebApplicationServer
from oauthlib.oauth2.rfc6749.grant_types import authorization_code from oauthlib.oauth2.rfc6749.grant_types import authorization_code
from oauthlib.oauth2.rfc6749.grant_types import base from oauthlib.oauth2.rfc6749.grant_types import base
from tornado.escape import url_escape
from tornado.log import app_log from tornado.log import app_log
from .. import orm from .. import orm
from ..utils import compare_token from ..utils import compare_token
from ..utils import hash_token from ..utils import hash_token
from ..utils import url_path_join
# patch absolute-uri check # patch absolute-uri check
# because we want to allow relative uri oauth # because we want to allow relative uri oauth
@@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator):
) )
if oauth_client is None: if oauth_client is None:
return False return False
if not client_secret or not oauth_client.secret:
# disallow authentication with no secret
return False
if not compare_token(oauth_client.secret, client_secret): if not compare_token(oauth_client.secret, client_secret):
app_log.warning("Client secret mismatch for %s", client_id) app_log.warning("Client secret mismatch for %s", client_id)
return False return False
@@ -339,10 +342,10 @@ class JupyterHubRequestValidator(RequestValidator):
.filter_by(identifier=request.client.client_id) .filter_by(identifier=request.client.client_id)
.first() .first()
) )
orm_access_token = orm.OAuthAccessToken( orm_access_token = orm.APIToken.new(
client=client, client_id=client.identifier,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
expires_at=orm.OAuthAccessToken.now() + token['expires_in'], expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']),
refresh_token=token['refresh_token'], refresh_token=token['refresh_token'],
# TODO: save scopes, # TODO: save scopes,
# scopes=scopes, # scopes=scopes,
@@ -412,6 +415,8 @@ class JupyterHubRequestValidator(RequestValidator):
) )
if orm_client is None: if orm_client is None:
return False return False
if not orm_client.secret:
return False
request.client = orm_client request.client = orm_client
return True return True
@@ -574,14 +579,16 @@ class JupyterHubOAuthServer(WebApplicationServer):
app_log.info(f'Creating oauth client {client_id}') app_log.info(f'Creating oauth client {client_id}')
else: else:
app_log.info(f'Updating oauth client {client_id}') app_log.info(f'Updating oauth client {client_id}')
orm_client.secret = hash_token(client_secret) orm_client.secret = hash_token(client_secret) if client_secret else ""
orm_client.redirect_uri = redirect_uri orm_client.redirect_uri = redirect_uri
orm_client.description = description orm_client.description = description
self.db.commit() self.db.commit()
def fetch_by_client_id(self, client_id): def fetch_by_client_id(self, client_id):
"""Find a client by its id""" """Find a client by its id"""
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first() client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
if client and client.secret:
return client
def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs): def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs):

View File

@@ -277,9 +277,6 @@ class User(Base):
last_activity = Column(DateTime, nullable=True) last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan") api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
oauth_tokens = relationship(
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
)
oauth_codes = relationship( oauth_codes = relationship(
"OAuthCode", backref="user", cascade="all, delete-orphan" "OAuthCode", backref="user", cascade="all, delete-orphan"
) )
@@ -485,7 +482,9 @@ class Hashed(Expiring):
@classmethod @classmethod
def check_token(cls, db, token): def check_token(cls, db, token):
"""Check if a token is acceptable""" """Check if a token is acceptable"""
print("checking", cls, token, len(token), cls.min_length)
if len(token) < cls.min_length: if len(token) < cls.min_length:
print("raising")
raise ValueError( raise ValueError(
"Tokens must be at least %i characters, got %r" "Tokens must be at least %i characters, got %r"
% (cls.min_length, token) % (cls.min_length, token)
@@ -530,6 +529,20 @@ class Hashed(Expiring):
return orm_token return orm_token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class APIToken(Hashed, Base): class APIToken(Hashed, Base):
"""An API token""" """An API token"""
@@ -548,6 +561,15 @@ class APIToken(Hashed, Base):
def api_id(self): def api_id(self):
return 'a%i' % self.id return 'a%i' % self.id
# added in 2.0
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
refresh_token = Column(Unicode(255))
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# token metadata for bookkeeping # token metadata for bookkeeping
now = datetime.utcnow # for expiry now = datetime.utcnow # for expiry
created = Column(DateTime, default=datetime.utcnow) created = Column(DateTime, default=datetime.utcnow)
@@ -566,8 +588,12 @@ class APIToken(Hashed, Base):
# this shouldn't happen # this shouldn't happen
kind = 'owner' kind = 'owner'
name = 'unknown' name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format( return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format(
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
client_id=self.client_id,
) )
@classmethod @classmethod
@@ -588,6 +614,14 @@ class APIToken(Hashed, Base):
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind) raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match: for orm_token in prefix_match:
if orm_token.match(token): if orm_token.match(token):
if not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token return orm_token
@classmethod @classmethod
@@ -600,6 +634,7 @@ class APIToken(Hashed, Base):
note='', note='',
generated=True, generated=True,
expires_in=None, expires_in=None,
client_id='jupyterhub',
): ):
"""Generate a new API token for a user or service""" """Generate a new API token for a user or service"""
assert user or service assert user or service
@@ -614,7 +649,12 @@ class APIToken(Hashed, Base):
cls.check_token(db, token) cls.check_token(db, token)
# two stages to ensure orm_token.generated has been set # two stages to ensure orm_token.generated has been set
# before token setter is called # before token setter is called
orm_token = cls(generated=generated, note=note or '') orm_token = cls(
generated=generated,
note=note or '',
grant_type=GrantType.authorization_code,
client_id=client_id,
)
orm_token.token = token orm_token.token = token
if user: if user:
assert user.id is not None assert user.id is not None
@@ -641,76 +681,6 @@ class APIToken(Hashed, Base):
return token return token
# ------------------------------------
# OAuth tables
# ------------------------------------
class GrantType(enum.Enum):
# we only use authorization_code for now
authorization_code = 'authorization_code'
implicit = 'implicit'
password = 'password'
client_credentials = 'client_credentials'
refresh_token = 'refresh_token'
class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
@staticmethod
def now():
return datetime.utcnow().timestamp()
@property
def api_id(self):
return 'o%i' % self.id
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
)
grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer)
refresh_token = Column(Unicode(255))
# TODO: drop refresh_expires_at. Refresh tokens shouldn't expire
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
service = None # for API-equivalence with APIToken
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# from Hashed
hashed = Column(Unicode(255), unique=True)
prefix = Column(Unicode(16), index=True)
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
expires_in=self.expires_in,
)
@classmethod
def find(cls, db, token):
orm_token = super().find(db, token)
if orm_token and not orm_token.client_id:
app_log.warning(
"Deleting stale oauth token for %s with no client",
orm_token.user and orm_token.user.name,
)
db.delete(orm_token)
db.commit()
return
return orm_token
class OAuthCode(Expiring, Base): class OAuthCode(Expiring, Base):
__tablename__ = 'oauth_codes' __tablename__ = 'oauth_codes'
@@ -752,7 +722,7 @@ class OAuthClient(Base):
return self.identifier return self.identifier
access_tokens = relationship( access_tokens = relationship(
OAuthAccessToken, backref='client', cascade='all, delete-orphan' APIToken, backref='client', cascade='all, delete-orphan'
) )
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan') codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')

View File

@@ -51,6 +51,7 @@ from traitlets import Dict
from traitlets import HasTraits from traitlets import HasTraits
from traitlets import Instance from traitlets import Instance
from traitlets import Unicode from traitlets import Unicode
from traitlets import validate
from traitlets.config import LoggingConfigurable from traitlets.config import LoggingConfigurable
from .. import orm from .. import orm
@@ -284,6 +285,15 @@ class Service(LoggingConfigurable):
def _default_client_id(self): def _default_client_id(self):
return 'service-%s' % self.name return 'service-%s' % self.name
@validate("oauth_client_id")
def _validate_client_id(self, proposal):
if not proposal.value.startswith("service-"):
raise ValueError(
f"service {self.name} has oauth_client_id='{proposal.value}'."
" Service oauth client ids must start with 'service-'"
)
return proposal.value
oauth_redirect_uri = Unicode( oauth_redirect_uri = Unicode(
help="""OAuth redirect URI for this service. help="""OAuth redirect URI for this service.

View File

@@ -70,7 +70,11 @@ def populate_db(url):
code = orm.OAuthCode(client_id=client.identifier) code = orm.OAuthCode(client_id=client.identifier)
db.add(code) db.add(code)
db.commit() db.commit()
access_token = orm.OAuthAccessToken( if jupyterhub.version_info < (2, 0):
Token = orm.OAuthAccessToken
else:
Token = orm.APIToken
access_token = Token(
client_id=client.identifier, client_id=client.identifier,
user_id=user.id, user_id=user.id,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,

View File

@@ -273,7 +273,7 @@ async def test_get_self(app):
oauth_client = orm.OAuthClient(identifier='eurydice') oauth_client = orm.OAuthClient(identifier='eurydice')
db.add(oauth_client) db.add(oauth_client)
db.commit() db.commit()
oauth_token = orm.OAuthAccessToken( oauth_token = orm.APIToken(
user=u.orm_user, user=u.orm_user,
client=oauth_client, client=oauth_client,
token=token, token=token,
@@ -1423,12 +1423,11 @@ async def test_token_list(app, as_user, for_user, status):
if status != 200: if status != 200:
return return
reply = r.json() reply = r.json()
assert sorted(reply) == ['api_tokens', 'oauth_tokens'] assert sorted(reply) == ['api_tokens']
assert len(reply['api_tokens']) == len(for_user_obj.api_tokens) assert len(reply['api_tokens']) == len(for_user_obj.api_tokens)
assert all(token['user'] == for_user for token in reply['api_tokens']) assert all(token['user'] == for_user for token in reply['api_tokens'])
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
# validate individual token ids # validate individual token ids
for token in reply['api_tokens'] + reply['oauth_tokens']: for token in reply['api_tokens']:
r = await api_request( r = await api_request(
app, 'users', for_user, 'tokens', token['id'], headers=headers app, 'users', for_user, 'tokens', token['id'], headers=headers
) )

View File

@@ -355,7 +355,7 @@ def test_user_delete_cascade(db):
spawner.server = server = orm.Server() spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
) )
db.add(oauth_token) db.add(oauth_token)
@@ -377,7 +377,7 @@ def test_user_delete_cascade(db):
assert_not_found(db, orm.Spawner, spawner_id) assert_not_found(db, orm.Spawner, spawner_id)
assert_not_found(db, orm.Server, server_id) assert_not_found(db, orm.Server, server_id)
assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) assert_not_found(db, orm.APIToken, oauth_token_id)
def test_oauth_client_delete_cascade(db): def test_oauth_client_delete_cascade(db):
@@ -391,12 +391,12 @@ def test_oauth_client_delete_cascade(db):
# these should all be deleted automatically when the user goes away # these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
) )
db.add(oauth_token) db.add(oauth_token)
db.commit() db.commit()
assert user.oauth_tokens == [oauth_token] assert user.tokens == [oauth_token]
# record all of the ids # record all of the ids
oauth_code_id = oauth_code.id oauth_code_id = oauth_code.id
@@ -408,8 +408,8 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted # verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) assert_not_found(db, orm.APIToken, oauth_token_id)
assert user.oauth_tokens == [] assert user.tokens == []
assert user.oauth_codes == [] assert user.oauth_codes == []
@@ -510,32 +510,32 @@ def test_expiring_api_token(app, user):
def test_expiring_oauth_token(app, user): def test_expiring_oauth_token(app, user):
db = app.db db = app.db
token = "abc123" token = "abc123"
now = orm.OAuthAccessToken.now now = orm.APIToken.now
client = orm.OAuthClient(identifier="xxx", secret="yyy") client = orm.OAuthClient(identifier="xxx", secret="yyy")
db.add(client) db.add(client)
orm_token = orm.OAuthAccessToken( orm_token = orm.APIToken(
token=token, token=token,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
client=client, client=client,
user=user, user=user,
expires_at=now() + 30, expires_at=now() + datetime.timedelta(seconds=30),
) )
db.add(orm_token) db.add(orm_token)
db.commit() db.commit()
found = orm.OAuthAccessToken.find(db, token) found = orm.APIToken.find(db, token)
assert found is orm_token assert found is orm_token
# purge_expired doesn't delete non-expired # purge_expired doesn't delete non-expired
orm.OAuthAccessToken.purge_expired(db) orm.APIToken.purge_expired(db)
found = orm.OAuthAccessToken.find(db, token) found = orm.APIToken.find(db, token)
assert found is orm_token assert found is orm_token
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60): with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60):
found = orm.OAuthAccessToken.find(db, token) found = orm.APIToken.find(db, token)
assert found is None assert found is None
assert orm_token in db.query(orm.OAuthAccessToken) assert orm_token in db.query(orm.APIToken)
orm.OAuthAccessToken.purge_expired(db) orm.APIToken.purge_expired(db)
assert orm_token not in db.query(orm.OAuthAccessToken) assert orm_token not in db.query(orm.APIToken)
def test_expiring_oauth_code(app, user): def test_expiring_oauth_code(app, user):

View File

@@ -869,7 +869,7 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)] user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token') client = orm.OAuthClient(identifier='token')
app.db.add(client) app.db.add(client)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.APIToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code client=client, user=user, grant_type=orm.GrantType.authorization_code
) )
app.db.add(oauth_token) app.db.add(oauth_token)

View File

@@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url):
def auth_tokens(): def auth_tokens():
"""Return list of OAuth access tokens for the user""" """Return list of OAuth access tokens for the user"""
return list( return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id))
app.db.query(orm.OAuthAccessToken).filter(
orm.OAuthAccessToken.user_id == app_user.id
)
)
# ensure we start empty # ensure we start empty
assert auth_tokens() == [] assert auth_tokens() == []