mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-10 03:23:04 +00:00
remove separate oauth tokens
- merge oauth token fields into APITokens - create oauth client 'jupyterhub' which owns current API tokens - db upgrade is currently to drop both token tables, and force recreation on next start
This commit is contained in:
119
jupyterhub/alembic/versions/833da8570507_rbac.py
Normal file
119
jupyterhub/alembic/versions/833da8570507_rbac.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""rbac
|
||||
|
||||
Revision ID: 833da8570507
|
||||
Revises: 4dc2d5a8c53c
|
||||
Create Date: 2021-02-17 15:03:04.360368
|
||||
|
||||
"""
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '833da8570507'
|
||||
down_revision = '4dc2d5a8c53c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
# FIXME: currently drops all api tokens and forces recreation!
|
||||
# this ensures a consistent database, but requires:
|
||||
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
|
||||
# 2. any manually issued/stored tokens to be re-issued
|
||||
|
||||
# tokens loaded via configuration will be recreated on launch and unaffected
|
||||
op.drop_table('api_tokens')
|
||||
op.drop_table('oauth_access_tokens')
|
||||
return
|
||||
# TODO: explore in-place migration. This seems hard!
|
||||
# 1. add new columns in api tokens
|
||||
# 2. fill default fields (client_id='jupyterhub') for all api tokens
|
||||
# 3. copy oauth tokens into api tokens
|
||||
# 4. give oauth tokens 'identify' scopes
|
||||
|
||||
c = op.get_bind()
|
||||
naming_convention = {
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
}
|
||||
with op.batch_alter_table(
|
||||
"api_tokens",
|
||||
naming_convention=naming_convention,
|
||||
) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
'client_id',
|
||||
sa.Unicode(255),
|
||||
# sa.ForeignKey('oauth_clients.identifier', ondelete='CASCADE'),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
# batch_cursor = op.get_bind()
|
||||
# batch_cursor.execute(
|
||||
# """
|
||||
# UPDATE api_tokens
|
||||
# SET client_id='jupyterhub'
|
||||
# WHERE client_id IS NULL
|
||||
# """
|
||||
# )
|
||||
batch_op.create_foreign_key(
|
||||
"fk_api_token_client_id",
|
||||
# 'api_tokens',
|
||||
'oauth_clients',
|
||||
['client_id'],
|
||||
['identifier'],
|
||||
ondelete='CASCADE',
|
||||
)
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE api_tokens
|
||||
SET client_id='jupyterhub'
|
||||
WHERE client_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
'api_tokens',
|
||||
sa.Column(
|
||||
'grant_type',
|
||||
sa.Enum(
|
||||
'authorization_code',
|
||||
'implicit',
|
||||
'password',
|
||||
'client_credentials',
|
||||
'refresh_token',
|
||||
name='granttype',
|
||||
),
|
||||
server_default='authorization_code',
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
'api_tokens', sa.Column('refresh_token', sa.Unicode(length=255), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'api_tokens', sa.Column('session_id', sa.Unicode(length=255), nullable=True)
|
||||
)
|
||||
|
||||
# TODO: migrate OAuth tokens into APIToken table
|
||||
|
||||
op.drop_index('ix_oauth_access_tokens_prefix', table_name='oauth_access_tokens')
|
||||
op.drop_table('oauth_access_tokens')
|
||||
|
||||
|
||||
def downgrade():
|
||||
# delete OAuth tokens for non-jupyterhub clients
|
||||
# drop new columns from api tokens
|
||||
op.drop_constraint(None, 'api_tokens', type_='foreignkey')
|
||||
op.drop_column('api_tokens', 'session_id')
|
||||
op.drop_column('api_tokens', 'refresh_token')
|
||||
op.drop_column('api_tokens', 'grant_type')
|
||||
op.drop_column('api_tokens', 'client_id')
|
||||
# FIXME: only drop tokens whose client id is not 'jupyterhub'
|
||||
# until then, drop all tokens
|
||||
op.drop_table("api_tokens")
|
||||
|
||||
op.drop_table('api_token_role_map')
|
||||
op.drop_table('service_role_map')
|
||||
op.drop_table('user_role_map')
|
||||
op.drop_table('roles')
|
@@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler):
|
||||
"/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user"
|
||||
)
|
||||
orm_token = orm.APIToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
raise web.HTTPError(404)
|
||||
|
||||
|
@@ -205,23 +205,6 @@ class APIHandler(BaseHandler):
|
||||
|
||||
def token_model(self, token):
|
||||
"""Get the JSON model for an APIToken"""
|
||||
expires_at = None
|
||||
if isinstance(token, orm.APIToken):
|
||||
kind = 'api_token'
|
||||
roles = [r.name for r in token.roles]
|
||||
extra = {'note': token.note}
|
||||
expires_at = token.expires_at
|
||||
elif isinstance(token, orm.OAuthAccessToken):
|
||||
kind = 'oauth'
|
||||
# oauth tokens do not bear roles
|
||||
roles = []
|
||||
extra = {'oauth_client': token.client.description or token.client.client_id}
|
||||
if token.expires_at:
|
||||
expires_at = datetime.fromtimestamp(token.expires_at)
|
||||
else:
|
||||
raise TypeError(
|
||||
"token must be an APIToken or OAuthAccessToken, not %s" % type(token)
|
||||
)
|
||||
|
||||
if token.user:
|
||||
owner_key = 'user'
|
||||
@@ -234,13 +217,14 @@ class APIHandler(BaseHandler):
|
||||
model = {
|
||||
owner_key: owner,
|
||||
'id': token.api_id,
|
||||
'kind': kind,
|
||||
'roles': [role for role in roles],
|
||||
'kind': 'api_token',
|
||||
'roles': [r.name for r in token.roles],
|
||||
'created': isoformat(token.created),
|
||||
'last_activity': isoformat(token.last_activity),
|
||||
'expires_at': isoformat(expires_at),
|
||||
'expires_at': isoformat(token.expires_at),
|
||||
'note': token.note,
|
||||
'oauth_client': token.client.description or token.client.client_id,
|
||||
}
|
||||
model.update(extra)
|
||||
return model
|
||||
|
||||
def user_model(self, user):
|
||||
|
@@ -32,9 +32,6 @@ class SelfAPIHandler(APIHandler):
|
||||
|
||||
async def get(self):
|
||||
user = self.current_user
|
||||
if user is None:
|
||||
# whoami can be accessed via oauth token
|
||||
user = self.get_current_user_oauth_token()
|
||||
if user is None:
|
||||
raise web.HTTPError(403)
|
||||
if isinstance(user, orm.Service):
|
||||
@@ -316,17 +313,7 @@ class UserTokenListAPIHandler(APIHandler):
|
||||
continue
|
||||
api_tokens.append(self.token_model(token))
|
||||
|
||||
oauth_tokens = []
|
||||
# OAuth tokens use integer timestamps
|
||||
now_timestamp = now.timestamp()
|
||||
for token in sorted(user.oauth_tokens, key=sort_key):
|
||||
if token.expires_at and token.expires_at < now_timestamp:
|
||||
# exclude expired tokens
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
oauth_tokens.append(self.token_model(token))
|
||||
self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens}))
|
||||
self.write(json.dumps({'api_tokens': api_tokens}))
|
||||
|
||||
# Todo: Set to @needs_scope('users:tokens')
|
||||
async def post(self, user_name):
|
||||
@@ -410,19 +397,15 @@ class UserTokenAPIHandler(APIHandler):
|
||||
(e.g. wrong owner, invalid key format, etc.)
|
||||
"""
|
||||
not_found = "No such token %s for user %s" % (token_id, user.name)
|
||||
prefix, id_ = token_id[0], token_id[1:]
|
||||
if prefix == 'a':
|
||||
Token = orm.APIToken
|
||||
elif prefix == 'o':
|
||||
Token = orm.OAuthAccessToken
|
||||
else:
|
||||
prefix, id_ = token_id[:1], token_id[1:]
|
||||
if prefix != 'a':
|
||||
raise web.HTTPError(404, not_found)
|
||||
try:
|
||||
id_ = int(id_)
|
||||
except ValueError:
|
||||
raise web.HTTPError(404, not_found)
|
||||
|
||||
orm_token = self.db.query(Token).filter(Token.id == id_).first()
|
||||
orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first()
|
||||
if orm_token is None or orm_token.user is not user.orm_user:
|
||||
raise web.HTTPError(404, "Token not found %s", orm_token)
|
||||
return orm_token
|
||||
@@ -444,10 +427,10 @@ class UserTokenAPIHandler(APIHandler):
|
||||
raise web.HTTPError(404, "No such user: %s" % user_name)
|
||||
token = self.find_token_by_id(user, token_id)
|
||||
# deleting an oauth token deletes *all* oauth tokens for that client
|
||||
if isinstance(token, orm.OAuthAccessToken):
|
||||
client_id = token.client_id
|
||||
if token.client_id != "jupyterhub":
|
||||
tokens = [
|
||||
token for token in user.oauth_tokens if token.client_id == client_id
|
||||
token for token in user.api_tokens if token.client_id == client_id
|
||||
]
|
||||
else:
|
||||
tokens = [token]
|
||||
|
@@ -2014,12 +2014,13 @@ class JupyterHub(Application):
|
||||
run periodically
|
||||
"""
|
||||
# this should be all the subclasses of Expiring
|
||||
for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode):
|
||||
for cls in (orm.APIToken, orm.OAuthCode):
|
||||
self.log.debug("Purging expired {name}s".format(name=cls.__name__))
|
||||
cls.purge_expired(self.db)
|
||||
|
||||
async def init_api_tokens(self):
|
||||
"""Load predefined API tokens (for services) into database"""
|
||||
|
||||
await self._add_tokens(self.service_tokens, kind='service')
|
||||
await self._add_tokens(self.api_tokens, kind='user')
|
||||
|
||||
@@ -2292,13 +2293,30 @@ class JupyterHub(Application):
|
||||
login_url=url_path_join(base_url, 'login'),
|
||||
token_expires_in=self.oauth_token_expires_in,
|
||||
)
|
||||
# ensure the default oauth client exists
|
||||
if (
|
||||
not self.db.query(orm.OAuthClient)
|
||||
.filter_by(identifier="jupyterhub")
|
||||
.first()
|
||||
):
|
||||
# create the oauth client for jupyterhub itself
|
||||
# this allows us to distinguish between orphaned tokens
|
||||
# (failed cascade deletion) and tokens issued by the hub
|
||||
# it has no client_secret, which means it cannot be used
|
||||
# to make requests
|
||||
self.oauth_provider.add_client(
|
||||
client_id="jupyterhub",
|
||||
client_secret="",
|
||||
redirect_uri="",
|
||||
description="JupyterHub",
|
||||
)
|
||||
|
||||
def cleanup_oauth_clients(self):
|
||||
"""Cleanup any OAuth clients that shouldn't be in the database.
|
||||
|
||||
This should mainly be services that have been removed from configuration or renamed.
|
||||
"""
|
||||
oauth_client_ids = set()
|
||||
oauth_client_ids = {"jupyterhub"}
|
||||
for service in self._service_map.values():
|
||||
if service.oauth_available:
|
||||
oauth_client_ids.add(service.oauth_client_id)
|
||||
|
@@ -247,26 +247,6 @@ class BaseHandler(RequestHandler):
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
def get_current_user_oauth_token(self):
|
||||
"""Get the current user identified by OAuth access token
|
||||
|
||||
Separate from API token because OAuth access tokens
|
||||
can only be used for identifying users,
|
||||
not using the API.
|
||||
"""
|
||||
token = self.get_auth_token()
|
||||
if token is None:
|
||||
return None
|
||||
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
return None
|
||||
|
||||
now = datetime.utcnow()
|
||||
recorded = self._record_activity(orm_token, now)
|
||||
if self._record_activity(orm_token.user, now) or recorded:
|
||||
self.db.commit()
|
||||
return self._user_from_orm(orm_token.user)
|
||||
|
||||
def _record_activity(self, obj, timestamp=None):
|
||||
"""record activity on an ORM object
|
||||
|
||||
@@ -373,7 +353,7 @@ class BaseHandler(RequestHandler):
|
||||
# FIXME: scopes should give us better control than this
|
||||
# don't consider API requests originating from a server
|
||||
# to be activity from the user
|
||||
if not orm_token.note.startswith("Server at "):
|
||||
if not orm_token.note or not orm_token.note.startswith("Server at "):
|
||||
recorded = self._record_activity(orm_token.user, now) or recorded
|
||||
if recorded:
|
||||
self.db.commit()
|
||||
@@ -501,10 +481,8 @@ class BaseHandler(RequestHandler):
|
||||
# don't clear session tokens if not logged in,
|
||||
# because that could be a malicious logout request!
|
||||
count = 0
|
||||
for access_token in (
|
||||
self.db.query(orm.OAuthAccessToken)
|
||||
.filter(orm.OAuthAccessToken.user_id == user.id)
|
||||
.filter(orm.OAuthAccessToken.session_id == session_id)
|
||||
for access_token in self.db.query(orm.APIToken).filter_by(
|
||||
user_id=user.id, session_id=session_id
|
||||
):
|
||||
self.db.delete(access_token)
|
||||
count += 1
|
||||
|
@@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler):
|
||||
return (token.last_activity or never, token.created or never)
|
||||
|
||||
now = datetime.utcnow()
|
||||
api_tokens = []
|
||||
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
|
||||
if token.expires_at and token.expires_at < now:
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
api_tokens.append(token)
|
||||
|
||||
# group oauth client tokens by client id
|
||||
# AccessTokens have expires_at as an integer timestamp
|
||||
now_timestamp = now.timestamp()
|
||||
oauth_tokens = defaultdict(list)
|
||||
for token in user.oauth_tokens:
|
||||
if token.expires_at and token.expires_at < now_timestamp:
|
||||
self.log.warning("Deleting expired token")
|
||||
all_tokens = defaultdict(list)
|
||||
for token in sorted(user.api_tokens, key=sort_key, reverse=True):
|
||||
if token.expires_at and token.expires_at < now:
|
||||
self.log.warning(f"Deleting expired token {token}")
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
if not token.client_id:
|
||||
# token should have been deleted when client was deleted
|
||||
self.log.warning("Deleting stale oauth token for %s", user.name)
|
||||
self.log.warning("Deleting stale oauth token {token}")
|
||||
self.db.delete(token)
|
||||
self.db.commit()
|
||||
continue
|
||||
oauth_tokens[token.client_id].append(token)
|
||||
all_tokens[token.client_id].append(token)
|
||||
|
||||
# individually list tokens issued by jupyterhub itself
|
||||
api_tokens = all_tokens.pop("jupyterhub", [])
|
||||
|
||||
# group all other tokens issued under their owners
|
||||
# get the earliest created and latest last_activity
|
||||
# timestamp for a given oauth client
|
||||
oauth_clients = []
|
||||
for client_id, tokens in oauth_tokens.items():
|
||||
|
||||
for client_id, tokens in all_tokens.items():
|
||||
created = tokens[0].created
|
||||
last_activity = tokens[0].last_activity
|
||||
for token in tokens[1:]:
|
||||
|
@@ -2,18 +2,18 @@
|
||||
|
||||
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
from oauthlib import uri_validate
|
||||
from oauthlib.oauth2 import RequestValidator
|
||||
from oauthlib.oauth2 import WebApplicationServer
|
||||
from oauthlib.oauth2.rfc6749.grant_types import authorization_code
|
||||
from oauthlib.oauth2.rfc6749.grant_types import base
|
||||
from tornado.escape import url_escape
|
||||
from tornado.log import app_log
|
||||
|
||||
from .. import orm
|
||||
from ..utils import compare_token
|
||||
from ..utils import hash_token
|
||||
from ..utils import url_path_join
|
||||
|
||||
# patch absolute-uri check
|
||||
# because we want to allow relative uri oauth
|
||||
@@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
)
|
||||
if oauth_client is None:
|
||||
return False
|
||||
if not client_secret or not oauth_client.secret:
|
||||
# disallow authentication with no secret
|
||||
return False
|
||||
if not compare_token(oauth_client.secret, client_secret):
|
||||
app_log.warning("Client secret mismatch for %s", client_id)
|
||||
return False
|
||||
@@ -339,10 +342,10 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
.filter_by(identifier=request.client.client_id)
|
||||
.first()
|
||||
)
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client=client,
|
||||
orm_access_token = orm.APIToken.new(
|
||||
client_id=client.identifier,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
expires_at=orm.OAuthAccessToken.now() + token['expires_in'],
|
||||
expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']),
|
||||
refresh_token=token['refresh_token'],
|
||||
# TODO: save scopes,
|
||||
# scopes=scopes,
|
||||
@@ -412,6 +415,8 @@ class JupyterHubRequestValidator(RequestValidator):
|
||||
)
|
||||
if orm_client is None:
|
||||
return False
|
||||
if not orm_client.secret:
|
||||
return False
|
||||
request.client = orm_client
|
||||
return True
|
||||
|
||||
@@ -574,14 +579,16 @@ class JupyterHubOAuthServer(WebApplicationServer):
|
||||
app_log.info(f'Creating oauth client {client_id}')
|
||||
else:
|
||||
app_log.info(f'Updating oauth client {client_id}')
|
||||
orm_client.secret = hash_token(client_secret)
|
||||
orm_client.secret = hash_token(client_secret) if client_secret else ""
|
||||
orm_client.redirect_uri = redirect_uri
|
||||
orm_client.description = description
|
||||
self.db.commit()
|
||||
|
||||
def fetch_by_client_id(self, client_id):
|
||||
"""Find a client by its id"""
|
||||
return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first()
|
||||
if client and client.secret:
|
||||
return client
|
||||
|
||||
|
||||
def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs):
|
||||
|
@@ -277,9 +277,6 @@ class User(Base):
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
|
||||
api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan")
|
||||
oauth_tokens = relationship(
|
||||
"OAuthAccessToken", backref="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_codes = relationship(
|
||||
"OAuthCode", backref="user", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -485,7 +482,9 @@ class Hashed(Expiring):
|
||||
@classmethod
|
||||
def check_token(cls, db, token):
|
||||
"""Check if a token is acceptable"""
|
||||
print("checking", cls, token, len(token), cls.min_length)
|
||||
if len(token) < cls.min_length:
|
||||
print("raising")
|
||||
raise ValueError(
|
||||
"Tokens must be at least %i characters, got %r"
|
||||
% (cls.min_length, token)
|
||||
@@ -530,6 +529,20 @@ class Hashed(Expiring):
|
||||
return orm_token
|
||||
|
||||
|
||||
# ------------------------------------
|
||||
# OAuth tables
|
||||
# ------------------------------------
|
||||
|
||||
|
||||
class GrantType(enum.Enum):
|
||||
# we only use authorization_code for now
|
||||
authorization_code = 'authorization_code'
|
||||
implicit = 'implicit'
|
||||
password = 'password'
|
||||
client_credentials = 'client_credentials'
|
||||
refresh_token = 'refresh_token'
|
||||
|
||||
|
||||
class APIToken(Hashed, Base):
|
||||
"""An API token"""
|
||||
|
||||
@@ -548,6 +561,15 @@ class APIToken(Hashed, Base):
|
||||
def api_id(self):
|
||||
return 'a%i' % self.id
|
||||
|
||||
# added in 2.0
|
||||
client_id = Column(
|
||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||
)
|
||||
grant_type = Column(Enum(GrantType), nullable=False)
|
||||
refresh_token = Column(Unicode(255))
|
||||
# the browser session id associated with a given token
|
||||
session_id = Column(Unicode(255))
|
||||
|
||||
# token metadata for bookkeeping
|
||||
now = datetime.utcnow # for expiry
|
||||
created = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -566,8 +588,12 @@ class APIToken(Hashed, Base):
|
||||
# this shouldn't happen
|
||||
kind = 'owner'
|
||||
name = 'unknown'
|
||||
return "<{cls}('{pre}...', {kind}='{name}')>".format(
|
||||
cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name
|
||||
return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format(
|
||||
cls=self.__class__.__name__,
|
||||
pre=self.prefix,
|
||||
kind=kind,
|
||||
name=name,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -588,6 +614,14 @@ class APIToken(Hashed, Base):
|
||||
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
|
||||
for orm_token in prefix_match:
|
||||
if orm_token.match(token):
|
||||
if not orm_token.client_id:
|
||||
app_log.warning(
|
||||
"Deleting stale oauth token for %s with no client",
|
||||
orm_token.user and orm_token.user.name,
|
||||
)
|
||||
db.delete(orm_token)
|
||||
db.commit()
|
||||
return
|
||||
return orm_token
|
||||
|
||||
@classmethod
|
||||
@@ -600,6 +634,7 @@ class APIToken(Hashed, Base):
|
||||
note='',
|
||||
generated=True,
|
||||
expires_in=None,
|
||||
client_id='jupyterhub',
|
||||
):
|
||||
"""Generate a new API token for a user or service"""
|
||||
assert user or service
|
||||
@@ -614,7 +649,12 @@ class APIToken(Hashed, Base):
|
||||
cls.check_token(db, token)
|
||||
# two stages to ensure orm_token.generated has been set
|
||||
# before token setter is called
|
||||
orm_token = cls(generated=generated, note=note or '')
|
||||
orm_token = cls(
|
||||
generated=generated,
|
||||
note=note or '',
|
||||
grant_type=GrantType.authorization_code,
|
||||
client_id=client_id,
|
||||
)
|
||||
orm_token.token = token
|
||||
if user:
|
||||
assert user.id is not None
|
||||
@@ -641,76 +681,6 @@ class APIToken(Hashed, Base):
|
||||
return token
|
||||
|
||||
|
||||
# ------------------------------------
|
||||
# OAuth tables
|
||||
# ------------------------------------
|
||||
|
||||
|
||||
class GrantType(enum.Enum):
|
||||
# we only use authorization_code for now
|
||||
authorization_code = 'authorization_code'
|
||||
implicit = 'implicit'
|
||||
password = 'password'
|
||||
client_credentials = 'client_credentials'
|
||||
refresh_token = 'refresh_token'
|
||||
|
||||
|
||||
class OAuthAccessToken(Hashed, Base):
|
||||
__tablename__ = 'oauth_access_tokens'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@staticmethod
|
||||
def now():
|
||||
return datetime.utcnow().timestamp()
|
||||
|
||||
@property
|
||||
def api_id(self):
|
||||
return 'o%i' % self.id
|
||||
|
||||
client_id = Column(
|
||||
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
|
||||
)
|
||||
grant_type = Column(Enum(GrantType), nullable=False)
|
||||
expires_at = Column(Integer)
|
||||
refresh_token = Column(Unicode(255))
|
||||
# TODO: drop refresh_expires_at. Refresh tokens shouldn't expire
|
||||
refresh_expires_at = Column(Integer)
|
||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||
service = None # for API-equivalence with APIToken
|
||||
|
||||
# the browser session id associated with a given token
|
||||
session_id = Column(Unicode(255))
|
||||
|
||||
# from Hashed
|
||||
hashed = Column(Unicode(255), unique=True)
|
||||
prefix = Column(Unicode(16), index=True)
|
||||
|
||||
created = Column(DateTime, default=datetime.utcnow)
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format(
|
||||
cls=self.__class__.__name__,
|
||||
client_id=self.client_id,
|
||||
user=self.user and self.user.name,
|
||||
prefix=self.prefix,
|
||||
expires_in=self.expires_in,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, token):
|
||||
orm_token = super().find(db, token)
|
||||
if orm_token and not orm_token.client_id:
|
||||
app_log.warning(
|
||||
"Deleting stale oauth token for %s with no client",
|
||||
orm_token.user and orm_token.user.name,
|
||||
)
|
||||
db.delete(orm_token)
|
||||
db.commit()
|
||||
return
|
||||
return orm_token
|
||||
|
||||
|
||||
class OAuthCode(Expiring, Base):
|
||||
__tablename__ = 'oauth_codes'
|
||||
|
||||
@@ -752,7 +722,7 @@ class OAuthClient(Base):
|
||||
return self.identifier
|
||||
|
||||
access_tokens = relationship(
|
||||
OAuthAccessToken, backref='client', cascade='all, delete-orphan'
|
||||
APIToken, backref='client', cascade='all, delete-orphan'
|
||||
)
|
||||
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')
|
||||
|
||||
|
@@ -51,6 +51,7 @@ from traitlets import Dict
|
||||
from traitlets import HasTraits
|
||||
from traitlets import Instance
|
||||
from traitlets import Unicode
|
||||
from traitlets import validate
|
||||
from traitlets.config import LoggingConfigurable
|
||||
|
||||
from .. import orm
|
||||
@@ -284,6 +285,15 @@ class Service(LoggingConfigurable):
|
||||
def _default_client_id(self):
|
||||
return 'service-%s' % self.name
|
||||
|
||||
@validate("oauth_client_id")
|
||||
def _validate_client_id(self, proposal):
|
||||
if not proposal.value.startswith("service-"):
|
||||
raise ValueError(
|
||||
f"service {self.name} has oauth_client_id='{proposal.value}'."
|
||||
" Service oauth client ids must start with 'service-'"
|
||||
)
|
||||
return proposal.value
|
||||
|
||||
oauth_redirect_uri = Unicode(
|
||||
help="""OAuth redirect URI for this service.
|
||||
|
||||
|
@@ -70,7 +70,11 @@ def populate_db(url):
|
||||
code = orm.OAuthCode(client_id=client.identifier)
|
||||
db.add(code)
|
||||
db.commit()
|
||||
access_token = orm.OAuthAccessToken(
|
||||
if jupyterhub.version_info < (2, 0):
|
||||
Token = orm.OAuthAccessToken
|
||||
else:
|
||||
Token = orm.APIToken
|
||||
access_token = Token(
|
||||
client_id=client.identifier,
|
||||
user_id=user.id,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
|
@@ -273,7 +273,7 @@ async def test_get_self(app):
|
||||
oauth_client = orm.OAuthClient(identifier='eurydice')
|
||||
db.add(oauth_client)
|
||||
db.commit()
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
oauth_token = orm.APIToken(
|
||||
user=u.orm_user,
|
||||
client=oauth_client,
|
||||
token=token,
|
||||
@@ -1423,12 +1423,11 @@ async def test_token_list(app, as_user, for_user, status):
|
||||
if status != 200:
|
||||
return
|
||||
reply = r.json()
|
||||
assert sorted(reply) == ['api_tokens', 'oauth_tokens']
|
||||
assert sorted(reply) == ['api_tokens']
|
||||
assert len(reply['api_tokens']) == len(for_user_obj.api_tokens)
|
||||
assert all(token['user'] == for_user for token in reply['api_tokens'])
|
||||
assert all(token['user'] == for_user for token in reply['oauth_tokens'])
|
||||
# validate individual token ids
|
||||
for token in reply['api_tokens'] + reply['oauth_tokens']:
|
||||
for token in reply['api_tokens']:
|
||||
r = await api_request(
|
||||
app, 'users', for_user, 'tokens', token['id'], headers=headers
|
||||
)
|
||||
|
@@ -355,7 +355,7 @@ def test_user_delete_cascade(db):
|
||||
spawner.server = server = orm.Server()
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
oauth_token = orm.APIToken(
|
||||
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
db.add(oauth_token)
|
||||
@@ -377,7 +377,7 @@ def test_user_delete_cascade(db):
|
||||
assert_not_found(db, orm.Spawner, spawner_id)
|
||||
assert_not_found(db, orm.Server, server_id)
|
||||
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||
assert_not_found(db, orm.APIToken, oauth_token_id)
|
||||
|
||||
|
||||
def test_oauth_client_delete_cascade(db):
|
||||
@@ -391,12 +391,12 @@ def test_oauth_client_delete_cascade(db):
|
||||
# these should all be deleted automatically when the user goes away
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
oauth_token = orm.APIToken(
|
||||
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
assert user.oauth_tokens == [oauth_token]
|
||||
assert user.tokens == [oauth_token]
|
||||
|
||||
# record all of the ids
|
||||
oauth_code_id = oauth_code.id
|
||||
@@ -408,8 +408,8 @@ def test_oauth_client_delete_cascade(db):
|
||||
|
||||
# verify that everything gets deleted
|
||||
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||
assert user.oauth_tokens == []
|
||||
assert_not_found(db, orm.APIToken, oauth_token_id)
|
||||
assert user.tokens == []
|
||||
assert user.oauth_codes == []
|
||||
|
||||
|
||||
@@ -510,32 +510,32 @@ def test_expiring_api_token(app, user):
|
||||
def test_expiring_oauth_token(app, user):
|
||||
db = app.db
|
||||
token = "abc123"
|
||||
now = orm.OAuthAccessToken.now
|
||||
now = orm.APIToken.now
|
||||
client = orm.OAuthClient(identifier="xxx", secret="yyy")
|
||||
db.add(client)
|
||||
orm_token = orm.OAuthAccessToken(
|
||||
orm_token = orm.APIToken(
|
||||
token=token,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
client=client,
|
||||
user=user,
|
||||
expires_at=now() + 30,
|
||||
expires_at=now() + datetime.timedelta(seconds=30),
|
||||
)
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
found = orm.APIToken.find(db, token)
|
||||
assert found is orm_token
|
||||
# purge_expired doesn't delete non-expired
|
||||
orm.OAuthAccessToken.purge_expired(db)
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
orm.APIToken.purge_expired(db)
|
||||
found = orm.APIToken.find(db, token)
|
||||
assert found is orm_token
|
||||
|
||||
with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60):
|
||||
found = orm.OAuthAccessToken.find(db, token)
|
||||
with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60):
|
||||
found = orm.APIToken.find(db, token)
|
||||
assert found is None
|
||||
assert orm_token in db.query(orm.OAuthAccessToken)
|
||||
orm.OAuthAccessToken.purge_expired(db)
|
||||
assert orm_token not in db.query(orm.OAuthAccessToken)
|
||||
assert orm_token in db.query(orm.APIToken)
|
||||
orm.APIToken.purge_expired(db)
|
||||
assert orm_token not in db.query(orm.APIToken)
|
||||
|
||||
|
||||
def test_expiring_oauth_code(app, user):
|
||||
|
@@ -869,7 +869,7 @@ async def test_oauth_token_page(app):
|
||||
user = app.users[orm.User.find(app.db, name)]
|
||||
client = orm.OAuthClient(identifier='token')
|
||||
app.db.add(client)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
oauth_token = orm.APIToken(
|
||||
client=client, user=user, grant_type=orm.GrantType.authorization_code
|
||||
)
|
||||
app.db.add(oauth_token)
|
||||
|
@@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url):
|
||||
|
||||
def auth_tokens():
|
||||
"""Return list of OAuth access tokens for the user"""
|
||||
return list(
|
||||
app.db.query(orm.OAuthAccessToken).filter(
|
||||
orm.OAuthAccessToken.user_id == app_user.id
|
||||
)
|
||||
)
|
||||
return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id))
|
||||
|
||||
# ensure we start empty
|
||||
assert auth_tokens() == []
|
||||
|
Reference in New Issue
Block a user