Refactored orm.get_class, improved resource filtereing

This commit is contained in:
Omar Richardson
2021-01-05 19:58:39 +01:00
parent e21713c24f
commit 82c837eb89
6 changed files with 37 additions and 50 deletions

View File

@@ -1858,7 +1858,7 @@ class JupyterHub(Application):
# make sure all users, services and tokens have at least one role (update with default)
for bearer in role_bearers:
Class = roles.get_orm_class(bearer)
Class = orm.get_class(bearer)
for obj in db.query(Class):
if len(obj.roles) < 1:
roles.update_roles(db, obj=obj, kind=bearer)

View File

@@ -978,3 +978,18 @@ def new_session_factory(
# this off gives us a major performance boost
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
return session_factory
def get_class(resource_name):
"""Translates resource string names to ORM classes"""
class_dict = {
'users': User,
'services': Service,
'tokens': APIToken,
'groups': Group,
}
if resource_name not in class_dict:
raise ValueError(
"Kind must be one of %s, not %s" % (", ".join(class_dict), resource_name)
)
return class_dict[resource_name]

View File

@@ -135,25 +135,12 @@ def add_role(db, role_dict):
db.commit()
def get_orm_class(kind): # Todo: merge and move to orm.py
if kind == 'users':
Class = orm.User
elif kind == 'services':
Class = orm.Service
elif kind == 'tokens':
Class = orm.APIToken
else:
raise ValueError("kind must be users, services or tokens, not %r" % kind)
return Class
def existing_only(func):
"""Decorator for checking if objects and roles exist"""
def check_existence(db, objname, kind, rolename):
Class = get_orm_class(kind)
Class = orm.get_class(kind)
obj = Class.find(db, objname)
role = orm.Role.find(db, rolename)
@@ -209,7 +196,7 @@ def update_roles(db, obj, kind, roles=None):
"""Updates object's roles if specified,
assigns default if no roles specified"""
Class = get_orm_class(kind)
Class = orm.get_class(kind)
user_role = orm.Role.find(db, 'user')
if roles:
@@ -252,11 +239,8 @@ def update_roles(db, obj, kind, roles=None):
def mock_roles(app, name, kind):
"""Loads and assigns default roles for mocked objects"""
Class = get_orm_class(kind)
Class = orm.get_class(kind)
obj = Class.find(app.db, name=name)
default_roles = get_default_roles()
for role in default_roles:

View File

@@ -22,7 +22,6 @@ def get_user_scopes(name):
users:activity
users:servers
users:tokens
"""
scope_list = [
'users',
@@ -51,29 +50,16 @@ def _needs_scope_expansion(filter_, filter_value, sub_scope):
def _check_user_in_expanded_scope(handler, user_name, scope_group_names):
"""Check if username is present in set of allowed groups"""
user = handler.find_user(user_name)
if user is None:
raise web.HTTPError(404, 'No such user found')
group_names = {group.name for group in user.groups}
group_names = {group.name for group in user.groups} # Todo: Replace with SQL query
return bool(set(scope_group_names) & group_names)
def get_orm_class(kind):
class_dict = {
'users': orm.User,
'services': orm.Service,
'tokens': orm.APIToken,
'groups': orm.Group,
}
if kind not in class_dict:
raise ValueError(
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
)
return class_dict[kind]
def _get_scope_filter(db, req_scope, sub_scope):
# Rough draft
"""Produce a filter for `*ListAPIHandlers* so that get method knows which models to return"""
scope_translator = {
'read:users': 'users',
'read:services': 'services',
@@ -82,7 +68,7 @@ def _get_scope_filter(db, req_scope, sub_scope):
if req_scope not in scope_translator:
raise AttributeError("Scope not found; scope filter not constructed")
kind = scope_translator[req_scope]
Class = get_orm_class(kind)
Class = orm.get_class(kind)
sub_scope_values = next(iter(sub_scope.values()))
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
scope_filter = {entry.name for entry in query.all()}
@@ -94,6 +80,10 @@ def _get_scope_filter(db, req_scope, sub_scope):
def _check_scope(api_handler, req_scope, scopes, **kwargs):
"""Check if scopes satisfy requirements
Returns either Scope.ALL for unrestricted access, Scope.NONE for refused access or
an iterable with a filter
"""
# Parse user name and server name together
if 'user' in kwargs and 'server' in kwargs:
kwargs['server'] = "{}/{}".format(kwargs['user'], kwargs['server'])
@@ -178,10 +168,10 @@ def needs_scope(scope):
self.scopes |= get_user_scopes(self.current_user.name)
parsed_scopes = _parse_scopes(self.scopes)
scope_filter = _check_scope(self, scope, parsed_scopes, **s_kwargs)
# todo: This checks if True or set of resource names. Not very nice yet
# todo: This checks if True/False or set of resource names. Can be improved
if isinstance(scope_filter, set):
kwargs['scope_filter'] = scope_filter
if scope_filter:
if isinstance(scope_filter, set):
kwargs['scope_filter'] = scope_filter
return func(self, *args, **kwargs)
else:
# catching attr error occurring for older_requirements test

View File

@@ -25,6 +25,7 @@ from .utils import async_requests
from .utils import auth_header
from .utils import find_user
# --------------------
# Authentication tests
# --------------------
@@ -166,7 +167,7 @@ TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
@mark.user
@mark.role
async def test_get_users(app): # todo: Sync with scope tests
async def test_get_users(app):
db = app.db
r = await api_request(app, 'users', headers=auth_header(db, 'admin'))
assert r.status_code == 200
@@ -185,7 +186,7 @@ async def test_get_users(app): # todo: Sync with scope tests
]
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
assert r.status_code == 200
r_user_model = json.loads(r.text)[0]
r_user_model = r.json()[0]
assert r_user_model['name'] == user_model['name']

View File

@@ -217,7 +217,7 @@ async def test_expand_groups(app, user_name, in_group, status_code):
async def test_user_filter(app):
user_name = 'rollerblade'
user_name = 'rita'
test_role = {
'name': 'test',
'description': '',
@@ -247,8 +247,7 @@ async def test_user_filter(app):
app.db.commit()
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
result_names = {user['name'] for user in r.json()}
assert result_names == name_in_scope
@@ -278,8 +277,7 @@ async def test_user_filter_with_group(app): # todo: Move role setup to setup me
app.db.commit()
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
result_names = {user['name'] for user in r.json()}
assert result_names == name_set
@@ -308,6 +306,5 @@ async def test_group_scope_filter(app):
app.db.commit()
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
assert r.status_code == 200
data = json.loads(r.content)
result_names = {user['name'] for user in data}
result_names = {user['name'] for user in r.json()}
assert result_names == {'sitwell', 'bluths'}