store relationship between oauth client and service/spawner

so that we can look up the spawner/service from the oauth client and vice versa
This commit is contained in:
Min RK
2021-05-12 14:48:16 +02:00
parent 563146445f
commit 7e46d5d0fc
9 changed files with 85 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
"""rbac """RBAC
Revision ID: 833da8570507 Revision ID: 833da8570507
Revises: 4dc2d5a8c53c Revises: 4dc2d5a8c53c
@@ -16,6 +16,30 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
# associate spawners and services with their oauth clients
op.add_column(
'services', sa.Column('oauth_client_id', sa.Unicode(length=255), nullable=True)
)
op.create_foreign_key(
None,
'services',
'oauth_clients',
['oauth_client_id'],
['identifier'],
ondelete='SET NULL',
)
op.add_column(
'spawners', sa.Column('oauth_client_id', sa.Unicode(length=255), nullable=True)
)
op.create_foreign_key(
None,
'spawners',
'oauth_clients',
['oauth_client_id'],
['identifier'],
ondelete='SET NULL',
)
# FIXME, maybe: currently drops all api tokens and forces recreation! # FIXME, maybe: currently drops all api tokens and forces recreation!
# this ensures a consistent database, but requires: # this ensures a consistent database, but requires:
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway) # 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
@@ -33,6 +57,12 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_constraint(None, 'spawners', type_='foreignkey')
op.drop_column('spawners', 'oauth_client_id')
op.drop_constraint(None, 'services', type_='foreignkey')
op.drop_column('services', 'oauth_client_id')
# delete OAuth tokens for non-jupyterhub clients # delete OAuth tokens for non-jupyterhub clients
# drop new columns from api tokens # drop new columns from api tokens
op.drop_constraint(None, 'api_tokens', type_='foreignkey') op.drop_constraint(None, 'api_tokens', type_='foreignkey')

View File

@@ -394,7 +394,7 @@ class JupyterHub(Application):
even if your Hub authentication is still valid. even if your Hub authentication is still valid.
If your Hub authentication is valid, If your Hub authentication is valid,
logging in may be a transparent redirect as you refresh the page. logging in may be a transparent redirect as you refresh the page.
This does not affect JupyterHub API tokens in general, This does not affect JupyterHub API tokens in general,
which do not expire by default. which do not expire by default.
Only tokens issued during the oauth flow Only tokens issued during the oauth flow
@@ -887,7 +887,7 @@ class JupyterHub(Application):
"/", "/",
help=""" help="""
The routing prefix for the Hub itself. The routing prefix for the Hub itself.
Override to send only a subset of traffic to the Hub. Override to send only a subset of traffic to the Hub.
Default is to use the Hub as the default route for all requests. Default is to use the Hub as the default route for all requests.
@@ -899,7 +899,7 @@ class JupyterHub(Application):
may want to handle these events themselves, may want to handle these events themselves,
in which case they can register their own default target with the proxy in which case they can register their own default target with the proxy
and set e.g. `hub_routespec = /hub/` to serve only the hub's own pages, or even `/hub/api/` for api-only operation. and set e.g. `hub_routespec = /hub/` to serve only the hub's own pages, or even `/hub/api/` for api-only operation.
Note: hub_routespec must include the base_url, if any. Note: hub_routespec must include the base_url, if any.
.. versionadded:: 1.4 .. versionadded:: 1.4
@@ -1484,7 +1484,7 @@ class JupyterHub(Application):
Can be a Unicode string (e.g. '/hub/home') or a callable based on the handler object: Can be a Unicode string (e.g. '/hub/home') or a callable based on the handler object:
:: ::
def default_url_fn(handler): def default_url_fn(handler):
user = handler.current_user user = handler.current_user
if user and user.admin: if user and user.admin:
@@ -1956,6 +1956,7 @@ class JupyterHub(Application):
for name, usernames in self.load_groups.items(): for name, usernames in self.load_groups.items():
group = orm.Group.find(db, name) group = orm.Group.find(db, name)
if group is None: if group is None:
self.log.info(f"Creating group {name}")
group = orm.Group(name=name) group = orm.Group(name=name)
db.add(group) db.add(group)
for username in usernames: for username in usernames:
@@ -1970,8 +1971,10 @@ class JupyterHub(Application):
if user is None: if user is None:
if not self.authenticator.validate_username(username): if not self.authenticator.validate_username(username):
raise ValueError("Group username %r is not valid" % username) raise ValueError("Group username %r is not valid" % username)
self.log.info(f"Creating user {username} for group {name}")
user = orm.User(name=username) user = orm.User(name=username)
db.add(user) db.add(user)
self.log.debug(f"Adding user {username} to group {name}")
group.users.append(user) group.users.append(user)
db.commit() db.commit()
@@ -2264,6 +2267,10 @@ class JupyterHub(Application):
allowed_roles=service.oauth_roles, allowed_roles=service.oauth_roles,
description="JupyterHub service %s" % service.name, description="JupyterHub service %s" % service.name,
) )
service.orm.oauth_client_id = service.oauth_client_id
else:
if service.oauth_client:
self.db.delete(service.oauth_client)
self._service_map[name] = service self._service_map[name] = service

View File

