mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-19 07:53:00 +00:00
Added unit tests and fixed bugs in scope filter
This commit is contained in:
@@ -40,7 +40,7 @@ class GroupListAPIHandler(_GroupAPIHandler):
|
|||||||
"""List groups"""
|
"""List groups"""
|
||||||
groups = self.db.query(orm.Group)
|
groups = self.db.query(orm.Group)
|
||||||
if scope_filter is not None:
|
if scope_filter is not None:
|
||||||
groups.filter(orm.Group.name._in(scope_filter))
|
groups = groups.filter(orm.Group.name.in_(scope_filter))
|
||||||
data = [self.group_model(g) for g in groups]
|
data = [self.group_model(g) for g in groups]
|
||||||
self.write(json.dumps(data))
|
self.write(json.dumps(data))
|
||||||
|
|
||||||
|
@@ -93,7 +93,7 @@ class UserListAPIHandler(APIHandler):
|
|||||||
# no filter, return all users
|
# no filter, return all users
|
||||||
query = self.db.query(orm.User)
|
query = self.db.query(orm.User)
|
||||||
if scope_filter is not None:
|
if scope_filter is not None:
|
||||||
query.filter(orm.User.name.in_(scope_filter))
|
query = query.filter(orm.User.name.in_(scope_filter))
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
self.user_model(u, include_servers=True, include_state=True)
|
self.user_model(u, include_servers=True, include_state=True)
|
||||||
|
@@ -135,7 +135,7 @@ def add_role(db, role_dict):
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
def get_orm_class(kind):
|
def get_orm_class(kind): # Todo: merge and move to orm.py
|
||||||
if kind == 'users':
|
if kind == 'users':
|
||||||
Class = orm.User
|
Class = orm.User
|
||||||
elif kind == 'services':
|
elif kind == 'services':
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
"""Test scopes for API handlers"""
|
"""Test scopes for API handlers"""
|
||||||
|
import json
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -214,3 +215,100 @@ async def test_expand_groups(app, user_name, in_group, status_code):
|
|||||||
app, 'users', user_name, headers=auth_header(app.db, user_name)
|
app, 'users', user_name, headers=auth_header(app.db, user_name)
|
||||||
)
|
)
|
||||||
assert r.status_code == status_code
|
assert r.status_code == status_code
|
||||||
|
|
||||||
|
|
||||||
|
async def test_user_filter(app):
|
||||||
|
user_name = 'rollerblade'
|
||||||
|
test_role = {
|
||||||
|
'name': 'test',
|
||||||
|
'description': '',
|
||||||
|
'users': [user_name],
|
||||||
|
'scopes': [
|
||||||
|
'read:users!user=lindsay',
|
||||||
|
'read:users!user=gob',
|
||||||
|
'read:users!user=oscar',
|
||||||
|
],
|
||||||
|
}
|
||||||
|
roles.add_role(app.db, test_role)
|
||||||
|
name_in_scope = {'lindsay', 'oscar', 'gob'}
|
||||||
|
outside_scope = {'maeby', 'marta'}
|
||||||
|
group_name = 'bluth'
|
||||||
|
group = orm.Group.find(app.db, name=group_name)
|
||||||
|
if not group:
|
||||||
|
group = orm.Group(name=group_name)
|
||||||
|
app.db.add(group)
|
||||||
|
for name in name_in_scope | outside_scope:
|
||||||
|
user = add_user(app.db, name=name)
|
||||||
|
if name not in group.users:
|
||||||
|
group.users.append(user)
|
||||||
|
kind = 'users'
|
||||||
|
user = add_user(app.db, name=user_name)
|
||||||
|
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||||
|
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||||
|
app.db.commit()
|
||||||
|
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = json.loads(r.content)
|
||||||
|
result_names = {user['name'] for user in data}
|
||||||
|
assert result_names == name_in_scope
|
||||||
|
|
||||||
|
|
||||||
|
async def test_user_filter_with_group(app): # todo: Move role setup to setup method
|
||||||
|
user_name = 'sally'
|
||||||
|
test_role = {
|
||||||
|
'name': 'test',
|
||||||
|
'description': '',
|
||||||
|
'users': [user_name],
|
||||||
|
'scopes': ['read:users!group=sitwell'],
|
||||||
|
}
|
||||||
|
roles.add_role(app.db, test_role)
|
||||||
|
user = add_user(app.db, name=user_name)
|
||||||
|
name_set = {'sally', 'stan'}
|
||||||
|
group_name = 'sitwell'
|
||||||
|
group = orm.Group.find(app.db, name=group_name)
|
||||||
|
if not group:
|
||||||
|
group = orm.Group(name=group_name)
|
||||||
|
app.db.add(group)
|
||||||
|
for name in name_set:
|
||||||
|
user = add_user(app.db, name=name)
|
||||||
|
if name not in group.users:
|
||||||
|
group.users.append(user)
|
||||||
|
kind = 'users'
|
||||||
|
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||||
|
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||||
|
app.db.commit()
|
||||||
|
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = json.loads(r.content)
|
||||||
|
result_names = {user['name'] for user in data}
|
||||||
|
assert result_names == name_set
|
||||||
|
|
||||||
|
|
||||||
|
async def test_group_scope_filter(app):
|
||||||
|
user_name = 'rollerblade'
|
||||||
|
test_role = {
|
||||||
|
'name': 'test',
|
||||||
|
'description': '',
|
||||||
|
'users': [user_name],
|
||||||
|
'scopes': [
|
||||||
|
'read:groups!group=sitwell',
|
||||||
|
'read:groups!group=bluths',
|
||||||
|
],
|
||||||
|
}
|
||||||
|
roles.add_role(app.db, test_role)
|
||||||
|
user = add_user(app.db, name=user_name)
|
||||||
|
group_set = {'sitwell', 'bluths', 'austero'}
|
||||||
|
for group_name in group_set:
|
||||||
|
group = orm.Group.find(app.db, name=group_name)
|
||||||
|
if not group:
|
||||||
|
group = orm.Group(name=group_name)
|
||||||
|
app.db.add(group)
|
||||||
|
kind = 'users'
|
||||||
|
roles.update_roles(app.db, user, kind, roles=['test'])
|
||||||
|
roles.remove_obj(app.db, user_name, kind, 'user')
|
||||||
|
app.db.commit()
|
||||||
|
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = json.loads(r.content)
|
||||||
|
result_names = {user['name'] for user in data}
|
||||||
|
assert result_names == {'sitwell', 'bluths'}
|
||||||
|
@@ -30,6 +30,8 @@ from tornado.httpclient import HTTPError
|
|||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.platform.asyncio import to_asyncio_future
|
from tornado.platform.asyncio import to_asyncio_future
|
||||||
|
|
||||||
|
from . import orm # todo: only necessary for scopes, move later
|
||||||
|
|
||||||
|
|
||||||
def random_port():
|
def random_port():
|
||||||
"""Get a single random port."""
|
"""Get a single random port."""
|
||||||
@@ -347,25 +349,45 @@ def check_user_in_expanded_scope(handler, user_name, scope_group_names):
|
|||||||
|
|
||||||
|
|
||||||
def _flatten_groups(groups):
|
def _flatten_groups(groups):
|
||||||
user_set = {}
|
user_set = set()
|
||||||
for group in groups:
|
for group in groups:
|
||||||
user_set |= group.users
|
user_set |= {
|
||||||
|
user.name for user in group.users
|
||||||
|
} # todo: I think this could be one query, no for loop
|
||||||
return user_set
|
return user_set
|
||||||
|
|
||||||
|
|
||||||
def _get_scope_filter(req_scope, sub_scope):
|
def get_orm_class(kind):
|
||||||
|
class_dict = {
|
||||||
|
'users': orm.User,
|
||||||
|
'services': orm.Service,
|
||||||
|
'tokens': orm.APIToken,
|
||||||
|
'groups': orm.Group,
|
||||||
|
}
|
||||||
|
if kind not in class_dict:
|
||||||
|
raise ValueError(
|
||||||
|
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
|
||||||
|
)
|
||||||
|
return class_dict[kind]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_scope_filter(db, req_scope, sub_scope):
|
||||||
# Rough draft
|
# Rough draft
|
||||||
scope_translator = {
|
scope_translator = {
|
||||||
'read:users': 'users',
|
'read:users': 'users',
|
||||||
'users': 'users',
|
'read:services': 'services',
|
||||||
'read:groups': 'groups',
|
'read:groups': 'groups',
|
||||||
}
|
}
|
||||||
if req_scope not in scope_translator:
|
if req_scope not in scope_translator:
|
||||||
raise AttributeError("Scope not found")
|
raise AttributeError("Scope not found; scope filter not constructed")
|
||||||
kind = scope_translator[req_scope]
|
kind = scope_translator[req_scope]
|
||||||
scope_filter = None # todo: orm.Class(kind).find
|
Class = get_orm_class(kind)
|
||||||
if 'group' in sub_scope and kind == 'user':
|
sub_scope_values = next(iter(sub_scope.values()))
|
||||||
scope_filter += _flatten_groups(sub_scope['group'])
|
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
|
||||||
|
scope_filter = [entry.name for entry in query.all()]
|
||||||
|
if 'group' in sub_scope and kind == 'users':
|
||||||
|
groups = db.query(orm.Group).filter(orm.Group.name.in_(sub_scope['group']))
|
||||||
|
scope_filter += _flatten_groups(groups)
|
||||||
return set(scope_filter)
|
return set(scope_filter)
|
||||||
|
|
||||||
|
|
||||||
@@ -380,9 +402,8 @@ def check_scope(api_handler, req_scope, scopes, **kwargs):
|
|||||||
# Apply filters
|
# Apply filters
|
||||||
sub_scope = scopes[req_scope]
|
sub_scope = scopes[req_scope]
|
||||||
if 'scope_filter' in kwargs:
|
if 'scope_filter' in kwargs:
|
||||||
scope_filter = _get_scope_filter(req_scope, sub_scope)
|
scope_filter = _get_scope_filter(api_handler.db, req_scope, sub_scope)
|
||||||
kwargs['scope_filter'] = scope_filter
|
return scope_filter
|
||||||
return True
|
|
||||||
else:
|
else:
|
||||||
# Interface change: Now can have multiple filters
|
# Interface change: Now can have multiple filters
|
||||||
for (filter_, filter_value) in kwargs.items():
|
for (filter_, filter_value) in kwargs.items():
|
||||||
@@ -449,11 +470,16 @@ def needs_scope(scope):
|
|||||||
if resource_name in bound_sig.arguments:
|
if resource_name in bound_sig.arguments:
|
||||||
resource_value = bound_sig.arguments[resource_name]
|
resource_value = bound_sig.arguments[resource_name]
|
||||||
s_kwargs[resource] = resource_value
|
s_kwargs[resource] = resource_value
|
||||||
|
if 'scope_filter' in bound_sig.arguments:
|
||||||
|
s_kwargs['scope_filter'] = None
|
||||||
if 'all' in self.scopes and self.current_user:
|
if 'all' in self.scopes and self.current_user:
|
||||||
# todo: What if no user is found? See test_api/test_referer_check
|
# todo: What if no user is found? See test_api/test_referer_check
|
||||||
self.scopes |= get_user_scopes(self.current_user.name)
|
self.scopes |= get_user_scopes(self.current_user.name)
|
||||||
parsed_scopes = parse_scopes(self.scopes)
|
parsed_scopes = parse_scopes(self.scopes)
|
||||||
if check_scope(self, scope, parsed_scopes, **s_kwargs):
|
scope_filter = check_scope(self, scope, parsed_scopes, **s_kwargs)
|
||||||
|
if scope_filter:
|
||||||
|
if isinstance(scope_filter, set):
|
||||||
|
kwargs['scope_filter'] = scope_filter
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# catching attr error occurring for older_requirements test
|
# catching attr error occurring for older_requirements test
|
||||||
|
Reference in New Issue
Block a user