mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-08 10:34:10 +00:00
support groups in _intersect_scopes
Requires db resolution
This commit is contained in:
@@ -13,7 +13,9 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
from tornado import web
|
from tornado import web
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
|
|
||||||
@@ -114,21 +116,38 @@ class Scope(Enum):
|
|||||||
ALL = True
|
ALL = True
|
||||||
|
|
||||||
|
|
||||||
def _intersect_expanded_scopes(scopes_a, scopes_b):
|
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
||||||
"""Intersect two sets of expanded scopes by comparing their permissions
|
"""Intersect two sets of scopes by comparing their permissions
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
scopes_a, scopes_b: sets of expanded scopes
|
scopes_a, scopes_b: sets of expanded scopes
|
||||||
|
db (optional): db connection for resolving group membership
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
intersection: set of expanded scopes as intersection of the arguments
|
intersection: set of expanded scopes as intersection of the arguments
|
||||||
|
|
||||||
Note: Intersects correctly with ALL and exact filter matches
|
If db is given, group membership will be accounted for in intersections,
|
||||||
(i.e. users!user=x & read:users:name -> read:users:name!user=x)
|
Otherwise, it can result in lower than intended permissions,
|
||||||
|
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
|
||||||
Does not currently intersect with containing filters
|
|
||||||
(i.e. users!group=x & users!user=y even if user y is in group x)
|
|
||||||
"""
|
"""
|
||||||
|
empty_set = frozenset()
|
||||||
|
|
||||||
|
# cached lookups for group membership of users and servers
|
||||||
|
@lru_cache()
|
||||||
|
def groups_for_user(username):
|
||||||
|
"""Get set of group names for a given username"""
|
||||||
|
user = db.query(orm.User).filter_by(name=username).first()
|
||||||
|
if user is None:
|
||||||
|
return empty_set
|
||||||
|
else:
|
||||||
|
return {group.name for group in user.groups}
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def groups_for_server(server):
|
||||||
|
"""Get set of group names for a given server"""
|
||||||
|
username, _, servername = server.partition("/")
|
||||||
|
return groups_for_user(username)
|
||||||
|
|
||||||
parsed_scopes_a = parse_scopes(scopes_a)
|
parsed_scopes_a = parse_scopes(scopes_a)
|
||||||
parsed_scopes_b = parse_scopes(scopes_b)
|
parsed_scopes_b = parse_scopes(scopes_b)
|
||||||
|
|
||||||
@@ -144,11 +163,14 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
|
|||||||
elif filters_b == Scope.ALL:
|
elif filters_b == Scope.ALL:
|
||||||
common_filters[base] = filters_a
|
common_filters[base] = filters_a
|
||||||
else:
|
else:
|
||||||
# warn *if* there are non-overlapping user= and group= filters
|
|
||||||
common_entities = filters_a.keys() & filters_b.keys()
|
common_entities = filters_a.keys() & filters_b.keys()
|
||||||
all_entities = filters_a.keys() | filters_b.keys()
|
all_entities = filters_a.keys() | filters_b.keys()
|
||||||
|
|
||||||
|
# if we don't have a db session, we can't check group membership
|
||||||
|
# warn *if* there are non-overlapping user= and group= filters that we can't check
|
||||||
if (
|
if (
|
||||||
not warned
|
db is None
|
||||||
|
and not warned
|
||||||
and 'group' in all_entities
|
and 'group' in all_entities
|
||||||
and ('user' in all_entities or 'server' in all_entities)
|
and ('user' in all_entities or 'server' in all_entities)
|
||||||
):
|
):
|
||||||
@@ -160,39 +182,67 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
|
|||||||
not warned
|
not warned
|
||||||
and "group" in a
|
and "group" in a
|
||||||
and b_key in b
|
and b_key in b
|
||||||
and set(a["group"]).difference(b.get("group", []))
|
and a["group"].difference(b.get("group", []))
|
||||||
and set(b[b_key]).difference(a.get(b_key, []))
|
and b[b_key].difference(a.get(b_key, []))
|
||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"{base}[!{b_key}={b[b_key]}, !group={a['group']}] combinations of filters present,"
|
f"{base}[!{b_key}={b[b_key]}, !group={a['group']}] combinations of filters present,"
|
||||||
" intersection between not considered. May result in lower than intended permissions.",
|
" without db access. Intersection between not considered."
|
||||||
|
" May result in lower than intended permissions.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
warned = True
|
warned = True
|
||||||
|
|
||||||
common_filters[base] = {
|
common_filters[base] = {
|
||||||
entity: set(parsed_scopes_a[base][entity])
|
entity: filters_a[entity] & filters_b[entity]
|
||||||
& set(parsed_scopes_b[base][entity])
|
|
||||||
for entity in common_entities
|
for entity in common_entities
|
||||||
}
|
}
|
||||||
|
|
||||||
if 'server' in all_entities and 'user' in all_entities:
|
# resolve hierarchies (group/user/server) in both directions
|
||||||
if filters_a.get('server') == filters_b.get('server'):
|
common_servers = common_filters[base].get("server", set())
|
||||||
continue
|
common_users = common_filters[base].get("user", set())
|
||||||
|
|
||||||
additional_servers = set()
|
|
||||||
# resolve user/server hierarchy in both directions
|
|
||||||
for a, b in [(filters_a, filters_b), (filters_b, filters_a)]:
|
for a, b in [(filters_a, filters_b), (filters_b, filters_a)]:
|
||||||
if 'server' in a and 'user' in b:
|
if 'server' in a and b.get('server') != a['server']:
|
||||||
for server in a['server']:
|
# skip already-added servers (includes overlapping servers)
|
||||||
|
servers = a['server'].difference(common_servers)
|
||||||
|
|
||||||
|
# resolve user/server hierarchy
|
||||||
|
if servers and 'user' in b:
|
||||||
|
for server in servers:
|
||||||
username, _, servername = server.partition("/")
|
username, _, servername = server.partition("/")
|
||||||
if username in b['user']:
|
if username in b['user']:
|
||||||
additional_servers.add(server)
|
common_servers.add(server)
|
||||||
|
|
||||||
if additional_servers:
|
# resolve group/server hierarchy if db available
|
||||||
if "server" not in common_filters[base]:
|
servers = servers.difference(common_servers)
|
||||||
common_filters[base]["server"] = set()
|
if db is not None and servers and 'group' in b:
|
||||||
common_filters[base]["server"].update(additional_servers)
|
for server in servers:
|
||||||
|
server_groups = groups_for_server(server)
|
||||||
|
if server_groups & b['group']:
|
||||||
|
common_servers.add(server)
|
||||||
|
|
||||||
|
# resolve group/user hierarchy if db available and user sets aren't identical
|
||||||
|
if (
|
||||||
|
db is not None
|
||||||
|
and 'user' in a
|
||||||
|
and 'group' in b
|
||||||
|
and b.get('user') != a['user']
|
||||||
|
):
|
||||||
|
# skip already-added users (includes overlapping users)
|
||||||
|
users = a['user'].difference(common_users)
|
||||||
|
for username in users:
|
||||||
|
groups = groups_for_user(username)
|
||||||
|
if groups & b["group"]:
|
||||||
|
common_users.add(username)
|
||||||
|
|
||||||
|
# add server filter if there wasn't one before
|
||||||
|
if common_servers and "server" not in common_filters[base]:
|
||||||
|
common_filters[base]["server"] = common_servers
|
||||||
|
|
||||||
|
# add user filter if it's non-empty and there wasn't one before
|
||||||
|
if common_users and "user" not in common_filters[base]:
|
||||||
|
common_filters[base]["user"] = common_users
|
||||||
|
|
||||||
return unparse_scopes(common_filters)
|
return unparse_scopes(common_filters)
|
||||||
|
|
||||||
@@ -244,11 +294,21 @@ def get_scopes_for(orm_object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
owner_scopes = roles.expand_roles_to_scopes(owner)
|
owner_scopes = roles.expand_roles_to_scopes(owner)
|
||||||
|
|
||||||
|
if token_scopes == {'all'}:
|
||||||
|
# token_scopes is only 'all', return owner scopes as-is
|
||||||
|
# short-circuit common case where we don't need to compute an intersection
|
||||||
|
return owner_scopes
|
||||||
|
|
||||||
if 'all' in token_scopes:
|
if 'all' in token_scopes:
|
||||||
token_scopes.remove('all')
|
token_scopes.remove('all')
|
||||||
token_scopes |= owner_scopes
|
token_scopes |= owner_scopes
|
||||||
|
|
||||||
intersection = _intersect_expanded_scopes(token_scopes, owner_scopes)
|
intersection = _intersect_expanded_scopes(
|
||||||
|
token_scopes,
|
||||||
|
owner_scopes,
|
||||||
|
db=sa.inspect(orm_object).session,
|
||||||
|
)
|
||||||
discarded_token_scopes = token_scopes - intersection
|
discarded_token_scopes = token_scopes - intersection
|
||||||
|
|
||||||
# Not taking symmetric difference here because token owner can naturally have more scopes than token
|
# Not taking symmetric difference here because token owner can naturally have more scopes than token
|
||||||
@@ -473,7 +533,8 @@ def check_scope_filter(sub_scope, orm_resource, kind):
|
|||||||
# Fall back on checking if we have group access for this user
|
# Fall back on checking if we have group access for this user
|
||||||
orm_resource = orm_resource.user
|
orm_resource = orm_resource.user
|
||||||
kind = 'user'
|
kind = 'user'
|
||||||
elif kind == 'user' and 'group' in sub_scope:
|
|
||||||
|
if kind == 'user' and 'group' in sub_scope:
|
||||||
group_names = {group.name for group in orm_resource.groups}
|
group_names = {group.name for group in orm_resource.groups}
|
||||||
user_in_group = bool(group_names & set(sub_scope['group']))
|
user_in_group = bool(group_names & set(sub_scope['group']))
|
||||||
if user_in_group:
|
if user_in_group:
|
||||||
|
@@ -38,7 +38,7 @@ from traitlets import Unicode
|
|||||||
from traitlets import validate
|
from traitlets import validate
|
||||||
from traitlets.config import SingletonConfigurable
|
from traitlets.config import SingletonConfigurable
|
||||||
|
|
||||||
from ..scopes import _intersect_scopes
|
from ..scopes import _intersect_expanded_scopes
|
||||||
from ..utils import url_path_join
|
from ..utils import url_path_join
|
||||||
|
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ def check_scopes(required_scopes, scopes):
|
|||||||
if isinstance(required_scopes, str):
|
if isinstance(required_scopes, str):
|
||||||
required_scopes = {required_scopes}
|
required_scopes = {required_scopes}
|
||||||
|
|
||||||
intersection = _intersect_scopes(required_scopes, scopes)
|
intersection = _intersect_expanded_scopes(required_scopes, scopes)
|
||||||
# re-intersect with required_scopes in case the intersection
|
# re-intersect with required_scopes in case the intersection
|
||||||
# applies stricter filters than required_scopes declares
|
# applies stricter filters than required_scopes declares
|
||||||
# e.g. required_scopes = {'read:users'} and intersection has only {'read:users!user=x'}
|
# e.g. required_scopes = {'read:users'} and intersection has only {'read:users!user=x'}
|
||||||
|
@@ -763,3 +763,81 @@ def test_intersect_expanded_scopes(left, right, expected, should_warn, recwarn):
|
|||||||
assert len(recwarn) == 1
|
assert len(recwarn) == 1
|
||||||
else:
|
else:
|
||||||
assert len(recwarn) == 0
|
assert len(recwarn) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"left, right, expected, groups",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
["users!group=gx"],
|
||||||
|
["users!user=ux"],
|
||||||
|
["users!user=ux"],
|
||||||
|
{"gx": ["ux"]},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["read:users!group=gx"],
|
||||||
|
["read:users!user=nosuchuser"],
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["read:users!group=gx"],
|
||||||
|
["read:users!server=nosuchuser/server"],
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["read:users!group=gx"],
|
||||||
|
["read:users!server=ux/server"],
|
||||||
|
["read:users!server=ux/server"],
|
||||||
|
{"gx": ["ux"]},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["read:users!group=gx"],
|
||||||
|
["read:users!server=ux/server", "read:users!user=uy"],
|
||||||
|
["read:users!server=ux/server"],
|
||||||
|
{"gx": ["ux"], "gy": ["uy"]},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["read:users!group=gy"],
|
||||||
|
["read:users!server=ux/server", "read:users!user=uy"],
|
||||||
|
["read:users!user=uy"],
|
||||||
|
{"gx": ["ux"], "gy": ["uy"]},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_intersect_groups(request, db, left, right, expected, groups):
|
||||||
|
if isinstance(left, str):
|
||||||
|
left = set([left])
|
||||||
|
if isinstance(right, str):
|
||||||
|
right = set([right])
|
||||||
|
|
||||||
|
# if we have a db connection, we can actually resolve
|
||||||
|
created = []
|
||||||
|
for groupname, members in groups.items():
|
||||||
|
group = orm.Group.find(db, name=groupname)
|
||||||
|
if not group:
|
||||||
|
group = orm.Group(name=groupname)
|
||||||
|
db.add(group)
|
||||||
|
created.append(group)
|
||||||
|
db.commit()
|
||||||
|
for username in members:
|
||||||
|
user = orm.User.find(db, name=username)
|
||||||
|
if user is None:
|
||||||
|
user = orm.User(name=username)
|
||||||
|
db.add(user)
|
||||||
|
created.append(user)
|
||||||
|
user.groups.append(group)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
def _cleanup():
|
||||||
|
for obj in created:
|
||||||
|
db.delete(obj)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
request.addfinalizer(_cleanup)
|
||||||
|
|
||||||
|
# run every test in both directions, to ensure symmetry of the inputs
|
||||||
|
for a, b in [(left, right), (right, left)]:
|
||||||
|
intersection = _intersect_expanded_scopes(set(left), set(right), db)
|
||||||
|
assert intersection == set(expected)
|
||||||
|
@@ -6,6 +6,7 @@ from urllib.parse import urlparse
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import jupyterhub
|
import jupyterhub
|
||||||
|
from .. import orm
|
||||||
from ..utils import url_path_join
|
from ..utils import url_path_join
|
||||||
from .mocking import public_url
|
from .mocking import public_url
|
||||||
from .mocking import StubSingleUserSpawner
|
from .mocking import StubSingleUserSpawner
|
||||||
@@ -16,6 +17,8 @@ from .utils import AsyncSession
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"access_scopes, server_name, expect_success",
|
"access_scopes, server_name, expect_success",
|
||||||
[
|
[
|
||||||
|
(["access:users:servers!group=$group"], "", True),
|
||||||
|
(["access:users:servers!group=other-group"], "", False),
|
||||||
(["access:users:servers"], "", True),
|
(["access:users:servers"], "", True),
|
||||||
(["access:users:servers"], "named", True),
|
(["access:users:servers"], "named", True),
|
||||||
(["access:users:servers!user=$user"], "", True),
|
(["access:users:servers!user=$user"], "", True),
|
||||||
@@ -46,6 +49,16 @@ async def test_singleuser_auth(
|
|||||||
# login, start the server
|
# login, start the server
|
||||||
cookies = await app.login_user('nandy')
|
cookies = await app.login_user('nandy')
|
||||||
user = app.users['nandy']
|
user = app.users['nandy']
|
||||||
|
|
||||||
|
group = orm.Group.find(app.db, name="visitors")
|
||||||
|
if group is None:
|
||||||
|
group = orm.Group(name="visitors")
|
||||||
|
app.db.add(group)
|
||||||
|
app.db.commit()
|
||||||
|
if group not in user.groups:
|
||||||
|
user.groups.append(group)
|
||||||
|
app.db.commit()
|
||||||
|
|
||||||
if server_name not in user.spawners or not user.spawners[server_name].active:
|
if server_name not in user.spawners or not user.spawners[server_name].active:
|
||||||
await user.spawn(server_name)
|
await user.spawn(server_name)
|
||||||
await app.proxy.add_user(user, server_name)
|
await app.proxy.add_user(user, server_name)
|
||||||
@@ -85,6 +98,7 @@ async def test_singleuser_auth(
|
|||||||
access_scopes = [
|
access_scopes = [
|
||||||
s.replace("$server", f"{user.name}/{server_name}") for s in access_scopes
|
s.replace("$server", f"{user.name}/{server_name}") for s in access_scopes
|
||||||
]
|
]
|
||||||
|
access_scopes = [s.replace("$group", f"{group.name}") for s in access_scopes]
|
||||||
other_user = create_user_with_scopes(*access_scopes, name="burgess")
|
other_user = create_user_with_scopes(*access_scopes, name="burgess")
|
||||||
|
|
||||||
cookies = await app.login_user('burgess')
|
cookies = await app.login_user('burgess')
|
||||||
|
Reference in New Issue
Block a user