mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-10 11:33:01 +00:00
allow HubAuth to be async
Switches requests to tornado AsyncHTTPClient instead of requests For backward-compatibility, use opt-in `sync=False` arg for all public methods that _may_ be async When sync=True (default), async functions still used, but blocking via ThreadPool + asyncio run_until_complete
This commit is contained in:
@@ -23,6 +23,7 @@ If you are using OAuth, you will also need to register an oauth callback handler
|
|||||||
A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
|
A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@@ -34,14 +35,26 @@ import string
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from http import HTTPStatus
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||||
from tornado.httputil import url_concat
|
from tornado.httputil import url_concat
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.web import HTTPError, RequestHandler
|
from tornado.web import HTTPError, RequestHandler
|
||||||
from traitlets import Dict, Instance, Integer, Set, Unicode, default, observe, validate
|
from traitlets import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Instance,
|
||||||
|
Integer,
|
||||||
|
Set,
|
||||||
|
Unicode,
|
||||||
|
default,
|
||||||
|
observe,
|
||||||
|
validate,
|
||||||
|
)
|
||||||
from traitlets.config import SingletonConfigurable
|
from traitlets.config import SingletonConfigurable
|
||||||
|
|
||||||
from ..scopes import _intersect_expanded_scopes
|
from ..scopes import _intersect_expanded_scopes
|
||||||
@@ -351,7 +364,47 @@ class HubAuth(SingletonConfigurable):
|
|||||||
return {f'access:services!service={service_name}'}
|
return {f'access:services!service={service_name}'}
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def _check_hub_authorization(self, url, api_token, cache_key=None, use_cache=True):
|
_pool = Any(help="Thread pool for running async methods in the background")
|
||||||
|
|
||||||
|
@default("_pool")
|
||||||
|
def _new_pool(self):
|
||||||
|
# start a single ThreadPool in the background
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
pool = ThreadPoolExecutor(1)
|
||||||
|
# create an event loop in the thread
|
||||||
|
pool.submit(self._setup_asyncio_thread).result()
|
||||||
|
return pool
|
||||||
|
|
||||||
|
def _setup_asyncio_thread(self):
|
||||||
|
"""Create asyncio loop
|
||||||
|
|
||||||
|
To be called from the background thread,
|
||||||
|
so that any thread-local state is setup correctly
|
||||||
|
"""
|
||||||
|
self._thread_loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
|
def _synchronize(self, async_f, *args, **kwargs):
|
||||||
|
"""Call an async method in our background thread"""
|
||||||
|
future = self._pool.submit(
|
||||||
|
lambda: self._thread_loop.run_until_complete(async_f(*args, **kwargs))
|
||||||
|
)
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
def _call_coroutine(self, sync, async_f, *args, **kwargs):
|
||||||
|
"""Call an async coroutine function, either blocking or returning an awaitable
|
||||||
|
|
||||||
|
if not sync: calls function directly, returning awaitable
|
||||||
|
else: Block on a call in our background thread, return actual result
|
||||||
|
"""
|
||||||
|
if not sync:
|
||||||
|
return async_f(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self._synchronize(async_f, *args, **kwargs)
|
||||||
|
|
||||||
|
async def _check_hub_authorization(
|
||||||
|
self, url, api_token, cache_key=None, use_cache=True
|
||||||
|
):
|
||||||
"""Identify a user with the Hub
|
"""Identify a user with the Hub
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -374,7 +427,7 @@ class HubAuth(SingletonConfigurable):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
app_log.debug("HubAuth cache miss: %s", cache_key)
|
app_log.debug("HubAuth cache miss: %s", cache_key)
|
||||||
|
|
||||||
data = self._api_request(
|
data = await self._api_request(
|
||||||
'GET',
|
'GET',
|
||||||
url,
|
url,
|
||||||
headers={"Authorization": "token " + api_token},
|
headers={"Authorization": "token " + api_token},
|
||||||
@@ -389,18 +442,26 @@ class HubAuth(SingletonConfigurable):
|
|||||||
self.cache[cache_key] = data
|
self.cache[cache_key] = data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _api_request(self, method, url, **kwargs):
|
async def _api_request(self, method, url, **kwargs):
|
||||||
"""Make an API request"""
|
"""Make an API request"""
|
||||||
allow_403 = kwargs.pop('allow_403', False)
|
allow_403 = kwargs.pop('allow_403', False)
|
||||||
headers = kwargs.setdefault('headers', {})
|
headers = kwargs.setdefault('headers', {})
|
||||||
headers.setdefault('Authorization', 'token %s' % self.api_token)
|
headers.setdefault('Authorization', f'token {self.api_token}')
|
||||||
if "cert" not in kwargs and self.certfile and self.keyfile:
|
# translate requests args to tornado's
|
||||||
kwargs["cert"] = (self.certfile, self.keyfile)
|
if self.certfile:
|
||||||
if self.client_ca:
|
kwargs["client_cert"] = self.certfile
|
||||||
kwargs["verify"] = self.client_ca
|
if self.keyfile:
|
||||||
|
kwargs["client_key"] = self.keyfile
|
||||||
|
if self.client_ca:
|
||||||
|
kwargs["ca_certs"] = self.client_ca
|
||||||
|
req = HTTPRequest(
|
||||||
|
url,
|
||||||
|
method=method,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
r = requests.request(method, url, **kwargs)
|
r = await AsyncHTTPClient().fetch(req, raise_error=False)
|
||||||
except requests.ConnectionError as e:
|
except Exception as e:
|
||||||
app_log.error("Error connecting to %s: %s", self.api_url, e)
|
app_log.error("Error connecting to %s: %s", self.api_url, e)
|
||||||
msg = "Failed to connect to Hub API at %r." % self.api_url
|
msg = "Failed to connect to Hub API at %r." % self.api_url
|
||||||
msg += (
|
msg += (
|
||||||
@@ -415,35 +476,46 @@ class HubAuth(SingletonConfigurable):
|
|||||||
raise HTTPError(500, msg)
|
raise HTTPError(500, msg)
|
||||||
|
|
||||||
data = None
|
data = None
|
||||||
if r.status_code == 403 and allow_403:
|
try:
|
||||||
|
status = HTTPStatus(r.code)
|
||||||
|
except ValueError:
|
||||||
|
app_log.error(
|
||||||
|
f"Unknown error checking authorization with JupyterHub: {r.code}"
|
||||||
|
)
|
||||||
|
app_log.error(r.body.decode("utf8", "replace"))
|
||||||
|
|
||||||
|
response_text = r.body.decode("utf8", "replace")
|
||||||
|
if status.value == 403 and allow_403:
|
||||||
pass
|
pass
|
||||||
elif r.status_code == 403:
|
elif status.value == 403:
|
||||||
app_log.error(
|
app_log.error(
|
||||||
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
|
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
|
||||||
r.status_code,
|
status.value,
|
||||||
r.reason,
|
status.description,
|
||||||
)
|
)
|
||||||
app_log.error(r.text)
|
app_log.error(response_text)
|
||||||
raise HTTPError(
|
raise HTTPError(
|
||||||
500, "Permission failure checking authorization, I may need a new token"
|
500, "Permission failure checking authorization, I may need a new token"
|
||||||
)
|
)
|
||||||
elif r.status_code >= 500:
|
elif status.value >= 500:
|
||||||
app_log.error(
|
app_log.error(
|
||||||
"Upstream failure verifying auth token: [%i] %s",
|
"Upstream failure verifying auth token: [%i] %s",
|
||||||
r.status_code,
|
status.value,
|
||||||
r.reason,
|
status.description,
|
||||||
)
|
)
|
||||||
app_log.error(r.text)
|
app_log.error(response_text)
|
||||||
raise HTTPError(502, "Failed to check authorization (upstream problem)")
|
raise HTTPError(502, "Failed to check authorization (upstream problem)")
|
||||||
elif r.status_code >= 400:
|
elif status.value >= 400:
|
||||||
app_log.warning(
|
app_log.warning(
|
||||||
"Failed to check authorization: [%i] %s", r.status_code, r.reason
|
"Failed to check authorization: [%i] %s",
|
||||||
|
status.value,
|
||||||
|
status.description,
|
||||||
)
|
)
|
||||||
app_log.warning(r.text)
|
app_log.warning(response_text)
|
||||||
msg = "Failed to check authorization"
|
msg = "Failed to check authorization"
|
||||||
# pass on error from oauth failure
|
# pass on error from oauth failure
|
||||||
try:
|
try:
|
||||||
response = r.json()
|
response = json.loads(response_text)
|
||||||
# prefer more specific 'error_description', fallback to 'error'
|
# prefer more specific 'error_description', fallback to 'error'
|
||||||
description = response.get(
|
description = response.get(
|
||||||
"error_description", response.get("error", "Unknown error")
|
"error_description", response.get("error", "Unknown error")
|
||||||
@@ -454,7 +526,7 @@ class HubAuth(SingletonConfigurable):
|
|||||||
msg += ": " + description
|
msg += ": " + description
|
||||||
raise HTTPError(500, msg)
|
raise HTTPError(500, msg)
|
||||||
else:
|
else:
|
||||||
data = r.json()
|
data = json.loads(response_text)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -464,19 +536,25 @@ class HubAuth(SingletonConfigurable):
|
|||||||
"Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
|
"Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
|
||||||
)
|
)
|
||||||
|
|
||||||
def user_for_token(self, token, use_cache=True, session_id=''):
|
def user_for_token(self, token, use_cache=True, session_id='', *, sync=True):
|
||||||
"""Ask the Hub to identify the user for a given token.
|
"""Ask the Hub to identify the user for a given token.
|
||||||
|
|
||||||
|
.. versionadded:: 2.4
|
||||||
|
async support via `sync` argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): the token
|
token (str): the token
|
||||||
use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True)
|
use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True)
|
||||||
|
sync (bool): whether to block for the result or return an awaitable
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
||||||
|
|
||||||
The 'name' field contains the user's name.
|
The 'name' field contains the user's name.
|
||||||
"""
|
"""
|
||||||
return self._check_hub_authorization(
|
return self._call_coroutine(
|
||||||
|
sync,
|
||||||
|
self._check_hub_authorization,
|
||||||
url=url_path_join(
|
url=url_path_join(
|
||||||
self.api_url,
|
self.api_url,
|
||||||
"user",
|
"user",
|
||||||
@@ -521,7 +599,7 @@ class HubAuth(SingletonConfigurable):
|
|||||||
"""Base class doesn't store tokens in cookies"""
|
"""Base class doesn't store tokens in cookies"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_user_cookie(self, handler):
|
async def _get_user_cookie(self, handler):
|
||||||
"""Get the user model from a cookie"""
|
"""Get the user model from a cookie"""
|
||||||
# overridden in HubOAuth to store the access token after oauth
|
# overridden in HubOAuth to store the access token after oauth
|
||||||
return None
|
return None
|
||||||
@@ -533,20 +611,26 @@ class HubAuth(SingletonConfigurable):
|
|||||||
"""
|
"""
|
||||||
return handler.get_cookie('jupyterhub-session-id', '')
|
return handler.get_cookie('jupyterhub-session-id', '')
|
||||||
|
|
||||||
def get_user(self, handler):
|
def get_user(self, handler, *, sync=True):
|
||||||
"""Get the Hub user for a given tornado handler.
|
"""Get the Hub user for a given tornado handler.
|
||||||
|
|
||||||
Checks cookie with the Hub to identify the current user.
|
Checks cookie with the Hub to identify the current user.
|
||||||
|
|
||||||
|
.. versionadded:: 2.4
|
||||||
|
async support via `sync` argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
handler (tornado.web.RequestHandler): the current request handler
|
handler (tornado.web.RequestHandler): the current request handler
|
||||||
|
sync (bool): whether to block for the result or return an awaitable
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
||||||
|
|
||||||
The 'name' field contains the user's name.
|
The 'name' field contains the user's name.
|
||||||
"""
|
"""
|
||||||
|
return self._call_coroutine(sync, self._get_user, handler)
|
||||||
|
|
||||||
|
async def _get_user(self, handler):
|
||||||
# only allow this to be called once per handler
|
# only allow this to be called once per handler
|
||||||
# avoids issues if an error is raised,
|
# avoids issues if an error is raised,
|
||||||
# since this may be called again when trying to render the error page
|
# since this may be called again when trying to render the error page
|
||||||
@@ -561,13 +645,15 @@ class HubAuth(SingletonConfigurable):
|
|||||||
# is token-authenticated (CORS-related)
|
# is token-authenticated (CORS-related)
|
||||||
token = self.get_token(handler, in_cookie=False)
|
token = self.get_token(handler, in_cookie=False)
|
||||||
if token:
|
if token:
|
||||||
user_model = self.user_for_token(token, session_id=session_id)
|
user_model = await self.user_for_token(
|
||||||
|
token, session_id=session_id, sync=False
|
||||||
|
)
|
||||||
if user_model:
|
if user_model:
|
||||||
handler._token_authenticated = True
|
handler._token_authenticated = True
|
||||||
|
|
||||||
# no token, check cookie
|
# no token, check cookie
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
user_model = self._get_user_cookie(handler)
|
user_model = await self._get_user_cookie(handler)
|
||||||
|
|
||||||
# cache result
|
# cache result
|
||||||
handler._cached_hub_user = user_model
|
handler._cached_hub_user = user_model
|
||||||
@@ -627,11 +713,13 @@ class HubOAuth(HubAuth):
|
|||||||
token = token.decode('ascii', 'replace')
|
token = token.decode('ascii', 'replace')
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def _get_user_cookie(self, handler):
|
async def _get_user_cookie(self, handler):
|
||||||
token = self._get_token_cookie(handler)
|
token = self._get_token_cookie(handler)
|
||||||
session_id = self.get_session_id(handler)
|
session_id = self.get_session_id(handler)
|
||||||
if token:
|
if token:
|
||||||
user_model = self.user_for_token(token, session_id=session_id)
|
user_model = await self.user_for_token(
|
||||||
|
token, session_id=session_id, sync=False
|
||||||
|
)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
app_log.warning("Token stored in cookie may have expired")
|
app_log.warning("Token stored in cookie may have expired")
|
||||||
handler.clear_cookie(self.cookie_name)
|
handler.clear_cookie(self.cookie_name)
|
||||||
@@ -686,7 +774,7 @@ class HubOAuth(HubAuth):
|
|||||||
def _token_url(self):
|
def _token_url(self):
|
||||||
return url_path_join(self.api_url, 'oauth2/token')
|
return url_path_join(self.api_url, 'oauth2/token')
|
||||||
|
|
||||||
def token_for_code(self, code):
|
def token_for_code(self, code, *, sync=True):
|
||||||
"""Get token for OAuth temporary code
|
"""Get token for OAuth temporary code
|
||||||
|
|
||||||
This is the last step of OAuth login.
|
This is the last step of OAuth login.
|
||||||
@@ -697,6 +785,9 @@ class HubOAuth(HubAuth):
|
|||||||
Returns:
|
Returns:
|
||||||
token (str): JupyterHub API Token
|
token (str): JupyterHub API Token
|
||||||
"""
|
"""
|
||||||
|
return self._call_coroutine(sync, self._token_for_code, code)
|
||||||
|
|
||||||
|
async def _token_for_code(self, code):
|
||||||
# GitHub specifies a POST request yet requires URL parameters
|
# GitHub specifies a POST request yet requires URL parameters
|
||||||
params = dict(
|
params = dict(
|
||||||
client_id=self.oauth_client_id,
|
client_id=self.oauth_client_id,
|
||||||
@@ -706,10 +797,10 @@ class HubOAuth(HubAuth):
|
|||||||
redirect_uri=self.oauth_redirect_uri,
|
redirect_uri=self.oauth_redirect_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
token_reply = self._api_request(
|
token_reply = await self._api_request(
|
||||||
'POST',
|
'POST',
|
||||||
self.oauth_token_url,
|
self.oauth_token_url,
|
||||||
data=urlencode(params).encode('utf8'),
|
body=urlencode(params).encode('utf8'),
|
||||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1114,10 +1205,12 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
|||||||
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
||||||
raise HTTPError(403, "oauth state does not match. Try logging in again.")
|
raise HTTPError(403, "oauth state does not match. Try logging in again.")
|
||||||
next_url = self.hub_auth.get_next_url(cookie_state)
|
next_url = self.hub_auth.get_next_url(cookie_state)
|
||||||
# TODO: make async (in a Thread?)
|
|
||||||
token = self.hub_auth.token_for_code(code)
|
token = await self.hub_auth.token_for_code(code, sync=False)
|
||||||
session_id = self.hub_auth.get_session_id(self)
|
session_id = self.hub_auth.get_session_id(self)
|
||||||
user_model = self.hub_auth.user_for_token(token, session_id=session_id)
|
user_model = await self.hub_auth.user_for_token(
|
||||||
|
token, session_id=session_id, sync=False
|
||||||
|
)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
raise HTTPError(500, "oauth callback failed to identify a user")
|
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||||
app_log.info("Logged-in user %s", user_model)
|
app_log.info("Logged-in user %s", user_model)
|
||||||
|
Reference in New Issue
Block a user