mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-11 03:52:59 +00:00
test coverage for Authenticator.managed_groups
- tests - docs - ensure all group APIs are rejected when auth is in control - use 'groups' field in return value of authenticate/refresh_user, instead of defining new method - log group changes in sync_groups
This commit is contained in:
@@ -247,6 +247,36 @@ class MyAuthenticator(Authenticator):
|
||||
spawner.environment['UPSTREAM_TOKEN'] = auth_state['upstream_token']
|
||||
```
|
||||
|
||||
## Authenticator-managed group membership
|
||||
|
||||
:::{versionadded} 2.2
|
||||
:::
|
||||
|
||||
Some identity providers may have their own concept of group membership that you would like to preserve in JupyterHub.
|
||||
This is now possible with `Authenticator.managed_groups`.
|
||||
|
||||
You can set the config:
|
||||
|
||||
```python
|
||||
c.Authenticator.manage_groups = True
|
||||
```
|
||||
|
||||
to enable this behavior.
|
||||
The default is False for Authenticators that ship with JupyterHub,
|
||||
but may be True for custom Authenticators.
|
||||
Check your Authenticator's documentation for manage_groups support.
|
||||
|
||||
If True, {meth}`.Authenticator.authenticate` and {meth}`.Authenticator.refresh_user` may include a field `groups`
|
||||
which is a list of group names the user should be a member of:
|
||||
|
||||
- Membership will be added for any group in the list
|
||||
- Membership in any groups not in the list will be revoked
|
||||
- Any groups not already present in the database will be created
|
||||
- If `None` is returned, no changes are made to the user's group membership
|
||||
|
||||
If authenticator-managed groups are enabled,
|
||||
all group-management via the API is disabled.
|
||||
|
||||
## pre_spawn_start and post_spawn_stop hooks
|
||||
|
||||
Authenticators uses two hooks, [pre_spawn_start(user, spawner)][] and
|
||||
|
@@ -19,7 +19,7 @@ c.AzureAdOAuthenticator.tenant_id = os.getenv("AAD_TENANT_ID")
|
||||
c.AzureAdOAuthenticator.username_claim = "email"
|
||||
c.AzureAdOAuthenticator.authorize_url = os.getenv("AAD_AUTHORIZE_URL")
|
||||
c.AzureAdOAuthenticator.token_url = os.getenv("AAD_TOKEN_URL")
|
||||
c.Authenticator.authenticator_managed_groups = True
|
||||
c.Authenticator.manage_groups = True
|
||||
c.Authenticator.refresh_pre_spawn = True
|
||||
|
||||
# Optionally set a global password that all users must use
|
||||
|
@@ -33,6 +33,11 @@ class _GroupAPIHandler(APIHandler):
|
||||
raise web.HTTPError(404, "No such group: %s", group_name)
|
||||
return group
|
||||
|
||||
def check_authenticator_managed_groups(self):
|
||||
"""Raise error on group-management APIs if Authenticator is managing groups"""
|
||||
if self.authenticator.manage_groups:
|
||||
raise web.HTTPError(400, "Group management via API is disabled")
|
||||
|
||||
|
||||
class GroupListAPIHandler(_GroupAPIHandler):
|
||||
@needs_scope('list:groups')
|
||||
@@ -69,8 +74,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
|
||||
async def post(self):
|
||||
"""POST creates Multiple groups"""
|
||||
|
||||
if self.authenticator.manage_groups:
|
||||
raise web.HTTPError(400, "Group management via API is disabled")
|
||||
self.check_authenticator_managed_groups()
|
||||
|
||||
model = self.get_json_body()
|
||||
if not model or not isinstance(model, dict) or not model.get('groups'):
|
||||
@@ -110,6 +114,7 @@ class GroupAPIHandler(_GroupAPIHandler):
|
||||
@needs_scope('admin:groups')
|
||||
async def post(self, group_name):
|
||||
"""POST creates a group by name"""
|
||||
self.check_authenticator_managed_groups()
|
||||
model = self.get_json_body()
|
||||
if model is None:
|
||||
model = {}
|
||||
@@ -136,6 +141,7 @@ class GroupAPIHandler(_GroupAPIHandler):
|
||||
@needs_scope('delete:groups')
|
||||
def delete(self, group_name):
|
||||
"""Delete a group by name"""
|
||||
self.check_authenticator_managed_groups()
|
||||
group = self.find_group(group_name)
|
||||
self.log.info("Deleting group %s", group_name)
|
||||
self.db.delete(group)
|
||||
@@ -149,6 +155,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
|
||||
@needs_scope('groups')
|
||||
def post(self, group_name):
|
||||
"""POST adds users to a group"""
|
||||
self.check_authenticator_managed_groups()
|
||||
group = self.find_group(group_name)
|
||||
data = self.get_json_body()
|
||||
self._check_group_model(data)
|
||||
@@ -167,6 +174,7 @@ class GroupUsersAPIHandler(_GroupAPIHandler):
|
||||
@needs_scope('groups')
|
||||
async def delete(self, group_name):
|
||||
"""DELETE removes users from a group"""
|
||||
self.check_authenticator_managed_groups()
|
||||
group = self.find_group(group_name)
|
||||
data = self.get_json_body()
|
||||
self._check_group_model(data)
|
||||
|
@@ -582,9 +582,13 @@ class Authenticator(LoggingConfigurable):
|
||||
or None if Authentication failed.
|
||||
|
||||
The Authenticator may return a dict instead, which MUST have a
|
||||
key `name` holding the username, and MAY have two optional keys
|
||||
set: `auth_state`, a dictionary of of auth state that will be
|
||||
persisted; and `admin`, the admin setting value for the user.
|
||||
key `name` holding the username, and MAY have additional keys:
|
||||
|
||||
- `auth_state`, a dictionary of of auth state that will be
|
||||
persisted;
|
||||
- `admin`, the admin setting value for the user
|
||||
- `groups`, the list of group names the user should be a member of,
|
||||
if Authenticator.manage_groups is True.
|
||||
"""
|
||||
|
||||
def pre_spawn_start(self, user, spawner):
|
||||
@@ -640,25 +644,14 @@ class Authenticator(LoggingConfigurable):
|
||||
config=True,
|
||||
help="""Let authenticator manage user groups
|
||||
|
||||
Authenticator must implement get_user_groups for this to be useful.
|
||||
If True, Authenticator.authenticate and/or .refresh_user
|
||||
may return a list of group names in the 'groups' field,
|
||||
which will be assigned to the user.
|
||||
|
||||
All group-assignment APIs are disabled if this is True.
|
||||
""",
|
||||
)
|
||||
|
||||
def load_user_groups(self, user, auth_state):
|
||||
"""Hook called allowing authenticator to read user groups
|
||||
|
||||
Updates user group memberships
|
||||
|
||||
Args:
|
||||
auth_state (dict): Proprietary dict returned by authenticator
|
||||
user(User): the User object associated with the auth-state
|
||||
|
||||
Returns:
|
||||
groups (list):
|
||||
List of user group memberships
|
||||
"""
|
||||
return None
|
||||
|
||||
auto_login = Bool(
|
||||
False,
|
||||
config=True,
|
||||
|
@@ -622,9 +622,6 @@ class BaseHandler(RequestHandler):
|
||||
def authenticate(self, data):
|
||||
return maybe_future(self.authenticator.get_authenticated_user(self, data))
|
||||
|
||||
def load_user_groups(self, user, auth_info):
|
||||
return maybe_future(self.authenticator.load_user_groups(user, auth_info))
|
||||
|
||||
def get_next_url(self, user=None, default=None):
|
||||
"""Get the next_url for login redirect
|
||||
|
||||
@@ -776,6 +773,13 @@ class BaseHandler(RequestHandler):
|
||||
# always ensure default roles ('user', 'admin' if admin) are assigned
|
||||
# after a successful login
|
||||
roles.assign_default_roles(self.db, entity=user)
|
||||
|
||||
# apply authenticator-managed groups
|
||||
if self.authenticator.manage_groups:
|
||||
group_names = authenticated.get("groups")
|
||||
if group_names is not None:
|
||||
user.sync_groups(group_names)
|
||||
|
||||
# always set auth_state and commit,
|
||||
# because there could be key-rotation or clearing of previous values
|
||||
# going on.
|
||||
@@ -783,12 +787,6 @@ class BaseHandler(RequestHandler):
|
||||
# auth_state is not enabled. Force None.
|
||||
auth_state = None
|
||||
|
||||
if self.authenticator.manage_groups:
|
||||
# Run authenticator user-group reload hook
|
||||
user_groups = await self.load_user_groups(user, authenticated)
|
||||
if user_groups is not None:
|
||||
user.sync_groups(user_groups)
|
||||
|
||||
await user.save_auth_state(auth_state)
|
||||
|
||||
return user
|
||||
|
@@ -1806,6 +1806,38 @@ async def test_group_add_delete_users(app):
|
||||
assert sorted(u.name for u in group.users) == sorted(names[2:])
|
||||
|
||||
|
||||
@mark.group
|
||||
async def test_auth_managed_groups(request, app, group, user):
|
||||
group.users.append(user)
|
||||
app.db.commit()
|
||||
app.authenticator.manage_groups = True
|
||||
request.addfinalizer(lambda: setattr(app.authenticator, "manage_groups", False))
|
||||
# create groups
|
||||
r = await api_request(app, 'groups', method='post')
|
||||
assert r.status_code == 400
|
||||
r = await api_request(app, 'groups/newgroup', method='post')
|
||||
assert r.status_code == 400
|
||||
# delete groups
|
||||
r = await api_request(app, f'groups/{group.name}', method='delete')
|
||||
assert r.status_code == 400
|
||||
# add users to group
|
||||
r = await api_request(
|
||||
app,
|
||||
f'groups/{group.name}/users',
|
||||
method='post',
|
||||
data=json.dumps({"users": [user.name]}),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
# remove users from group
|
||||
r = await api_request(
|
||||
app,
|
||||
f'groups/{group.name}/users',
|
||||
method='delete',
|
||||
data=json.dumps({"users": [user.name]}),
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
# -----------------
|
||||
# Service API tests
|
||||
# -----------------
|
||||
|
@@ -7,6 +7,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from requests import HTTPError
|
||||
from traitlets import Any
|
||||
from traitlets.config import Config
|
||||
|
||||
from .mocking import MockPAMAuthenticator
|
||||
@@ -14,6 +15,7 @@ from .mocking import MockStructGroup
|
||||
from .mocking import MockStructPasswd
|
||||
from .utils import add_user
|
||||
from .utils import async_requests
|
||||
from .utils import get_page
|
||||
from .utils import public_url
|
||||
from jupyterhub import auth
|
||||
from jupyterhub import crypto
|
||||
@@ -527,3 +529,71 @@ async def test_nullauthenticator(app):
|
||||
r = await async_requests.get(public_url(app))
|
||||
assert urlparse(r.url).path.endswith("/hub/login")
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
class MockGroupsAuthenticator(auth.Authenticator):
|
||||
authenticated_groups = Any()
|
||||
refresh_groups = Any()
|
||||
|
||||
manage_groups = True
|
||||
|
||||
def authenticate(self, handler, data):
|
||||
return {
|
||||
"name": data["username"],
|
||||
"groups": self.authenticated_groups,
|
||||
}
|
||||
|
||||
async def refresh_user(self, user, handler):
|
||||
return {
|
||||
"name": user.name,
|
||||
"groups": self.refresh_groups,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"authenticated_groups, refresh_groups",
|
||||
[
|
||||
(None, None),
|
||||
(["auth1"], None),
|
||||
(None, ["auth1"]),
|
||||
(["auth1"], ["auth1", "auth2"]),
|
||||
(["auth1", "auth2"], ["auth1"]),
|
||||
(["auth1", "auth2"], ["auth3"]),
|
||||
(["auth1", "auth2"], ["auth3"]),
|
||||
],
|
||||
)
|
||||
async def test_auth_managed_groups(
|
||||
app, user, group, authenticated_groups, refresh_groups
|
||||
):
|
||||
|
||||
authenticator = MockGroupsAuthenticator(
|
||||
parent=app,
|
||||
authenticated_groups=authenticated_groups,
|
||||
refresh_groups=refresh_groups,
|
||||
)
|
||||
|
||||
user.groups.append(group)
|
||||
app.db.commit()
|
||||
before_groups = [group.name]
|
||||
if authenticated_groups is None:
|
||||
expected_authenticated_groups = before_groups
|
||||
else:
|
||||
expected_authenticated_groups = authenticated_groups
|
||||
if refresh_groups is None:
|
||||
expected_refresh_groups = expected_authenticated_groups
|
||||
else:
|
||||
expected_refresh_groups = refresh_groups
|
||||
|
||||
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
|
||||
cookies = await app.login_user(user.name)
|
||||
assert not app.db.dirty
|
||||
groups = sorted(g.name for g in user.groups)
|
||||
assert groups == expected_authenticated_groups
|
||||
|
||||
# force refresh_user on next request
|
||||
user._auth_refreshed -= 10 + app.authenticator.auth_refresh_age
|
||||
r = await get_page('home', app, cookies=cookies, allow_redirects=False)
|
||||
assert r.status_code == 200
|
||||
assert not app.db.dirty
|
||||
groups = sorted(g.name for g in user.groups)
|
||||
assert groups == expected_refresh_groups
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from .. import orm
|
||||
from ..user import UserDict
|
||||
from .utils import add_user
|
||||
|
||||
@@ -20,3 +21,35 @@ async def test_userdict_get(db, attr):
|
||||
assert userdict.get(key).id == u.id
|
||||
# `in` should find it now
|
||||
assert key in userdict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"group_names",
|
||||
[
|
||||
["isin1", "isin2"],
|
||||
["isin1"],
|
||||
["notin", "isin1"],
|
||||
["new-group", "isin1"],
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_sync_groups(app, user, group_names):
|
||||
expected = sorted(group_names)
|
||||
db = app.db
|
||||
db.add(orm.Group(name="notin"))
|
||||
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
|
||||
for group in in_groups:
|
||||
db.add(group)
|
||||
db.commit()
|
||||
user.groups = in_groups
|
||||
db.commit()
|
||||
user.sync_groups(group_names)
|
||||
assert not app.db.dirty
|
||||
after_groups = sorted(g.name for g in user.groups)
|
||||
assert after_groups == expected
|
||||
# double-check backref
|
||||
for group in db.query(orm.Group):
|
||||
if group.name in expected:
|
||||
assert user.orm_user in group.users
|
||||
else:
|
||||
assert user.orm_user not in group.users
|
||||
|
@@ -253,18 +253,41 @@ class User:
|
||||
def spawner_class(self):
|
||||
return self.settings.get('spawner_class', LocalProcessSpawner)
|
||||
|
||||
def sync_groups(self, user_groups):
|
||||
"""Syncronize groups with database"""
|
||||
def sync_groups(self, group_names):
|
||||
"""Synchronize groups with database"""
|
||||
|
||||
if user_groups:
|
||||
current_groups = {g.name for g in self.orm_user.groups}
|
||||
new_groups = set(group_names)
|
||||
if current_groups == new_groups:
|
||||
# no change, nothing to do
|
||||
return
|
||||
|
||||
# log group changes
|
||||
new_groups = set(group_names).difference(current_groups)
|
||||
removed_groups = current_groups.difference(group_names)
|
||||
if new_groups:
|
||||
self.log.info("Adding user {self.name} to group(s): {new_groups}")
|
||||
if removed_groups:
|
||||
self.log.info("Removing user {self.name} from group(s): {removed_groups}")
|
||||
|
||||
if group_names:
|
||||
groups = (
|
||||
self.db.query(orm.Group).filter(orm.Group.name.in_(user_groups)).all()
|
||||
self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all()
|
||||
)
|
||||
groups = {g.name: g for g in groups}
|
||||
|
||||
self.groups = [groups.get(g, orm.Group(name=g)) for g in user_groups]
|
||||
existing_groups = {g.name for g in groups}
|
||||
for group_name in group_names:
|
||||
if group_name not in existing_groups:
|
||||
# create groups that don't exist yet
|
||||
self.log.info(
|
||||
f"Creating new group {group_name} for user {self.name}"
|
||||
)
|
||||
group = orm.Group(name=group_name)
|
||||
self.db.add(group)
|
||||
groups.append(group)
|
||||
self.groups = groups
|
||||
else:
|
||||
self.groups = []
|
||||
self.db.commit()
|
||||
|
||||
async def save_auth_state(self, auth_state):
|
||||
"""Encrypt and store auth_state"""
|
||||
|
Reference in New Issue
Block a user