support groups in _intersect_scopes

Requires db resolution
This commit is contained in:
Min RK
2021-05-21 15:23:31 +02:00
parent 40de16e0e1
commit fbea31d00a
4 changed files with 184 additions and 31 deletions

View File

@@ -13,7 +13,9 @@ import functools
import inspect
import warnings
from enum import Enum
from functools import lru_cache
import sqlalchemy as sa
from tornado import web
from tornado.log import app_log
@@ -114,21 +116,38 @@ class Scope(Enum):
ALL = True
def _intersect_expanded_scopes(scopes_a, scopes_b):
"""Intersect two sets of expanded scopes by comparing their permissions
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
"""Intersect two sets of scopes by comparing their permissions
Arguments:
scopes_a, scopes_b: sets of expanded scopes
db (optional): db connection for resolving group membership
Returns:
intersection: set of expanded scopes as intersection of the arguments
Note: Intersects correctly with ALL and exact filter matches
(i.e. users!user=x & read:users:name -> read:users:name!user=x)
Does not currently intersect with containing filters
(i.e. users!group=x & users!user=y even if user y is in group x)
If db is given, group membership will be accounted for in intersections,
Otherwise, it can result in lower than intended permissions,
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
"""
empty_set = frozenset()
# cached lookups for group membership of users and servers
@lru_cache()
def groups_for_user(username):
"""Get set of group names for a given username"""
user = db.query(orm.User).filter_by(name=username).first()
if user is None:
return empty_set
else:
return {group.name for group in user.groups}
@lru_cache()
def groups_for_server(server):
"""Get set of group names for a given server"""
username, _, servername = server.partition("/")
return groups_for_user(username)
parsed_scopes_a = parse_scopes(scopes_a)
parsed_scopes_b = parse_scopes(scopes_b)
@@ -144,11 +163,14 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
elif filters_b == Scope.ALL:
common_filters[base] = filters_a
else:
# warn *if* there are non-overlapping user= and group= filters
common_entities = filters_a.keys() & filters_b.keys()
all_entities = filters_a.keys() | filters_b.keys()
# if we don't have a db session, we can't check group membership
# warn *if* there are non-overlapping user= and group= filters that we can't check
if (
not warned
db is None
and not warned
and 'group' in all_entities
and ('user' in all_entities or 'server' in all_entities)
):
@@ -160,39 +182,67 @@ def _intersect_expanded_scopes(scopes_a, scopes_b):
not warned
and "group" in a
and b_key in b
and set(a["group"]).difference(b.get("group", []))
and set(b[b_key]).difference(a.get(b_key, []))
and a["group"].difference(b.get("group", []))
and b[b_key].difference(a.get(b_key, []))
):
warnings.warn(
f"{base}[!{b_key}={b[b_key]}, !group={a['group']}] combinations of filters present,"
" intersection between not considered. May result in lower than intended permissions.",
" without db access. Intersection between not considered."
" May result in lower than intended permissions.",
UserWarning,
)
warned = True
common_filters[base] = {
entity: set(parsed_scopes_a[base][entity])
& set(parsed_scopes_b[base][entity])
entity: filters_a[entity] & filters_b[entity]
for entity in common_entities
}
if 'server' in all_entities and 'user' in all_entities:
if filters_a.get('server') == filters_b.get('server'):
continue
# resolve hierarchies (group/user/server) in both directions
common_servers = common_filters[base].get("server", set())
common_users = common_filters[base].get("user", set())
additional_servers = set()
# resolve user/server hierarchy in both directions
for a, b in [(filters_a, filters_b), (filters_b, filters_a)]:
if 'server' in a and 'user' in b:
for server in a['server']:
if 'server' in a and b.get('server') != a['server']:
# skip already-added servers (includes overlapping servers)
servers = a['server'].difference(common_servers)
# resolve user/server hierarchy
if servers and 'user' in b:
for server in servers:
username, _, servername = server.partition("/")
if username in b['user']:
additional_servers.add(server)
common_servers.add(server)
if additional_servers:
if "server" not in common_filters[base]:
common_filters[base]["server"] = set()
common_filters[base]["server"].update(additional_servers)
# resolve group/server hierarchy if db available
servers = servers.difference(common_servers)
if db is not None and servers and 'group' in b:
for server in servers:
server_groups = groups_for_server(server)
if server_groups & b['group']:
common_servers.add(server)
# resolve group/user hierarchy if db available and user sets aren't identical
if (
db is not None
and 'user' in a
and 'group' in b
and b.get('user') != a['user']
):
# skip already-added users (includes overlapping users)
users = a['user'].difference(common_users)
for username in users:
groups = groups_for_user(username)
if groups & b["group"]:
common_users.add(username)
# add server filter if there wasn't one before
if common_servers and "server" not in common_filters[base]:
common_filters[base]["server"] = common_servers
# add user filter if it's non-empty and there wasn't one before
if common_users and "user" not in common_filters[base]:
common_filters[base]["user"] = common_users
return unparse_scopes(common_filters)
@@ -244,11 +294,21 @@ def get_scopes_for(orm_object):
)
owner_scopes = roles.expand_roles_to_scopes(owner)
if token_scopes == {'all'}:
# token_scopes is only 'all', return owner scopes as-is
# short-circuit common case where we don't need to compute an intersection
return owner_scopes
if 'all' in token_scopes:
token_scopes.remove('all')
token_scopes |= owner_scopes
intersection = _intersect_expanded_scopes(token_scopes, owner_scopes)
intersection = _intersect_expanded_scopes(
token_scopes,
owner_scopes,
db=sa.inspect(orm_object).session,
)
discarded_token_scopes = token_scopes - intersection
# Not taking symmetric difference here because token owner can naturally have more scopes than token
@@ -473,7 +533,8 @@ def check_scope_filter(sub_scope, orm_resource, kind):
# Fall back on checking if we have group access for this user
orm_resource = orm_resource.user
kind = 'user'
elif kind == 'user' and 'group' in sub_scope:
if kind == 'user' and 'group' in sub_scope:
group_names = {group.name for group in orm_resource.groups}
user_in_group = bool(group_names & set(sub_scope['group']))
if user_in_group:

View File

@@ -38,7 +38,7 @@ from traitlets import Unicode
from traitlets import validate
from traitlets.config import SingletonConfigurable
from ..scopes import _intersect_scopes
from ..scopes import _intersect_expanded_scopes
from ..utils import url_path_join
@@ -70,7 +70,7 @@ def check_scopes(required_scopes, scopes):
if isinstance(required_scopes, str):
required_scopes = {required_scopes}
intersection = _intersect_scopes(required_scopes, scopes)
intersection = _intersect_expanded_scopes(required_scopes, scopes)
# re-intersect with required_scopes in case the intersection
# applies stricter filters than required_scopes declares
# e.g. required_scopes = {'read:users'} and intersection has only {'read:users!user=x'}

View File

@@ -763,3 +763,81 @@ def test_intersect_expanded_scopes(left, right, expected, should_warn, recwarn):
assert len(recwarn) == 1
else:
assert len(recwarn) == 0
@pytest.mark.parametrize(
"left, right, expected, groups",
[
(
["users!group=gx"],
["users!user=ux"],
["users!user=ux"],
{"gx": ["ux"]},
),
(
["read:users!group=gx"],
["read:users!user=nosuchuser"],
[],
{},
),
(
["read:users!group=gx"],
["read:users!server=nosuchuser/server"],
[],
{},
),
(
["read:users!group=gx"],
["read:users!server=ux/server"],
["read:users!server=ux/server"],
{"gx": ["ux"]},
),
(
["read:users!group=gx"],
["read:users!server=ux/server", "read:users!user=uy"],
["read:users!server=ux/server"],
{"gx": ["ux"], "gy": ["uy"]},
),
(
["read:users!group=gy"],
["read:users!server=ux/server", "read:users!user=uy"],
["read:users!user=uy"],
{"gx": ["ux"], "gy": ["uy"]},
),
],
)
def test_intersect_groups(request, db, left, right, expected, groups):
if isinstance(left, str):
left = set([left])
if isinstance(right, str):
right = set([right])
# if we have a db connection, we can actually resolve
created = []
for groupname, members in groups.items():
group = orm.Group.find(db, name=groupname)
if not group:
group = orm.Group(name=groupname)
db.add(group)
created.append(group)
db.commit()
for username in members:
user = orm.User.find(db, name=username)
if user is None:
user = orm.User(name=username)
db.add(user)
created.append(user)
user.groups.append(group)
db.commit()
def _cleanup():
for obj in created:
db.delete(obj)
db.commit()
request.addfinalizer(_cleanup)
# run every test in both directions, to ensure symmetry of the inputs
for a, b in [(left, right), (right, left)]:
intersection = _intersect_expanded_scopes(set(left), set(right), db)
assert intersection == set(expected)

View File

@@ -6,6 +6,7 @@ from urllib.parse import urlparse
import pytest
import jupyterhub
from .. import orm
from ..utils import url_path_join
from .mocking import public_url
from .mocking import StubSingleUserSpawner
@@ -16,6 +17,8 @@ from .utils import AsyncSession
@pytest.mark.parametrize(
"access_scopes, server_name, expect_success",
[
(["access:users:servers!group=$group"], "", True),
(["access:users:servers!group=other-group"], "", False),
(["access:users:servers"], "", True),
(["access:users:servers"], "named", True),
(["access:users:servers!user=$user"], "", True),
@@ -46,6 +49,16 @@ async def test_singleuser_auth(
# login, start the server
cookies = await app.login_user('nandy')
user = app.users['nandy']
group = orm.Group.find(app.db, name="visitors")
if group is None:
group = orm.Group(name="visitors")
app.db.add(group)
app.db.commit()
if group not in user.groups:
user.groups.append(group)
app.db.commit()
if server_name not in user.spawners or not user.spawners[server_name].active:
await user.spawn(server_name)
await app.proxy.add_user(user, server_name)
@@ -85,6 +98,7 @@ async def test_singleuser_auth(
access_scopes = [
s.replace("$server", f"{user.name}/{server_name}") for s in access_scopes
]
access_scopes = [s.replace("$group", f"{group.name}") for s in access_scopes]
other_user = create_user_with_scopes(*access_scopes, name="burgess")
cookies = await app.login_user('burgess')