test coverage for Authenticator.managed_groups

- tests
- docs
- ensure all group APIs are rejected when auth is in control
- use 'groups' field in return value of authenticate/refresh_user, instead of defining new method
- log group changes in sync_groups
This commit is contained in:
Min RK
2022-01-20 14:44:47 +01:00
parent 144abcb965
commit 88be7a9967
9 changed files with 226 additions and 39 deletions

View File

@@ -247,6 +247,36 @@ class MyAuthenticator(Authenticator):
spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token']
```
## Authenticator-managed group membership
:::{versionadded} 2.2
:::
Some identity providers may have their own concept of group membership that you would like to preserve in JupyterHub.
This is now possible with `Authenticator.managed_groups`.
You can set the config:
```python
c.Authenticator.manage_groups = True
```
to enable this behavior.
The default is False for Authenticators that ship with JupyterHub,
but may be True for custom Authenticators.
Check your Authenticator's documentation for manage_groups support.
If True, {meth}`.Authenticator.authenticate` and {meth}`.Authenticator.refresh_user` may include a field `groups`
which is a list of group names the user should be a member of:
- Membership will be added for any group in the list
- Membership in any groups not in the list will be revoked
- Any groups not already present in the database will be created
- If `None` is returned, no changes are made to the user's group membership
If authenticator-managed groups are enabled,
all group-management via the API is disabled.
## pre_spawn_start and post_spawn_stop hooks
Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and

View File

@@ -19,7 +19,7 @@ c.AzureAdOAuthenticator.tenant_id = os.getenv("AAD_TENANT_ID")
c.AzureAdOAuthenticator.username_claim = "email"
c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL")
c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL")
c.Authenticator.authenticator_managed_groups = True
c.Authenticator.manage_groups = True
c.Authenticator.refresh_pre_spawn = True
# Optionally set a global password that all users must use

View File

@@ -33,6 +33,11 @@ class _GroupAPIHandler(APIHandler):
raise web.HTTPError(404, "No such group: %s", group_name)
return group
def check_authenticator_managed_groups(self):
"""Raise error on group-management APIs if Authenticator is managing groups"""
if self.authenticator.manage_groups:
raise web.HTTPError(400, "Group management via API is disabled")
class GroupListAPIHandler(_GroupAPIHandler):
@needs_scope('list:groups')
@@ -69,8 +74,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
async def post(self):
"""POST creates Multiple groups"""
if self.authenticator.manage_groups:
raise web.HTTPError(400, "Group management via API is disabled")
self.check_authenticator_managed_groups()
model = self.get_json_body()
if not model or not isinstance(model, dict) or not model.get('groups'):
@@ -110,6 +114,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('admin:groups')
async def post(self, group_name):
"""POST creates a group by name"""
self.check_authenticator_managed_groups()
model = self.get_json_body()
if model is None:
model = {}
@@ -136,6 +141,7 @@ class GroupAPIHandler(_GroupAPIHandler):
@needs_scope('delete:groups')
def delete(self, group_name):
"""Delete a group by name"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name)
self.log.info("Deleting group %s", group_name)
self.db.delete(group)
@@ -149,6 +155,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups')
def post(self, group_name):
"""POST adds users to a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name)
data = self.get_json_body()
self._check_group_model(data)
@@ -167,6 +174,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
@needs_scope('groups')
async def delete(self, group_name):
"""DELETE removes users from a group"""
self.check_authenticator_managed_groups()
group = self.find_group(group_name)
data = self.get_json_body()
self._check_group_model(data)

View File

@@ -582,9 +582,13 @@ class Authenticator(LoggingConfigurable):
or None if Authentication failed.
The Authenticator may return a dict instead, which MUST have a
key `name` holding the username, and MAY have two optional keys
set: `auth_state`, a dictionary of of auth state that will be
persisted; and `admin`, the admin setting value for the user.
key `name` holding the username, and MAY have additional keys:
- `auth_state`, a dictionary of of auth state that will be
persisted;
- `admin`, the admin setting value for the user
- `groups`, the list of group names the user should be a member of,
if Authenticator.manage_groups is True.
"""
def pre_spawn_start(self, user, spawner):
@@ -639,26 +643,15 @@ class Authenticator(LoggingConfigurable):
False,
config=True,
help="""Let authenticator manage user groups
Authenticator must implement get_user_groups for this to be useful.
If True, Authenticator.authenticate and/or .refresh_user
may return a list of group names in the 'groups' field,
which will be assigned to the user.
All group-assignment APIs are disabled if this is True.
""",
)
def load_user_groups(self, user, auth_state):
"""Hook called allowing authenticator to read user groups
Updates user group memberships
Args:
auth_state (dict): Proprietary dict returned by authenticator
user(User): the User object associated with the auth-state
Returns:
groups (list):
List of user group memberships
"""
return None
auto_login = Bool(
False,
config=True,

View File

@@ -622,9 +622,6 @@ class BaseHandler(RequestHandler):
def authenticate(self, data):
return maybe_future(self.authenticator.get_authenticated_user(self, data))
def load_user_groups(self, user, auth_info):
return maybe_future(self.authenticator.load_user_groups(user, auth_info))
def get_next_url(self, user=None, default=None):
"""Get the next_url for login redirect
@@ -776,6 +773,13 @@ class BaseHandler(RequestHandler):
# always ensure default roles ('user', 'admin' if admin) are assigned
# after a successful login
roles.assign_default_roles(self.db, entity=user)
# apply authenticator-managed groups
if self.authenticator.manage_groups:
group_names = authenticated.get("groups")
if group_names is not None:
user.sync_groups(group_names)
# always set auth_state and commit,
# because there could be key-rotation or clearing of previous values
# going on.
@@ -783,12 +787,6 @@ class BaseHandler(RequestHandler):
# auth_state is not enabled. Force None.
auth_state = None
if self.authenticator.manage_groups:
# Run authenticator user-group reload hook
user_groups = await self.load_user_groups(user, authenticated)
if user_groups is not None:
user.sync_groups(user_groups)
await user.save_auth_state(auth_state)
return user

View File

@@ -1806,6 +1806,38 @@ async def test_group_add_delete_users(app):
assert sorted(u.name for u in group.users) == sorted(names[2:])
@mark.group
async def test_auth_managed_groups(request, app, group, user):
group.users.append(user)
app.db.commit()
app.authenticator.manage_groups = True
request.addfinalizer(lambda: setattr(app.authenticator, "manage_groups", False))
# create groups
r = await api_request(app, 'groups', method='post')
assert r.status_code == 400
r = await api_request(app, 'groups/newgroup', method='post')
assert r.status_code == 400
# delete groups
r = await api_request(app, f'groups/{group.name}', method='delete')
assert r.status_code == 400
# add users to group
r = await api_request(
app,
f'groups/{group.name}/users',
method='post',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# remove users from group
r = await api_request(
app,
f'groups/{group.name}/users',
method='delete',
data=json.dumps({"users": [user.name]}),
)
assert r.status_code == 400
# -----------------
# Service API tests
# -----------------

View File

@@ -7,6 +7,7 @@ from urllib.parse import urlparse
import pytest
from requests import HTTPError
from traitlets import Any
from traitlets.config import Config
from .mocking import MockPAMAuthenticator
@@ -14,6 +15,7 @@ from .mocking import MockStructGroup
from .mocking import MockStructPasswd
from .utils import add_user
from .utils import async_requests
from .utils import get_page
from .utils import public_url
from jupyterhub import auth
from jupyterhub import crypto
@@ -527,3 +529,71 @@ async def test_nullauthenticator(app):
r = await async_requests.get(public_url(app))
assert urlparse(r.url).path.endswith("/hub/login")
assert r.status_code == 403
class MockGroupsAuthenticator(auth.Authenticator):
authenticated_groups = Any()
refresh_groups = Any()
manage_groups = True
def authenticate(self, handler, data):
return {
"name": data["username"],
"groups": self.authenticated_groups,
}
async def refresh_user(self, user, handler):
return {
"name": user.name,
"groups": self.refresh_groups,
}
@pytest.mark.parametrize(
"authenticated_groups, refresh_groups",
[
(None, None),
(["auth1"], None),
(None, ["auth1"]),
(["auth1"], ["auth1", "auth2"]),
(["auth1", "auth2"], ["auth1"]),
(["auth1", "auth2"], ["auth3"]),
(["auth1", "auth2"], ["auth3"]),
],
)
async def test_auth_managed_groups(
app, user, group, authenticated_groups, refresh_groups
):
authenticator = MockGroupsAuthenticator(
parent=app,
authenticated_groups=authenticated_groups,
refresh_groups=refresh_groups,
)
user.groups.append(group)
app.db.commit()
before_groups = [group.name]
if authenticated_groups is None:
expected_authenticated_groups = before_groups
else:
expected_authenticated_groups = authenticated_groups
if refresh_groups is None:
expected_refresh_groups = expected_authenticated_groups
else:
expected_refresh_groups = refresh_groups
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
cookies = await app.login_user(user.name)
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_authenticated_groups
# force refresh_user on next request
user._auth_refreshed -= 10 + app.authenticator.auth_refresh_age
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
assert r.status_code == 200
assert not app.db.dirty
groups = sorted(g.name for g in user.groups)
assert groups == expected_refresh_groups

View File

@@ -1,5 +1,6 @@
import pytest
from .. import orm
from ..user import UserDict
from .utils import add_user
@@ -20,3 +21,35 @@ async def test_userdict_get(db, attr):
assert userdict.get(key).id == u.id
# `in` should find it now
assert key in userdict
@pytest.mark.parametrize(
"group_names",
[
["isin1", "isin2"],
["isin1"],
["notin", "isin1"],
["new-group", "isin1"],
[],
],
)
def test_sync_groups(app, user, group_names):
expected = sorted(group_names)
db = app.db
db.add(orm.Group(name="notin"))
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
for group in in_groups:
db.add(group)
db.commit()
user.groups = in_groups
db.commit()
user.sync_groups(group_names)
assert not app.db.dirty
after_groups = sorted(g.name for g in user.groups)
assert after_groups == expected
# double-check backref
for group in db.query(orm.Group):
if group.name in expected:
assert user.orm_user in group.users
else:
assert user.orm_user not in group.users

View File

@@ -253,18 +253,41 @@ class User:
def spawner_class(self):
return self.settings.get('spawner_class', LocalProcessSpawner)
def sync_groups(self, user_groups):
"""Syncronize groups with database"""
def sync_groups(self, group_names):
"""Synchronize groups with database"""
if user_groups:
current_groups = {g.name for g in self.orm_user.groups}
new_groups = set(group_names)
if current_groups == new_groups:
# no change, nothing to do
return
# log group changes
new_groups = set(group_names).difference(current_groups)
removed_groups = current_groups.difference(group_names)
if new_groups:
self.log.info("Adding user {self.name} to group(s): {new_groups}")
if removed_groups:
self.log.info("Removing user {self.name} from group(s): {removed_groups}")
if group_names:
groups = (
self.db.query(orm.Group).filter(orm.Group.name.in_(user_groups)).all()
self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all()
)
groups = {g.name: g for g in groups}
self.groups = [groups.get(g, orm.Group(name=g)) for g in user_groups]
existing_groups = {g.name for g in groups}
for group_name in group_names:
if group_name not in existing_groups:
# create groups that don't exist yet
self.log.info(
f"Creating new group {group_name} for user {self.name}"
)
group = orm.Group(name=group_name)
self.db.add(group)
groups.append(group)
self.groups = groups
else:
self.groups = []
self.db.commit()
async def save_auth_state(self, auth_state):
"""Encrypt and store auth_state"""