test coverage for Authenticator.managed_groups

- tests
- docs
- ensure all group APIs are rejected when auth is in control
- use 'groups' field in return value of authenticate/refresh_user, instead of defining new method
- log group changes in sync_groups
This commit is contained in:
Min RK
2022-01-20 14:44:47 +01:00
parent 144abcb965
commit 88be7a9967
9 changed files with 226 additions and 39 deletions

View File

@@ -247,6 +247,36 @@ class MyAuthenticator(Authenticator):
spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token'] spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token']
``` ```
## Authenticator-managed group membership
:::{versionadded} 2.2
:::
Some identity providers may have their own concept of group membership that you would like to preserve in JupyterHub.
This is now possible with `Authenticator.managed_groups`.
You can set the config:
```python
c.Authenticator.manage_groups = True
```
to enable this behavior.
The default is False for Authenticators that ship with JupyterHub,
but may be True for custom Authenticators.
Check your Authenticator's documentation for manage_groups support.
If True, {meth}`.Authenticator.authenticate` and {meth}`.Authenticator.refresh_user` may include a field `groups`
which is a list of group names the user should be a member of:
- Membership will be added for any group in the list
- Membership in any groups not in the list will be revoked
- Any groups not already present in the database will be created
- If `None` is returned, no changes are made to the user's group membership
If authenticator-managed groups are enabled,
all group-management via the API is disabled.
## pre_spawn_start and post_spawn_stop hooks ## pre_spawn_start and post_spawn_stop hooks
Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and

View File

@@ -19,7 +19,7 @@ c.AzureAdOAuthenticator.tenant_id = os.getenv("AAD_TENANT_ID")
c.AzureAdOAuthenticator.username_claim = "email" c.AzureAdOAuthenticator.username_claim = "email"
c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL") c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL")
c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL") c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL")
c.Authenticator.authenticator_managed_groups = True c.Authenticator.manage_groups = True
c.Authenticator.refresh_pre_spawn = True c.Authenticator.refresh_pre_spawn = True
# Optionally set a global password that all users must use # Optionally set a global password that all users must use

View File

