diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 86d904ef..6da806ad 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -111,7 +111,6 @@ from .objects import Hub, Server # For faking stats from .emptyclass import EmptyClass - common_aliases = { 'log-level': 'Application.log_level', 'f': 'JupyterHub.config_file', @@ -119,7 +118,6 @@ common_aliases = { 'db': 'JupyterHub.db_url', } - aliases = { 'base-url': 'JupyterHub.base_url', 'y': 'JupyterHub.answer_yes', @@ -2129,6 +2127,7 @@ class JupyterHub(Application): name = spec['name'] # get/create orm orm_service = orm.Service.find(self.db, name=name) + allowed_roles = spec.get('allowed_roles', []) if orm_service is None: # not found, create a new one orm_service = orm.Service(name=name) @@ -2193,6 +2192,7 @@ class JupyterHub(Application): client_id=service.oauth_client_id, client_secret=service.api_token, redirect_uri=service.oauth_redirect_uri, + allowed_roles=allowed_roles, description="JupyterHub service %s" % service.name, ) diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index 17b3eead..052616a7 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -586,7 +586,9 @@ class JupyterHubOAuthServer(WebApplicationServer): self.db = db super().__init__(validator, *args, **kwargs) - def add_client(self, client_id, client_secret, redirect_uri, description=''): + def add_client( + self, client_id, client_secret, redirect_uri, allowed_roles, description='' + ): """Add a client hash its client_secret before putting it in the database. @@ -610,6 +612,7 @@ class JupyterHubOAuthServer(WebApplicationServer): orm_client.secret = hash_token(client_secret) if client_secret else "" orm_client.redirect_uri = redirect_uri orm_client.description = description + orm_client.allowed_roles = allowed_roles self.db.commit() def fetch_by_client_id(self, client_id): diff --git a/jupyterhub/services/service.py b/jupyterhub/services/service.py index c72ae382..380cddf8 100644 --- a/jupyterhub/services/service.py +++ b/jupyterhub/services/service.py @@ -50,6 +50,7 @@ from traitlets import default from traitlets import Dict from traitlets import HasTraits from traitlets import Instance +from traitlets import List from traitlets import Unicode from traitlets import validate from traitlets.config import LoggingConfigurable @@ -189,6 +190,13 @@ class Service(LoggingConfigurable): """ ).tag(input=True) + allowed_roles = List( + help="""OAuth allowed roles. + + List of roles that are passed to generated tokens if the service act as an OAuth client + on behalf of users""" + ).tag(input=True) + api_token = Unicode( help="""The API token to use for the service. diff --git a/jupyterhub/tests/test_roles.py b/jupyterhub/tests/test_roles.py index f538ff21..450640f1 100644 --- a/jupyterhub/tests/test_roles.py +++ b/jupyterhub/tests/test_roles.py @@ -13,6 +13,7 @@ from .. import roles from ..scopes import get_scopes_for from ..utils import maybe_future from .mocking import MockHub +from .test_scopes import create_temp_role from .utils import add_user from .utils import api_request @@ -898,3 +899,19 @@ async def test_valid_names(name, valid): else: with pytest.raises(ValueError): roles._validate_role_name(name) + + +async def test_oauth_allowed_roles(app, create_temp_role): + allowed_roles = ['oracle', 'goose'] + service = { + 'name': 'oas1', + 'api_token': 'some-token', + 'allowed_roles': ['oracle', 'goose'], + } + for role in allowed_roles: + create_temp_role('read:users', role_name=role) + app.services.append(service) + app.init_services() + app_service = app.services[0] + assert app_service['name'] == 'oas1' + assert set(app_service['allowed_roles']) == set(allowed_roles)