Merge pull request #3548 from C4IROcean/authenticator_user_group_management

Authenticator user group management
This commit is contained in:
Min RK
2022-01-25 14:36:41 +01:00
committed by GitHub
11 changed files with 278 additions and 3 deletions

View File

@@ -247,6 +247,36 @@ class MyAuthenticator(Authenticator):
spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token'] spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token']
``` ```
## Authenticator-managed group membership
:::{versionadded} 2.2
:::
Some identity providers may have their own concept of group membership that you would like to preserve in JupyterHub.
This is now possible with `Authenticator.managed_groups`.
You can set the config:
```python
c.Authenticator.manage_groups = True
```
to enable this behavior.
The default is False for Authenticators that ship with JupyterHub,
but may be True for custom Authenticators.
Check your Authenticator's documentation for manage_groups support.
If True, {meth}`.Authenticator.authenticate` and {meth}`.Authenticator.refresh_user` may include a field `groups`
which is a list of group names the user should be a member of:
- Membership will be added for any group in the list
- Membership in any groups not in the list will be revoked
- Any groups not already present in the database will be created
- If `None` is returned, no changes are made to the user's group membership
If authenticator-managed groups are enabled,
all group-management via the API is disabled.
## pre_spawn_start and post_spawn_stop hooks ## pre_spawn_start and post_spawn_stop hooks
Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and

View File

@@ -0,0 +1,30 @@
"""sample jupyterhub config file for testing
configures jupyterhub with dummyauthenticator and simplespawner
to enable testing without administrative privileges.
"""
c = get_config() # noqa
c.Application.log_level = 'DEBUG'
from oauthenticator.azuread import AzureAdOAuthenticator
import os
c.JupyterHub.authenticator_class = AzureAdOAuthenticator
c.AzureAdOAuthenticator.client_id = os.getenv("AAD_CLIENT_ID")
c.AzureAdOAuthenticator.client_secret = os.getenv("AAD_CLIENT_SECRET")
c.AzureAdOAuthenticator.oauth_callback_url = os.getenv("AAD_CALLBACK_URL")
c.AzureAdOAuthenticator.tenant_id = os.getenv("AAD_TENANT_ID")
c.AzureAdOAuthenticator.username_claim = "email"
c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL")
c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL")
c.Authenticator.manage_groups = True
c.Authenticator.refresh_pre_spawn = True
# Optionally set a global password that all users must use
# c.DummyAuthenticator.password = "your_password"
from jupyterhub.spawner import SimpleLocalProcessSpawner
c.JupyterHub.spawner_class = SimpleLocalProcessSpawner

View File

@@ -0,0 +1,2 @@
oauthenticator
pyjwt

View File

