Added unit tests and fixed bugs in scope filter

This commit is contained in:
Omar Richardson
2021-01-04 22:44:23 +01:00
parent f4ba57b1d7
commit 82bebfaff2
5 changed files with 139 additions and 15 deletions

View File

@@ -40,7 +40,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
"""List groups"""
groups = self.db.query(orm.Group)
if scope_filter is not None:
groups.filter(orm.Group.name._in(scope_filter))
groups = groups.filter(orm.Group.name.in_(scope_filter))
data = [self.group_model(g) for g in groups]
self.write(json.dumps(data))

View File

@@ -93,7 +93,7 @@ class UserListAPIHandler(APIHandler):
# no filter, return all users
query = self.db.query(orm.User)
if scope_filter is not None:
query.filter(orm.User.name.in_(scope_filter))
query = query.filter(orm.User.name.in_(scope_filter))
data = [
self.user_model(u, include_servers=True, include_state=True)

View File

@@ -135,7 +135,7 @@ def add_role(db, role_dict):
db.commit()
def get_orm_class(kind):
def get_orm_class(kind): # Todo: merge and move to orm.py
if kind == 'users':
Class = orm.User
elif kind == 'services':

View File

@@ -1,4 +1,5 @@
"""Test scopes for API handlers"""
import json
from unittest import mock
import pytest
@@ -214,3 +215,100 @@ async def test_expand_groups(app, user_name, in_group, status_code):
app, 'users', user_name, headers=auth_header(app.db, user_name)
)
assert r.status_code == status_code
async def test_user_filter(app):
user_name = 'rollerblade'
test_role = {
'name': 'test',
'description': '',
'users': [user_name],
'scopes': [
'read:users!user=lindsay',
'read:users!user=gob',
'read:users!user=oscar',
],
}
roles.add_role(app.db, test_role)
name_in_scope = {'lindsay', 'oscar', 'gob'}
outside_scope = {'maeby', 'marta'}
group_name = 'bluth'
group = orm.Group.find(app.db, name=group_name)
if not group:
group = orm.Group(name=group_name)
app.db.add(group)
for name in name_in_scope | outside_scope:
user = add_user(app.db, name=name)
if name not in group.users:
group.users.append(user)
kind = 'users'
user = add_user(app.db, name=user_name)
roles.update_roles(app.db, user, kind, roles=['test'])
roles.remove_obj(app.db, user_name, kind, 'user')
app.db.commit()
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
assert result_names == name_in_scope
async def test_user_filter_with_group(app): # todo: Move role setup to setup method
user_name = 'sally'
test_role = {
'name': 'test',
'description': '',
'users': [user_name],
'scopes': ['read:users!group=sitwell'],
}
roles.add_role(app.db, test_role)
user = add_user(app.db, name=user_name)
name_set = {'sally', 'stan'}
group_name = 'sitwell'
group = orm.Group.find(app.db, name=group_name)
if not group:
group = orm.Group(name=group_name)
app.db.add(group)
for name in name_set:
user = add_user(app.db, name=name)
if name not in group.users:
group.users.append(user)
kind = 'users'
roles.update_roles(app.db, user, kind, roles=['test'])
roles.remove_obj(app.db, user_name, kind, 'user')
app.db.commit()
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
assert result_names == name_set
async def test_group_scope_filter(app):
user_name = 'rollerblade'
test_role = {
'name': 'test',
'description': '',
'users': [user_name],
'scopes': [
'read:groups!group=sitwell',
'read:groups!group=bluths',
],
}
roles.add_role(app.db, test_role)
user = add_user(app.db, name=user_name)
group_set = {'sitwell', 'bluths', 'austero'}
for group_name in group_set:
group = orm.Group.find(app.db, name=group_name)
if not group:
group = orm.Group(name=group_name)
app.db.add(group)
kind = 'users'
roles.update_roles(app.db, user, kind, roles=['test'])
roles.remove_obj(app.db, user_name, kind, 'user')
app.db.commit()
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
assert result_names == {'sitwell', 'bluths'}

View File

@@ -30,6 +30,8 @@ from tornado.httpclient import HTTPError
from tornado.log import app_log
from tornado.platform.asyncio import to_asyncio_future
from . import orm # todo: only necessary for scopes, move later
def random_port():
"""Get a single random port."""
@@ -347,25 +349,45 @@ def check_user_in_expanded_scope(handler, user_name, scope_group_names):
def _flatten_groups(groups):
user_set = {}
user_set = set()
for group in groups:
user_set |= group.users
user_set |= {
user.name for user in group.users
} # todo: I think this could be one query, no for loop
return user_set
def _get_scope_filter(req_scope, sub_scope):
def get_orm_class(kind):
class_dict = {
'users': orm.User,
'services': orm.Service,
'tokens': orm.APIToken,
'groups': orm.Group,
}
if kind not in class_dict:
raise ValueError(
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
)
return class_dict[kind]
def _get_scope_filter(db, req_scope, sub_scope):
# Rough draft
scope_translator = {
'read:users': 'users',
'users': 'users',
'read:services': 'services',
'read:groups': 'groups',
}
if req_scope not in scope_translator:
raise AttributeError("Scope not found")
raise AttributeError("Scope not found; scope filter not constructed")
kind = scope_translator[req_scope]
scope_filter = None # todo: orm.Class(kind).find
if 'group' in sub_scope and kind == 'user':
scope_filter += _flatten_groups(sub_scope['group'])
Class = get_orm_class(kind)
sub_scope_values = next(iter(sub_scope.values()))
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
scope_filter = [entry.name for entry in query.all()]
if 'group' in sub_scope and kind == 'users':
groups = db.query(orm.Group).filter(orm.Group.name.in_(sub_scope['group']))
scope_filter += _flatten_groups(groups)
return set(scope_filter)
@@ -380,9 +402,8 @@ def check_scope(api_handler, req_scope, scopes, **kwargs):
# Apply filters
sub_scope = scopes[req_scope]
if 'scope_filter' in kwargs:
scope_filter = _get_scope_filter(req_scope, sub_scope)
kwargs['scope_filter'] = scope_filter
return True
scope_filter = _get_scope_filter(api_handler.db, req_scope, sub_scope)
return scope_filter
else:
# Interface change: Now can have multiple filters
for (filter_, filter_value) in kwargs.items():
@@ -449,11 +470,16 @@ def needs_scope(scope):
if resource_name in bound_sig.arguments:
resource_value = bound_sig.arguments[resource_name]
s_kwargs[resource] = resource_value
if 'scope_filter' in bound_sig.arguments:
s_kwargs['scope_filter'] = None
if 'all' in self.scopes and self.current_user:
# todo: What if no user is found? See test_api/test_referer_check
self.scopes |= get_user_scopes(self.current_user.name)
parsed_scopes = parse_scopes(self.scopes)
if check_scope(self, scope, parsed_scopes, **s_kwargs):
scope_filter = check_scope(self, scope, parsed_scopes, **s_kwargs)
if scope_filter:
if isinstance(scope_filter, set):
kwargs['scope_filter'] = scope_filter
return func(self, *args, **kwargs)
else:
# catching attr error occurring for older_requirements test