mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-14 13:33:00 +00:00
Fix tests, passing commit
arg in decorator,
and extracting message from exceptions. Also, lint.
This commit is contained in:
@@ -232,7 +232,7 @@ def _existing_only(func):
|
||||
"""Decorator for checking if roles exist"""
|
||||
|
||||
@wraps(func)
|
||||
def _check_existence(db, entity, role=None, *, rolename=None):
|
||||
def _check_existence(db, entity, role=None, commit=True, *, rolename=None):
|
||||
if isinstance(role, str):
|
||||
rolename = role
|
||||
if rolename is not None:
|
||||
@@ -241,7 +241,7 @@ def _existing_only(func):
|
||||
if role is None:
|
||||
raise ValueError(f"Role {rolename} does not exist")
|
||||
|
||||
return func(db, entity, role)
|
||||
return func(db, entity, role, commit=commit)
|
||||
|
||||
return _check_existence
|
||||
|
||||
|
@@ -13,7 +13,7 @@ from traitlets.config import Config
|
||||
|
||||
from jupyterhub import auth, crypto, orm
|
||||
|
||||
from .mocking import MockHub, MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||
from .utils import add_user, async_requests, get_page, public_url
|
||||
|
||||
|
||||
@@ -595,10 +595,6 @@ async def test_auth_managed_groups(
|
||||
assert groups == expected_refresh_groups
|
||||
|
||||
|
||||
def get_role_names(role_list):
|
||||
return [role['name'] for role in role_list]
|
||||
|
||||
|
||||
class MockRolesAuthenticator(auth.Authenticator):
|
||||
authenticated_roles = Any()
|
||||
refresh_roles = Any()
|
||||
@@ -618,45 +614,44 @@ class MockRolesAuthenticator(auth.Authenticator):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"authenticated_roles, refresh_roles",
|
||||
"authenticated_roles",
|
||||
[
|
||||
([{"name": "testrole-1", "users": "testuser-1"}], None),
|
||||
([{"name": "testrole-2", "users": "testuser-1"}], None),
|
||||
([{"name": "testrole-3", "users": "testuser-1"}], None),
|
||||
([{"name": "testrole-4", "users": "testuser-1"}], None),
|
||||
([{"name": "testrole-5", "users": "testuser-1"}], None),
|
||||
(None),
|
||||
([{"name": "role-1"}]),
|
||||
([{"name": "role-2", "description": "test role 2"}]),
|
||||
([{"name": "role-3", "scopes": ["admin:servers"]}]),
|
||||
],
|
||||
)
|
||||
async def test_auth_managed_roles(app, user, role, authenticated_roles, refresh_roles):
|
||||
|
||||
async def test_auth_managed_roles(app, user, role, authenticated_roles):
|
||||
authenticator = MockRolesAuthenticator(
|
||||
parent=app,
|
||||
authenticated_roles=authenticated_roles,
|
||||
)
|
||||
hub = MockHub(db_url='sqlite:///jupyterhub.sqlite')
|
||||
user.roles.append(role)
|
||||
app.db.commit()
|
||||
before_roles = role.name
|
||||
before_roles = [
|
||||
{
|
||||
'name': r.name,
|
||||
'description': r.description,
|
||||
'scopes': r.scopes,
|
||||
'users': r.users,
|
||||
}
|
||||
for r in user.roles
|
||||
]
|
||||
|
||||
if authenticated_roles is None:
|
||||
expected_authenticated_roles = before_roles
|
||||
expected_roles = before_roles
|
||||
else:
|
||||
expected_authenticated_roles = authenticated_roles
|
||||
expected_roles = authenticated_roles
|
||||
|
||||
# Check if user gets auth-managed roles
|
||||
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
|
||||
await app.login_user(user.name)
|
||||
assert not app.db.dirty
|
||||
all_roles = app.db.query(orm.Role).all()
|
||||
user_roles = sorted(g.name for g in user.roles)
|
||||
expected_authenticated_roles_names = get_role_names(expected_authenticated_roles)
|
||||
for name in expected_authenticated_roles_names:
|
||||
assert name in user_roles
|
||||
role = orm.Role.find(app.db, name)
|
||||
app.db.delete(role)
|
||||
|
||||
# Check if roles are deleted after restart
|
||||
await hub.initialize()
|
||||
all_roles = app.db.query(orm.Role).all()
|
||||
all_roles_names = sorted(g.name for g in all_roles)
|
||||
for name in expected_authenticated_roles_names:
|
||||
assert name not in all_roles_names
|
||||
assert len(user.roles) == len(expected_roles)
|
||||
|
||||
for expected_role in expected_roles:
|
||||
role = orm.Role.find(app.db, expected_role['name'])
|
||||
assert role.description == expected_role.get('description', None)
|
||||
assert len(role.scopes) == len(expected_role.get('scopes', []))
|
||||
|
@@ -303,10 +303,7 @@ class User:
|
||||
|
||||
def sync_roles(self, auth_roles):
|
||||
"""Synchronize roles with database"""
|
||||
auth_roles_by_name = {
|
||||
role['name']: role
|
||||
for role in auth_roles
|
||||
}
|
||||
auth_roles_by_name = {role['name']: role for role in auth_roles}
|
||||
|
||||
current_user_roles = {r.name for r in self.orm_user.roles}
|
||||
new_user_roles = set(auth_roles_by_name.keys())
|
||||
@@ -336,8 +333,12 @@ class User:
|
||||
# creates role, or if it exists, update its `description` and `scopes`
|
||||
try:
|
||||
orm_role = roles.create_role(self.db, role, commit=False)
|
||||
except (roles.RoleValueError, roles.InvalidNameError, scopes.ScopeNotFound) as e:
|
||||
raise web.HTTPError(409, e.value)
|
||||
except (
|
||||
roles.RoleValueError,
|
||||
roles.InvalidNameError,
|
||||
scopes.ScopeNotFound,
|
||||
) as e:
|
||||
raise web.HTTPError(409, str(e))
|
||||
|
||||
# Update the groups, services and users for the role
|
||||
groups = []
|
||||
@@ -366,11 +367,15 @@ class User:
|
||||
|
||||
# assign the granted roles to the current user
|
||||
for role_name in granted_roles:
|
||||
roles.grant_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
|
||||
roles.grant_role(
|
||||
self.db, entity=self.orm_user, rolename=role_name, commit=False
|
||||
)
|
||||
|
||||
# strip the user of roles no longer directly granted
|
||||
for role_name in stripped_roles:
|
||||
roles.strip_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
|
||||
roles.strip_role(
|
||||
self.db, entity=self.orm_user, rolename=role_name, commit=False
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
Reference in New Issue
Block a user