@@ -33,6 +33,11 @@ class _GroupAPIHandler(APIHandler):
raise web.HTTPError(404, "No such group: %s", group_name) raise web.HTTPError(404, "No such group: %s", group_name)
return group return group
def check_authenticator_managed_groups(self):
"""Raise error on group-management APIs if Authenticator is managing groups"""
if self.authenticator.manage_groups:
raise web.HTTPError(400, "Group management via API is disabled")
class GroupListAPIHandler(_GroupAPIHandler): class GroupListAPIHandler(_GroupAPIHandler):
@needs_scope('list:groups') @needs_scope('list:groups')
@@ -69,8 +74,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
async def post(self): async def post(self):
"""POST creates Multiple groups""" """POST creates Multiple groups"""
if self.authenticator.manage_groups: self.check_authenticator_managed_groups()
raise web.HTTPError(400, "Group management via API is disabled")
model = self.get_json_body() model = self.get_json_body()
if not model or not isinstance(model, dict) or not model.get('groups'): if not model or not isinstance(model, dict) or not model.get('groups'):
@@ -110,6 +114,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('admin:groups') @needs_scope('admin:groups')
async def post(self, group_name): async def post(self, group_name):
"""POST creates a group by name""" """POST creates a group by name"""
self.check_authenticator_managed_groups()
model = self.get_json_body() model = self.get_json_body()
if model is None: if model is None:
model = {} model = {}
@@ -136,6 +141,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('delete:groups') @needs_scope('delete:groups')
def delete(self, group_name): def delete(self, group_name):
"""Delete a group by name""" """Delete a group by name"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
self.log.info("Deleting group %s", group_name) self.log.info("Deleting group %s", group_name)
self.db.delete(group) self.db.delete(group)
@@ -149,6 +155,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups') @needs_scope('groups')
def post(self, group_name): def post(self, group_name):
"""POST adds users to a group""" """POST adds users to a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
data = self.get_json_body() data = self.get_json_body()
self._check_group_model(data) self._check_group_model(data)
@@ -167,6 +174,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups') @needs_scope('groups')
async def delete(self, group_name): async def delete(self, group_name):
"""DELETE removes users from a group""" """DELETE removes users from a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name) group = self.find_group(group_name)
data = self.get_json_body() data = self.get_json_body()
self._check_group_model(data) self._check_group_model(data)

View File

@@ -582,9 +582,13 @@ class Authenticator(LoggingConfigurable):
or None if Authentication failed. or None if Authentication failed.
The Authenticator may return a dict instead, which MUST have a The Authenticator may return a dict instead, which MUST have a
key `name` holding the username, and MAY have two optional keys key `name` holding the username, and MAY have additional keys:
set: `auth_state`, a dictionary of of auth state that will be
persisted; and `admin`, the admin setting value for the user. - `auth_state`, a dictionary of of auth state that will be
persisted;
- `admin`, the admin setting value for the user
- `groups`, the list of group names the user should be a member of,
if Authenticator.manage_groups is True.
""" """
def pre_spawn_start(self, user, spawner): def pre_spawn_start(self, user, spawner):
@@ -639,26 +643,15 @@ class Authenticator(LoggingConfigurable):
False, False,
config=True, config=True,
help="""Let authenticator manage user groups help="""Let authenticator manage user groups
Authenticator must implement get_user_groups for this to be useful. If True, Authenticator.authenticate and/or .refresh_user
may return a list of group names in the 'groups' field,
which will be assigned to the user.
All group-assignment APIs are disabled if this is True.
""", """,
) )
def load_user_groups(self, user, auth_state):
"""Hook called allowing authenticator to read user groups
Updates user group memberships
Args:
auth_state (dict): Proprietary dict returned by authenticator
user(User): the User object associated with the auth-state
Returns:
groups (list):
List of user group memberships
"""
return None
auto_login = Bool( auto_login = Bool(
False, False,
config=True, config=True,

View File

@@ -622,9 +622,6 @@ class BaseHandler(RequestHandler):
def authenticate(self, data): def authenticate(self, data):
return maybe_future(self.authenticator.get_authenticated_user(self, data)) return maybe_future(self.authenticator.get_authenticated_user(self, data))
def load_user_groups(self, user, auth_info):
return maybe_future(self.authenticator.load_user_groups(user, auth_info))
def get_next_url(self, user=None, default=None): def get_next_url(self, user=None, default=None):
"""Get the next_url for login redirect """Get the next_url for login redirect
@@ -776,6 +773,13 @@ class BaseHandler(RequestHandler):
# always ensure default roles ('user', 'admin' if admin) are assigned # always ensure default roles ('user', 'admin' if admin) are assigned
# after a successful login # after a successful login
roles.assign_default_roles(self.db, entity=user) roles.assign_default_roles(self.db, entity=user)
# apply authenticator-managed groups
if self.authenticator.manage_groups:
group_names = authenticated.get("groups")
if group_names is not None:
user.sync_groups(group_names)
# always set auth_state and commit, # always set auth_state and commit,
# because there could be key-rotation or clearing of previous values # because there could be key-rotation or clearing of previous values
# going on. # going on.
@@ -783,12 +787,6 @@ class BaseHandler(RequestHandler):
# auth_state is not enabled. Force None. # auth_state is not enabled. Force None.
auth_state = None auth_state = None
if self.authenticator.manage_groups:
# Run authenticator user-group reload hook
user_groups = await self.load_user_groups(user, authenticated)
if user_groups is not None:
user.sync_groups(user_groups)
await user.save_auth_state(auth_state) await user.save_auth_state(auth_state)
return user return user

View File

@@ -1806,6 +1806,38 @@ async def test_group_add_delete_users(app):
assert sorted(u.name for u in group.users) == sorted(names[2:]) assert sorted(u.name for u in group.users) == sorted(names[2:])
@mark.group
async def test_auth_managed_groups(request, app, group, user):
group.users.append(user)
app.db.commit()
app.authenticator.manage_groups = True
request.addfinalizer(lambda: setattr(app.authenticator, "manage_groups", False))
# create groups
r = await api_request(app, 'groups', method='post')
assert r.status_code == 400
r = await api_request(app, 'groups/newgroup', method='post')
assert r.status_code == 400
# delete groups
r = await api_request(app, f'groups/{group.name}', method='delete')
assert r.status_code == 400
# add users to group
r = await api_request(
app,
f'groups/{group.name}/users',
method='post',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# remove users from group
r = await api_request(
app,
f'groups/{group.name}/users',
method='delete',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# ----------------- # -----------------
# Service API tests # Service API tests
# ----------------- # -----------------

View File

@@ -7,6 +7,7 @@ from urllib.parse import urlparse
import pytest import pytest
from requests import HTTPError from requests import HTTPError
from traitlets import Any
from traitlets.config import Config from traitlets.config import Config
from .mocking import MockPAMAuthenticator from .mocking import MockPAMAuthenticator
@@ -14,6 +15,7 @@ from .mocking import MockStructGroup
from .mocking import MockStructPasswd from .mocking import MockStructPasswd
from .utils import add_user from .utils import add_user
from .utils import async_requests from .utils import async_requests
from .utils import get_page
from .utils import public_url from .utils import public_url
from jupyterhub import auth from jupyterhub import auth
from jupyterhub import crypto from jupyterhub import crypto
@@ -527,3 +529,71 @@ async def test_nullauthenticator(app):
r = await async_requests.get(public_url(app)) r = await async_requests.get(public_url(app))
assert urlparse(r.url).path.endswith("/hub/login") assert urlparse(r.url).path.endswith("/hub/login")
assert r.status_code == 403 assert r.status_code == 403
class MockGroupsAuthenticator(auth.Authenticator):
authenticated_groups = Any()
refresh_groups = Any()
manage_groups = True
def authenticate(self, handler, data):
return {
"name": data["username"],
"groups": self.authenticated_groups,
}
async def refresh_user(self, user, handler):
return {
"name": user.name,
"groups": self.refresh_groups,
}
@pytest.mark.parametrize(
"authenticated_groups, refresh_groups",
[
(None, None),
(["auth1"], None),
(None, ["auth1"]),
(["auth1"], ["auth1", "auth2"]),
(["auth1", "auth2"], ["auth1"]),
(["auth1", "auth2"], ["auth3"]),
(["auth1", "auth2"], ["auth3"]),
],
)
async def test_auth_managed_groups(
app, user, group, authenticated_groups, refresh_groups
):
authenticator = MockGroupsAuthenticator(
parent=app,
authenticated_groups=authenticated_groups,
refresh_groups=refresh_groups,
)
user.groups.append(group)
app.db.commit()
before_groups = [group.name]
if authenticated_groups is None:
expected_authenticated_groups = before_groups
else:
expected_authenticated_groups = authenticated_groups
if refresh_groups is None:
expected_refresh_groups = expected_authenticated_groups
else:
expected_refresh_groups = refresh_groups
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
cookies = await app.login_user(user.name)
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_authenticated_groups
# force refresh_user on next request
user._auth_refreshed -= 10 + app.authenticator.auth_refresh_age
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
assert r.status_code == 200
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_refresh_groups

View File

@@ -1,5 +1,6 @@
import pytest import pytest
from .. import orm
from ..user import UserDict from ..user import UserDict
from .utils import add_user from .utils import add_user
@@ -20,3 +21,35 @@ async def test_userdict_get(db, attr):
assert userdict.get(key).id == u.id assert userdict.get(key).id == u.id
# `in` should find it now # `in` should find it now
assert key in userdict assert key in userdict
@pytest.mark.parametrize(
"group_names",
[
["isin1", "isin2"],
["isin1"],
["notin", "isin1"],
["new-group", "isin1"],
[],
],
)
def test_sync_groups(app, user, group_names):
expected = sorted(group_names)
db = app.db
db.add(orm.Group(name="notin"))
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
for group in in_groups:
db.add(group)
db.commit()
user.groups = in_groups
db.commit()
user.sync_groups(group_names)
assert not app.db.dirty
after_groups = sorted(g.name for g in user.groups)
assert after_groups == expected
# double-check backref
for group in db.query(orm.Group):
if group.name in expected:
assert user.orm_user in group.users
else:
assert user.orm_user not in group.users

View File

@@ -253,18 +253,41 @@ class User:
def spawner_class(self): def spawner_class(self):
return self.settings.get('spawner_class', LocalProcessSpawner) return self.settings.get('spawner_class', LocalProcessSpawner)
def sync_groups(self, user_groups): def sync_groups(self, group_names):
"""Syncronize groups with database""" """Synchronize groups with database"""
if user_groups: current_groups = {g.name for g in self.orm_user.groups}
new_groups = set(group_names)
if current_groups == new_groups:
# no change, nothing to do
return
# log group changes
new_groups = set(group_names).difference(current_groups)
removed_groups = current_groups.difference(group_names)
if new_groups:
self.log.info("Adding user {self.name} to group(s): {new_groups}")
if removed_groups:
self.log.info("Removing user {self.name} from group(s): {removed_groups}")
if group_names:
groups = ( groups = (
self.db.query(orm.Group).filter(orm.Group.name.in_(user_groups)).all() self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all()
) )
groups = {g.name: g for g in groups} existing_groups = {g.name for g in groups}
for group_name in group_names:
self.groups = [groups.get(g, orm.Group(name=g)) for g in user_groups] if group_name not in existing_groups:
# create groups that don't exist yet
self.log.info(
f"Creating new group {group_name} for user {self.name}"
)
group = orm.Group(name=group_name)
self.db.add(group)
groups.append(group)
self.groups = groups
else: else:
self.groups = [] self.groups = []
self.db.commit()
async def save_auth_state(self, auth_state): async def save_auth_state(self, auth_state):
"""Encrypt and store auth_state""" """Encrypt and store auth_state"""