allow HubAuth to be async

Switches HTTP requests to use tornado's AsyncHTTPClient instead of the `requests` library

For backward-compatibility, use opt-in `sync=False` arg for all public methods that _may_ be async

When sync=True (default), the async functions are still used, but the call blocks via a ThreadPool + asyncio run_until_complete
This commit is contained in:
Min RK
2022-05-03 12:58:39 +02:00
parent 8a44748324
commit b9c83cf7ab

View File

@@ -23,6 +23,7 @@ If you are using OAuth, you will also need to register an oauth callback handler
A tornado implementation is provided in :class:`HubOAuthCallbackHandler`. A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
""" """
import asyncio
import base64 import base64
import hashlib import hashlib
import json import json
@@ -34,14 +35,26 @@ import string
import time import time
import uuid import uuid
import warnings import warnings
from functools import partial
from http import HTTPStatus
from unittest import mock from unittest import mock
from urllib.parse import urlencode from urllib.parse import urlencode
import requests from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.log import app_log from tornado.log import app_log
from tornado.web import HTTPError, RequestHandler from tornado.web import HTTPError, RequestHandler
from traitlets import Dict, Instance, Integer, Set, Unicode, default, observe, validate from traitlets import (
Any,
Dict,
Instance,
Integer,
Set,
Unicode,
default,
observe,
validate,
)
from traitlets.config import SingletonConfigurable from traitlets.config import SingletonConfigurable
from ..scopes import _intersect_expanded_scopes from ..scopes import _intersect_expanded_scopes
@@ -351,7 +364,47 @@ class HubAuth(SingletonConfigurable):
return {f'access:services!service={service_name}'} return {f'access:services!service={service_name}'}
return set() return set()
def _check_hub_authorization(self, url, api_token, cache_key=None, use_cache=True): _pool = Any(help="Thread pool for running async methods in the background")
@default("_pool")
def _new_pool(self):
    """Build the default value for ``_pool``.

    Creates a single-worker thread pool and, before returning it,
    waits for that worker to run ``_setup_asyncio_thread`` so the
    event loop is created from inside the background thread.

    Returns:
        ThreadPoolExecutor: the one-worker executor.
    """
    from concurrent.futures import ThreadPoolExecutor

    # exactly one worker: every submitted call runs on the same thread
    executor = ThreadPoolExecutor(1)
    # block until the worker has created its event loop
    loop_ready = executor.submit(self._setup_asyncio_thread)
    loop_ready.result()
    return executor
def _setup_asyncio_thread(self):
    """Create a fresh asyncio event loop and keep it on ``self._thread_loop``.

    Meant to be invoked from the background thread,
    so that any thread-local state is set up in that thread.
    """
    loop = asyncio.new_event_loop()
    self._thread_loop = loop
def _synchronize(self, async_f, *args, **kwargs):
    """Run an async function to completion on the background thread.

    Args:
        async_f: coroutine function to call
        *args, **kwargs: passed through to ``async_f``

    Returns:
        The result of the coroutine. The calling thread blocks on the
        pool future until the result (or exception) is available.
    """

    def _run_in_thread_loop():
        # executed in the pool thread: the coroutine is created there
        # and driven by the loop stored on self._thread_loop
        return self._thread_loop.run_until_complete(async_f(*args, **kwargs))

    return self._pool.submit(_run_in_thread_loop).result()
def _call_coroutine(self, sync, async_f, *args, **kwargs):
    """Dispatch a coroutine function in blocking or awaitable mode.

    Args:
        sync (bool): if true, block on a call in our background thread
            and return the actual result; otherwise call ``async_f``
            directly and return the awaitable it produces.
        async_f: coroutine function to call
        *args, **kwargs: passed through to ``async_f``
    """
    if sync:
        # blocking mode: hand off to the background thread and wait
        return self._synchronize(async_f, *args, **kwargs)
    # async mode: the caller is responsible for awaiting the result
    return async_f(*args, **kwargs)
