allow HubAuth to be async

Switches requests to tornado AsyncHTTPClient instead of requests

For backward-compatibility, use opt-in `sync=False` arg for all public methods that _may_ be async

When sync=True (default), async functions still used, but blocking via ThreadPool + asyncio run_until_complete
This commit is contained in:
Min RK
2022-05-03 12:58:39 +02:00
parent 8a44748324
commit b9c83cf7ab

View File

@@ -23,6 +23,7 @@ If you are using OAuth, you will also need to register an oauth callback handler
A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
"""
import asyncio
import base64
import hashlib
import json
@@ -34,14 +35,26 @@ import string
import time
import uuid
import warnings
from functools import partial
from http import HTTPStatus
from unittest import mock
from urllib.parse import urlencode
import requests
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.httputil import url_concat
from tornado.log import app_log
from tornado.web import HTTPError, RequestHandler
from traitlets import Dict, Instance, Integer, Set, Unicode, default, observe, validate
from traitlets import (
Any,
Dict,
Instance,
Integer,
Set,
Unicode,
default,
observe,
validate,
)
from traitlets.config import SingletonConfigurable
from ..scopes import _intersect_expanded_scopes
@@ -351,7 +364,47 @@ class HubAuth(SingletonConfigurable):
return {f'access:services!service={service_name}'}
return set()
def _check_hub_authorization(self, url, api_token, cache_key=None, use_cache=True):
_pool = Any(help="Thread pool for running async methods in the background")
@default("_pool")
def _new_pool(self):
# start a single ThreadPool in the background
from concurrent.futures import ThreadPoolExecutor
pool = ThreadPoolExecutor(1)
# create an event loop in the thread
pool.submit(self._setup_asyncio_thread).result()
return pool
def _setup_asyncio_thread(self):
"""Create asyncio loop
To be called from the background thread,
so that any thread-local state is setup correctly
"""
self._thread_loop = asyncio.new_event_loop()
def _synchronize(self, async_f, *args, **kwargs):
"""Call an async method in our background thread"""
future = self._pool.submit(
lambda: self._thread_loop.run_until_complete(async_f(*args, **kwargs))
)
return future.result()
def _call_coroutine(self, sync, async_f, *args, **kwargs):
"""Call an async coroutine function, either blocking or returning an awaitable
if not sync: calls function directly, returning awaitable
else: Block on a call in our background thread, return actual result
"""
if not sync:
return async_f(*args, **kwargs)
else:
return self._synchronize(async_f, *args, **kwargs)
async def _check_hub_authorization(
self, url, api_token, cache_key=None, use_cache=True
):
"""Identify a user with the Hub
Args:
@@ -374,7 +427,7 @@ class HubAuth(SingletonConfigurable):
except KeyError:
app_log.debug("HubAuth cache miss: %s", cache_key)
data = self._api_request(
data = await self._api_request(
'GET',
url,
headers={"Authorization": "token " + api_token},
@@ -389,18 +442,26 @@ class HubAuth(SingletonConfigurable):
self.cache[cache_key] = data
return data
def _api_request(self, method, url, **kwargs):
async def _api_request(self, method, url, **kwargs):
"""Make an API request"""
allow_403 = kwargs.pop('allow_403', False)
headers = kwargs.setdefault('headers', {})
headers.setdefault('Authorization', 'token %s' % self.api_token)
if "cert" not in kwargs and self.certfile and self.keyfile:
kwargs["cert"] = (self.certfile, self.keyfile)
if self.client_ca:
kwargs["verify"] = self.client_ca
headers.setdefault('Authorization', f'token {self.api_token}')
# translate requests args to tornado's
if self.certfile:
kwargs["client_cert"] = self.certfile
if self.keyfile:
kwargs["client_key"] = self.keyfile
if self.client_ca:
kwargs["ca_certs"] = self.client_ca
req = HTTPRequest(
url,
method=method,
**kwargs,
)
try:
r = requests.request(method, url, **kwargs)
except requests.ConnectionError as e:
r = await AsyncHTTPClient().fetch(req, raise_error=False)
except Exception as e:
app_log.error("Error connecting to %s: %s", self.api_url, e)
msg = "Failed to connect to Hub API at %r." % self.api_url
msg += (
@@ -415,35 +476,46 @@ class HubAuth(SingletonConfigurable):
raise HTTPError(500, msg)
data = None
if r.status_code == 403 and allow_403:
try:
status = HTTPStatus(r.code)
except ValueError:
app_log.error(
f"Unknown error checking authorization with JupyterHub: {r.code}"
)
app_log.error(r.body.decode("utf8", "replace"))
response_text = r.body.decode("utf8", "replace")
if status.value == 403 and allow_403:
pass
elif r.status_code == 403:
elif status.value == 403:
app_log.error(
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
r.status_code,
r.reason,
status.value,
status.description,
)
app_log.error(r.text)
app_log.error(response_text)
raise HTTPError(
500, "Permission failure checking authorization, I may need a new token"
)
elif r.status_code >= 500:
elif status.value >= 500:
app_log.error(
"Upstream failure verifying auth token: [%i] %s",
r.status_code,
r.reason,
status.value,
status.description,
)
app_log.error(r.text)
app_log.error(response_text)
raise HTTPError(502, "Failed to check authorization (upstream problem)")
elif r.status_code >= 400:
elif status.value >= 400:
app_log.warning(
"Failed to check authorization: [%i] %s", r.status_code, r.reason
"Failed to check authorization: [%i] %s",
status.value,
status.description,
)
app_log.warning(r.text)
app_log.warning(response_text)
msg = "Failed to check authorization"
# pass on error from oauth failure
try:
response = r.json()
response = json.loads(response_text)
# prefer more specific 'error_description', fallback to 'error'
description = response.get(
"error_description", response.get("error", "Unknown error")
@@ -454,7 +526,7 @@ class HubAuth(SingletonConfigurable):
msg += ": " + description
raise HTTPError(500, msg)
else:
data = r.json()
data = json.loads(response_text)
return data
@@ -464,19 +536,25 @@ class HubAuth(SingletonConfigurable):
"Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
)
def user_for_token(self, token, use_cache=True, session_id=''):
def user_for_token(self, token, use_cache=True, session_id='', *, sync=True):
"""Ask the Hub to identify the user for a given token.
.. versionadded:: 2.4
async support via `sync` argument.
Args:
token (str): the token
use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True)
sync (bool): whether to block for the result or return an awaitable
Returns:
user_model (dict): The user model, if a user is identified, None if authentication fails.
The 'name' field contains the user's name.
"""
return self._check_hub_authorization(
return self._call_coroutine(
sync,
self._check_hub_authorization,
url=url_path_join(
self.api_url,
"user",
@@ -521,7 +599,7 @@ class HubAuth(SingletonConfigurable):
"""Base class doesn't store tokens in cookies"""
return None
def _get_user_cookie(self, handler):
async def _get_user_cookie(self, handler):
"""Get the user model from a cookie"""
# overridden in HubOAuth to store the access token after oauth
return None
@@ -533,20 +611,26 @@ class HubAuth(SingletonConfigurable):
"""
return handler.get_cookie('jupyterhub-session-id', '')
def get_user(self, handler):
def get_user(self, handler, *, sync=True):
"""Get the Hub user for a given tornado handler.
Checks cookie with the Hub to identify the current user.
.. versionadded:: 2.4
async support via `sync` argument.
Args:
handler (tornado.web.RequestHandler): the current request handler
sync (bool): whether to block for the result or return an awaitable
Returns:
user_model (dict): The user model, if a user is identified, None if authentication fails.
The 'name' field contains the user's name.
"""
return self._call_coroutine(sync, self._get_user, handler)
async def _get_user(self, handler):
# only allow this to be called once per handler
# avoids issues if an error is raised,
# since this may be called again when trying to render the error page
@@ -561,13 +645,15 @@ class HubAuth(SingletonConfigurable):
# is token-authenticated (CORS-related)
token = self.get_token(handler, in_cookie=False)
if token:
user_model = self.user_for_token(token, session_id=session_id)
user_model = await self.user_for_token(
token, session_id=session_id, sync=False
)
if user_model:
handler._token_authenticated = True
# no token, check cookie
if user_model is None:
user_model = self._get_user_cookie(handler)
user_model = await self._get_user_cookie(handler)
# cache result
handler._cached_hub_user = user_model
@@ -627,11 +713,13 @@ class HubOAuth(HubAuth):
token = token.decode('ascii', 'replace')
return token
def _get_user_cookie(self, handler):
async def _get_user_cookie(self, handler):
token = self._get_token_cookie(handler)
session_id = self.get_session_id(handler)
if token:
user_model = self.user_for_token(token, session_id=session_id)
user_model = await self.user_for_token(
token, session_id=session_id, sync=False
)
if user_model is None:
app_log.warning("Token stored in cookie may have expired")
handler.clear_cookie(self.cookie_name)
@@ -686,7 +774,7 @@ class HubOAuth(HubAuth):
def _token_url(self):
return url_path_join(self.api_url, 'oauth2/token')
def token_for_code(self, code):
def token_for_code(self, code, *, sync=True):
"""Get token for OAuth temporary code
This is the last step of OAuth login.
@@ -697,6 +785,9 @@ class HubOAuth(HubAuth):
Returns:
token (str): JupyterHub API Token
"""
return self._call_coroutine(sync, self._token_for_code, code)
async def _token_for_code(self, code):
# GitHub specifies a POST request yet requires URL parameters
params = dict(
client_id=self.oauth_client_id,
@@ -706,10 +797,10 @@ class HubOAuth(HubAuth):
redirect_uri=self.oauth_redirect_uri,
)
token_reply = self._api_request(
token_reply = await self._api_request(
'POST',
self.oauth_token_url,
data=urlencode(params).encode('utf8'),
body=urlencode(params).encode('utf8'),
headers={'Content-Type': 'application/x-www-form-urlencoded'},
)
@@ -1114,10 +1205,12 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
raise HTTPError(403, "oauth state does not match. Try logging in again.")
next_url = self.hub_auth.get_next_url(cookie_state)
# TODO: make async (in a Thread?)
token = self.hub_auth.token_for_code(code)
token = await self.hub_auth.token_for_code(code, sync=False)
session_id = self.hub_auth.get_session_id(self)
user_model = self.hub_auth.user_for_token(token, session_id=session_id)
user_model = await self.hub_auth.user_for_token(
token, session_id=session_id, sync=False
)
if user_model is None:
raise HTTPError(500, "oauth callback failed to identify a user")
app_log.info("Logged-in user %s", user_model)