explicitly support async oauth_client_allowed_scopes

This commit is contained in:
Min RK
2022-08-02 13:33:43 +02:00
parent a35a2ec8b7
commit 6a470b44e7
2 changed files with 8 additions and 2 deletions

View File

@@ -329,10 +329,13 @@ class Spawner(LoggingConfigurable):
Default is an empty list, meaning minimal permissions to identify users, Default is an empty list, meaning minimal permissions to identify users,
no actions can be taken on their behalf. no actions can be taken on their behalf.
If callable, will be called with the Spawner as a single argument.
Callables may be async.
""", """,
).tag(config=True) ).tag(config=True)
def _get_oauth_client_allowed_scopes(self): async def _get_oauth_client_allowed_scopes(self):
"""Private method: get oauth allowed scopes """Private method: get oauth allowed scopes
Handle: Handle:
@@ -351,6 +354,8 @@ class Spawner(LoggingConfigurable):
allowed_scopes = self.oauth_client_allowed_scopes allowed_scopes = self.oauth_client_allowed_scopes
if callable(allowed_scopes): if callable(allowed_scopes):
allowed_scopes = allowed_scopes(self) allowed_scopes = allowed_scopes(self)
if inspect.isawaitable(allowed_scopes):
allowed_scopes = await allowed_scopes
scopes.extend(allowed_scopes) scopes.extend(allowed_scopes)
if self.oauth_roles: if self.oauth_roles:

View File

@@ -666,11 +666,12 @@ class User:
client_id = spawner.oauth_client_id client_id = spawner.oauth_client_id
oauth_provider = self.settings.get('oauth_provider') oauth_provider = self.settings.get('oauth_provider')
if oauth_provider: if oauth_provider:
allowed_scopes = await spawner._get_oauth_client_allowed_scopes()
oauth_client = oauth_provider.add_client( oauth_client = oauth_provider.add_client(
client_id, client_id,
api_token, api_token,
url_path_join(self.url, url_escape_path(server_name), 'oauth_callback'), url_path_join(self.url, url_escape_path(server_name), 'oauth_callback'),
allowed_scopes=spawner._get_oauth_client_allowed_scopes(), allowed_scopes=allowed_scopes,
description="Server at %s" description="Server at %s"
% (url_path_join(self.base_url, server_name) + '/'), % (url_path_join(self.base_url, server_name) + '/'),
) )