mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-07 18:14:10 +00:00
support groups in _intersect_scopes
Requires db resolution
This commit is contained in:
@@ -13,7 +13,9 @@ import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
|
||||
import sqlalchemy as sa
|
||||
from tornado import web
|
||||
from tornado.log import app_log
|
||||
|
||||
@@ -114,21 +116,38 @@ class Scope(Enum):
|
||||
ALL = True
|
||||
|
||||
|
||||
def _intersect_expanded_scopes(scopes_a, scopes_b):
|
||||
"""Intersect two sets of expanded scopes by comparing their permissions
|
||||
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
||||
"""Intersect two sets of scopes by comparing their permissions
|
||||
|
||||
Arguments:
|
||||
scopes_a, scopes_b: sets of expanded scopes
|
||||
db (optional): db connection for resolving group membership
|
||||
|
||||
Returns:
|
||||
intersection: set of expanded scopes as intersection of the arguments
|
||||
|
||||
Note: Intersects correctly with ALL and exact filter matches
|
||||
(i.e. users!user=x & read:users:name -> read:users:name!user=x)
|
||||
|
||||
Does not currently intersect with containing filters
|
||||
(i.e. users!group=x & users!user=y even if user y is in group x)
|
||||
If db is given, group membership will be accounted for in intersections,
|
||||
Otherwise, it can result in lower than intended permissions,
|
||||
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
|
||||
"""
|
||||
empty_set = frozenset()
|
||||
|
||||
# cached lookups for group membership of users and servers
|
||||
@lru_cache()
|
||||
def groups_for_user(username):
|
||||
"""Get set of group names for a given username"""
|
||||
user = db.query(orm.User).filter_by(name=username).first()
|
||||
if user is None:
|
||||
return empty_set
|
||||
else:
|
||||
return {group.name for group in user.groups}
|
||||
|
||||
@lru_cache()
|
||||
def groups_for_server(server):
|
||||
"""Get set of group names for a given server"""
|
||||
username, _, servername = server.partition("/")
|
||||
return groups_for_user(username)
|
||||
|
||||
parsed_scopes_a = parse_scopes(scopes_a)
|
||||
parsed_scopes_b = parse_scopes(scopes_b)
|
||||
|
||||
@@ -144,11 +163,14 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
|
||||
elif filters_b == Scope.ALL:
|
||||
common_filters[base] = filters_a
|
||||
else:
|
||||
# warn *if* there are non-overlapping user= and group= filters
|
||||
common_entities = filters_a.keys() & filters_b.keys()
|
||||
all_entities = filters_a.keys() | filters_b.keys()
|
||||
|
||||
# if we don't have a db session, we can't check group membership
|
||||
# warn *if* there are non-overlapping user= and group= filters that we can't check
|
||||
if (
|
||||
not warned
|
||||
db is None
|
||||
and not warned
|
||||
and 'group' in all_entities
|
||||
and ('user' in all_entities or 'server' in all_entities)
|
||||
):
|
||||
@@ -160,39 +182,67 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
|
||||
not warned
|
||||
and "group" in a
|
||||
and b_key in b
|
||||
and set(a["group"]).difference(b.get("group", []))
|
||||
and set(b[b_key]).difference(a.get(b_key, []))
|
||||
and a["group"].difference(b.get("group", []))
|
||||
and b[b_key].difference(a.get(b_key, []))
|
||||
):
|
||||
warnings.warn(
|
||||
f"{base}[!{b_key}={b[b_key]}, !group={a['group']}] combinations of filters present,"
|
||||
" intersection between not considered. May result in lower than intended permissions.",
|
||||
" without db access. Intersection between not considered."
|
||||
" May result in lower than intended permissions.",
|
||||
UserWarning,
|
||||
)
|
||||
warned = True
|
||||
|
||||
common_filters[base] = {
|
||||
entity: set(parsed_scopes_a[base][entity])
|
||||
& set(parsed_scopes_b[base][entity])
|
||||
entity: filters_a[entity] & filters_b[entity]
|
||||
for entity in common_entities
|
||||
}
|
||||
|
||||
if 'server' in all_entities and 'user' in all_entities:
|
||||
if filters_a.get('server') == filters_b.get('server'):
|
||||
continue
|
||||
# resolve hierarchies (group/user/server) in both directions
|
||||
common_servers = common_filters[base].get("server", set())
|
||||
common_users = common_filters[base].get("user", set())
|
||||
|
||||
additional_servers = set()
|
||||
# resolve user/server hierarchy in both directions
|
||||
for a, b in [(filters_a, filters_b), (filters_b, filters_a)]:
|
||||
if 'server' in a and 'user' in b:
|
||||
for server in a['server']:
|
||||
if 'server' in a and b.get('server') != a['server']:
|
||||
# skip already-added servers (includes overlapping servers)
|
||||
servers = a['server'].difference(common_servers)
|
||||
|
||||
# resolve user/server hierarchy
|
||||
if servers and 'user' in b:
|
||||
for server in servers:
|
||||
username, _, servername = server.partition("/")
|
||||
if username in b['user']:
|
||||
additional_servers.add(server)
|
||||
common_servers.add(server)
|
||||
|
||||
if additional_servers:
|
||||
if "server" not in common_filters[base]:
|
||||
common_filters[base]["server"] = set()
|
||||
common_filters[base]["server"].update(additional_servers)
|
||||
# resolve group/server hierarchy if db available
|
||||
servers = servers.difference(common_servers)
|
||||
if db is not None and servers and 'group' in b:
|
||||
for server in servers:
|
||||
server_groups = groups_for_server(server)
|
||||
if server_groups & b['group']:
|
||||
common_servers.add(server)
|
||||
|
||||
# resolve group/user hierarchy if db available and user sets aren't identical
|
||||
if (
|
||||
db is not None
|
||||
and 'user' in a
|
||||
and 'group' in b
|
||||
and b.get('user') != a['user']
|
||||
):
|
||||
# skip already-added users (includes overlapping users)
|
||||
users = a['user'].difference(common_users)
|
||||
for username in users:
|
||||
groups = groups_for_user(username)
|
||||
if groups & b["group"]:
|
||||
common_users.add(username)
|
||||
|
||||
# add server filter if there wasn't one before
|
||||
if common_servers and "server" not in common_filters[base]:
|
||||
common_filters[base]["server"] = common_servers
|
||||
|
||||
# add user filter if it's non-empty and there wasn't one before
|
||||
if common_users and "user" not in common_filters[base]:
|
||||
common_filters[base]["user"] = common_users
|
||||
|
||||
return unparse_scopes(common_filters)
|
||||
|
||||
@@ -244,11 +294,21 @@ def get_scopes_for(orm_object):
|
||||
)
|
||||
|
||||
owner_scopes = roles.expand_roles_to_scopes(owner)
|
||||
|
||||
if token_scopes == {'all'}:
|
||||
# token_scopes is only 'all', return owner scopes as-is
|
||||
# short-circuit common case where we don't need to compute an intersection
|
||||
return owner_scopes
|
||||
|
||||
if 'all' in token_scopes:
|
||||
token_scopes.remove('all')
|
||||
token_scopes |= owner_scopes
|
||||
|
||||
intersection = _intersect_expanded_scopes(token_scopes, owner_scopes)
|
||||
intersection = _intersect_expanded_scopes(
|
||||
token_scopes,
|
||||
owner_scopes,
|
||||
db=sa.inspect(orm_object).session,
|
||||
)
|
||||
discarded_token_scopes = token_scopes - intersection
|
||||
|
||||
# Not taking symmetric difference here because token owner can naturally have more scopes than token
|
||||
@@ -473,7 +533,8 @@ def check_scope_filter(sub_scope, orm_resource, kind):
|
||||
# Fall back on checking if we have group access for this user
|
||||
orm_resource = orm_resource.user
|
||||
kind = 'user'
|
||||
elif kind == 'user' and 'group' in sub_scope:
|
||||
|
||||
if kind == 'user' and 'group' in sub_scope:
|
||||
group_names = {group.name for group in orm_resource.groups}
|
||||
user_in_group = bool(group_names & set(sub_scope['group']))
|
||||
if user_in_group:
|
||||
|
@@ -38,7 +38,7 @@ from traitlets import Unicode
|
||||
from traitlets import validate
|
||||
from traitlets.config import SingletonConfigurable
|
||||
|
||||
from ..scopes import _intersect_scopes
|
||||
from ..scopes import _intersect_expanded_scopes
|
||||
from ..utils import url_path_join
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ def check_scopes(required_scopes, scopes):
|
||||
if isinstance(required_scopes, str):
|
||||
required_scopes = {required_scopes}
|
||||
|
||||
intersection = _intersect_scopes(required_scopes, scopes)
|
||||
intersection = _intersect_expanded_scopes(required_scopes, scopes)
|
||||
# re-intersect with required_scopes in case the intersection
|
||||
# applies stricter filters than required_scopes declares
|
||||
# e.g. required_scopes = {'read:users'} and intersection has only {'read:users!user=x'}
|
||||
|
@@ -763,3 +763,81 @@ def test_intersect_expanded_scopes(left, right, expected, should_warn, recwarn):
|
||||
assert len(recwarn) == 1
|
||||
else:
|
||||
assert len(recwarn) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"left, right, expected, groups",
|
||||
[
|
||||
(
|
||||
["users!group=gx"],
|
||||
["users!user=ux"],
|
||||
["users!user=ux"],
|
||||
{"gx": ["ux"]},
|
||||
),
|
||||
(
|
||||
["read:users!group=gx"],
|
||||
["read:users!user=nosuchuser"],
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
["read:users!group=gx"],
|
||||
["read:users!server=nosuchuser/server"],
|
||||
[],
|
||||
{},
|
||||
),
|
||||
(
|
||||
["read:users!group=gx"],
|
||||
["read:users!server=ux/server"],
|
||||
["read:users!server=ux/server"],
|
||||
{"gx": ["ux"]},
|
||||
),
|
||||
(
|
||||
["read:users!group=gx"],
|
||||
["read:users!server=ux/server", "read:users!user=uy"],
|
||||
["read:users!server=ux/server"],
|
||||
{"gx": ["ux"], "gy": ["uy"]},
|
||||
),
|
||||
(
|
||||
["read:users!group=gy"],
|
||||
["read:users!server=ux/server", "read:users!user=uy"],
|
||||
["read:users!user=uy"],
|
||||
{"gx": ["ux"], "gy": ["uy"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_intersect_groups(request, db, left, right, expected, groups):
|
||||
if isinstance(left, str):
|
||||
left = set([left])
|
||||
if isinstance(right, str):
|
||||
right = set([right])
|
||||
|
||||
# if we have a db connection, we can actually resolve
|
||||
created = []
|
||||
for groupname, members in groups.items():
|
||||
group = orm.Group.find(db, name=groupname)
|
||||
if not group:
|
||||
group = orm.Group(name=groupname)
|
||||
db.add(group)
|
||||
created.append(group)
|
||||
db.commit()
|
||||
for username in members:
|
||||
user = orm.User.find(db, name=username)
|
||||
if user is None:
|
||||
user = orm.User(name=username)
|
||||
db.add(user)
|
||||
created.append(user)
|
||||
user.groups.append(group)
|
||||
db.commit()
|
||||
|
||||
def _cleanup():
|
||||
for obj in created:
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
|
||||
request.addfinalizer(_cleanup)
|
||||
|
||||
# run every test in both directions, to ensure symmetry of the inputs
|
||||
for a, b in [(left, right), (right, left)]:
|
||||
intersection = _intersect_expanded_scopes(set(left), set(right), db)
|
||||
assert intersection == set(expected)
|
||||
|
@@ -6,6 +6,7 @@ from urllib.parse import urlparse
|
||||
import pytest
|
||||
|
||||
import jupyterhub
|
||||
from .. import orm
|
||||
from ..utils import url_path_join
|
||||
from .mocking import public_url
|
||||
from .mocking import StubSingleUserSpawner
|
||||
@@ -16,6 +17,8 @@ from .utils import AsyncSession
|
||||
@pytest.mark.parametrize(
|
||||
"access_scopes, server_name, expect_success",
|
||||
[
|
||||
(["access:users:servers!group=$group"], "", True),
|
||||
(["access:users:servers!group=other-group"], "", False),
|
||||
(["access:users:servers"], "", True),
|
||||
(["access:users:servers"], "named", True),
|
||||
(["access:users:servers!user=$user"], "", True),
|
||||
@@ -46,6 +49,16 @@ async def test_singleuser_auth(
|
||||
# login, start the server
|
||||
cookies = await app.login_user('nandy')
|
||||
user = app.users['nandy']
|
||||
|
||||
group = orm.Group.find(app.db, name="visitors")
|
||||
if group is None:
|
||||
group = orm.Group(name="visitors")
|
||||
app.db.add(group)
|
||||
app.db.commit()
|
||||
if group not in user.groups:
|
||||
user.groups.append(group)
|
||||
app.db.commit()
|
||||
|
||||
if server_name not in user.spawners or not user.spawners[server_name].active:
|
||||
await user.spawn(server_name)
|
||||
await app.proxy.add_user(user, server_name)
|
||||
@@ -85,6 +98,7 @@ async def test_singleuser_auth(
|
||||
access_scopes = [
|
||||
s.replace("$server", f"{user.name}/{server_name}") for s in access_scopes
|
||||
]
|
||||
access_scopes = [s.replace("$group", f"{group.name}") for s in access_scopes]
|
||||
other_user = create_user_with_scopes(*access_scopes, name="burgess")
|
||||
|
||||
cookies = await app.login_user('burgess')
|
||||
|
Reference in New Issue
Block a user