Added unit tests and fixed bugs in scope filter

This commit is contained in:
Omar Richardson
2021-01-04 22:44:23 +01:00
parent f4ba57b1d7
commit 82bebfaff2
5 changed files with 139 additions and 15 deletions

View File

@@ -30,6 +30,8 @@ from tornado.httpclient import HTTPError
from tornado.log import app_log
from tornado.platform.asyncio import to_asyncio_future
from . import orm # todo: only necessary for scopes, move later
def random_port():
"""Get a single random port."""
@@ -347,25 +349,45 @@ def check_user_in_expanded_scope(handler, user_name, scope_group_names):
def _flatten_groups(groups):
user_set = {}
user_set = set()
for group in groups:
user_set |= group.users
user_set |= {
user.name for user in group.users
} # todo: I think this could be one query, no for loop
return user_set
def _get_scope_filter(req_scope, sub_scope):
def get_orm_class(kind):
class_dict = {
'users': orm.User,
'services': orm.Service,
'tokens': orm.APIToken,
'groups': orm.Group,
}
if kind not in class_dict:
raise ValueError(
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
)
return class_dict[kind]
def _get_scope_filter(db, req_scope, sub_scope):
# Rough draft
scope_translator = {
'read:users': 'users',
'users': 'users',
'read:services': 'services',
'read:groups': 'groups',
}
if req_scope not in scope_translator:
raise AttributeError("Scope not found")
raise AttributeError("Scope not found; scope filter not constructed")
kind = scope_translator[req_scope]
scope_filter = None # todo: orm.Class(kind).find
if 'group' in sub_scope and kind == 'user':
scope_filter += _flatten_groups(sub_scope['group'])
Class = get_orm_class(kind)
sub_scope_values = next(iter(sub_scope.values()))
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
scope_filter = [entry.name for entry in query.all()]
if 'group' in sub_scope and kind == 'users':
groups = db.query(orm.Group).filter(orm.Group.name.in_(sub_scope['group']))
scope_filter += _flatten_groups(groups)
return set(scope_filter)
@@ -380,9 +402,8 @@ def check_scope(api_handler, req_scope, scopes, **kwargs):
# Apply filters
sub_scope = scopes[req_scope]
if 'scope_filter' in kwargs:
scope_filter = _get_scope_filter(req_scope, sub_scope)
kwargs['scope_filter'] = scope_filter
return True
scope_filter = _get_scope_filter(api_handler.db, req_scope, sub_scope)
return scope_filter
else:
# Interface change: Now can have multiple filters
for (filter_, filter_value) in kwargs.items():
@@ -449,11 +470,16 @@ def needs_scope(scope):
if resource_name in bound_sig.arguments:
resource_value = bound_sig.arguments[resource_name]
s_kwargs[resource] = resource_value
if 'scope_filter' in bound_sig.arguments:
s_kwargs['scope_filter'] = None
if 'all' in self.scopes and self.current_user:
# todo: What if no user is found? See test_api/test_referer_check
self.scopes |= get_user_scopes(self.current_user.name)
parsed_scopes = parse_scopes(self.scopes)
if check_scope(self, scope, parsed_scopes, **s_kwargs):
scope_filter = check_scope(self, scope, parsed_scopes, **s_kwargs)
if scope_filter:
if isinstance(scope_filter, set):
kwargs['scope_filter'] = scope_filter
return func(self, *args, **kwargs)
else:
# catching attr error occurring for older_requirements test