Fix tests, passing commit arg in decorator,

and extracting message from exceptions. Also, lint.
This commit is contained in:
krassowski
2024-03-24 20:18:59 +00:00
parent c685d4bec9
commit 1799b57e4b
3 changed files with 40 additions and 40 deletions

View File

@@ -232,7 +232,7 @@ def _existing_only(func):
"""Decorator for checking if roles exist"""
@wraps(func)
def _check_existence(db, entity, role=None, *, rolename=None):
def _check_existence(db, entity, role=None, commit=True, *, rolename=None):
if isinstance(role, str):
rolename = role
if rolename is not None:
@@ -241,7 +241,7 @@ def _existing_only(func):
if role is None:
raise ValueError(f"Role {rolename} does not exist")
return func(db, entity, role)
return func(db, entity, role, commit=commit)
return _check_existence

View File

@@ -13,7 +13,7 @@ from traitlets.config import Config
from jupyterhub import auth, crypto, orm
from .mocking import MockHub, MockPAMAuthenticator, MockStructGroup, MockStructPasswd
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
from .utils import add_user, async_requests, get_page, public_url
@@ -595,10 +595,6 @@ async def test_auth_managed_groups(
assert groups == expected_refresh_groups
def get_role_names(role_list):
return [role['name'] for role in role_list]
class MockRolesAuthenticator(auth.Authenticator):
authenticated_roles = Any()
refresh_roles = Any()
@@ -618,45 +614,44 @@ class MockRolesAuthenticator(auth.Authenticator):
@pytest.mark.parametrize(
"authenticated_roles, refresh_roles",
"authenticated_roles",
[
([{"name": "testrole-1", "users": "testuser-1"}], None),
([{"name": "testrole-2", "users": "testuser-1"}], None),
([{"name": "testrole-3", "users": "testuser-1"}], None),
([{"name": "testrole-4", "users": "testuser-1"}], None),
([{"name": "testrole-5", "users": "testuser-1"}], None),
(None),
([{"name": "role-1"}]),
([{"name": "role-2", "description": "test role 2"}]),
([{"name": "role-3", "scopes": ["admin:servers"]}]),
],
)
async def test_auth_managed_roles(app, user, role, authenticated_roles, refresh_roles):
async def test_auth_managed_roles(app, user, role, authenticated_roles):
authenticator = MockRolesAuthenticator(
parent=app,
authenticated_roles=authenticated_roles,
)
hub = MockHub(db_url='sqlite:///jupyterhub.sqlite')
user.roles.append(role)
app.db.commit()
before_roles = role.name
before_roles = [
{
'name': r.name,
'description': r.description,
'scopes': r.scopes,
'users': r.users,
}
for r in user.roles
]
if authenticated_roles is None:
expected_authenticated_roles = before_roles
expected_roles = before_roles
else:
expected_authenticated_roles = authenticated_roles
expected_roles = authenticated_roles
# Check if user gets auth-managed roles
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
await app.login_user(user.name)
assert not app.db.dirty
all_roles = app.db.query(orm.Role).all()
user_roles = sorted(g.name for g in user.roles)
expected_authenticated_roles_names = get_role_names(expected_authenticated_roles)
for name in expected_authenticated_roles_names:
assert name in user_roles
role = orm.Role.find(app.db, name)
app.db.delete(role)
# Check if roles are deleted after restart
await hub.initialize()
all_roles = app.db.query(orm.Role).all()
all_roles_names = sorted(g.name for g in all_roles)
for name in expected_authenticated_roles_names:
assert name not in all_roles_names
assert len(user.roles) == len(expected_roles)
for expected_role in expected_roles:
role = orm.Role.find(app.db, expected_role['name'])
assert role.description == expected_role.get('description', None)
assert len(role.scopes) == len(expected_role.get('scopes', []))

View File

@@ -303,10 +303,7 @@ class User:
def sync_roles(self, auth_roles):
"""Synchronize roles with database"""
auth_roles_by_name = {
role['name']: role
for role in auth_roles
}
auth_roles_by_name = {role['name']: role for role in auth_roles}
current_user_roles = {r.name for r in self.orm_user.roles}
new_user_roles = set(auth_roles_by_name.keys())
@@ -336,8 +333,12 @@ class User:
# creates role, or if it exists, update its `description` and `scopes`
try:
orm_role = roles.create_role(self.db, role, commit=False)
except (roles.RoleValueError, roles.InvalidNameError, scopes.ScopeNotFound) as e:
raise web.HTTPError(409, e.value)
except (
roles.RoleValueError,
roles.InvalidNameError,
scopes.ScopeNotFound,
) as e:
raise web.HTTPError(409, str(e))
# Update the groups, services and users for the role
groups = []
@@ -366,11 +367,15 @@ class User:
# assign the granted roles to the current user
for role_name in granted_roles:
roles.grant_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
roles.grant_role(
self.db, entity=self.orm_user, rolename=role_name, commit=False
)
# strip the user of roles no longer directly granted
for role_name in stripped_roles:
roles.strip_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
roles.strip_role(
self.db, entity=self.orm_user, rolename=role_name, commit=False
)
self.db.commit()