async def _check_hub_authorization(
self, url, api_token, cache_key=None, use_cache=True
):
"""Identify a user with the Hub """Identify a user with the Hub
Args: Args:
@@ -374,7 +427,7 @@ class HubAuth(SingletonConfigurable):
except KeyError: except KeyError:
app_log.debug("HubAuth cache miss: %s", cache_key) app_log.debug("HubAuth cache miss: %s", cache_key)
data = self._api_request( data = await self._api_request(
'GET', 'GET',
url, url,
headers={"Authorization": "token " + api_token}, headers={"Authorization": "token " + api_token},
@@ -389,18 +442,26 @@ class HubAuth(SingletonConfigurable):
self.cache[cache_key] = data self.cache[cache_key] = data
return data return data
def _api_request(self, method, url, **kwargs): async def _api_request(self, method, url, **kwargs):
"""Make an API request""" """Make an API request"""
allow_403 = kwargs.pop('allow_403', False) allow_403 = kwargs.pop('allow_403', False)
headers = kwargs.setdefault('headers', {}) headers = kwargs.setdefault('headers', {})
headers.setdefault('Authorization', 'token %s' % self.api_token) headers.setdefault('Authorization', f'token {self.api_token}')
if "cert" not in kwargs and self.certfile and self.keyfile: # translate requests args to tornado's
kwargs["cert"] = (self.certfile, self.keyfile) if self.certfile:
if self.client_ca: kwargs["client_cert"] = self.certfile
kwargs["verify"] = self.client_ca if self.keyfile:
kwargs["client_key"] = self.keyfile
if self.client_ca:
kwargs["ca_certs"] = self.client_ca
req = HTTPRequest(
url,
method=method,
**kwargs,
)
try: try:
r = requests.request(method, url, **kwargs) r = await AsyncHTTPClient().fetch(req, raise_error=False)
except requests.ConnectionError as e: except Exception as e:
app_log.error("Error connecting to %s: %s", self.api_url, e) app_log.error("Error connecting to %s: %s", self.api_url, e)
msg = "Failed to connect to Hub API at %r." % self.api_url msg = "Failed to connect to Hub API at %r." % self.api_url
msg += ( msg += (
@@ -415,35 +476,46 @@ class HubAuth(SingletonConfigurable):
raise HTTPError(500, msg) raise HTTPError(500, msg)
data = None data = None
if r.status_code == 403 and allow_403: try:
status = HTTPStatus(r.code)
except ValueError:
app_log.error(
f"Unknown error checking authorization with JupyterHub: {r.code}"
)
app_log.error(r.body.decode("utf8", "replace"))
response_text = r.body.decode("utf8", "replace")
if status.value == 403 and allow_403:
pass pass
elif r.status_code == 403: elif status.value == 403:
app_log.error( app_log.error(
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s", "I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
r.status_code, status.value,
r.reason, status.description,
) )
app_log.error(r.text) app_log.error(response_text)
raise HTTPError( raise HTTPError(
500, "Permission failure checking authorization, I may need a new token" 500, "Permission failure checking authorization, I may need a new token"
) )
elif r.status_code >= 500: elif status.value >= 500:
app_log.error( app_log.error(
"Upstream failure verifying auth token: [%i] %s", "Upstream failure verifying auth token: [%i] %s",
r.status_code, status.value,
r.reason, status.description,
) )
app_log.error(r.text) app_log.error(response_text)
raise HTTPError(502, "Failed to check authorization (upstream problem)") raise HTTPError(502, "Failed to check authorization (upstream problem)")
elif r.status_code >= 400: elif status.value >= 400:
app_log.warning( app_log.warning(
"Failed to check authorization: [%i] %s", r.status_code, r.reason "Failed to check authorization: [%i] %s",
status.value,
status.description,
) )
app_log.warning(r.text) app_log.warning(response_text)
msg = "Failed to check authorization" msg = "Failed to check authorization"
# pass on error from oauth failure # pass on error from oauth failure
try: try:
response = r.json() response = json.loads(response_text)
# prefer more specific 'error_description', fallback to 'error' # prefer more specific 'error_description', fallback to 'error'
description = response.get( description = response.get(
"error_description", response.get("error", "Unknown error") "error_description", response.get("error", "Unknown error")
@@ -454,7 +526,7 @@ class HubAuth(SingletonConfigurable):
msg += ": " + description msg += ": " + description
raise HTTPError(500, msg) raise HTTPError(500, msg)
else: else:
data = r.json() data = json.loads(response_text)
return data return data
@@ -464,19 +536,25 @@ class HubAuth(SingletonConfigurable):
"Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens." "Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
) )
def user_for_token(self, token, use_cache=True, session_id=''): def user_for_token(self, token, use_cache=True, session_id='', *, sync=True):
"""Ask the Hub to identify the user for a given token. """Ask the Hub to identify the user for a given token.
.. versionadded:: 2.4
async support via `sync` argument.
Args: Args:
token (str): the token token (str): the token
use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True) use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True)
sync (bool): whether to block for the result or return an awaitable
Returns: Returns:
user_model (dict): The user model, if a user is identified, None if authentication fails. user_model (dict): The user model, if a user is identified, None if authentication fails.
The 'name' field contains the user's name. The 'name' field contains the user's name.
""" """
return self._check_hub_authorization( return self._call_coroutine(
sync,
self._check_hub_authorization,
url=url_path_join( url=url_path_join(
self.api_url, self.api_url,
"user", "user",
@@ -521,7 +599,7 @@ class HubAuth(SingletonConfigurable):
"""Base class doesn't store tokens in cookies""" """Base class doesn't store tokens in cookies"""
return None return None
def _get_user_cookie(self, handler): async def _get_user_cookie(self, handler):
"""Get the user model from a cookie""" """Get the user model from a cookie"""
# overridden in HubOAuth to store the access token after oauth # overridden in HubOAuth to store the access token after oauth
return None return None
@@ -533,20 +611,26 @@ class HubAuth(SingletonConfigurable):
""" """
return handler.get_cookie('jupyterhub-session-id', '') return handler.get_cookie('jupyterhub-session-id', '')
def get_user(self, handler): def get_user(self, handler, *, sync=True):
"""Get the Hub user for a given tornado handler. """Get the Hub user for a given tornado handler.
Checks cookie with the Hub to identify the current user. Checks cookie with the Hub to identify the current user.
.. versionadded:: 2.4
async support via `sync` argument.
Args: Args:
handler (tornado.web.RequestHandler): the current request handler handler (tornado.web.RequestHandler): the current request handler
sync (bool): whether to block for the result or return an awaitable
Returns: Returns:
user_model (dict): The user model, if a user is identified, None if authentication fails. user_model (dict): The user model, if a user is identified, None if authentication fails.
The 'name' field contains the user's name. The 'name' field contains the user's name.
""" """
return self._call_coroutine(sync, self._get_user, handler)
async def _get_user(self, handler):
# only allow this to be called once per handler # only allow this to be called once per handler
# avoids issues if an error is raised, # avoids issues if an error is raised,
# since this may be called again when trying to render the error page # since this may be called again when trying to render the error page
@@ -561,13 +645,15 @@ class HubAuth(SingletonConfigurable):
# is token-authenticated (CORS-related) # is token-authenticated (CORS-related)
token = self.get_token(handler, in_cookie=False) token = self.get_token(handler, in_cookie=False)
if token: if token:
user_model = self.user_for_token(token, session_id=session_id) user_model = await self.user_for_token(
token, session_id=session_id, sync=False
)
if user_model: if user_model:
handler._token_authenticated = True handler._token_authenticated = True
# no token, check cookie # no token, check cookie
if user_model is None: if user_model is None:
user_model = self._get_user_cookie(handler) user_model = await self._get_user_cookie(handler)
# cache result # cache result
handler._cached_hub_user = user_model handler._cached_hub_user = user_model
@@ -627,11 +713,13 @@ class HubOAuth(HubAuth):
token = token.decode('ascii', 'replace') token = token.decode('ascii', 'replace')
return token return token
def _get_user_cookie(self, handler): async def _get_user_cookie(self, handler):
token = self._get_token_cookie(handler) token = self._get_token_cookie(handler)
session_id = self.get_session_id(handler) session_id = self.get_session_id(handler)
if token: if token:
user_model = self.user_for_token(token, session_id=session_id) user_model = await self.user_for_token(
token, session_id=session_id, sync=False
)
if user_model is None: if user_model is None:
app_log.warning("Token stored in cookie may have expired") app_log.warning("Token stored in cookie may have expired")
handler.clear_cookie(self.cookie_name) handler.clear_cookie(self.cookie_name)
@@ -686,7 +774,7 @@ class HubOAuth(HubAuth):
def _token_url(self): def _token_url(self):
return url_path_join(self.api_url, 'oauth2/token') return url_path_join(self.api_url, 'oauth2/token')
def token_for_code(self, code): def token_for_code(self, code, *, sync=True):
"""Get token for OAuth temporary code """Get token for OAuth temporary code
This is the last step of OAuth login. This is the last step of OAuth login.
@@ -697,6 +785,9 @@ class HubOAuth(HubAuth):
Returns: Returns:
token (str): JupyterHub API Token token (str): JupyterHub API Token
""" """
return self._call_coroutine(sync, self._token_for_code, code)
async def _token_for_code(self, code):
# GitHub specifies a POST request yet requires URL parameters # GitHub specifies a POST request yet requires URL parameters
params = dict( params = dict(
client_id=self.oauth_client_id, client_id=self.oauth_client_id,
@@ -706,10 +797,10 @@ class HubOAuth(HubAuth):
redirect_uri=self.oauth_redirect_uri, redirect_uri=self.oauth_redirect_uri,
) )
token_reply = self._api_request( token_reply = await self._api_request(
'POST', 'POST',
self.oauth_token_url, self.oauth_token_url,
data=urlencode(params).encode('utf8'), body=urlencode(params).encode('utf8'),
headers={'Content-Type': 'application/x-www-form-urlencoded'}, headers={'Content-Type': 'application/x-www-form-urlencoded'},
) )
@@ -1114,10 +1205,12 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
app_log.warning("oauth state %r != %r", arg_state, cookie_state) app_log.warning("oauth state %r != %r", arg_state, cookie_state)
raise HTTPError(403, "oauth state does not match. Try logging in again.") raise HTTPError(403, "oauth state does not match. Try logging in again.")
next_url = self.hub_auth.get_next_url(cookie_state) next_url = self.hub_auth.get_next_url(cookie_state)
# TODO: make async (in a Thread?)
token = self.hub_auth.token_for_code(code) token = await self.hub_auth.token_for_code(code, sync=False)
session_id = self.hub_auth.get_session_id(self) session_id = self.hub_auth.get_session_id(self)
user_model = self.hub_auth.user_for_token(token, session_id=session_id) user_model = await self.hub_auth.user_for_token(
token, session_id=session_id, sync=False
)
if user_model is None: if user_model is None:
raise HTTPError(500, "oauth callback failed to identify a user") raise HTTPError(500, "oauth callback failed to identify a user")
app_log.info("Logged-in user %s", user_model) app_log.info("Logged-in user %s", user_model)