@@ -33,6 +33,11 @@ class _GroupAPIHandler(APIHandler):
raise web.HTTPError(404, "No such group: %s", group_name) raise web.HTTPError(404, "No such group: %s", group_name)
return group return group
def check_authenticator_managed_groups(self):
"""Raise error on group-management APIs if Authenticator is managing groups"""
if self.authenticator.manage_groups:
raise web.HTTPError(400, "Group management via API is disabled")
class GroupListAPIHandler(_GroupAPIHandler): class GroupListAPIHandler(_GroupAPIHandler):
@needs_scope('list:groups') @needs_scope('list:groups')
@@ -68,6 +73,9 @@ class GroupListAPIHandler(_GroupAPIHandler):
@needs_scope('admin:groups') @needs_scope('admin:groups')
async def post(self): async def post(self):
"""POST creates Multiple groups""" """POST creates Multiple groups"""
self.check_authenticator_managed_groups()
model = self.get_json_body() model = self.get_json_body()
if not model or not isinstance(model, dict) or not model.get('groups'): if not model or not isinstance(model, dict) or not model.get('groups'):
raise web.HTTPError(400, "Must specify at least one group to create") raise web.HTTPError(400, "Must specify at least one group to create")
@@ -106,6 +114,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('admin:groups') @needs_scope('admin:groups')
async def post(self, group_name): async def post(self, group_name):
"""POST creates a group by name""" """POST creates a group by name"""
self.check_authenticator_managed_groups()
model = self.get_json_body() model = self.get_json_body()
if model is None: if model is None:
model = {} model = {}
@@ -132,6 +141,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('delete:groups') @needs_scope('delete:groups')
def delete(self, group_name): def delete(self, group_name):
"""Delete a group by name""" """Delete a group by name"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
self.log.info("Deleting group %s", group_name) self.log.info("Deleting group %s", group_name)
self.db.delete(group) self.db.delete(group)
@@ -145,6 +155,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups') @needs_scope('groups')
def post(self, group_name): def post(self, group_name):
"""POST adds users to a group""" """POST adds users to a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
data = self.get_json_body() data = self.get_json_body()
self._check_group_model(data) self._check_group_model(data)
@@ -163,6 +174,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups') @needs_scope('groups')
async def delete(self, group_name): async def delete(self, group_name):
"""DELETE removes users from a group""" """DELETE removes users from a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
data = self.get_json_body() data = self.get_json_body()
self._check_group_model(data) self._check_group_model(data)

View File

@@ -2001,6 +2001,9 @@ class JupyterHub(Application):
async def init_groups(self): async def init_groups(self):
"""Load predefined groups into the database""" """Load predefined groups into the database"""
db = self.db db = self.db
if self.authenticator.manage_groups and self.load_groups:
raise ValueError("Group management has been offloaded to the authenticator")
for name, usernames in self.load_groups.items(): for name, usernames in self.load_groups.items():
group = orm.Group.find(db, name) group = orm.Group.find(db, name)
if group is None: if group is None:

View File

@@ -582,9 +582,13 @@ class Authenticator(LoggingConfigurable):
or None if Authentication failed. or None if Authentication failed.
The Authenticator may return a dict instead, which MUST have a The Authenticator may return a dict instead, which MUST have a
key `name` holding the username, and MAY have two optional keys key `name` holding the username, and MAY have additional keys:
set: `auth_state`, a dictionary of of auth state that will be
persisted; and `admin`, the admin setting value for the user. - `auth_state`, a dictionary of of auth state that will be
persisted;
- `admin`, the admin setting value for the user
- `groups`, the list of group names the user should be a member of,
if Authenticator.manage_groups is True.
""" """
def pre_spawn_start(self, user, spawner): def pre_spawn_start(self, user, spawner):
@@ -635,6 +639,19 @@ class Authenticator(LoggingConfigurable):
""" """
self.allowed_users.discard(user.name) self.allowed_users.discard(user.name)
manage_groups = Bool(
False,
config=True,
help="""Let authenticator manage user groups
If True, Authenticator.authenticate and/or .refresh_user
may return a list of group names in the 'groups' field,
which will be assigned to the user.
All group-assignment APIs are disabled if this is True.
""",
)
auto_login = Bool( auto_login = Bool(
False, False,
config=True, config=True,

View File

@@ -774,13 +774,22 @@ class BaseHandler(RequestHandler):
# always ensure default roles ('user', 'admin' if admin) are assigned # always ensure default roles ('user', 'admin' if admin) are assigned
# after a successful login # after a successful login
roles.assign_default_roles(self.db, entity=user) roles.assign_default_roles(self.db, entity=user)
# apply authenticator-managed groups
if self.authenticator.manage_groups:
group_names = authenticated.get("groups")
if group_names is not None:
user.sync_groups(group_names)
# always set auth_state and commit, # always set auth_state and commit,
# because there could be key-rotation or clearing of previous values # because there could be key-rotation or clearing of previous values
# going on. # going on.
if not self.authenticator.enable_auth_state: if not self.authenticator.enable_auth_state:
# auth_state is not enabled. Force None. # auth_state is not enabled. Force None.
auth_state = None auth_state = None
await user.save_auth_state(auth_state) await user.save_auth_state(auth_state)
return user return user
async def login_user(self, data=None): async def login_user(self, data=None):
@@ -794,6 +803,7 @@ class BaseHandler(RequestHandler):
self.set_login_cookie(user) self.set_login_cookie(user)
self.statsd.incr('login.success') self.statsd.incr('login.success')
self.statsd.timing('login.authenticate.success', auth_timer.ms) self.statsd.timing('login.authenticate.success', auth_timer.ms)
self.log.info("User logged in: %s", user.name) self.log.info("User logged in: %s", user.name)
user._auth_refreshed = time.monotonic() user._auth_refreshed = time.monotonic()
return user return user

View File

@@ -1806,6 +1806,38 @@ async def test_group_add_delete_users(app):
assert sorted(u.name for u in group.users) == sorted(names[2:]) assert sorted(u.name for u in group.users) == sorted(names[2:])
@mark.group
async def test_auth_managed_groups(request, app, group, user):
group.users.append(user)
app.db.commit()
app.authenticator.manage_groups = True
request.addfinalizer(lambda: setattr(app.authenticator, "manage_groups", False))
# create groups
r = await api_request(app, 'groups', method='post')
assert r.status_code == 400
r = await api_request(app, 'groups/newgroup', method='post')
assert r.status_code == 400
# delete groups
r = await api_request(app, f'groups/{group.name}', method='delete')
assert r.status_code == 400
# add users to group
r = await api_request(
app,
f'groups/{group.name}/users',
method='post',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# remove users from group
r = await api_request(
app,
f'groups/{group.name}/users',
method='delete',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# ----------------- # -----------------
# Service API tests # Service API tests
# ----------------- # -----------------

View File

@@ -7,6 +7,7 @@ from urllib.parse import urlparse
import pytest import pytest
from requests import HTTPError from requests import HTTPError
from traitlets import Any
from traitlets.config import Config from traitlets.config import Config
from .mocking import MockPAMAuthenticator from .mocking import MockPAMAuthenticator
@@ -14,6 +15,7 @@ from .mocking import MockStructGroup
from .mocking import MockStructPasswd from .mocking import MockStructPasswd
from .utils import add_user from .utils import add_user
from .utils import async_requests from .utils import async_requests
from .utils import get_page
from .utils import public_url from .utils import public_url
from jupyterhub import auth from jupyterhub import auth
from jupyterhub import crypto from jupyterhub import crypto
@@ -527,3 +529,71 @@ async def test_nullauthenticator(app):
r = await async_requests.get(public_url(app)) r = await async_requests.get(public_url(app))
assert urlparse(r.url).path.endswith("/hub/login") assert urlparse(r.url).path.endswith("/hub/login")
assert r.status_code == 403 assert r.status_code == 403
class MockGroupsAuthenticator(auth.Authenticator):
authenticated_groups = Any()
refresh_groups = Any()
manage_groups = True
def authenticate(self, handler, data):
return {
"name": data["username"],
"groups": self.authenticated_groups,
}
async def refresh_user(self, user, handler):
return {
"name": user.name,
"groups": self.refresh_groups,
}
@pytest.mark.parametrize(
"authenticated_groups, refresh_groups",
[
(None, None),
(["auth1"], None),
(None, ["auth1"]),
(["auth1"], ["auth1", "auth2"]),
(["auth1", "auth2"], ["auth1"]),
(["auth1", "auth2"], ["auth3"]),
(["auth1", "auth2"], ["auth3"]),
],
)
async def test_auth_managed_groups(
app, user, group, authenticated_groups, refresh_groups
):
authenticator = MockGroupsAuthenticator(
parent=app,
authenticated_groups=authenticated_groups,
refresh_groups=refresh_groups,
)
user.groups.append(group)
app.db.commit()
before_groups = [group.name]
if authenticated_groups is None:
expected_authenticated_groups = before_groups
else:
expected_authenticated_groups = authenticated_groups
if refresh_groups is None:
expected_refresh_groups = expected_authenticated_groups
else:
expected_refresh_groups = refresh_groups
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
cookies = await app.login_user(user.name)
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_authenticated_groups
# force refresh_user on next request
user._auth_refreshed -= 10 + app.authenticator.auth_refresh_age
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
assert r.status_code == 200
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_refresh_groups

View File

@@ -1,5 +1,6 @@
import pytest import pytest
from .. import orm
from ..user import UserDict from ..user import UserDict
from .utils import add_user from .utils import add_user
@@ -20,3 +21,35 @@ async def test_userdict_get(db, attr):
assert userdict.get(key).id == u.id assert userdict.get(key).id == u.id
# `in` should find it now # `in` should find it now
assert key in userdict assert key in userdict
@pytest.mark.parametrize(
"group_names",
[
["isin1", "isin2"],
["isin1"],
["notin", "isin1"],
["new-group", "isin1"],
[],
],
)
def test_sync_groups(app, user, group_names):
expected = sorted(group_names)
db = app.db
db.add(orm.Group(name="notin"))
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
for group in in_groups:
db.add(group)
db.commit()
user.groups = in_groups
db.commit()
user.sync_groups(group_names)
assert not app.db.dirty
after_groups = sorted(g.name for g in user.groups)
assert after_groups == expected
# double-check backref
for group in db.query(orm.Group):
if group.name in expected:
assert user.orm_user in group.users
else:
assert user.orm_user not in group.users

View File

@@ -253,6 +253,42 @@ class User:
def spawner_class(self): def spawner_class(self):
return self.settings.get('spawner_class', LocalProcessSpawner) return self.settings.get('spawner_class', LocalProcessSpawner)
def sync_groups(self, group_names):
"""Synchronize groups with database"""
current_groups = {g.name for g in self.orm_user.groups}
new_groups = set(group_names)
if current_groups == new_groups:
# no change, nothing to do
return
# log group changes
new_groups = set(group_names).difference(current_groups)
removed_groups = current_groups.difference(group_names)
if new_groups:
self.log.info("Adding user {self.name} to group(s): {new_groups}")
if removed_groups:
self.log.info("Removing user {self.name} from group(s): {removed_groups}")
if group_names:
groups = (
self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all()
)
existing_groups = {g.name for g in groups}
for group_name in group_names:
if group_name not in existing_groups:
# create groups that don't exist yet
self.log.info(
f"Creating new group {group_name} for user {self.name}"
)
group = orm.Group(name=group_name)
self.db.add(group)
groups.append(group)
self.groups = groups
else:
self.groups = []
self.db.commit()
async def save_auth_state(self, auth_state): async def save_auth_state(self, auth_state):
"""Encrypt and store auth_state""" """Encrypt and store auth_state"""
if auth_state is None: if auth_state is None: