diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index afce5974..32276a20 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -13,7 +13,9 @@ import functools import inspect import warnings from enum import Enum +from functools import lru_cache +import sqlalchemy as sa from tornado import web from tornado.log import app_log @@ -114,21 +116,38 @@ class Scope(Enum): ALL = True -def _intersect_expanded_scopes(scopes_a, scopes_b): - """Intersect two sets of expanded scopes by comparing their permissions +def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): + """Intersect two sets of scopes by comparing their permissions Arguments: scopes_a, scopes_b: sets of expanded scopes + db (optional): db connection for resolving group membership Returns: intersection: set of expanded scopes as intersection of the arguments - Note: Intersects correctly with ALL and exact filter matches - (i.e. users!user=x & read:users:name -> read:users:name!user=x) - - Does not currently intersect with containing filters - (i.e. users!group=x & users!user=y even if user y is in group x) + If db is given, group membership will be accounted for in intersections, + Otherwise, it can result in lower than intended permissions, + (i.e. users!group=x & users!user=y will be empty, even if user y is in group x.) """ + empty_set = frozenset() + + # cached lookups for group membership of users and servers + @lru_cache() + def groups_for_user(username): + """Get set of group names for a given username""" + user = db.query(orm.User).filter_by(name=username).first() + if user is None: + return empty_set + else: + return {group.name for group in user.groups} + + @lru_cache() + def groups_for_server(server): + """Get set of group names for a given server""" + username, _, servername = server.partition("/") + return groups_for_user(username) + parsed_scopes_a = parse_scopes(scopes_a) parsed_scopes_b = parse_scopes(scopes_b) @@ -144,11 +163,14 @@ def _intersect_expanded_scopes(scopes_a, scopes_b): elif filters_b == Scope.ALL: common_filters[base] = filters_a else: - # warn *if* there are non-overlapping user= and group= filters common_entities = filters_a.keys() & filters_b.keys() all_entities = filters_a.keys() | filters_b.keys() + + # if we don't have a db session, we can't check group membership + # warn *if* there are non-overlapping user= and group= filters that we can't check if ( - not warned + db is None + and not warned and 'group' in all_entities and ('user' in all_entities or 'server' in all_entities) ): @@ -160,39 +182,67 @@ def _intersect_expanded_scopes(scopes_a, scopes_b): not warned and "group" in a and b_key in b - and set(a["group"]).difference(b.get("group", [])) - and set(b[b_key]).difference(a.get(b_key, [])) + and a["group"].difference(b.get("group", [])) + and b[b_key].difference(a.get(b_key, [])) ): warnings.warn( f"{base}[!{b_key}={b[b_key]}, !group={a['group']}] combinations of filters present," - " intersection between not considered. May result in lower than intended permissions.", + " without db access. Intersection between not considered." + " May result in lower than intended permissions.", UserWarning, ) warned = True common_filters[base] = { - entity: set(parsed_scopes_a[base][entity]) - & set(parsed_scopes_b[base][entity]) + entity: filters_a[entity] & filters_b[entity] for entity in common_entities } - if 'server' in all_entities and 'user' in all_entities: - if filters_a.get('server') == filters_b.get('server'): - continue + # resolve hierarchies (group/user/server) in both directions + common_servers = common_filters[base].get("server", set()) + common_users = common_filters[base].get("user", set()) - additional_servers = set() - # resolve user/server hierarchy in both directions - for a, b in [(filters_a, filters_b), (filters_b, filters_a)]: - if 'server' in a and 'user' in b: - for server in a['server']: + for a, b in [(filters_a, filters_b), (filters_b, filters_a)]: + if 'server' in a and b.get('server') != a['server']: + # skip already-added servers (includes overlapping servers) + servers = a['server'].difference(common_servers) + + # resolve user/server hierarchy + if servers and 'user' in b: + for server in servers: username, _, servername = server.partition("/") if username in b['user']: - additional_servers.add(server) + common_servers.add(server) - if additional_servers: - if "server" not in common_filters[base]: - common_filters[base]["server"] = set() - common_filters[base]["server"].update(additional_servers) + # resolve group/server hierarchy if db available + servers = servers.difference(common_servers) + if db is not None and servers and 'group' in b: + for server in servers: + server_groups = groups_for_server(server) + if server_groups & b['group']: + common_servers.add(server) + + # resolve group/user hierarchy if db available and user sets aren't identical + if ( + db is not None + and 'user' in a + and 'group' in b + and b.get('user') != a['user'] + ): + # skip already-added users (includes overlapping users) + users = a['user'].difference(common_users) + for username in users: + groups = groups_for_user(username) + if groups & b["group"]: + common_users.add(username) + + # add server filter if there wasn't one before + if common_servers and "server" not in common_filters[base]: + common_filters[base]["server"] = common_servers + + # add user filter if it's non-empty and there wasn't one before + if common_users and "user" not in common_filters[base]: + common_filters[base]["user"] = common_users return unparse_scopes(common_filters) @@ -244,11 +294,21 @@ def get_scopes_for(orm_object): ) owner_scopes = roles.expand_roles_to_scopes(owner) + + if token_scopes == {'all'}: + # token_scopes is only 'all', return owner scopes as-is + # short-circuit common case where we don't need to compute an intersection + return owner_scopes + if 'all' in token_scopes: token_scopes.remove('all') token_scopes |= owner_scopes - intersection = _intersect_expanded_scopes(token_scopes, owner_scopes) + intersection = _intersect_expanded_scopes( + token_scopes, + owner_scopes, + db=sa.inspect(orm_object).session, + ) discarded_token_scopes = token_scopes - intersection # Not taking symmetric difference here because token owner can naturally have more scopes than token @@ -473,7 +533,8 @@ def check_scope_filter(sub_scope, orm_resource, kind): # Fall back on checking if we have group access for this user orm_resource = orm_resource.user kind = 'user' - elif kind == 'user' and 'group' in sub_scope: + + if kind == 'user' and 'group' in sub_scope: group_names = {group.name for group in orm_resource.groups} user_in_group = bool(group_names & set(sub_scope['group'])) if user_in_group: diff --git a/jupyterhub/services/auth.py b/jupyterhub/services/auth.py index 0127883e..64080723 100644 --- a/jupyterhub/services/auth.py +++ b/jupyterhub/services/auth.py @@ -38,7 +38,7 @@ from traitlets import Unicode from traitlets import validate from traitlets.config import SingletonConfigurable -from ..scopes import _intersect_scopes +from ..scopes import _intersect_expanded_scopes from ..utils import url_path_join @@ -70,7 +70,7 @@ def check_scopes(required_scopes, scopes): if isinstance(required_scopes, str): required_scopes = {required_scopes} - intersection = _intersect_scopes(required_scopes, scopes) + intersection = _intersect_expanded_scopes(required_scopes, scopes) # re-intersect with required_scopes in case the intersection # applies stricter filters than required_scopes declares # e.g. required_scopes = {'read:users'} and intersection has only {'read:users!user=x'} diff --git a/jupyterhub/tests/test_scopes.py b/jupyterhub/tests/test_scopes.py index 81642cba..305112f7 100644 --- a/jupyterhub/tests/test_scopes.py +++ b/jupyterhub/tests/test_scopes.py @@ -763,3 +763,81 @@ def test_intersect_expanded_scopes(left, right, expected, should_warn, recwarn): assert len(recwarn) == 1 else: assert len(recwarn) == 0 + + +@pytest.mark.parametrize( + "left, right, expected, groups", + [ + ( + ["users!group=gx"], + ["users!user=ux"], + ["users!user=ux"], + {"gx": ["ux"]}, + ), + ( + ["read:users!group=gx"], + ["read:users!user=nosuchuser"], + [], + {}, + ), + ( + ["read:users!group=gx"], + ["read:users!server=nosuchuser/server"], + [], + {}, + ), + ( + ["read:users!group=gx"], + ["read:users!server=ux/server"], + ["read:users!server=ux/server"], + {"gx": ["ux"]}, + ), + ( + ["read:users!group=gx"], + ["read:users!server=ux/server", "read:users!user=uy"], + ["read:users!server=ux/server"], + {"gx": ["ux"], "gy": ["uy"]}, + ), + ( + ["read:users!group=gy"], + ["read:users!server=ux/server", "read:users!user=uy"], + ["read:users!user=uy"], + {"gx": ["ux"], "gy": ["uy"]}, + ), + ], +) +def test_intersect_groups(request, db, left, right, expected, groups): + if isinstance(left, str): + left = set([left]) + if isinstance(right, str): + right = set([right]) + + # if we have a db connection, we can actually resolve + created = [] + for groupname, members in groups.items(): + group = orm.Group.find(db, name=groupname) + if not group: + group = orm.Group(name=groupname) + db.add(group) + created.append(group) + db.commit() + for username in members: + user = orm.User.find(db, name=username) + if user is None: + user = orm.User(name=username) + db.add(user) + created.append(user) + user.groups.append(group) + db.commit() + + def _cleanup(): + for obj in created: + db.delete(obj) + db.commit() + + request.addfinalizer(_cleanup) + + # run every test in both directions, to ensure symmetry of the inputs + for a, b in [(left, right), (right, left)]: + intersection = _intersect_expanded_scopes(set(left), set(right), db) + assert intersection == set(expected) diff --git a/jupyterhub/tests/test_singleuser.py b/jupyterhub/tests/test_singleuser.py index 0b0eef5a..d2fe3657 100644 --- a/jupyterhub/tests/test_singleuser.py +++ b/jupyterhub/tests/test_singleuser.py @@ -6,6 +6,7 @@ from urllib.parse import urlparse import pytest import jupyterhub +from .. import orm from ..utils import url_path_join from .mocking import public_url from .mocking import StubSingleUserSpawner @@ -16,6 +17,8 @@ from .utils import AsyncSession @pytest.mark.parametrize( "access_scopes, server_name, expect_success", [ + (["access:users:servers!group=$group"], "", True), + (["access:users:servers!group=other-group"], "", False), (["access:users:servers"], "", True), (["access:users:servers"], "named", True), (["access:users:servers!user=$user"], "", True), @@ -46,6 +49,16 @@ async def test_singleuser_auth( # login, start the server cookies = await app.login_user('nandy') user = app.users['nandy'] + + group = orm.Group.find(app.db, name="visitors") + if group is None: + group = orm.Group(name="visitors") + app.db.add(group) + app.db.commit() + if group not in user.groups: + user.groups.append(group) + app.db.commit() + if server_name not in user.spawners or not user.spawners[server_name].active: await user.spawn(server_name) await app.proxy.add_user(user, server_name) @@ -85,6 +98,7 @@ async def test_singleuser_auth( access_scopes = [ s.replace("$server", f"{user.name}/{server_name}") for s in access_scopes ] + access_scopes = [s.replace("$group", f"{group.name}") for s in access_scopes] other_user = create_user_with_scopes(*access_scopes, name="burgess") cookies = await app.login_user('burgess')