mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-10 19:43:01 +00:00
Refactored orm.get_class, improved resource filtereing
This commit is contained in:
@@ -1858,7 +1858,7 @@ class JupyterHub(Application):
|
|||||||
|
|
||||||
# make sure all users, services and tokens have at least one role (update with default)
|
# make sure all users, services and tokens have at least one role (update with default)
|
||||||
for bearer in role_bearers:
|
for bearer in role_bearers:
|
||||||
Class = roles.get_orm_class(bearer)
|
Class = orm.get_class(bearer)
|
||||||
for obj in db.query(Class):
|
for obj in db.query(Class):
|
||||||
if len(obj.roles) < 1:
|
if len(obj.roles) < 1:
|
||||||
roles.update_roles(db, obj=obj, kind=bearer)
|
roles.update_roles(db, obj=obj, kind=bearer)
|
||||||
|
@@ -978,3 +978,18 @@ def new_session_factory(
|
|||||||
# this off gives us a major performance boost
|
# this off gives us a major performance boost
|
||||||
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||||
return session_factory
|
return session_factory
|
||||||
|
|
||||||
|
|
||||||
|
def get_class(resource_name):
|
||||||
|
"""Translates resource string names to ORM classes"""
|
||||||
|
class_dict = {
|
||||||
|
'users': User,
|
||||||
|
'services': Service,
|
||||||
|
'tokens': APIToken,
|
||||||
|
'groups': Group,
|
||||||
|
}
|
||||||
|
if resource_name not in class_dict:
|
||||||
|
raise ValueError(
|
||||||
|
"Kind must be one of %s, not %s" % (", ".join(class_dict), resource_name)
|
||||||
|
)
|
||||||
|
return class_dict[resource_name]
|
||||||
|
@@ -135,25 +135,12 @@ def add_role(db, role_dict):
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
def get_orm_class(kind): # Todo: merge and move to orm.py
|
|
||||||
if kind == 'users':
|
|
||||||
Class = orm.User
|
|
||||||
elif kind == 'services':
|
|
||||||
Class = orm.Service
|
|
||||||
elif kind == 'tokens':
|
|
||||||
Class = orm.APIToken
|
|
||||||
else:
|
|
||||||
raise ValueError("kind must be users, services or tokens, not %r" % kind)
|
|
||||||
|
|
||||||
return Class
|
|
||||||
|
|
||||||
|
|
||||||
def existing_only(func):
|
def existing_only(func):
|
||||||
"""Decorator for checking if objects and roles exist"""
|
"""Decorator for checking if objects and roles exist"""
|
||||||
|
|
||||||
def check_existence(db, objname, kind, rolename):
|
def check_existence(db, objname, kind, rolename):
|
||||||
|
|
||||||
Class = get_orm_class(kind)
|
Class = orm.get_class(kind)
|
||||||
obj = Class.find(db, objname)
|
obj = Class.find(db, objname)
|
||||||
role = orm.Role.find(db, rolename)
|
role = orm.Role.find(db, rolename)
|
||||||
|
|
||||||
@@ -209,7 +196,7 @@ def update_roles(db, obj, kind, roles=None):
|
|||||||
"""Updates object's roles if specified,
|
"""Updates object's roles if specified,
|
||||||
assigns default if no roles specified"""
|
assigns default if no roles specified"""
|
||||||
|
|
||||||
Class = get_orm_class(kind)
|
Class = orm.get_class(kind)
|
||||||
user_role = orm.Role.find(db, 'user')
|
user_role = orm.Role.find(db, 'user')
|
||||||
|
|
||||||
if roles:
|
if roles:
|
||||||
@@ -252,11 +239,8 @@ def update_roles(db, obj, kind, roles=None):
|
|||||||
|
|
||||||
|
|
||||||
def mock_roles(app, name, kind):
|
def mock_roles(app, name, kind):
|
||||||
|
|
||||||
"""Loads and assigns default roles for mocked objects"""
|
"""Loads and assigns default roles for mocked objects"""
|
||||||
|
Class = orm.get_class(kind)
|
||||||
Class = get_orm_class(kind)
|
|
||||||
|
|
||||||
obj = Class.find(app.db, name=name)
|
obj = Class.find(app.db, name=name)
|
||||||
default_roles = get_default_roles()
|
default_roles = get_default_roles()
|
||||||
for role in default_roles:
|
for role in default_roles:
|
||||||
|
@@ -22,7 +22,6 @@ def get_user_scopes(name):
|
|||||||
users:activity
|
users:activity
|
||||||
users:servers
|
users:servers
|
||||||
users:tokens
|
users:tokens
|
||||||
|
|
||||||
"""
|
"""
|
||||||
scope_list = [
|
scope_list = [
|
||||||
'users',
|
'users',
|
||||||
@@ -51,29 +50,16 @@ def _needs_scope_expansion(filter_, filter_value, sub_scope):
|
|||||||
|
|
||||||
|
|
||||||
def _check_user_in_expanded_scope(handler, user_name, scope_group_names):
|
def _check_user_in_expanded_scope(handler, user_name, scope_group_names):
|
||||||
|
"""Check if username is present in set of allowed groups"""
|
||||||
user = handler.find_user(user_name)
|
user = handler.find_user(user_name)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise web.HTTPError(404, 'No such user found')
|
raise web.HTTPError(404, 'No such user found')
|
||||||
group_names = {group.name for group in user.groups}
|
group_names = {group.name for group in user.groups} # Todo: Replace with SQL query
|
||||||
return bool(set(scope_group_names) & group_names)
|
return bool(set(scope_group_names) & group_names)
|
||||||
|
|
||||||
|
|
||||||
def get_orm_class(kind):
|
|
||||||
class_dict = {
|
|
||||||
'users': orm.User,
|
|
||||||
'services': orm.Service,
|
|
||||||
'tokens': orm.APIToken,
|
|
||||||
'groups': orm.Group,
|
|
||||||
}
|
|
||||||
if kind not in class_dict:
|
|
||||||
raise ValueError(
|
|
||||||
"Kind must be one of %s, not %s" % (", ".join(class_dict), kind)
|
|
||||||
)
|
|
||||||
return class_dict[kind]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_scope_filter(db, req_scope, sub_scope):
|
def _get_scope_filter(db, req_scope, sub_scope):
|
||||||
# Rough draft
|
"""Produce a filter for `*ListAPIHandlers* so that get method knows which models to return"""
|
||||||
scope_translator = {
|
scope_translator = {
|
||||||
'read:users': 'users',
|
'read:users': 'users',
|
||||||
'read:services': 'services',
|
'read:services': 'services',
|
||||||
@@ -82,7 +68,7 @@ def _get_scope_filter(db, req_scope, sub_scope):
|
|||||||
if req_scope not in scope_translator:
|
if req_scope not in scope_translator:
|
||||||
raise AttributeError("Scope not found; scope filter not constructed")
|
raise AttributeError("Scope not found; scope filter not constructed")
|
||||||
kind = scope_translator[req_scope]
|
kind = scope_translator[req_scope]
|
||||||
Class = get_orm_class(kind)
|
Class = orm.get_class(kind)
|
||||||
sub_scope_values = next(iter(sub_scope.values()))
|
sub_scope_values = next(iter(sub_scope.values()))
|
||||||
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
|
query = db.query(Class).filter(Class.name.in_(sub_scope_values))
|
||||||
scope_filter = {entry.name for entry in query.all()}
|
scope_filter = {entry.name for entry in query.all()}
|
||||||
@@ -94,6 +80,10 @@ def _get_scope_filter(db, req_scope, sub_scope):
|
|||||||
|
|
||||||
|
|
||||||
def _check_scope(api_handler, req_scope, scopes, **kwargs):
|
def _check_scope(api_handler, req_scope, scopes, **kwargs):
|
||||||
|
"""Check if scopes satisfy requirements
|
||||||
|
Returns either Scope.ALL for unrestricted access, Scope.NONE for refused access or
|
||||||
|
an iterable with a filter
|
||||||
|
"""
|
||||||
# Parse user name and server name together
|
# Parse user name and server name together
|
||||||
if 'user' in kwargs and 'server' in kwargs:
|
if 'user' in kwargs and 'server' in kwargs:
|
||||||
kwargs['server'] = "{}/{}".format(kwargs['user'], kwargs['server'])
|
kwargs['server'] = "{}/{}".format(kwargs['user'], kwargs['server'])
|
||||||
@@ -178,10 +168,10 @@ def needs_scope(scope):
|
|||||||
self.scopes |= get_user_scopes(self.current_user.name)
|
self.scopes |= get_user_scopes(self.current_user.name)
|
||||||
parsed_scopes = _parse_scopes(self.scopes)
|
parsed_scopes = _parse_scopes(self.scopes)
|
||||||
scope_filter = _check_scope(self, scope, parsed_scopes, **s_kwargs)
|
scope_filter = _check_scope(self, scope, parsed_scopes, **s_kwargs)
|
||||||
# todo: This checks if True or set of resource names. Not very nice yet
|
# todo: This checks if True/False or set of resource names. Can be improved
|
||||||
if scope_filter:
|
|
||||||
if isinstance(scope_filter, set):
|
if isinstance(scope_filter, set):
|
||||||
kwargs['scope_filter'] = scope_filter
|
kwargs['scope_filter'] = scope_filter
|
||||||
|
if scope_filter:
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# catching attr error occurring for older_requirements test
|
# catching attr error occurring for older_requirements test
|
||||||
|
@@ -25,6 +25,7 @@ from .utils import async_requests
|
|||||||
from .utils import auth_header
|
from .utils import auth_header
|
||||||
from .utils import find_user
|
from .utils import find_user
|
||||||
|
|
||||||
|
|
||||||
# --------------------
|
# --------------------
|
||||||
# Authentication tests
|
# Authentication tests
|
||||||
# --------------------
|
# --------------------
|
||||||
@@ -166,7 +167,7 @@ TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z')
|
|||||||
|
|
||||||
@mark.user
|
@mark.user
|
||||||
@mark.role
|
@mark.role
|
||||||
async def test_get_users(app): # todo: Sync with scope tests
|
async def test_get_users(app):
|
||||||
db = app.db
|
db = app.db
|
||||||
r = await api_request(app, 'users', headers=auth_header(db, 'admin'))
|
r = await api_request(app, 'users', headers=auth_header(db, 'admin'))
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -185,7 +186,7 @@ async def test_get_users(app): # todo: Sync with scope tests
|
|||||||
]
|
]
|
||||||
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
|
r = await api_request(app, 'users', headers=auth_header(db, 'user'))
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
r_user_model = json.loads(r.text)[0]
|
r_user_model = r.json()[0]
|
||||||
assert r_user_model['name'] == user_model['name']
|
assert r_user_model['name'] == user_model['name']
|
||||||
|
|
||||||
|
|
||||||
|
@@ -217,7 +217,7 @@ async def test_expand_groups(app, user_name, in_group, status_code):
|
|||||||
|
|
||||||
|
|
||||||
async def test_user_filter(app):
|
async def test_user_filter(app):
|
||||||
user_name = 'rollerblade'
|
user_name = 'rita'
|
||||||
test_role = {
|
test_role = {
|
||||||
'name': 'test',
|
'name': 'test',
|
||||||
'description': '',
|
'description': '',
|
||||||
@@ -247,8 +247,7 @@ async def test_user_filter(app):
|
|||||||
app.db.commit()
|
app.db.commit()
|
||||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = json.loads(r.content)
|
result_names = {user['name'] for user in r.json()}
|
||||||
result_names = {user['name'] for user in data}
|
|
||||||
assert result_names == name_in_scope
|
assert result_names == name_in_scope
|
||||||
|
|
||||||
|
|
||||||
@@ -278,8 +277,7 @@ async def test_user_filter_with_group(app): # todo: Move role setup to setup me
|
|||||||
app.db.commit()
|
app.db.commit()
|
||||||
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
r = await api_request(app, 'users', headers=auth_header(app.db, user_name))
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = json.loads(r.content)
|
result_names = {user['name'] for user in r.json()}
|
||||||
result_names = {user['name'] for user in data}
|
|
||||||
assert result_names == name_set
|
assert result_names == name_set
|
||||||
|
|
||||||
|
|
||||||
@@ -308,6 +306,5 @@ async def test_group_scope_filter(app):
|
|||||||
app.db.commit()
|
app.db.commit()
|
||||||
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
|
r = await api_request(app, 'groups', headers=auth_header(app.db, user_name))
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = json.loads(r.content)
|
result_names = {user['name'] for user in r.json()}
|
||||||
result_names = {user['name'] for user in data}
|
|
||||||
assert result_names == {'sitwell', 'bluths'}
|
assert result_names == {'sitwell', 'bluths'}
|
||||||
|
Reference in New Issue
Block a user