mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-10 03:23:04 +00:00
Refactored orm.get_class, improved resource filtereing
This commit is contained in:
@@ -1858,7 +1858,7 @@ class JupyterHub(Application):
|
||||
|
||||
# make sure all users, services and tokens have at least one role (update with default)
|
||||
for bearer in role_bearers:
|
||||
Class = roles.get_orm_class(bearer)
|
||||
Class = orm.get_class(bearer)
|
||||
for obj in db.query(Class):
|
||||
if len(obj.roles) < 1:
|
||||
roles.update_roles(db, obj=obj, kind=bearer)
|
||||
|
@@ -978,3 +978,18 @@ def new_session_factory(
|
||||
# this off gives us a major performance boost
|
||||
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
return session_factory
|
||||
|
||||
|
||||
def get_class(resource_name):
|
||||
"""Translates resource string names to ORM classes"""
|
||||
class_dict = {
|
||||
'users': User,
|
||||
'services': Service,
|
||||
'tokens': APIToken,
|
||||
'groups': Group,
|
||||
}
|
||||
if resource_name not in class_dict:
|
||||
raise ValueError(
|
||||
"Kind must be one of %s, not %s" % (", ".join(class_dict), resource_name)
|
||||
)
|
||||
return class_dict[resource_name]
|
||||
|
@@ -135,25 +135,12 @@ def add_role(db, role_dict):
|
||||
db.commit()
|
||||
|
||||
|
||||
def get_orm_class(kind): # Todo: merge and move to orm.py
|
||||
if kind == 'users':
|
||||
Class = orm.User
|
||||
elif kind == 'services':
|
||||
Class = orm.Service
|
||||
elif kind == 'tokens':
|
||||
Class = orm.APIToken
|
||||
else:
|
||||
raise ValueError("kind must be users, services or tokens, not %r" % kind)
|
||||
|
||||
return Class
|
||||
|
||||
|
||||
def existing_only(func):
|
||||
"""Decorator for checking if objects and roles exist"""
|
||||
|
||||
def check_existence(db, objname, kind, rolename):
|
||||
|
||||
Class = get_orm_class(kind)
|
||||
Class = orm.get_class(kind)
|
||||
obj = Class.find(db, objname)
|
||||
role = orm.Role.find(db, rolename)
|
||||
|
||||
@@ -209,7 +196,7 @@ def update_roles(db, obj, kind, roles=None):
|
||||
"""Updates object's roles if specified,
|
||||
assigns default if no roles specified"""
|
||||
|
||||
Class = get_orm_class(kind)
|
||||
Class = orm.get_class(kind)
|
||||
user_role = orm.Role.find(db, 'user')
|
||||
|
||||
if roles:
|
||||
@@ -252,11 +239,8 @@ def update_roles(db, obj, kind, roles=None):
|
||||
|
||||
|
||||
def mock_roles(app, name, kind):
|
||||
|
||||
"""Loads and assigns default roles for mocked objects"""
|
||||
|
||||
Class = get_orm_class(kind)
|
||||
|
||||
Class = orm.get_class(kind)
|
||||
obj = Class.find(app.db, name=name)
|
||||
default_roles = get_default_roles()
|
||||
for role in default_roles:
|
||||
|
@@ -22,7 +22,6 @@ def get_user_scopes(name):
|
||||
users:activity
|
||||
users:servers
|
||||
users:tokens
|
||||
|
||||
"""
|
||||
scope_list = [
|
||||
'users',
|
||||
@@ -51,29 +50,16 @@ def _needs_scope_expansion(filter_, filter_value, sub_scope):
|
||||
|
||||
|
||||
def _check_user_in_expanded_scope(handler, user_name, scope_group_names):
|
||||
"""Check if username is present in set of allowed groups"""
|
||||
user = handler.find_user(user_name)
|
||||
if user is None:
|
||||
raise web.HTTPError(404, 'No such user found')
|
||||
group_names = {group.name for group in user.groups}
|
||||
group_names = {group.name for group in user.groups} # Todo: Replace with SQL query
|
||||
return bool(set(scope_group_names) & group_names)
|
||||
|
||||
|
||||
def get_orm_class(kind):
|
||||
class_dict = {
|
||||
'users': orm.User,
|
||||
'services': orm.Service,
|
||||
'tokens': orm.APIToken,
|
||||
'groups': orm.Group,
|
||||
}
|
||||
if kind not in class_dict:
|
||||
raise ValueError(
|
||||
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
|
||||
)
|
||||
return class_dict[kind]
|
||||
|
||||
|
||||
def _get_scope_filter(db, req_scope, sub_scope):
|
||||
# Rough draft
|
||||
"""Produce a filter for `*ListAPIHandlers* so that get method knows which models to return"""
|
||||
scope_translator = {
|
||||
'read:users': 'users',
|
||||
'read:services': 'services',
|
||||
@@ -82,7 +68,7 @@ def _get_scope_filter(db, req_scope, sub_scope):
|
||||
if req_scope not in scope_translator:
|
||||
raise AttributeError("Scope not found; scope filter not constructed")
|
||||
kind = scope_translator[req_scope]
|
||||
Class = get_orm_class(kind)
|
||||
Class = orm.get_class(kind)
|
||||
sub_scope_values = next(iter(sub_scope.values()))
|
||||
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
|
||||
scope_filter = {entry.name for entry in query.all()}
|
||||
@@ -94,6 +80,10 @@ def _get_scope_filter(db, req_scope, sub_scope):
|
||||
|
||||
|
||||
def _check_scope(api_handler, req_scope, scopes, **kwargs):
|
||||
"""Check if scopes satisfy requirements
|
||||
Returns either Scope.ALL for unrestricted access, Scope.NONE for refused access or
|
||||
an iterable with a filter
|
||||
"""
|
||||
# Parse user name and server name together
|
||||
if 'user' in kwargs and 'server' in kwargs:
|
||||
kwargs['server'] = "{}/{}".format(kwargs['user'], kwargs['server'])
|
||||
@@ -178,10 +168,10 @@ def needs_scope(scope):
|
||||
self.scopes |= get_user_scopes(self.current_user.name)
|
||||
parsed_scopes = _parse_scopes(self.scopes)
|
||||
scope_filter = _check_scope(self, scope, parsed_scopes, **s_kwargs)
|
||||
# todo: This checks if True or set of resource names. Not very nice yet
|
||||
# todo: This checks if True/False or set of resource names. Can be improved
|
||||
if isinstance(scope_filter, set):
|
||||
kwargs['scope_filter'] = scope_filter
|
||||
if scope_filter:
|
||||
if isinstance(scope_filter, set):
|
||||
kwargs['scope_filter'] = scope_filter
|
||||
return func(self, *args, **kwargs)
|
||||
else:
|
||||
# catching attr error occurring for older_requirements test
|
||||
|
@@ -25,6 +25,7 @@ from .utils import async_requests
|
||||
from .utils import auth_header
|
||||
from .utils import find_user
|
||||
|
||||
|
||||
# --------------------
|
||||
# Authentication tests
|
||||
# --------------------
|
||||
@@ -166,7 +167,7 @@ TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
|
||||
|
||||
@mark.user
|
||||
@mark.role
|
||||
async def test_get_users(app): # todo: Sync with scope tests
|
||||
async def test_get_users(app):
|
||||
db = app.db
|
||||
r = await api_request(app, 'users', headers=auth_header(db, 'admin'))
|
||||
assert r.status_code == 200
|
||||
@@ -185,7 +186,7 @@ async def test_get_users(app): # todo: Sync with scope tests
|
||||
]
|
||||
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
|
||||
assert r.status_code == 200
|
||||
r_user_model = json.loads(r.text)[0]
|
||||
r_user_model = r.json()[0]
|
||||
assert r_user_model['name'] == user_model['name']
|
||||
|
||||
|
||||
|
@@ -217,7 +217,7 @@ async def test_expand_groups(app, user_name, in_group, status_code):
|
||||
|
||||
|
||||
async def test_user_filter(app):
|
||||
user_name = 'rollerblade'
|
||||
user_name = 'rita'
|
||||
test_role = {
|
||||
'name': 'test',
|
||||
'description': '',
|
||||
@@ -247,8 +247,7 @@ async def test_user_filter(app):
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
result_names = {user['name'] for user in r.json()}
|
||||
assert result_names == name_in_scope
|
||||
|
||||
|
||||
@@ -278,8 +277,7 @@ async def test_user_filter_with_group(app): # todo: Move role setup to setup me
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
result_names = {user['name'] for user in r.json()}
|
||||
assert result_names == name_set
|
||||
|
||||
|
||||
@@ -308,6 +306,5 @@ async def test_group_scope_filter(app):
|
||||
app.db.commit()
|
||||
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
|
||||
assert r.status_code == 200
|
||||
data = json.loads(r.content)
|
||||
result_names = {user['name'] for user in data}
|
||||
result_names = {user['name'] for user in r.json()}
|
||||
assert result_names == {'sitwell', 'bluths'}
|
||||
|
Reference in New Issue
Block a user