Fix tests, passing commit arg in decorator,

and extracting message from exceptions. Also, lint.
This commit is contained in:
krassowski
2024-03-24 20:18:59 +00:00
parent c685d4bec9
commit 1799b57e4b
3 changed files with 40 additions and 40 deletions

View File

@@ -232,7 +232,7 @@ def _existing_only(func):
"""Decorator for checking if roles exist""" """Decorator for checking if roles exist"""
@wraps(func) @wraps(func)
def _check_existence(db, entity, role=None, *, rolename=None): def _check_existence(db, entity, role=None, commit=True, *, rolename=None):
if isinstance(role, str): if isinstance(role, str):
rolename = role rolename = role
if rolename is not None: if rolename is not None:
@@ -241,7 +241,7 @@ def _existing_only(func):
if role is None: if role is None:
raise ValueError(f"Role {rolename} does not exist") raise ValueError(f"Role {rolename} does not exist")
return func(db, entity, role) return func(db, entity, role, commit=commit)
return _check_existence return _check_existence

View File

@@ -13,7 +13,7 @@ from traitlets.config import Config
from jupyterhub import auth, crypto, orm from jupyterhub import auth, crypto, orm
from .mocking import MockHub, MockPAMAuthenticator, MockStructGroup, MockStructPasswd from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
from .utils import add_user, async_requests, get_page, public_url from .utils import add_user, async_requests, get_page, public_url
@@ -595,10 +595,6 @@ async def test_auth_managed_groups(
assert groups == expected_refresh_groups assert groups == expected_refresh_groups
def get_role_names(role_list):
return [role['name'] for role in role_list]
class MockRolesAuthenticator(auth.Authenticator): class MockRolesAuthenticator(auth.Authenticator):
authenticated_roles = Any() authenticated_roles = Any()
refresh_roles = Any() refresh_roles = Any()
@@ -618,45 +614,44 @@ class MockRolesAuthenticator(auth.Authenticator):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"authenticated_roles, refresh_roles", "authenticated_roles",
[ [
([{"name": "testrole-1", "users": "testuser-1"}], None), (None),
([{"name": "testrole-2", "users": "testuser-1"}], None), ([{"name": "role-1"}]),
([{"name": "testrole-3", "users": "testuser-1"}], None), ([{"name": "role-2", "description": "test role 2"}]),
([{"name": "testrole-4", "users": "testuser-1"}], None), ([{"name": "role-3", "scopes": ["admin:servers"]}]),
([{"name": "testrole-5", "users": "testuser-1"}], None),
], ],
) )
async def test_auth_managed_roles(app, user, role, authenticated_roles, refresh_roles): async def test_auth_managed_roles(app, user, role, authenticated_roles):
authenticator = MockRolesAuthenticator( authenticator = MockRolesAuthenticator(
parent=app, parent=app,
authenticated_roles=authenticated_roles, authenticated_roles=authenticated_roles,
) )
hub = MockHub(db_url='sqlite:///jupyterhub.sqlite')
user.roles.append(role) user.roles.append(role)
app.db.commit() app.db.commit()
before_roles = role.name before_roles = [
{
'name': r.name,
'description': r.description,
'scopes': r.scopes,
'users': r.users,
}
for r in user.roles
]
if authenticated_roles is None: if authenticated_roles is None:
expected_authenticated_roles = before_roles expected_roles = before_roles
else: else:
expected_authenticated_roles = authenticated_roles expected_roles = authenticated_roles
# Check if user gets auth-managed roles # Check if user gets auth-managed roles
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}): with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
await app.login_user(user.name)
assert not app.db.dirty assert not app.db.dirty
all_roles = app.db.query(orm.Role).all()
user_roles = sorted(g.name for g in user.roles)
expected_authenticated_roles_names = get_role_names(expected_authenticated_roles)
for name in expected_authenticated_roles_names:
assert name in user_roles
role = orm.Role.find(app.db, name)
app.db.delete(role)
# Check if roles are deleted after restart assert len(user.roles) == len(expected_roles)
await hub.initialize()
all_roles = app.db.query(orm.Role).all() for expected_role in expected_roles:
all_roles_names = sorted(g.name for g in all_roles) role = orm.Role.find(app.db, expected_role['name'])
for name in expected_authenticated_roles_names: assert role.description == expected_role.get('description', None)
assert name not in all_roles_names assert len(role.scopes) == len(expected_role.get('scopes', []))

View File

@@ -303,10 +303,7 @@ class User:
def sync_roles(self, auth_roles): def sync_roles(self, auth_roles):
"""Synchronize roles with database""" """Synchronize roles with database"""
auth_roles_by_name = { auth_roles_by_name = {role['name']: role for role in auth_roles}
role['name']: role
for role in auth_roles
}
current_user_roles = {r.name for r in self.orm_user.roles} current_user_roles = {r.name for r in self.orm_user.roles}
new_user_roles = set(auth_roles_by_name.keys()) new_user_roles = set(auth_roles_by_name.keys())
@@ -336,8 +333,12 @@ class User:
# creates role, or if it exists, update its `description` and `scopes` # creates role, or if it exists, update its `description` and `scopes`
try: try:
orm_role = roles.create_role(self.db, role, commit=False) orm_role = roles.create_role(self.db, role, commit=False)
except (roles.RoleValueError, roles.InvalidNameError, scopes.ScopeNotFound) as e: except (
raise web.HTTPError(409, e.value) roles.RoleValueError,
roles.InvalidNameError,
scopes.ScopeNotFound,
) as e:
raise web.HTTPError(409, str(e))
# Update the groups, services and users for the role # Update the groups, services and users for the role
groups = [] groups = []
@@ -366,11 +367,15 @@ class User:
# assign the granted roles to the current user # assign the granted roles to the current user
for role_name in granted_roles: for role_name in granted_roles:
roles.grant_role(self.db, entity=self.orm_user, rolename=role_name, commit=False) roles.grant_role(
self.db, entity=self.orm_user, rolename=role_name, commit=False
)
# strip the user of roles no longer directly granted # strip the user of roles no longer directly granted
for role_name in stripped_roles: for role_name in stripped_roles:
roles.strip_role(self.db, entity=self.orm_user, rolename=role_name, commit=False) roles.strip_role(
self.db, entity=self.orm_user, rolename=role_name, commit=False
)
self.db.commit() self.db.commit()