mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-11 03:52:59 +00:00
update to roles utils
This commit is contained in:
@@ -88,7 +88,7 @@ class UserListAPIHandler(APIHandler):
|
|||||||
user = self.user_from_username(name)
|
user = self.user_from_username(name)
|
||||||
if admin:
|
if admin:
|
||||||
user.admin = True
|
user.admin = True
|
||||||
roles.DefaultRoles.add_default_role(self.db, user)
|
roles.update_roles(self.db, user)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
try:
|
try:
|
||||||
await maybe_future(self.authenticator.add_user(user))
|
await maybe_future(self.authenticator.add_user(user))
|
||||||
@@ -151,7 +151,7 @@ class UserAPIHandler(APIHandler):
|
|||||||
self._check_user_model(data)
|
self._check_user_model(data)
|
||||||
if 'admin' in data:
|
if 'admin' in data:
|
||||||
user.admin = data['admin']
|
user.admin = data['admin']
|
||||||
roles.DefaultRoles.add_default_role(self.db, user)
|
roles.update_roles(self.db, user)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -208,9 +208,9 @@ class UserAPIHandler(APIHandler):
|
|||||||
if key == 'auth_state':
|
if key == 'auth_state':
|
||||||
await user.save_auth_state(value)
|
await user.save_auth_state(value)
|
||||||
else:
|
else:
|
||||||
if key == 'admin' and value != user.admin:
|
|
||||||
roles.DefaultRoles.change_admin(self.db, user=user, admin=value)
|
|
||||||
setattr(user, key, value)
|
setattr(user, key, value)
|
||||||
|
if key == 'admin':
|
||||||
|
roles.update_roles(self.db, user=user)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
user_ = self.user_model(user)
|
user_ = self.user_model(user)
|
||||||
user_['auth_state'] = await user.get_auth_state()
|
user_['auth_state'] = await user.get_auth_state()
|
||||||
|
@@ -1715,7 +1715,6 @@ class JupyterHub(Application):
|
|||||||
|
|
||||||
for name in admin_users:
|
for name in admin_users:
|
||||||
# ensure anyone specified as admin in config is admin in db
|
# ensure anyone specified as admin in config is admin in db
|
||||||
# and gets admin role
|
|
||||||
user = orm.User.find(db, name)
|
user = orm.User.find(db, name)
|
||||||
if user is None:
|
if user is None:
|
||||||
user = orm.User(name=name, admin=True)
|
user = orm.User(name=name, admin=True)
|
||||||
@@ -1825,11 +1824,15 @@ class JupyterHub(Application):
|
|||||||
"""Load default and predefined roles into the database"""
|
"""Load default and predefined roles into the database"""
|
||||||
db = self.db
|
db = self.db
|
||||||
# load default roles
|
# load default roles
|
||||||
roles.DefaultRoles.load_to_database(db)
|
default_roles = roles.get_default_roles()
|
||||||
|
for role in default_roles:
|
||||||
|
roles.add_role(db, role)
|
||||||
|
|
||||||
# load predefined roles from config file
|
# load predefined roles from config file
|
||||||
for predef_role in self.load_roles:
|
for predef_role in self.load_roles:
|
||||||
role = roles.add_predef_role(db, predef_role)
|
roles.add_role(db, predef_role)
|
||||||
|
role = orm.Role.find(db, predef_role['name'])
|
||||||
|
|
||||||
# handle users
|
# handle users
|
||||||
for username in predef_role['users']:
|
for username in predef_role['users']:
|
||||||
username = self.authenticator.normalize_username(username)
|
username = self.authenticator.normalize_username(username)
|
||||||
@@ -1847,9 +1850,11 @@ class JupyterHub(Application):
|
|||||||
db.add(user)
|
db.add(user)
|
||||||
roles.add_user(db, user=user, role=role)
|
roles.add_user(db, user=user, role=role)
|
||||||
|
|
||||||
# make sure every existing user has a default user or admin role
|
# make sure all users have at least one role (update with default)
|
||||||
for user in db.query(orm.User):
|
for user in db.query(orm.User):
|
||||||
roles.DefaultRoles.add_default_role(db, user)
|
if len(user.roles) < 1:
|
||||||
|
roles.update_roles(db, user)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
async def _add_tokens(self, token_dict, kind):
|
async def _add_tokens(self, token_dict, kind):
|
||||||
|
@@ -4,83 +4,58 @@
|
|||||||
from .orm import Role
|
from .orm import Role
|
||||||
|
|
||||||
|
|
||||||
# define default roles
|
def get_default_roles():
|
||||||
class DefaultRoles:
|
|
||||||
|
|
||||||
user = Role(name='user', description='Everything the user can do', scopes=['all'])
|
"""Returns a list of default roles dictionaries"""
|
||||||
admin = Role(
|
|
||||||
name='admin',
|
|
||||||
description='Admin privileges (currently can do everything)',
|
|
||||||
scopes=[
|
|
||||||
'all',
|
|
||||||
'users',
|
|
||||||
'users:tokens',
|
|
||||||
'admin:users',
|
|
||||||
'admin:users:servers',
|
|
||||||
'groups',
|
|
||||||
'admin:groups',
|
|
||||||
'read:services',
|
|
||||||
'proxy',
|
|
||||||
'shutdown',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
server = Role(
|
|
||||||
name='server',
|
|
||||||
description='Post activity only',
|
|
||||||
scopes=['users:activity!user=username'],
|
|
||||||
)
|
|
||||||
roles = (user, admin, server)
|
|
||||||
|
|
||||||
def __init__(cls, roles=roles):
|
default_roles = [
|
||||||
cls.roles = roles
|
{
|
||||||
|
'name': 'user',
|
||||||
|
'description': 'Everything the user can do',
|
||||||
|
'scopes': ['all'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'admin',
|
||||||
|
'description': 'Admin privileges (currently can do everything)',
|
||||||
|
'scopes': [
|
||||||
|
'all',
|
||||||
|
'users',
|
||||||
|
'users:tokens',
|
||||||
|
'admin:users',
|
||||||
|
'admin:users:servers',
|
||||||
|
'groups',
|
||||||
|
'admin:groups',
|
||||||
|
'read:services',
|
||||||
|
'proxy',
|
||||||
|
'shutdown',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'server',
|
||||||
|
'description': 'Post activity only',
|
||||||
|
'scopes': ['users:activity!user=username'],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return default_roles
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_user_role(cls, db):
|
|
||||||
return Role.find(db, name=cls.user.name)
|
|
||||||
|
|
||||||
@classmethod
|
def add_role(db, role_dict):
|
||||||
def get_admin_role(cls, db):
|
|
||||||
return Role.find(db, name=cls.admin.name)
|
|
||||||
|
|
||||||
@classmethod
|
"""Adds a new role to database or modifies an existing one"""
|
||||||
def get_server_role(cls, db):
|
|
||||||
return Role.find(db, name=cls.server.name)
|
|
||||||
|
|
||||||
@classmethod
|
role = Role.find(db, role_dict['name'])
|
||||||
def load_to_database(cls, db):
|
|
||||||
for role in cls.roles:
|
|
||||||
db_role = Role.find(db, name=role.name)
|
|
||||||
if db_role is None:
|
|
||||||
new_role = Role(
|
|
||||||
name=role.name, description=role.description, scopes=role.scopes,
|
|
||||||
)
|
|
||||||
db.add(new_role)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
@classmethod
|
if role is None:
|
||||||
def add_default_role(cls, db, user):
|
role = Role(
|
||||||
role = None
|
name=role_dict['name'],
|
||||||
if user.admin and cls.admin not in user.roles:
|
description=role_dict['description'],
|
||||||
role = cls.get_admin_role(db)
|
scopes=role_dict['scopes'],
|
||||||
if not user.admin and cls.user not in user.roles:
|
)
|
||||||
role = cls.get_user_role(db)
|
db.add(role)
|
||||||
if role is not None:
|
else:
|
||||||
add_user(db, user, role)
|
role.description = role_dict['description']
|
||||||
db.commit()
|
role.scopes = role_dict['scopes']
|
||||||
|
db.commit()
|
||||||
@classmethod
|
|
||||||
def change_admin(cls, db, user, admin):
|
|
||||||
user_role = cls.get_user_role(db)
|
|
||||||
admin_role = cls.get_admin_role(db)
|
|
||||||
if admin:
|
|
||||||
if user_role in user.roles:
|
|
||||||
remove_user(db, user, user_role)
|
|
||||||
add_user(db, user, admin_role)
|
|
||||||
else:
|
|
||||||
if admin_role in user.roles:
|
|
||||||
remove_user(db, user, admin_role)
|
|
||||||
add_user(db, user, user_role)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def add_user(db, user, role):
|
def add_user(db, user, role):
|
||||||
@@ -95,33 +70,21 @@ def remove_user(db, user, role):
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
def add_predef_role(db, predef_role):
|
def update_roles(db, user):
|
||||||
"""
|
|
||||||
Returns either the role to write into db or updated role if already in db
|
"""Updates roles if user has no role with default or when user admin status is changed"""
|
||||||
"""
|
|
||||||
role = Role.find(db, predef_role['name'])
|
user_role = Role.find(db, 'user')
|
||||||
# if a new role, add to db, if existing, rewrite its attributes apart from the name
|
admin_role = Role.find(db, 'admin')
|
||||||
if role is None:
|
|
||||||
role = Role(
|
if user.admin:
|
||||||
name=predef_role['name'],
|
if user_role in user.roles:
|
||||||
description=predef_role['description'],
|
remove_user(db, user, user_role)
|
||||||
scopes=predef_role['scopes'],
|
add_user(db, user, admin_role)
|
||||||
)
|
|
||||||
db.add(role)
|
|
||||||
db.commit()
|
|
||||||
else:
|
else:
|
||||||
# check if it's not one of the default roles
|
if admin_role in user.roles:
|
||||||
if not any(d.name == predef_role['name'] for d in DefaultRoles.roles):
|
remove_user(db, user, admin_role)
|
||||||
# if description and scopes specified, rewrite the old ones
|
# only add user role if the user has no other roles
|
||||||
if 'description' in predef_role.keys():
|
if len(user.roles) < 1:
|
||||||
role.description = predef_role['description']
|
add_user(db, user, user_role)
|
||||||
if 'scopes' in predef_role.keys():
|
db.commit()
|
||||||
role.scopes = predef_role['scopes']
|
|
||||||
# FIXME - for now deletes old users and writes new ones
|
|
||||||
role.users = []
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The role %r is a default role that cannot be overwritten, use a different role name"
|
|
||||||
% predef_role['name']
|
|
||||||
)
|
|
||||||
return role
|
|
||||||
|
@@ -336,7 +336,8 @@ class MockHub(JupyterHub):
|
|||||||
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
|
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
|
||||||
if user is None:
|
if user is None:
|
||||||
user = orm.User(name='user')
|
user = orm.User(name='user')
|
||||||
roles.DefaultRoles.add_default_role(self.db, user=user)
|
user_role = orm.Role.find(self.db, 'user')
|
||||||
|
roles.add_user(self.db, user=user, role=user_role)
|
||||||
self.db.add(user)
|
self.db.add(user)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
|
@@ -233,8 +233,8 @@ async def test_add_user(app):
|
|||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert not user.admin
|
assert not user.admin
|
||||||
# assert newuser has default 'user' role
|
# assert newuser has default 'user' role
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) in user.roles
|
assert orm.Role.find(db, 'user') in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'admin') not in user.roles
|
||||||
|
|
||||||
|
|
||||||
@mark.user
|
@mark.user
|
||||||
@@ -291,8 +291,8 @@ async def test_add_multi_user(app):
|
|||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert not user.admin
|
assert not user.admin
|
||||||
# assert default 'user' role added
|
# assert default 'user' role added
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) in user.roles
|
assert orm.Role.find(db, 'user') in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'admin') not in user.roles
|
||||||
|
|
||||||
# try to create the same users again
|
# try to create the same users again
|
||||||
r = await api_request(
|
r = await api_request(
|
||||||
@@ -333,8 +333,8 @@ async def test_add_multi_user_admin(app):
|
|||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert user.admin
|
assert user.admin
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'user') not in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) in user.roles
|
assert orm.Role.find(db, 'admin') in user.roles
|
||||||
|
|
||||||
|
|
||||||
@mark.user
|
@mark.user
|
||||||
@@ -369,13 +369,12 @@ async def test_add_admin(app):
|
|||||||
)
|
)
|
||||||
assert r.status_code == 201
|
assert r.status_code == 201
|
||||||
user = find_user(db, name)
|
user = find_user(db, name)
|
||||||
user_role = orm.Role.find(db, 'user')
|
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert user.admin
|
assert user.admin
|
||||||
# assert newadmin has default 'admin' role
|
# assert newadmin has default 'admin' role
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'user') not in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) in user.roles
|
assert orm.Role.find(db, 'admin') in user.roles
|
||||||
|
|
||||||
|
|
||||||
@mark.user
|
@mark.user
|
||||||
@@ -397,8 +396,8 @@ async def test_make_admin(app):
|
|||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert not user.admin
|
assert not user.admin
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) in user.roles
|
assert orm.Role.find(db, 'user') in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'admin') not in user.roles
|
||||||
|
|
||||||
r = await api_request(
|
r = await api_request(
|
||||||
app, 'users', name, method='patch', data=json.dumps({'admin': True})
|
app, 'users', name, method='patch', data=json.dumps({'admin': True})
|
||||||
@@ -409,8 +408,8 @@ async def test_make_admin(app):
|
|||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.name == name
|
assert user.name == name
|
||||||
assert user.admin
|
assert user.admin
|
||||||
assert roles.DefaultRoles.get_user_role(db=db) not in user.roles
|
assert orm.Role.find(db, 'user') not in user.roles
|
||||||
assert roles.DefaultRoles.get_admin_role(db=db) in user.roles
|
assert orm.Role.find(db, 'admin') in user.roles
|
||||||
|
|
||||||
|
|
||||||
@mark.user
|
@mark.user
|
||||||
|
@@ -3,7 +3,7 @@
|
|||||||
from pytest import mark
|
from pytest import mark
|
||||||
|
|
||||||
from .. import orm
|
from .. import orm
|
||||||
from ..roles import DefaultRoles
|
from .. import roles
|
||||||
from .mocking import MockHub
|
from .mocking import MockHub
|
||||||
|
|
||||||
|
|
||||||
@@ -95,37 +95,41 @@ def test_role_delete_cascade(db):
|
|||||||
@mark.role
|
@mark.role
|
||||||
async def test_load_roles(tmpdir, request):
|
async def test_load_roles(tmpdir, request):
|
||||||
"""Test loading default and predefined roles in app.py"""
|
"""Test loading default and predefined roles in app.py"""
|
||||||
to_load = [
|
roles_to_load = [
|
||||||
{
|
{
|
||||||
'name': 'teacher',
|
'name': 'teacher',
|
||||||
'description': 'Access users information, servers and groups without create/delete privileges',
|
'description': 'Access users information, servers and groups without create/delete privileges',
|
||||||
'scopes': ['users', 'groups'],
|
'scopes': ['users', 'groups'],
|
||||||
'users': ['cyclops', 'gandalf'],
|
'users': ['cyclops', 'gandalf'],
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
'name': 'user',
|
||||||
|
'description': 'Only read access',
|
||||||
|
'scopes': ['read:all'],
|
||||||
|
'users': ['test_user'],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
kwargs = {'load_roles': to_load}
|
kwargs = {'load_roles': roles_to_load}
|
||||||
ssl_enabled = getattr(request.module, "ssl_enabled", False)
|
ssl_enabled = getattr(request.module, "ssl_enabled", False)
|
||||||
if ssl_enabled:
|
if ssl_enabled:
|
||||||
kwargs['internal_certs_location'] = str(tmpdir)
|
kwargs['internal_certs_location'] = str(tmpdir)
|
||||||
# keep the users and groups from test_load_groups
|
hub = MockHub(**kwargs)
|
||||||
hub = MockHub(test_clean_db=False, **kwargs)
|
|
||||||
hub.init_db()
|
hub.init_db()
|
||||||
|
db = hub.db
|
||||||
await hub.init_users()
|
await hub.init_users()
|
||||||
await hub.init_roles()
|
await hub.init_roles()
|
||||||
db = hub.db
|
# test if the 'user' role has been overwritten
|
||||||
# test default roles loaded to database
|
user_role = orm.Role.find(db, 'user')
|
||||||
assert DefaultRoles.get_user_role(db) is not None
|
assert user_role is not None
|
||||||
assert DefaultRoles.get_admin_role(db) is not None
|
assert user_role.scopes == ['read:all']
|
||||||
assert DefaultRoles.get_server_role(db) is not None
|
# test other default roles loaded to database
|
||||||
# test if every existing user has a correct default role
|
assert orm.Role.find(db, 'user') is not None
|
||||||
|
assert orm.Role.find(db, 'admin') is not None
|
||||||
|
assert orm.Role.find(db, 'server') is not None
|
||||||
|
# test if every existing user has a role (and no duplicates)
|
||||||
for user in db.query(orm.User):
|
for user in db.query(orm.User):
|
||||||
|
assert len(user.roles) > 0
|
||||||
assert len(user.roles) == len(set(user.roles))
|
assert len(user.roles) == len(set(user.roles))
|
||||||
if user.admin:
|
|
||||||
assert DefaultRoles.get_admin_role(db) in user.roles
|
|
||||||
assert DefaultRoles.get_user_role(db) not in user.roles
|
|
||||||
else:
|
|
||||||
assert DefaultRoles.get_user_role(db) in user.roles
|
|
||||||
assert DefaultRoles.get_admin_role(db) not in user.roles
|
|
||||||
# test if predefined roles loaded and assigned
|
# test if predefined roles loaded and assigned
|
||||||
teacher_role = orm.Role.find(db, name='teacher')
|
teacher_role = orm.Role.find(db, name='teacher')
|
||||||
assert teacher_role is not None
|
assert teacher_role is not None
|
||||||
|
Reference in New Issue
Block a user