mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-16 14:33:00 +00:00
store relationship between oauth client and service/spawner
so that we can look up the spawner/service from the oauth client and vice versa
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""rbac
|
||||
"""RBAC
|
||||
|
||||
Revision ID: 833da8570507
|
||||
Revises: 4dc2d5a8c53c
|
||||
@@ -16,6 +16,30 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
# associate spawners and services with their oauth clients
|
||||
op.add_column(
|
||||
'services', sa.Column('oauth_client_id', sa.Unicode(length=255), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
None,
|
||||
'services',
|
||||
'oauth_clients',
|
||||
['oauth_client_id'],
|
||||
['identifier'],
|
||||
ondelete='SET NULL',
|
||||
)
|
||||
op.add_column(
|
||||
'spawners', sa.Column('oauth_client_id', sa.Unicode(length=255), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
None,
|
||||
'spawners',
|
||||
'oauth_clients',
|
||||
['oauth_client_id'],
|
||||
['identifier'],
|
||||
ondelete='SET NULL',
|
||||
)
|
||||
|
||||
# FIXME, maybe: currently drops all api tokens and forces recreation!
|
||||
# this ensures a consistent database, but requires:
|
||||
# 1. all servers to be stopped for upgrade (maybe unavoidable anyway)
|
||||
@@ -33,6 +57,12 @@ def upgrade():
|
||||
|
||||
|
||||
def downgrade():
|
||||
|
||||
op.drop_constraint(None, 'spawners', type_='foreignkey')
|
||||
op.drop_column('spawners', 'oauth_client_id')
|
||||
op.drop_constraint(None, 'services', type_='foreignkey')
|
||||
op.drop_column('services', 'oauth_client_id')
|
||||
|
||||
# delete OAuth tokens for non-jupyterhub clients
|
||||
# drop new columns from api tokens
|
||||
op.drop_constraint(None, 'api_tokens', type_='foreignkey')
|
||||
|
@@ -1956,6 +1956,7 @@ class JupyterHub(Application):
|
||||
for name, usernames in self.load_groups.items():
|
||||
group = orm.Group.find(db, name)
|
||||
if group is None:
|
||||
self.log.info(f"Creating group {name}")
|
||||
group = orm.Group(name=name)
|
||||
db.add(group)
|
||||
for username in usernames:
|
||||
@@ -1970,8 +1971,10 @@ class JupyterHub(Application):
|
||||
if user is None:
|
||||
if not self.authenticator.validate_username(username):
|
||||
raise ValueError("Group username %r is not valid" % username)
|
||||
self.log.info(f"Creating user {username} for group {name}")
|
||||
user = orm.User(name=username)
|
||||
db.add(user)
|
||||
self.log.debug(f"Adding user {username} to group {name}")
|
||||
group.users.append(user)
|
||||
db.commit()
|
||||
|
||||
@@ -2264,6 +2267,10 @@ class JupyterHub(Application):
|
||||
allowed_roles=service.oauth_roles,
|
||||
description="JupyterHub service %s" % service.name,
|
||||
)
|
||||
service.orm.oauth_client_id = service.oauth_client_id
|
||||
else:
|
||||
if service.oauth_client:
|
||||
self.db.delete(service.oauth_client)
|
||||
|
||||
self._service_map[name] = service
|
||||
|
||||
|
@@ -10,7 +10,6 @@ from http.client import responses
|
||||
from jinja2 import TemplateNotFound
|
||||
from tornado import web
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.httputil import urlparse
|
||||
|
||||
from .. import __version__
|
||||
from .. import orm
|
||||
@@ -590,8 +589,9 @@ class TokenPageHandler(BaseHandler):
|
||||
token = tokens[0]
|
||||
oauth_clients.append(
|
||||
{
|
||||
'client': token.client,
|
||||
'description': token.client.description or token.client.identifier,
|
||||
'client': token.oauth_client,
|
||||
'description': token.oauth_client.description
|
||||
or token.oauth_client.identifier,
|
||||
'created': created,
|
||||
'last_activity': last_activity,
|
||||
'tokens': tokens,
|
||||
|
@@ -2,8 +2,6 @@
|
||||
|
||||
implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
from oauthlib import uri_validate
|
||||
from oauthlib.oauth2 import RequestValidator
|
||||
from oauthlib.oauth2 import WebApplicationServer
|
||||
|
@@ -326,6 +326,21 @@ class Spawner(Base):
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
user_options = Column(JSONDict)
|
||||
|
||||
# added in 2.0
|
||||
oauth_client_id = Column(
|
||||
Unicode(255),
|
||||
ForeignKey(
|
||||
'oauth_clients.identifier',
|
||||
ondelete='SET NULL',
|
||||
),
|
||||
)
|
||||
oauth_client = relationship(
|
||||
'OAuthClient',
|
||||
backref=backref("spawner", uselist=False),
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
# properties on the spawner wrapper
|
||||
# some APIs get these low-level objects
|
||||
# when the spawner isn't running,
|
||||
@@ -377,6 +392,21 @@ class Service(Base):
|
||||
)
|
||||
pid = Column(Integer)
|
||||
|
||||
# added in 2.0
|
||||
oauth_client_id = Column(
|
||||
Unicode(255),
|
||||
ForeignKey(
|
||||
'oauth_clients.identifier',
|
||||
ondelete='SET NULL',
|
||||
),
|
||||
)
|
||||
oauth_client = relationship(
|
||||
'OAuthClient',
|
||||
backref=backref("service", uselist=False),
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
def new_api_token(self, token=None, **kwargs):
|
||||
"""Create a new API token
|
||||
If `token` is given, load that token.
|
||||
@@ -567,6 +597,7 @@ class APIToken(Hashed, Base):
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
)
|
||||
|
||||
# FIXME: refresh_tokens not implemented
|
||||
# should be a relation to another token table
|
||||
# refresh_token = Column(
|
||||
@@ -746,7 +777,7 @@ class OAuthClient(Base):
|
||||
return self.identifier
|
||||
|
||||
access_tokens = relationship(
|
||||
APIToken, backref='client', cascade='all, delete-orphan'
|
||||
APIToken, backref='oauth_client', cascade='all, delete-orphan'
|
||||
)
|
||||
codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan')
|
||||
|
||||
|
@@ -307,7 +307,7 @@ async def test_get_self(app):
|
||||
db.commit()
|
||||
oauth_token = orm.APIToken(
|
||||
user=u.orm_user,
|
||||
client=oauth_client,
|
||||
oauth_client=oauth_client,
|
||||
token=token,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
|
@@ -364,7 +364,7 @@ def test_user_delete_cascade(db):
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.APIToken(
|
||||
client=oauth_client,
|
||||
oauth_client=oauth_client,
|
||||
user=user,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
@@ -401,7 +401,7 @@ def test_oauth_client_delete_cascade(db):
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.APIToken(
|
||||
client=oauth_client,
|
||||
oauth_client=oauth_client,
|
||||
user=user,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
@@ -525,7 +525,7 @@ def test_expiring_oauth_token(app, user):
|
||||
db.add(client)
|
||||
orm_token = orm.APIToken(
|
||||
token=token,
|
||||
client=client,
|
||||
oauth_client=client,
|
||||
user=user,
|
||||
expires_at=now() + timedelta(seconds=30),
|
||||
)
|
||||
|
@@ -870,7 +870,7 @@ async def test_oauth_token_page(app):
|
||||
client = orm.OAuthClient(identifier='token')
|
||||
app.db.add(client)
|
||||
oauth_token = orm.APIToken(
|
||||
client=client,
|
||||
oauth_client=client,
|
||||
user=user,
|
||||
)
|
||||
app.db.add(oauth_token)
|
||||
|
@@ -590,15 +590,11 @@ class User:
|
||||
client_id = spawner.oauth_client_id
|
||||
oauth_provider = self.settings.get('oauth_provider')
|
||||
if oauth_provider:
|
||||
oauth_client = oauth_provider.fetch_by_client_id(client_id)
|
||||
# create a new OAuth client + secret on every launch
|
||||
# containers that resume will be updated below
|
||||
|
||||
allowed_roles = spawner.oauth_roles
|
||||
if callable(allowed_roles):
|
||||
allowed_roles = allowed_roles(spawner)
|
||||
|
||||
oauth_provider.add_client(
|
||||
oauth_client = oauth_provider.add_client(
|
||||
client_id,
|
||||
api_token,
|
||||
url_path_join(self.url, server_name, 'oauth_callback'),
|
||||
@@ -606,6 +602,7 @@ class User:
|
||||
description="Server at %s"
|
||||
% (url_path_join(self.base_url, server_name) + '/'),
|
||||
)
|
||||
spawner.orm_spawner.oauth_client = oauth_client
|
||||
db.commit()
|
||||
|
||||
# trigger pre-spawn hook on authenticator
|
||||
@@ -614,7 +611,7 @@ class User:
|
||||
spawner._start_pending = True
|
||||
|
||||
if authenticator:
|
||||
# pre_spawn_start can thow errors that can lead to a redirect loop
|
||||
# pre_spawn_start can throw errors that can lead to a redirect loop
|
||||
# if left uncaught (see https://github.com/jupyterhub/jupyterhub/issues/2683)
|
||||
await maybe_future(authenticator.pre_spawn_start(self, spawner))
|
||||
|
||||
|
Reference in New Issue
Block a user