mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-14 21:43:01 +00:00
Fix tests, passing commit
arg in decorator,
and extracting message from exceptions. Also, lint.
This commit is contained in:
@@ -232,7 +232,7 @@ def _existing_only(func):
|
|||||||
"""Decorator for checking if roles exist"""
|
"""Decorator for checking if roles exist"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def _check_existence(db, entity, role=None, *, rolename=None):
|
def _check_existence(db, entity, role=None, commit=True, *, rolename=None):
|
||||||
if isinstance(role, str):
|
if isinstance(role, str):
|
||||||
rolename = role
|
rolename = role
|
||||||
if rolename is not None:
|
if rolename is not None:
|
||||||
@@ -241,7 +241,7 @@ def _existing_only(func):
|
|||||||
if role is None:
|
if role is None:
|
||||||
raise ValueError(f"Role {rolename} does not exist")
|
raise ValueError(f"Role {rolename} does not exist")
|
||||||
|
|
||||||
return func(db, entity, role)
|
return func(db, entity, role, commit=commit)
|
||||||
|
|
||||||
return _check_existence
|
return _check_existence
|
||||||
|
|
||||||
|
@@ -13,7 +13,7 @@ from traitlets.config import Config
|
|||||||
|
|
||||||
from jupyterhub import auth, crypto, orm
|
from jupyterhub import auth, crypto, orm
|
||||||
|
|
||||||
from .mocking import MockHub, MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
from .mocking import MockPAMAuthenticator, MockStructGroup, MockStructPasswd
|
||||||
from .utils import add_user, async_requests, get_page, public_url
|
from .utils import add_user, async_requests, get_page, public_url
|
||||||
|
|
||||||
|
|
||||||
@@ -595,10 +595,6 @@ async def test_auth_managed_groups(
|
|||||||
assert groups == expected_refresh_groups
|
assert groups == expected_refresh_groups
|
||||||
|
|
||||||
|
|
||||||
def get_role_names(role_list):
|
|
||||||
return [role['name'] for role in role_list]
|
|
||||||
|
|
||||||
|
|
||||||
class MockRolesAuthenticator(auth.Authenticator):
|
class MockRolesAuthenticator(auth.Authenticator):
|
||||||
authenticated_roles = Any()
|
authenticated_roles = Any()
|
||||||
refresh_roles = Any()
|
refresh_roles = Any()
|
||||||
@@ -618,45 +614,44 @@ class MockRolesAuthenticator(auth.Authenticator):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"authenticated_roles, refresh_roles",
|
"authenticated_roles",
|
||||||
[
|
[
|
||||||
([{"name": "testrole-1", "users": "testuser-1"}], None),
|
(None),
|
||||||
([{"name": "testrole-2", "users": "testuser-1"}], None),
|
([{"name": "role-1"}]),
|
||||||
([{"name": "testrole-3", "users": "testuser-1"}], None),
|
([{"name": "role-2", "description": "test role 2"}]),
|
||||||
([{"name": "testrole-4", "users": "testuser-1"}], None),
|
([{"name": "role-3", "scopes": ["admin:servers"]}]),
|
||||||
([{"name": "testrole-5", "users": "testuser-1"}], None),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_auth_managed_roles(app, user, role, authenticated_roles, refresh_roles):
|
async def test_auth_managed_roles(app, user, role, authenticated_roles):
|
||||||
|
|
||||||
authenticator = MockRolesAuthenticator(
|
authenticator = MockRolesAuthenticator(
|
||||||
parent=app,
|
parent=app,
|
||||||
authenticated_roles=authenticated_roles,
|
authenticated_roles=authenticated_roles,
|
||||||
)
|
)
|
||||||
hub = MockHub(db_url='sqlite:///jupyterhub.sqlite')
|
|
||||||
user.roles.append(role)
|
user.roles.append(role)
|
||||||
app.db.commit()
|
app.db.commit()
|
||||||
before_roles = role.name
|
before_roles = [
|
||||||
|
{
|
||||||
|
'name': r.name,
|
||||||
|
'description': r.description,
|
||||||
|
'scopes': r.scopes,
|
||||||
|
'users': r.users,
|
||||||
|
}
|
||||||
|
for r in user.roles
|
||||||
|
]
|
||||||
|
|
||||||
if authenticated_roles is None:
|
if authenticated_roles is None:
|
||||||
expected_authenticated_roles = before_roles
|
expected_roles = before_roles
|
||||||
else:
|
else:
|
||||||
expected_authenticated_roles = authenticated_roles
|
expected_roles = authenticated_roles
|
||||||
|
|
||||||
# Check if user gets auth-managed roles
|
# Check if user gets auth-managed roles
|
||||||
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
|
with mock.patch.dict(app.tornado_settings, {"authenticator": authenticator}):
|
||||||
|
await app.login_user(user.name)
|
||||||
assert not app.db.dirty
|
assert not app.db.dirty
|
||||||
all_roles = app.db.query(orm.Role).all()
|
|
||||||
user_roles = sorted(g.name for g in user.roles)
|
|
||||||
expected_authenticated_roles_names = get_role_names(expected_authenticated_roles)
|
|
||||||
for name in expected_authenticated_roles_names:
|
|
||||||
assert name in user_roles
|
|
||||||
role = orm.Role.find(app.db, name)
|
|
||||||
app.db.delete(role)
|
|
||||||
|
|
||||||
# Check if roles are deleted after restart
|
assert len(user.roles) == len(expected_roles)
|
||||||
await hub.initialize()
|
|
||||||
all_roles = app.db.query(orm.Role).all()
|
for expected_role in expected_roles:
|
||||||
all_roles_names = sorted(g.name for g in all_roles)
|
role = orm.Role.find(app.db, expected_role['name'])
|
||||||
for name in expected_authenticated_roles_names:
|
assert role.description == expected_role.get('description', None)
|
||||||
assert name not in all_roles_names
|
assert len(role.scopes) == len(expected_role.get('scopes', []))
|
||||||
|
@@ -303,10 +303,7 @@ class User:
|
|||||||
|
|
||||||
def sync_roles(self, auth_roles):
|
def sync_roles(self, auth_roles):
|
||||||
"""Synchronize roles with database"""
|
"""Synchronize roles with database"""
|
||||||
auth_roles_by_name = {
|
auth_roles_by_name = {role['name']: role for role in auth_roles}
|
||||||
role['name']: role
|
|
||||||
for role in auth_roles
|
|
||||||
}
|
|
||||||
|
|
||||||
current_user_roles = {r.name for r in self.orm_user.roles}
|
current_user_roles = {r.name for r in self.orm_user.roles}
|
||||||
new_user_roles = set(auth_roles_by_name.keys())
|
new_user_roles = set(auth_roles_by_name.keys())
|
||||||
@@ -336,8 +333,12 @@ class User:
|
|||||||
# creates role, or if it exists, update its `description` and `scopes`
|
# creates role, or if it exists, update its `description` and `scopes`
|
||||||
try:
|
try:
|
||||||
orm_role = roles.create_role(self.db, role, commit=False)
|
orm_role = roles.create_role(self.db, role, commit=False)
|
||||||
except (roles.RoleValueError, roles.InvalidNameError, scopes.ScopeNotFound) as e:
|
except (
|
||||||
raise web.HTTPError(409, e.value)
|
roles.RoleValueError,
|
||||||
|
roles.InvalidNameError,
|
||||||
|
scopes.ScopeNotFound,
|
||||||
|
) as e:
|
||||||
|
raise web.HTTPError(409, str(e))
|
||||||
|
|
||||||
# Update the groups, services and users for the role
|
# Update the groups, services and users for the role
|
||||||
groups = []
|
groups = []
|
||||||
@@ -366,11 +367,15 @@ class User:
|
|||||||
|
|
||||||
# assign the granted roles to the current user
|
# assign the granted roles to the current user
|
||||||
for role_name in granted_roles:
|
for role_name in granted_roles:
|
||||||
roles.grant_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
|
roles.grant_role(
|
||||||
|
self.db, entity=self.orm_user, rolename=role_name, commit=False
|
||||||
|
)
|
||||||
|
|
||||||
# strip the user of roles no longer directly granted
|
# strip the user of roles no longer directly granted
|
||||||
for role_name in stripped_roles:
|
for role_name in stripped_roles:
|
||||||
roles.strip_role(self.db, entity=self.orm_user, rolename=role_name, commit=False)
|
roles.strip_role(
|
||||||
|
self.db, entity=self.orm_user, rolename=role_name, commit=False
|
||||||
|
)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user