mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-09 19:13:03 +00:00
allow HubAuth to be async
Switches requests to tornado AsyncHTTPClient instead of requests For backward-compatibility, use opt-in `sync=False` arg for all public methods that _may_ be async When sync=True (default), async functions still used, but blocking via ThreadPool + asyncio run_until_complete
This commit is contained in:
@@ -23,6 +23,7 @@ If you are using OAuth, you will also need to register an oauth callback handler
|
||||
A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@@ -34,14 +35,26 @@ import string
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from unittest import mock
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.log import app_log
|
||||
from tornado.web import HTTPError, RequestHandler
|
||||
from traitlets import Dict, Instance, Integer, Set, Unicode, default, observe, validate
|
||||
from traitlets import (
|
||||
Any,
|
||||
Dict,
|
||||
Instance,
|
||||
Integer,
|
||||
Set,
|
||||
Unicode,
|
||||
default,
|
||||
observe,
|
||||
validate,
|
||||
)
|
||||
from traitlets.config import SingletonConfigurable
|
||||
|
||||
from ..scopes import _intersect_expanded_scopes
|
||||
@@ -351,7 +364,47 @@ class HubAuth(SingletonConfigurable):
|
||||
return {f'access:services!service={service_name}'}
|
||||
return set()
|
||||
|
||||
def _check_hub_authorization(self, url, api_token, cache_key=None, use_cache=True):
|
||||
_pool = Any(help="Thread pool for running async methods in the background")
|
||||
|
||||
@default("_pool")
|
||||
def _new_pool(self):
|
||||
# start a single ThreadPool in the background
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
pool = ThreadPoolExecutor(1)
|
||||
# create an event loop in the thread
|
||||
pool.submit(self._setup_asyncio_thread).result()
|
||||
return pool
|
||||
|
||||
def _setup_asyncio_thread(self):
|
||||
"""Create asyncio loop
|
||||
|
||||
To be called from the background thread,
|
||||
so that any thread-local state is setup correctly
|
||||
"""
|
||||
self._thread_loop = asyncio.new_event_loop()
|
||||
|
||||
def _synchronize(self, async_f, *args, **kwargs):
|
||||
"""Call an async method in our background thread"""
|
||||
future = self._pool.submit(
|
||||
lambda: self._thread_loop.run_until_complete(async_f(*args, **kwargs))
|
||||
)
|
||||
return future.result()
|
||||
|
||||
def _call_coroutine(self, sync, async_f, *args, **kwargs):
|
||||
"""Call an async coroutine function, either blocking or returning an awaitable
|
||||
|
||||
if not sync: calls function directly, returning awaitable
|
||||
else: Block on a call in our background thread, return actual result
|
||||
"""
|
||||
if not sync:
|
||||
return async_f(*args, **kwargs)
|
||||
else:
|
||||
return self._synchronize(async_f, *args, **kwargs)
|
||||
|
||||
async def _check_hub_authorization(
|
||||
self, url, api_token, cache_key=None, use_cache=True
|
||||
):
|
||||
"""Identify a user with the Hub
|
||||
|
||||
Args:
|
||||
@@ -374,7 +427,7 @@ class HubAuth(SingletonConfigurable):
|
||||
except KeyError:
|
||||
app_log.debug("HubAuth cache miss: %s", cache_key)
|
||||
|
||||
data = self._api_request(
|
||||
data = await self._api_request(
|
||||
'GET',
|
||||
url,
|
||||
headers={"Authorization": "token " + api_token},
|
||||
@@ -389,18 +442,26 @@ class HubAuth(SingletonConfigurable):
|
||||
self.cache[cache_key] = data
|
||||
return data
|
||||
|
||||
def _api_request(self, method, url, **kwargs):
|
||||
async def _api_request(self, method, url, **kwargs):
|
||||
"""Make an API request"""
|
||||
allow_403 = kwargs.pop('allow_403', False)
|
||||
headers = kwargs.setdefault('headers', {})
|
||||
headers.setdefault('Authorization', 'token %s' % self.api_token)
|
||||
if "cert" not in kwargs and self.certfile and self.keyfile:
|
||||
kwargs["cert"] = (self.certfile, self.keyfile)
|
||||
if self.client_ca:
|
||||
kwargs["verify"] = self.client_ca
|
||||
headers.setdefault('Authorization', f'token {self.api_token}')
|
||||
# translate requests args to tornado's
|
||||
if self.certfile:
|
||||
kwargs["client_cert"] = self.certfile
|
||||
if self.keyfile:
|
||||
kwargs["client_key"] = self.keyfile
|
||||
if self.client_ca:
|
||||
kwargs["ca_certs"] = self.client_ca
|
||||
req = HTTPRequest(
|
||||
url,
|
||||
method=method,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
r = requests.request(method, url, **kwargs)
|
||||
except requests.ConnectionError as e:
|
||||
r = await AsyncHTTPClient().fetch(req, raise_error=False)
|
||||
except Exception as e:
|
||||
app_log.error("Error connecting to %s: %s", self.api_url, e)
|
||||
msg = "Failed to connect to Hub API at %r." % self.api_url
|
||||
msg += (
|
||||
@@ -415,35 +476,46 @@ class HubAuth(SingletonConfigurable):
|
||||
raise HTTPError(500, msg)
|
||||
|
||||
data = None
|
||||
if r.status_code == 403 and allow_403:
|
||||
try:
|
||||
status = HTTPStatus(r.code)
|
||||
except ValueError:
|
||||
app_log.error(
|
||||
f"Unknown error checking authorization with JupyterHub: {r.code}"
|
||||
)
|
||||
app_log.error(r.body.decode("utf8", "replace"))
|
||||
|
||||
response_text = r.body.decode("utf8", "replace")
|
||||
if status.value == 403 and allow_403:
|
||||
pass
|
||||
elif r.status_code == 403:
|
||||
elif status.value == 403:
|
||||
app_log.error(
|
||||
"I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
|
||||
r.status_code,
|
||||
r.reason,
|
||||
status.value,
|
||||
status.description,
|
||||
)
|
||||
app_log.error(r.text)
|
||||
app_log.error(response_text)
|
||||
raise HTTPError(
|
||||
500, "Permission failure checking authorization, I may need a new token"
|
||||
)
|
||||
elif r.status_code >= 500:
|
||||
elif status.value >= 500:
|
||||
app_log.error(
|
||||
"Upstream failure verifying auth token: [%i] %s",
|
||||
r.status_code,
|
||||
r.reason,
|
||||
status.value,
|
||||
status.description,
|
||||
)
|
||||
app_log.error(r.text)
|
||||
app_log.error(response_text)
|
||||
raise HTTPError(502, "Failed to check authorization (upstream problem)")
|
||||
elif r.status_code >= 400:
|
||||
elif status.value >= 400:
|
||||
app_log.warning(
|
||||
"Failed to check authorization: [%i] %s", r.status_code, r.reason
|
||||
"Failed to check authorization: [%i] %s",
|
||||
status.value,
|
||||
status.description,
|
||||
)
|
||||
app_log.warning(r.text)
|
||||
app_log.warning(response_text)
|
||||
msg = "Failed to check authorization"
|
||||
# pass on error from oauth failure
|
||||
try:
|
||||
response = r.json()
|
||||
response = json.loads(response_text)
|
||||
# prefer more specific 'error_description', fallback to 'error'
|
||||
description = response.get(
|
||||
"error_description", response.get("error", "Unknown error")
|
||||
@@ -454,7 +526,7 @@ class HubAuth(SingletonConfigurable):
|
||||
msg += ": " + description
|
||||
raise HTTPError(500, msg)
|
||||
else:
|
||||
data = r.json()
|
||||
data = json.loads(response_text)
|
||||
|
||||
return data
|
||||
|
||||
@@ -464,19 +536,25 @@ class HubAuth(SingletonConfigurable):
|
||||
"Identifying users by shared cookie is removed in JupyterHub 2.0. Use OAuth tokens."
|
||||
)
|
||||
|
||||
def user_for_token(self, token, use_cache=True, session_id=''):
|
||||
def user_for_token(self, token, use_cache=True, session_id='', *, sync=True):
|
||||
"""Ask the Hub to identify the user for a given token.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
async support via `sync` argument.
|
||||
|
||||
Args:
|
||||
token (str): the token
|
||||
use_cache (bool): Specify use_cache=False to skip cached cookie values (default: True)
|
||||
sync (bool): whether to block for the result or return an awaitable
|
||||
|
||||
Returns:
|
||||
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
||||
|
||||
The 'name' field contains the user's name.
|
||||
"""
|
||||
return self._check_hub_authorization(
|
||||
return self._call_coroutine(
|
||||
sync,
|
||||
self._check_hub_authorization,
|
||||
url=url_path_join(
|
||||
self.api_url,
|
||||
"user",
|
||||
@@ -521,7 +599,7 @@ class HubAuth(SingletonConfigurable):
|
||||
"""Base class doesn't store tokens in cookies"""
|
||||
return None
|
||||
|
||||
def _get_user_cookie(self, handler):
|
||||
async def _get_user_cookie(self, handler):
|
||||
"""Get the user model from a cookie"""
|
||||
# overridden in HubOAuth to store the access token after oauth
|
||||
return None
|
||||
@@ -533,20 +611,26 @@ class HubAuth(SingletonConfigurable):
|
||||
"""
|
||||
return handler.get_cookie('jupyterhub-session-id', '')
|
||||
|
||||
def get_user(self, handler):
|
||||
def get_user(self, handler, *, sync=True):
|
||||
"""Get the Hub user for a given tornado handler.
|
||||
|
||||
Checks cookie with the Hub to identify the current user.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
async support via `sync` argument.
|
||||
|
||||
Args:
|
||||
handler (tornado.web.RequestHandler): the current request handler
|
||||
sync (bool): whether to block for the result or return an awaitable
|
||||
|
||||
Returns:
|
||||
user_model (dict): The user model, if a user is identified, None if authentication fails.
|
||||
|
||||
The 'name' field contains the user's name.
|
||||
"""
|
||||
return self._call_coroutine(sync, self._get_user, handler)
|
||||
|
||||
async def _get_user(self, handler):
|
||||
# only allow this to be called once per handler
|
||||
# avoids issues if an error is raised,
|
||||
# since this may be called again when trying to render the error page
|
||||
@@ -561,13 +645,15 @@ class HubAuth(SingletonConfigurable):
|
||||
# is token-authenticated (CORS-related)
|
||||
token = self.get_token(handler, in_cookie=False)
|
||||
if token:
|
||||
user_model = self.user_for_token(token, session_id=session_id)
|
||||
user_model = await self.user_for_token(
|
||||
token, session_id=session_id, sync=False
|
||||
)
|
||||
if user_model:
|
||||
handler._token_authenticated = True
|
||||
|
||||
# no token, check cookie
|
||||
if user_model is None:
|
||||
user_model = self._get_user_cookie(handler)
|
||||
user_model = await self._get_user_cookie(handler)
|
||||
|
||||
# cache result
|
||||
handler._cached_hub_user = user_model
|
||||
@@ -627,11 +713,13 @@ class HubOAuth(HubAuth):
|
||||
token = token.decode('ascii', 'replace')
|
||||
return token
|
||||
|
||||
def _get_user_cookie(self, handler):
|
||||
async def _get_user_cookie(self, handler):
|
||||
token = self._get_token_cookie(handler)
|
||||
session_id = self.get_session_id(handler)
|
||||
if token:
|
||||
user_model = self.user_for_token(token, session_id=session_id)
|
||||
user_model = await self.user_for_token(
|
||||
token, session_id=session_id, sync=False
|
||||
)
|
||||
if user_model is None:
|
||||
app_log.warning("Token stored in cookie may have expired")
|
||||
handler.clear_cookie(self.cookie_name)
|
||||
@@ -686,7 +774,7 @@ class HubOAuth(HubAuth):
|
||||
def _token_url(self):
|
||||
return url_path_join(self.api_url, 'oauth2/token')
|
||||
|
||||
def token_for_code(self, code):
|
||||
def token_for_code(self, code, *, sync=True):
|
||||
"""Get token for OAuth temporary code
|
||||
|
||||
This is the last step of OAuth login.
|
||||
@@ -697,6 +785,9 @@ class HubOAuth(HubAuth):
|
||||
Returns:
|
||||
token (str): JupyterHub API Token
|
||||
"""
|
||||
return self._call_coroutine(sync, self._token_for_code, code)
|
||||
|
||||
async def _token_for_code(self, code):
|
||||
# GitHub specifies a POST request yet requires URL parameters
|
||||
params = dict(
|
||||
client_id=self.oauth_client_id,
|
||||
@@ -706,10 +797,10 @@ class HubOAuth(HubAuth):
|
||||
redirect_uri=self.oauth_redirect_uri,
|
||||
)
|
||||
|
||||
token_reply = self._api_request(
|
||||
token_reply = await self._api_request(
|
||||
'POST',
|
||||
self.oauth_token_url,
|
||||
data=urlencode(params).encode('utf8'),
|
||||
body=urlencode(params).encode('utf8'),
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
|
||||
@@ -1114,10 +1205,12 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
||||
raise HTTPError(403, "oauth state does not match. Try logging in again.")
|
||||
next_url = self.hub_auth.get_next_url(cookie_state)
|
||||
# TODO: make async (in a Thread?)
|
||||
token = self.hub_auth.token_for_code(code)
|
||||
|
||||
token = await self.hub_auth.token_for_code(code, sync=False)
|
||||
session_id = self.hub_auth.get_session_id(self)
|
||||
user_model = self.hub_auth.user_for_token(token, session_id=session_id)
|
||||
user_model = await self.hub_auth.user_for_token(
|
||||
token, session_id=session_id, sync=False
|
||||
)
|
||||
if user_model is None:
|
||||
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||
app_log.info("Logged-in user %s", user_model)
|
||||
|
Reference in New Issue
Block a user