Update with expand group test

This commit is contained in:
Omar Richardson
2020-11-19 09:57:50 +01:00
parent 54cb31b3a9
commit 71d99e1180
3 changed files with 31 additions and 11 deletions

View File

@@ -27,7 +27,6 @@ from .utils import api_request
from .utils import async_requests from .utils import async_requests
from .utils import auth_header from .utils import auth_header
from .utils import find_user from .utils import find_user
from .utils import get_scopes
# -------------------- # --------------------

View File

@@ -3,10 +3,13 @@ import pytest
from pytest import mark from pytest import mark
from tornado import web from tornado import web
from .. import orm
from ..utils import check_scope from ..utils import check_scope
from ..utils import needs_scope from ..utils import needs_scope
from ..utils import parse_scopes from ..utils import parse_scopes
from ..utils import Scope from ..utils import Scope
from .utils import api_request
from .utils import auth_header
def test_scope_constructor(): def test_scope_constructor():
@@ -74,7 +77,7 @@ def test_scope_parse_server_name():
) )
class MockAPI: class MockAPIHandler:
def __init__(self): def __init__(self):
self.scopes = ['users'] self.scopes = ['users']
@@ -152,7 +155,7 @@ class MockAPI:
], ],
) )
def test_scope_method_access(scopes, method, arguments, is_allowed): def test_scope_method_access(scopes, method, arguments, is_allowed):
obj = MockAPI() obj = MockAPIHandler()
obj.scopes = scopes obj.scopes = scopes
api_call = getattr(obj, method) api_call = getattr(obj, method)
if is_allowed: if is_allowed:
@@ -160,3 +163,19 @@ def test_scope_method_access(scopes, method, arguments, is_allowed):
else: else:
with pytest.raises(web.HTTPError): with pytest.raises(web.HTTPError):
api_call(*arguments) api_call(*arguments)
async def test_expand_groups(app):
db = app.db
user = orm.User(name='gob')
group = orm.Group(name='bluth')
db.add(group)
db.add(user)
group.users.append(user)
db.commit()
scopes = ['read:users!user=micheal', 'read:users!group=bluth', 'read:groups']
app.tornado_settings['mock_scopes'] = scopes
r = await api_request(app, 'users', 'micheal', headers=auth_header(db, 'micheal'))
assert r.status_code == 200
r = await api_request(app, 'users', 'gob', headers=auth_header(db, 'user'))
assert r.status_code == 200

View File

@@ -307,7 +307,6 @@ def needs_scope_expansion(filter_, filter_value, sub_scope):
""" """
Check if there is a requirements to expand the `group` scope to individual `user` scopes. Check if there is a requirements to expand the `group` scope to individual `user` scopes.
Assumptions: Assumptions:
req_scopes in scopes
filter_ != Scope.ALL filter_ != Scope.ALL
This can be made arbitrarily intelligent but that would make it more complex This can be made arbitrarily intelligent but that would make it more complex
@@ -334,7 +333,7 @@ def check_user_in_expanded_scope(handler, user_name, scope_group_names):
if user is None: if user is None:
raise web.HTTPError(404, 'No such user found') raise web.HTTPError(404, 'No such user found')
group_names = {group.name for group in user.groups} group_names = {group.name for group in user.groups}
return bool(scope_group_names & group_names) return bool(set(scope_group_names) & group_names)
def check_scope(api_handler, req_scope, scopes, **kwargs): def check_scope(api_handler, req_scope, scopes, **kwargs):
@@ -354,12 +353,15 @@ def check_scope(api_handler, req_scope, scopes, **kwargs):
filter_, filter_value = list(kwargs.items())[0] filter_, filter_value = list(kwargs.items())[0]
sub_scope = scopes[req_scope] sub_scope = scopes[req_scope]
if filter_ not in sub_scope: if filter_ not in sub_scope:
if needs_scope_expansion(filter_, filter_value, sub_scope): valid_scope = False
group_names = sub_scope['groups'] else:
return check_user_in_expanded_scope(api_handler, filter_value, group_names) valid_scope = filter_value in sub_scope[filter_]
else: if not valid_scope and needs_scope_expansion(filter_, filter_value, sub_scope):
return False group_names = sub_scope['group']
return filter_value in sub_scope[filter_] valid_scope |= check_user_in_expanded_scope(
api_handler, filter_value, group_names
)
return valid_scope
def parse_scopes(scope_list): def parse_scopes(scope_list):