diff --git a/jupyterhub/app.py b/jupyterhub/app.py index c7571453..b8465a44 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -1858,7 +1858,7 @@ class JupyterHub(Application): # make sure all users, services and tokens have at least one role (update with default) for bearer in role_bearers: - Class = roles.get_orm_class(bearer) + Class = orm.get_class(bearer) for obj in db.query(Class): if len(obj.roles) < 1: roles.update_roles(db, obj=obj, kind=bearer) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 90ca63f7..828badce 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -978,3 +978,18 @@ def new_session_factory( # this off gives us a major performance boost session_factory = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) return session_factory + + +def get_class(resource_name): + """Translates resource string names to ORM classes""" + class_dict = { + 'users': User, + 'services': Service, + 'tokens': APIToken, + 'groups': Group, + } + if resource_name not in class_dict: + raise ValueError( + "Kind must be one of %s, not %s" % (", ".join(class_dict), resource_name) + ) + return class_dict[resource_name] diff --git a/jupyterhub/roles.py b/jupyterhub/roles.py index bb2eab44..c46f7e19 100644 --- a/jupyterhub/roles.py +++ b/jupyterhub/roles.py @@ -135,25 +135,12 @@ def add_role(db, role_dict): db.commit() -def get_orm_class(kind): # Todo: merge and move to orm.py - if kind == 'users': - Class = orm.User - elif kind == 'services': - Class = orm.Service - elif kind == 'tokens': - Class = orm.APIToken - else: - raise ValueError("kind must be users, services or tokens, not %r" % kind) - - return Class - - def existing_only(func): """Decorator for checking if objects and roles exist""" def check_existence(db, objname, kind, rolename): - Class = get_orm_class(kind) + Class = orm.get_class(kind) obj = Class.find(db, objname) role = orm.Role.find(db, rolename) @@ -209,7 +196,7 @@ def update_roles(db, obj, kind, roles=None): """Updates object's roles if specified, assigns default if no roles specified""" - Class = get_orm_class(kind) + Class = orm.get_class(kind) user_role = orm.Role.find(db, 'user') if roles: @@ -252,11 +239,8 @@ def update_roles(db, obj, kind, roles=None): def mock_roles(app, name, kind): - """Loads and assigns default roles for mocked objects""" - - Class = get_orm_class(kind) - + Class = orm.get_class(kind) obj = Class.find(app.db, name=name) default_roles = get_default_roles() for role in default_roles: diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index ac2e2ae8..a67d3ce4 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -22,7 +22,6 @@ def get_user_scopes(name): users:activity users:servers users:tokens - """ scope_list = [ 'users', @@ -51,29 +50,16 @@ def _needs_scope_expansion(filter_, filter_value, sub_scope): def _check_user_in_expanded_scope(handler, user_name, scope_group_names): + """Check if username is present in set of allowed groups""" user = handler.find_user(user_name) if user is None: raise web.HTTPError(404, 'No such user found') - group_names = {group.name for group in user.groups} + group_names = {group.name for group in user.groups} # Todo: Replace with SQL query return bool(set(scope_group_names) & group_names) -def get_orm_class(kind): - class_dict = { - 'users': orm.User, - 'services': orm.Service, - 'tokens': orm.APIToken, - 'groups': orm.Group, - } - if kind not in class_dict: - raise ValueError( - "Kind must be one of %s, not %s" % (", ".join(class_dict), kind) - ) - return class_dict[kind] - - def _get_scope_filter(db, req_scope, sub_scope): - # Rough draft + """Produce a filter for `*ListAPIHandlers* so that get method knows which models to return""" scope_translator = { 'read:users': 'users', 'read:services': 'services', @@ -82,7 +68,7 @@ def _get_scope_filter(db, req_scope, sub_scope): if req_scope not in scope_translator: raise AttributeError("Scope not found; scope filter not constructed") kind = scope_translator[req_scope] - Class = get_orm_class(kind) + Class = orm.get_class(kind) sub_scope_values = next(iter(sub_scope.values())) query = db.query(Class).filter(Class.name.in_(sub_scope_values)) scope_filter = {entry.name for entry in query.all()} @@ -94,6 +80,10 @@ def _get_scope_filter(db, req_scope, sub_scope): def _check_scope(api_handler, req_scope, scopes, **kwargs): + """Check if scopes satisfy requirements + Returns either Scope.ALL for unrestricted access, Scope.NONE for refused access or + an iterable with a filter + """ # Parse user name and server name together if 'user' in kwargs and 'server' in kwargs: kwargs['server'] = "{}/{}".format(kwargs['user'], kwargs['server']) @@ -178,10 +168,10 @@ def needs_scope(scope): self.scopes |= get_user_scopes(self.current_user.name) parsed_scopes = _parse_scopes(self.scopes) scope_filter = _check_scope(self, scope, parsed_scopes, **s_kwargs) - # todo: This checks if True or set of resource names. Not very nice yet + # todo: This checks if True/False or set of resource names. Can be improved + if isinstance(scope_filter, set): + kwargs['scope_filter'] = scope_filter if scope_filter: - if isinstance(scope_filter, set): - kwargs['scope_filter'] = scope_filter return func(self, *args, **kwargs) else: # catching attr error occurring for older_requirements test diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index ff7b2389..a6883149 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -25,6 +25,7 @@ from .utils import async_requests from .utils import auth_header from .utils import find_user + # -------------------- # Authentication tests # -------------------- @@ -166,7 +167,7 @@ TIMESTAMP = normalize_timestamp(datetime.now().isoformat() + 'Z') @mark.user @mark.role -async def test_get_users(app): # todo: Sync with scope tests +async def test_get_users(app): db = app.db r = await api_request(app, 'users', headers=auth_header(db, 'admin')) assert r.status_code == 200 @@ -185,7 +186,7 @@ async def test_get_users(app): # todo: Sync with scope tests ] r = await api_request(app, 'users', headers=auth_header(db, 'user')) assert r.status_code == 200 - r_user_model = json.loads(r.text)[0] + r_user_model = r.json()[0] assert r_user_model['name'] == user_model['name'] diff --git a/jupyterhub/tests/test_scopes.py b/jupyterhub/tests/test_scopes.py index 9b66c47a..49a794bb 100644 --- a/jupyterhub/tests/test_scopes.py +++ b/jupyterhub/tests/test_scopes.py @@ -217,7 +217,7 @@ async def test_expand_groups(app, user_name, in_group, status_code): async def test_user_filter(app): - user_name = 'rollerblade' + user_name = 'rita' test_role = { 'name': 'test', 'description': '', @@ -247,8 +247,7 @@ async def test_user_filter(app): app.db.commit() r = await api_request(app, 'users', headers=auth_header(app.db, user_name)) assert r.status_code == 200 - data = json.loads(r.content) - result_names = {user['name'] for user in data} + result_names = {user['name'] for user in r.json()} assert result_names == name_in_scope @@ -278,8 +277,7 @@ async def test_user_filter_with_group(app): # todo: Move role setup to setup me app.db.commit() r = await api_request(app, 'users', headers=auth_header(app.db, user_name)) assert r.status_code == 200 - data = json.loads(r.content) - result_names = {user['name'] for user in data} + result_names = {user['name'] for user in r.json()} assert result_names == name_set @@ -308,6 +306,5 @@ async def test_group_scope_filter(app): app.db.commit() r = await api_request(app, 'groups', headers=auth_header(app.db, user_name)) assert r.status_code == 200 - data = json.loads(r.content) - result_names = {user['name'] for user in data} + result_names = {user['name'] for user in r.json()} assert result_names == {'sitwell', 'bluths'}