diff --git a/docs/source/reference/authenticators.md b/docs/source/reference/authenticators.md index 068fc248..61f8ecbe 100644 --- a/docs/source/reference/authenticators.md +++ b/docs/source/reference/authenticators.md @@ -247,6 +247,36 @@ class MyAuthenticator(Authenticator): spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token'] ``` +## Authenticator-managed group membership + +:::{versionadded} 2.2 +::: + +Some identity providers may have their own concept of group membership that you would like to preserve in JupyterHub. +This is now possible with `Authenticator.managed_groups`. + +You can set the config: + +```python +c.Authenticator.manage_groups = True +``` + +to enable this behavior. +The default is False for Authenticators that ship with JupyterHub, +but may be True for custom Authenticators. +Check your Authenticator's documentation for manage_groups support. + +If True, {meth}`.Authenticator.authenticate` and {meth}`.Authenticator.refresh_user` may include a field `groups` +which is a list of group names the user should be a member of: + +- Membership will be added for any group in the list +- Membership in any groups not in the list will be revoked +- Any groups not already present in the database will be created +- If `None` is returned, no changes are made to the user's group membership + +If authenticator-managed groups are enabled, +all group-management via the API is disabled. + ## pre_spawn_start and post_spawn_stop hooks Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and diff --git a/examples/azuread-with-group-management/jupyterhub_config.py b/examples/azuread-with-group-management/jupyterhub_config.py index 9614dd70..f8da8746 100644 --- a/examples/azuread-with-group-management/jupyterhub_config.py +++ b/examples/azuread-with-group-management/jupyterhub_config.py @@ -19,7 +19,7 @@ c.AzureAdOAuthenticator.tenant_id = os.getenv("AAD_TENANT_ID") c.AzureAdOAuthenticator.username_claim = "email" c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL") c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL") -c.Authenticator.authenticator_managed_groups = True +c.Authenticator.manage_groups = True c.Authenticator.refresh_pre_spawn = True # Optionally set a global password that all users must use diff --git a/jupyterhub/apihandlers/groups.py b/jupyterhub/apihandlers/groups.py index d587ec17..dbe9fec2 100644 --- a/jupyterhub/apihandlers/groups.py +++ b/jupyterhub/apihandlers/groups.py @@ -33,6 +33,11 @@ class _GroupAPIHandler(APIHandler): raise web.HTTPError(404, "No such group: %s", group_name) return group + def check_authenticator_managed_groups(self): + """Raise error on group-management APIs if Authenticator is managing groups""" + if self.authenticator.manage_groups: + raise web.HTTPError(400, "Group management via API is disabled") + class GroupListAPIHandler(_GroupAPIHandler): @needs_scope('list:groups') @@ -69,8 +74,7 @@ class GroupListAPIHandler(_GroupAPIHandler): async def post(self): """POST creates Multiple groups""" - if self.authenticator.manage_groups: - raise web.HTTPError(400, "Group management via API is disabled") + self.check_authenticator_managed_groups() model = self.get_json_body() if not model or not isinstance(model, dict) or not model.get('groups'): @@ -110,6 +114,7 @@ class GroupAPIHandler(_GroupAPIHandler): @needs_scope('admin:groups') async def post(self, group_name): """POST creates a group by name""" + self.check_authenticator_managed_groups() model = self.get_json_body() if model is None: model = {} @@ -136,6 +141,7 @@ class GroupAPIHandler(_GroupAPIHandler): @needs_scope('delete:groups') def delete(self, group_name): """Delete a group by name""" + self.check_authenticator_managed_groups() group = self.find_group(group_name) self.log.info("Deleting group %s", group_name) self.db.delete(group) @@ -149,6 +155,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler): @needs_scope('groups') def post(self, group_name): """POST adds users to a group""" + self.check_authenticator_managed_groups() group = self.find_group(group_name) data = self.get_json_body() self._check_group_model(data) @@ -167,6 +174,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler): @needs_scope('groups') async def delete(self, group_name): """DELETE removes users from a group""" + self.check_authenticator_managed_groups() group = self.find_group(group_name) data = self.get_json_body() self._check_group_model(data) diff --git a/jupyterhub/auth.py b/jupyterhub/auth.py index 403bc43d..1ac38988 100644 --- a/jupyterhub/auth.py +++ b/jupyterhub/auth.py @@ -582,9 +582,13 @@ class Authenticator(LoggingConfigurable): or None if Authentication failed. The Authenticator may return a dict instead, which MUST have a - key `name` holding the username, and MAY have two optional keys - set: `auth_state`, a dictionary of of auth state that will be - persisted; and `admin`, the admin setting value for the user. + key `name` holding the username, and MAY have additional keys: + + - `auth_state`, a dictionary of of auth state that will be + persisted; + - `admin`, the admin setting value for the user + - `groups`, the list of group names the user should be a member of, + if Authenticator.manage_groups is True. """ def pre_spawn_start(self, user, spawner): @@ -639,26 +643,15 @@ class Authenticator(LoggingConfigurable): False, config=True, help="""Let authenticator manage user groups - - Authenticator must implement get_user_groups for this to be useful. + + If True, Authenticator.authenticate and/or .refresh_user + may return a list of group names in the 'groups' field, + which will be assigned to the user. + + All group-assignment APIs are disabled if this is True. """, ) - def load_user_groups(self, user, auth_state): - """Hook called allowing authenticator to read user groups - - Updates user group memberships - - Args: - auth_state (dict): Proprietary dict returned by authenticator - user(User): the User object associated with the auth-state - - Returns: - groups (list): - List of user group memberships - """ - return None - auto_login = Bool( False, config=True, diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index f8dbe797..f3573fdc 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -622,9 +622,6 @@ class BaseHandler(RequestHandler): def authenticate(self, data): return maybe_future(self.authenticator.get_authenticated_user(self, data)) - def load_user_groups(self, user, auth_info): - return maybe_future(self.authenticator.load_user_groups(user, auth_info)) - def get_next_url(self, user=None, default=None): """Get the next_url for login redirect @@ -776,6 +773,13 @@ class BaseHandler(RequestHandler): # always ensure default roles ('user', 'admin' if admin) are assigned # after a successful login roles.assign_default_roles(self.db, entity=user) + + # apply authenticator-managed groups + if self.authenticator.manage_groups: + group_names = authenticated.get("groups") + if group_names is not None: + user.sync_groups(group_names) + # always set auth_state and commit, # because there could be key-rotation or clearing of previous values # going on. @@ -783,12 +787,6 @@ class BaseHandler(RequestHandler): # auth_state is not enabled. Force None. auth_state = None - if self.authenticator.manage_groups: - # Run authenticator user-group reload hook - user_groups = await self.load_user_groups(user, authenticated) - if user_groups is not None: - user.sync_groups(user_groups) - await user.save_auth_state(auth_state) return user diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 1397ea63..cc63e78d 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -1806,6 +1806,38 @@ async def test_group_add_delete_users(app): assert sorted(u.name for u in group.users) == sorted(names[2:]) +@mark.group +async def test_auth_managed_groups(request, app, group, user): + group.users.append(user) + app.db.commit() + app.authenticator.manage_groups = True + request.addfinalizer(lambda: setattr(app.authenticator, "manage_groups", False)) + # create groups + r = await api_request(app, 'groups', method='post') + assert r.status_code == 400 + r = await api_request(app, 'groups/newgroup', method='post') + assert r.status_code == 400 + # delete groups + r = await api_request(app, f'groups/{group.name}', method='delete') + assert r.status_code == 400 + # add users to group + r = await api_request( + app, + f'groups/{group.name}/users', + method='post', + data=json.dumps({"users": [user.name]}), + ) + assert r.status_code == 400 + # remove users from group + r = await api_request( + app, + f'groups/{group.name}/users', + method='delete', + data=json.dumps({"users": [user.name]}), + ) + assert r.status_code == 400 + + # ----------------- # Service API tests # ----------------- diff --git a/jupyterhub/tests/test_auth.py b/jupyterhub/tests/test_auth.py index 1f627426..667a046d 100644 --- a/jupyterhub/tests/test_auth.py +++ b/jupyterhub/tests/test_auth.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse import pytest from requests import HTTPError +from traitlets import Any from traitlets.config import Config from .mocking import MockPAMAuthenticator @@ -14,6 +15,7 @@ from .mocking import MockStructGroup from .mocking import MockStructPasswd from .utils import add_user from .utils import async_requests +from .utils import get_page from .utils import public_url from jupyterhub import auth from jupyterhub import crypto @@ -527,3 +529,71 @@ async def test_nullauthenticator(app): r = await async_requests.get(public_url(app)) assert urlparse(r.url).path.endswith("/hub/login") assert r.status_code == 403 + + +class MockGroupsAuthenticator(auth.Authenticator): + authenticated_groups = Any() + refresh_groups = Any() + + manage_groups = True + + def authenticate(self, handler, data): + return { + "name": data["username"], + "groups": self.authenticated_groups, + } + + async def refresh_user(self, user, handler): + return { + "name": user.name, + "groups": self.refresh_groups, + } + + +@pytest.mark.parametrize( + "authenticated_groups, refresh_groups", + [ + (None, None), + (["auth1"], None), + (None, ["auth1"]), + (["auth1"], ["auth1", "auth2"]), + (["auth1", "auth2"], ["auth1"]), + (["auth1", "auth2"], ["auth3"]), + (["auth1", "auth2"], ["auth3"]), + ], +) +async def test_auth_managed_groups( + app, user, group, authenticated_groups, refresh_groups +): + + authenticator = MockGroupsAuthenticator( + parent=app, + authenticated_groups=authenticated_groups, + refresh_groups=refresh_groups, + ) + + user.groups.append(group) + app.db.commit() + before_groups = [group.name] + if authenticated_groups is None: + expected_authenticated_groups = before_groups + else: + expected_authenticated_groups = authenticated_groups + if refresh_groups is None: + expected_refresh_groups = expected_authenticated_groups + else: + expected_refresh_groups = refresh_groups + + with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}): + cookies = await app.login_user(user.name) + assert not app.db.dirty + groups = sorted(g.name for g in user.groups) + assert groups == expected_authenticated_groups + + # force refresh_user on next request + user._auth_refreshed -= 10 + app.authenticator.auth_refresh_age + r = await get_page('home', app, cookies=cookies, allow_redirects=False) + assert r.status_code == 200 + assert not app.db.dirty + groups = sorted(g.name for g in user.groups) + assert groups == expected_refresh_groups diff --git a/jupyterhub/tests/test_user.py b/jupyterhub/tests/test_user.py index df7bc8a2..61e0270b 100644 --- a/jupyterhub/tests/test_user.py +++ b/jupyterhub/tests/test_user.py @@ -1,5 +1,6 @@ import pytest +from .. import orm from ..user import UserDict from .utils import add_user @@ -20,3 +21,35 @@ async def test_userdict_get(db, attr): assert userdict.get(key).id == u.id # `in` should find it now assert key in userdict + + +@pytest.mark.parametrize( + "group_names", + [ + ["isin1", "isin2"], + ["isin1"], + ["notin", "isin1"], + ["new-group", "isin1"], + [], + ], +) +def test_sync_groups(app, user, group_names): + expected = sorted(group_names) + db = app.db + db.add(orm.Group(name="notin")) + in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")] + for group in in_groups: + db.add(group) + db.commit() + user.groups = in_groups + db.commit() + user.sync_groups(group_names) + assert not app.db.dirty + after_groups = sorted(g.name for g in user.groups) + assert after_groups == expected + # double-check backref + for group in db.query(orm.Group): + if group.name in expected: + assert user.orm_user in group.users + else: + assert user.orm_user not in group.users diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 0e314594..65e2d4cf 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -253,18 +253,41 @@ class User: def spawner_class(self): return self.settings.get('spawner_class', LocalProcessSpawner) - def sync_groups(self, user_groups): - """Syncronize groups with database""" + def sync_groups(self, group_names): + """Synchronize groups with database""" - if user_groups: + current_groups = {g.name for g in self.orm_user.groups} + new_groups = set(group_names) + if current_groups == new_groups: + # no change, nothing to do + return + + # log group changes + new_groups = set(group_names).difference(current_groups) + removed_groups = current_groups.difference(group_names) + if new_groups: + self.log.info("Adding user {self.name} to group(s): {new_groups}") + if removed_groups: + self.log.info("Removing user {self.name} from group(s): {removed_groups}") + + if group_names: groups = ( - self.db.query(orm.Group).filter(orm.Group.name.in_(user_groups)).all() + self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all() ) - groups = {g.name: g for g in groups} - - self.groups = [groups.get(g, orm.Group(name=g)) for g in user_groups] + existing_groups = {g.name for g in groups} + for group_name in group_names: + if group_name not in existing_groups: + # create groups that don't exist yet + self.log.info( + f"Creating new group {group_name} for user {self.name}" + ) + group = orm.Group(name=group_name) + self.db.add(group) + groups.append(group) + self.groups = groups else: self.groups = [] + self.db.commit() async def save_auth_state(self, auth_state): """Encrypt and store auth_state"""