mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-19 07:53:00 +00:00
Added unit tests and fixed bugs in scope filter
This commit is contained in:
@@ -40,7 +40,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
|
||||
"""List groups"""
|
||||
groups = self.db.query(orm.Group)
|
||||
if scope_filter is not None:
|
||||
groups.filter(orm.Group.name._in(scope_filter))
|
||||
groups = groups.filter(orm.Group.name.in_(scope_filter))
|
||||
data = [self.group_model(g) for g in groups]
|
||||
self.write(json.dumps(data))
|
||||
|
||||
|
@@ -93,7 +93,7 @@ class UserListAPIHandler(APIHandler):
|
||||
# no filter, return all users
|
||||
query = self.db.query(orm.User)
|
||||
if scope_filter is not None:
|
||||
query.filter(orm.User.name.in_(scope_filter))
|
||||
query = query.filter(orm.User.name.in_(scope_filter))
|
||||
|
||||
data = [
|
||||
self.user_model(u, include_servers=True, include_state=True)
|
||||
|
@@ -135,7 +135,7 @@ def add_role(db, role_dict):
|
||||
db.commit()
|
||||
|
||||
|
||||
def get_orm_class(kind):
|
||||
def get_orm_class(kind): # Todo: merge and move to orm.py
|
||||
if kind == 'users':
|
||||
Class = orm.User
|
||||
elif kind == 'services':
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Test scopes for API handlers"""
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -214,3 +215,100 @@ async def test_expand_groups(app, user_name, in_group, status_code):
|
||||
app, 'users', user_name, headers=auth_header(app.db, user_name)
|
||||
)
|
||||
assert r.status_code == status_code
|
||||
|
||||
|
||||
async def test_user_filter(app):
|
||||
user_name = 'rollerblade'
|
||||
test_role = {
|
||||
'name': 'test',
|
||||
'description': '',
|
||||
'users': [user_name],
|
||||
'scopes': [
|
||||
'read:users!user=lindsay',
|
||||
'read:users!user=gob',
|
||||
'read:users!user=oscar',
|
||||
],
|
||||
}
|
||||
roles.add_role(app.db, test_role)
|
||||
name_in_scope = {'lindsay', 'oscar', 'gob'}
|
||||
outside_scope = {'maeby', 'marta'}
|
||||
group_name = 'bluth'
|
||||
group = orm.Group.find(app.db, name=group_name)
|
||||
if not group:
|
||||
group = orm.Group(name=group_name)
|
||||
app.db.add(group)
|
||||
for name in name_in_scope | outside_scope:
|
||||
user = add_user(app.db, name=name)
|
||||
if name not in group.users:
|
||||
group.users.append(user)
|
||||
kind = 'users'
|
||||
user = add_user(app.db, name=user_name)
|
||||
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
assert result_names == name_in_scope
|
||||
|
||||
|
||||
async def test_user_filter_with_group(app): # todo: Move role setup to setup method
|
||||
user_name = 'sally'
|
||||
test_role = {
|
||||
'name': 'test',
|
||||
'description': '',
|
||||
'users': [user_name],
|
||||
'scopes': ['read:users!group=sitwell'],
|
||||
}
|
||||
roles.add_role(app.db, test_role)
|
||||
user = add_user(app.db, name=user_name)
|
||||
name_set = {'sally', 'stan'}
|
||||
group_name = 'sitwell'
|
||||
group = orm.Group.find(app.db, name=group_name)
|
||||
if not group:
|
||||
group = orm.Group(name=group_name)
|
||||
app.db.add(group)
|
||||
for name in name_set:
|
||||
user = add_user(app.db, name=name)
|
||||
if name not in group.users:
|
||||
group.users.append(user)
|
||||
kind = 'users'
|
||||
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
assert result_names == name_set
|
||||
|
||||
|
||||
async def test_group_scope_filter(app):
|
||||
user_name = 'rollerblade'
|
||||
test_role = {
|
||||
'name': 'test',
|
||||
'description': '',
|
||||
'users': [user_name],
|
||||
'scopes': [
|
||||
'read:groups!group=sitwell',
|
||||
'read:groups!group=bluths',
|
||||
],
|
||||
}
|
||||
roles.add_role(app.db, test_role)
|
||||
user = add_user(app.db, name=user_name)
|
||||
group_set = {'sitwell', 'bluths', 'austero'}
|
||||
for group_name in group_set:
|
||||
group = orm.Group.find(app.db, name=group_name)
|
||||
if not group:
|
||||
group = orm.Group(name=group_name)
|
||||
app.db.add(group)
|
||||
kind = 'users'
|
||||
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
assert result_names == {'sitwell', 'bluths'}
|
||||
|
@@ -30,6 +30,8 @@ from tornado.httpclient import HTTPError
|
||||
from tornado.log import app_log
|
||||
from tornado.platform.asyncio import to_asyncio_future
|
||||
|
||||
from . import orm # todo: only necessary for scopes, move later
|
||||
|
||||
|
||||
def random_port():
|
||||
"""Get a single random port."""
|
||||
@@ -347,25 +349,45 @@ def check_user_in_expanded_scope(handler, user_name, scope_group_names):
|
||||
|
||||
|
||||
def _flatten_groups(groups):
|
||||
user_set = {}
|
||||
user_set = set()
|
||||
for group in groups:
|
||||
user_set |= group.users
|
||||
user_set |= {
|
||||
user.name for user in group.users
|
||||
} # todo: I think this could be one query, no for loop
|
||||
return user_set
|
||||
|
||||
|
||||
def _get_scope_filter(req_scope, sub_scope):
|
||||
def get_orm_class(kind):
|
||||
class_dict = {
|
||||
'users': orm.User,
|
||||
'services': orm.Service,
|
||||
'tokens': orm.APIToken,
|
||||
'groups': orm.Group,
|
||||
}
|
||||
if kind not in class_dict:
|
||||
raise ValueError(
|
||||
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
|
||||
)
|
||||
return class_dict[kind]
|
||||
|
||||
|
||||
def _get_scope_filter(db, req_scope, sub_scope):
|
||||
# Rough draft
|
||||
scope_translator = {
|
||||
'read:users': 'users',
|
||||
'users': 'users',
|
||||
'read:services': 'services',
|
||||
'read:groups': 'groups',
|
||||
}
|
||||
if req_scope not in scope_translator:
|
||||
raise AttributeError("Scope not found")
|
||||
raise AttributeError("Scope not found; scope filter not constructed")
|
||||
kind = scope_translator[req_scope]
|
||||
scope_filter = None # todo: orm.Class(kind).find
|
||||
if 'group' in sub_scope and kind == 'user':
|
||||
scope_filter += _flatten_groups(sub_scope['group'])
|
||||
Class = get_orm_class(kind)
|
||||
sub_scope_values = next(iter(sub_scope.values()))
|
||||
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
|
||||
scope_filter = [entry.name for entry in query.all()]
|
||||
if 'group' in sub_scope and kind == 'users':
|
||||
groups = db.query(orm.Group).filter(orm.Group.name.in_(sub_scope['group']))
|
||||
scope_filter += _flatten_groups(groups)
|
||||
return set(scope_filter)
|
||||
|
||||
|
||||
@@ -380,9 +402,8 @@ def check_scope(api_handler, req_scope, scopes, **kwargs):
|
||||
# Apply filters
|
||||
sub_scope = scopes[req_scope]
|
||||
if 'scope_filter' in kwargs:
|
||||
scope_filter = _get_scope_filter(req_scope, sub_scope)
|
||||
kwargs['scope_filter'] = scope_filter
|
||||
return True
|
||||
scope_filter = _get_scope_filter(api_handler.db, req_scope, sub_scope)
|
||||
return scope_filter
|
||||
else:
|
||||
# Interface change: Now can have multiple filters
|
||||
for (filter_, filter_value) in kwargs.items():
|
||||
@@ -449,11 +470,16 @@ def needs_scope(scope):
|
||||
if resource_name in bound_sig.arguments:
|
||||
resource_value = bound_sig.arguments[resource_name]
|
||||
s_kwargs[resource] = resource_value
|
||||
if 'scope_filter' in bound_sig.arguments:
|
||||
s_kwargs['scope_filter'] = None
|
||||
if 'all' in self.scopes and self.current_user:
|
||||
# todo: What if no user is found? See test_api/test_referer_check
|
||||
self.scopes |= get_user_scopes(self.current_user.name)
|
||||
parsed_scopes = parse_scopes(self.scopes)
|
||||
if check_scope(self, scope, parsed_scopes, **s_kwargs):
|
||||
scope_filter = check_scope(self, scope, parsed_scopes, **s_kwargs)
|
||||
if scope_filter:
|
||||
if isinstance(scope_filter, set):
|
||||
kwargs['scope_filter'] = scope_filter
|
||||
return func(self, *args, **kwargs)
|
||||
else:
|
||||
# catching attr error occurring for older_requirements test
|
||||
|
Reference in New Issue
Block a user