diff --git a/jupyterhub/roles.py b/jupyterhub/roles.py index 8328ecab..57cca8a9 100644 --- a/jupyterhub/roles.py +++ b/jupyterhub/roles.py @@ -232,7 +232,7 @@ def _existing_only(func): """Decorator for checking if roles exist""" @wraps(func) - def _check_existence(db, entity, role=None, *, rolename=None): + def _check_existence(db, entity, role=None, commit=True, *, rolename=None): if isinstance(role, str): rolename = role if rolename is not None: @@ -241,7 +241,7 @@ def _existing_only(func): if role is None: raise ValueError(f"Role {rolename} does not exist") - return func(db, entity, role) + return func(db, entity, role, commit=commit) return _check_existence diff --git a/jupyterhub/tests/test_auth.py b/jupyterhub/tests/test_auth.py index 69a41534..0a024b27 100644 --- a/jupyterhub/tests/test_auth.py +++ b/jupyterhub/tests/test_auth.py @@ -13,7 +13,7 @@ from traitlets.config import Config from jupyterhub import auth, crypto, orm -from .mocking import MockHub, MockPAMAuthenticator, MockStructGroup, MockStructPasswd +from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd from .utils import add_user, async_requests, get_page, public_url @@ -595,10 +595,6 @@ async def test_auth_managed_groups( assert groups == expected_refresh_groups -def get_role_names(role_list): - return [role['name'] for role in role_list] - - class MockRolesAuthenticator(auth.Authenticator): authenticated_roles = Any() refresh_roles = Any() @@ -618,45 +614,44 @@ class MockRolesAuthenticator(auth.Authenticator): @pytest.mark.parametrize( - "authenticated_roles, refresh_roles", + "authenticated_roles", [ - ([{"name": "testrole-1", "users": "testuser-1"}], None), - ([{"name": "testrole-2", "users": "testuser-1"}], None), - ([{"name": "testrole-3", "users": "testuser-1"}], None), - ([{"name": "testrole-4", "users": "testuser-1"}], None), - ([{"name": "testrole-5", "users": "testuser-1"}], None), + (None), + ([{"name": "role-1"}]), + ([{"name": "role-2", "description": "test role 2"}]), + ([{"name": "role-3", "scopes": ["admin:servers"]}]), ], ) -async def test_auth_managed_roles(app, user, role, authenticated_roles, refresh_roles): - +async def test_auth_managed_roles(app, user, role, authenticated_roles): authenticator = MockRolesAuthenticator( parent=app, authenticated_roles=authenticated_roles, ) - hub = MockHub(db_url='sqlite:///jupyterhub.sqlite') user.roles.append(role) app.db.commit() - before_roles = role.name + before_roles = [ + { + 'name': r.name, + 'description': r.description, + 'scopes': r.scopes, + 'users': r.users, + } + for r in user.roles + ] if authenticated_roles is None: - expected_authenticated_roles = before_roles + expected_roles = before_roles else: - expected_authenticated_roles = authenticated_roles + expected_roles = authenticated_roles # Check if user gets auth-managed roles with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}): + await app.login_user(user.name) assert not app.db.dirty - all_roles = app.db.query(orm.Role).all() - user_roles = sorted(g.name for g in user.roles) - expected_authenticated_roles_names = get_role_names(expected_authenticated_roles) - for name in expected_authenticated_roles_names: - assert name in user_roles - role = orm.Role.find(app.db, name) - app.db.delete(role) - # Check if roles are deleted after restart - await hub.initialize() - all_roles = app.db.query(orm.Role).all() - all_roles_names = sorted(g.name for g in all_roles) - for name in expected_authenticated_roles_names: - assert name not in all_roles_names + assert len(user.roles) == len(expected_roles) + + for expected_role in expected_roles: + role = orm.Role.find(app.db, expected_role['name']) + assert role.description == expected_role.get('description', None) + assert len(role.scopes) == len(expected_role.get('scopes', [])) diff --git a/jupyterhub/user.py b/jupyterhub/user.py index 167c56fb..235962cd 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -303,10 +303,7 @@ class User: def sync_roles(self, auth_roles): """Synchronize roles with database""" - auth_roles_by_name = { - role['name']: role - for role in auth_roles - } + auth_roles_by_name = {role['name']: role for role in auth_roles} current_user_roles = {r.name for r in self.orm_user.roles} new_user_roles = set(auth_roles_by_name.keys()) @@ -336,8 +333,12 @@ class User: # creates role, or if it exists, update its `description` and `scopes` try: orm_role = roles.create_role(self.db, role, commit=False) - except (roles.RoleValueError, roles.InvalidNameError, scopes.ScopeNotFound) as e: - raise web.HTTPError(409, e.value) + except ( + roles.RoleValueError, + roles.InvalidNameError, + scopes.ScopeNotFound, + ) as e: + raise web.HTTPError(409, str(e)) # Update the groups, services and users for the role groups = [] @@ -366,11 +367,15 @@ class User: # assign the granted roles to the current user for role_name in granted_roles: - roles.grant_role(self.db, entity=self.orm_user, rolename=role_name, commit=False) + roles.grant_role( + self.db, entity=self.orm_user, rolename=role_name, commit=False + ) # strip the user of roles no longer directly granted for role_name in stripped_roles: - roles.strip_role(self.db, entity=self.orm_user, rolename=role_name, commit=False) + roles.strip_role( + self.db, entity=self.orm_user, rolename=role_name, commit=False + ) self.db.commit()