@@ -10,7 +10,6 @@ from http.client import responses
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from tornado import web from tornado import web
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.httputil import urlparse
from .. import __version__ from .. import __version__
from .. import orm from .. import orm
@@ -590,8 +589,9 @@ class TokenPageHandler(BaseHandler):
token = tokens[0] token = tokens[0]
oauth_clients.append( oauth_clients.append(
{ {
'client': token.client, 'client': token.oauth_client,
'description': token.client.description or token.client.identifier, 'description': token.oauth_client.description
or token.oauth_client.identifier,
'created': created, 'created': created,
'last_activity': last_activity, 'last_activity': last_activity,
'tokens': tokens, 'tokens': tokens,

View File

@@ -2,8 +2,6 @@
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
""" """
from datetime import timedelta
from oauthlib import uri_validate from oauthlib import uri_validate
from oauthlib.oauth2 import RequestValidator from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer from oauthlib.oauth2 import WebApplicationServer

View File

@@ -326,6 +326,21 @@ class Spawner(Base):
last_activity = Column(DateTime, nullable=True) last_activity = Column(DateTime, nullable=True)
user_options = Column(JSONDict) user_options = Column(JSONDict)
# added in 2.0
oauth_client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='SET NULL',
),
)
oauth_client = relationship(
'OAuthClient',
backref=backref("spawner", uselist=False),
cascade="all, delete-orphan",
single_parent=True,
)
# properties on the spawner wrapper # properties on the spawner wrapper
# some APIs get these low-level objects # some APIs get these low-level objects
# when the spawner isn't running, # when the spawner isn't running,
@@ -377,6 +392,21 @@ class Service(Base):
) )
pid = Column(Integer) pid = Column(Integer)
# added in 2.0
oauth_client_id = Column(
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='SET NULL',
),
)
oauth_client = relationship(
'OAuthClient',
backref=backref("service", uselist=False),
cascade="all, delete-orphan",
single_parent=True,
)
def new_api_token(self, token=None, **kwargs): def new_api_token(self, token=None, **kwargs):
"""Create a new API token """Create a new API token
If `token` is given, load that token. If `token` is given, load that token.
@@ -567,6 +597,7 @@ class APIToken(Hashed, Base):
ondelete='CASCADE', ondelete='CASCADE',
), ),
) )
# FIXME: refresh_tokens not implemented # FIXME: refresh_tokens not implemented
# should be a relation to another token table # should be a relation to another token table
# refresh_token = Column( # refresh_token = Column(
@@ -746,7 +777,7 @@ class OAuthClient(Base):
return self.identifier return self.identifier
access_tokens = relationship( access_tokens = relationship(
APIToken, backref='client', cascade='all, delete-orphan' APIToken, backref='oauth_client', cascade='all, delete-orphan'
) )
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan') codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')

View File

@@ -307,7 +307,7 @@ async def test_get_self(app):
db.commit() db.commit()
oauth_token = orm.APIToken( oauth_token = orm.APIToken(
user=u.orm_user, user=u.orm_user,
client=oauth_client, oauth_client=oauth_client,
token=token, token=token,
) )
db.add(oauth_token) db.add(oauth_token)

View File

@@ -364,7 +364,7 @@ def test_user_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.APIToken( oauth_token = orm.APIToken(
client=oauth_client, oauth_client=oauth_client,
user=user, user=user,
) )
db.add(oauth_token) db.add(oauth_token)
@@ -401,7 +401,7 @@ def test_oauth_client_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.APIToken( oauth_token = orm.APIToken(
client=oauth_client, oauth_client=oauth_client,
user=user, user=user,
) )
db.add(oauth_token) db.add(oauth_token)
@@ -525,7 +525,7 @@ def test_expiring_oauth_token(app, user):
db.add(client) db.add(client)
orm_token = orm.APIToken( orm_token = orm.APIToken(
token=token, token=token,
client=client, oauth_client=client,
user=user, user=user,
expires_at=now() + timedelta(seconds=30), expires_at=now() + timedelta(seconds=30),
) )

View File

@@ -870,7 +870,7 @@ async def test_oauth_token_page(app):
client = orm.OAuthClient(identifier='token') client = orm.OAuthClient(identifier='token')
app.db.add(client) app.db.add(client)
oauth_token = orm.APIToken( oauth_token = orm.APIToken(
client=client, oauth_client=client,
user=user, user=user,
) )
app.db.add(oauth_token) app.db.add(oauth_token)

View File

@@ -590,15 +590,11 @@ class User:
client_id = spawner.oauth_client_id client_id = spawner.oauth_client_id
oauth_provider = self.settings.get('oauth_provider') oauth_provider = self.settings.get('oauth_provider')
if oauth_provider: if oauth_provider:
oauth_client = oauth_provider.fetch_by_client_id(client_id)
# create a new OAuth client + secret on every launch
# containers that resume will be updated below
allowed_roles = spawner.oauth_roles allowed_roles = spawner.oauth_roles
if callable(allowed_roles): if callable(allowed_roles):
allowed_roles = allowed_roles(spawner) allowed_roles = allowed_roles(spawner)
oauth_provider.add_client( oauth_client = oauth_provider.add_client(
client_id, client_id,
api_token, api_token,
url_path_join(self.url, server_name, 'oauth_callback'), url_path_join(self.url, server_name, 'oauth_callback'),
@@ -606,6 +602,7 @@ class User:
description="Server at %s" description="Server at %s"
% (url_path_join(self.base_url, server_name) + '/'), % (url_path_join(self.base_url, server_name) + '/'),
) )
spawner.orm_spawner.oauth_client = oauth_client
db.commit() db.commit()
# trigger pre-spawn hook on authenticator # trigger pre-spawn hook on authenticator
@@ -614,7 +611,7 @@ class User:
spawner._start_pending = True spawner._start_pending = True
if authenticator: if authenticator:
# pre_spawn_start can thow errors that can lead to a redirect loop # pre_spawn_start can throw errors that can lead to a redirect loop
# if left uncaught (see https://github.com/jupyterhub/jupyterhub/issues/2683) # if left uncaught (see https://github.com/jupyterhub/jupyterhub/issues/2683)
await maybe_future(authenticator.pre_spawn_start(self, spawner)) await maybe_future(authenticator.pre_spawn_start(self, spawner))