mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-16 14:33:00 +00:00
237 lines
7.4 KiB
Python
237 lines
7.4 KiB
Python
import asyncio
|
|
import inspect
|
|
import os
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
import pytest
|
|
import requests
|
|
from certipy import Certipy
|
|
from sqlalchemy import text
|
|
from tornado.httputil import url_concat
|
|
|
|
from jupyterhub import metrics, orm
|
|
from jupyterhub.objects import Server
|
|
from jupyterhub.roles import assign_default_roles, update_roles
|
|
from jupyterhub.utils import url_path_join as ujoin
|
|
|
|
|
|
class _AsyncRequests:
    """Expose the top-level ``requests`` functions as awaitable calls.

    Every call is handed to a one-thread executor, so the blocking HTTP
    request never runs on (and never blocks) the IOLoop thread.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor(1)
        blocking_submit = self.executor.submit

        def submit_as_asyncio_future(*args, **kwargs):
            # convert the concurrent.futures.Future into an asyncio one
            return asyncio.wrap_future(blocking_submit(*args, **kwargs))

        # shadow submit on this executor instance so that anything
        # submitted through it can be awaited
        self.executor.submit = submit_as_asyncio_future

    def __getattr__(self, name):
        # only reached for attributes not found on the instance:
        # resolve them on the requests module instead
        target = getattr(requests, name)

        def run_in_executor(*args, **kwargs):
            return self.executor.submit(target, *args, **kwargs)

        return run_in_executor
|
|
|
|
|
|
# Module-level instance shared by the helpers in this file:
# async_requests.get is requests.get returning a Future, and likewise
# for any other attribute looked up on the requests module.
async_requests = _AsyncRequests()
|
|
|
|
|
|
class AsyncSession(requests.Session):
    """A ``requests.Session`` whose requests run on the background thread.

    ``request()`` submits the blocking call to the executor owned by
    ``async_requests`` and returns what that submit call returns.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Sessions are used for cookie authentication, which should look
        # like a regular page view: default to Sec-Fetch-Mode: navigate.
        self.headers.setdefault("Sec-Fetch-Mode", "navigate")

    def request(self, *args, **kwargs):
        submit = async_requests.executor.submit
        blocking_request = super().request
        return submit(blocking_request, *args, **kwargs)
|
|
|
|
|
|
def ssl_setup(cert_dir, authority_name):
    """Create a CA and an "external" certificate pair signed by it.

    The external certs share their authority with the internal one so
    that certificate trust works regardless of the chosen endpoint.

    Both the CA and the pair are created in the Certipy store at
    *cert_dir* with ``overwrite=True``. Returns whatever
    ``Certipy.create_signed_pair`` returns for the "external" pair.
    """
    store = Certipy(store_dir=cert_dir)
    # called for its effect on the store; the return value is not needed
    store.create_ca(authority_name, overwrite=True)
    return store.create_signed_pair(
        "external",
        authority_name,
        overwrite=True,
        alt_names=["DNS:localhost", "IP:127.0.0.1"],
    )
|
|
|
|
|
|
"""Skip tests that don't work under internal-ssl when testing under internal-ssl"""
|
|
# The condition must be a real bool: os.environ.get() returns a *string*,
# and pytest evaluates string skipif conditions as Python expressions, so
# e.g. SSL_ENABLED=true would raise NameError instead of skipping.
# Unset/empty and the usual "off" spellings mean SSL is not enabled;
# any other value enables the skip.
skip_if_ssl = pytest.mark.skipif(
    os.environ.get('SSL_ENABLED', '').strip().lower()
    not in ('', '0', 'false', 'no', 'off'),
    reason="Does not use internal SSL",
)
|
|
|
|
|
|
def check_db_locks(func):
    """Decorator that verifies no locks are held on database upon exit.

    This decorator for test functions verifies no locks are held on the
    application's database upon exit by creating and dropping a dummy table.

    The decorator relies on an instance of JupyterHubApp being the first
    argument to the decorated function (its ``session_factory`` is used to
    open the temporary session).

    If the decorated function returns an awaitable, the wrapper returns a
    coroutine that awaits it, runs the check, and then returns its result;
    otherwise the check runs immediately and the plain result is returned.

    Examples
    --------
        @check_db_locks
        def api_request(app, *api_path, **kwargs):

    """
    # function-scope import keeps this block self-contained
    from functools import wraps

    # preserve the decorated function's name, docstring and signature
    @wraps(func)
    def new_func(app, *args, **kwargs):
        maybe_future = func(app, *args, **kwargs)

        def _check():
            # use a separate session so the check does not go through
            # whatever session the decorated function was using
            temp_session = app.session_factory()
            try:
                temp_session.execute(text('CREATE TABLE dummy (foo INT)'))
                temp_session.execute(text('DROP TABLE dummy'))
            finally:
                temp_session.close()

        async def await_then_check():
            result = await maybe_future
            _check()
            return result

        if inspect.isawaitable(maybe_future):
            return await_then_check()
        else:
            _check()
            return maybe_future

    return new_func
|
|
|
|
|
|
def find_user(db, name, app=None):
    """Find user in database.

    Queries *db* for the first ``orm.User`` whose name equals *name*.

    Returns the query result as-is when *app* is None, otherwise the
    entry of ``app.users`` keyed by the found user's id. When no user
    matches, returns None in both cases instead of failing on the
    missing row's ``id``.
    """
    orm_user = db.query(orm.User).filter(orm.User.name == name).first()
    if orm_user is None or app is None:
        return orm_user
    return app.users[orm_user.id]
|
|
|
|
|
|
def add_user(db, app=None, **kwargs):
    """Add a user to the database, or update the existing one.

    The user is looked up by ``kwargs['name']``. If none is found, a new
    ``orm.User`` is built from *kwargs*, added to *db*, and the
    TOTAL_USERS metric is incremented; otherwise every item of *kwargs*
    is set as an attribute on the existing user. The session is then
    committed.

    A truthy ``roles`` keyword is applied with ``update_roles``; without
    one, ``assign_default_roles`` is called instead.

    Returns ``app.users[<user id>]`` when *app* is truthy, else the
    ``orm.User``.
    """
    user_row = find_user(db, name=kwargs.get('name'))
    if user_row is not None:
        for field, value in kwargs.items():
            setattr(user_row, field, value)
    else:
        user_row = orm.User(**kwargs)
        db.add(user_row)
        metrics.TOTAL_USERS.inc()
    db.commit()

    roles = kwargs.get('roles')
    if not roles:
        assign_default_roles(db, entity=user_row)
    else:
        update_roles(db, entity=user_row, roles=roles)

    if not app:
        return user_row
    return app.users[user_row.id]
|
|
|
|
|
|
def auth_header(db, name):
    """Return an Authorization header holding an API token for *name*.

    The token comes from calling ``new_api_token()`` on the user found
    in *db*.

    Raises KeyError if no user called *name* is found.
    """
    orm_user = find_user(db, name)
    if orm_user is not None:
        api_token = orm_user.new_api_token()
        return {'Authorization': f'token {api_token}'}
    raise KeyError(f"No such user: {name}")
|
|
|
|
|
|
@check_db_locks
async def api_request(
    app, *api_path, method='get', noauth=False, bypass_proxy=False, **kwargs
):
    """Make an API request

    The URL is ``<base>/api/<api_path...>``, where the base is the public
    hub URL, or ``app.hub.url`` when *bypass_proxy* is true. The request
    is issued with ``async_requests.<method>``; remaining *kwargs* are
    passed through to it.

    Authentication:

    - Unless *noauth* is true, an ``Authorization`` header is already
      given, or ``cookies`` are passed, a token header for the user named
      by the ``name`` keyword (default ``'admin'``) is added.
    - For cookie requests carrying ``_xsrf`` (and not *noauth*), the
      ``_xsrf`` value is added to the URL parameters.

    ``name`` is always consumed here and never forwarded to requests, and
    the caller's ``headers`` dict is never modified.

    The response is asserted to carry the expected
    Content-Security-Policy header and, for non-streamed responses with a
    body, an ``application/json`` content-type. Returns the response.
    """
    if bypass_proxy:
        # make a direct request to the hub,
        # skipping the proxy
        base_url = app.hub.url
    else:
        base_url = public_url(app, path='hub')

    # `name` only selects whose token to use; it is not a requests
    # argument, so remove it even when no token ends up being attached
    name = kwargs.pop('name', 'admin')

    # work on a copy so the caller's headers dict is not modified in-place
    headers = dict(kwargs.get('headers') or {})
    kwargs['headers'] = headers
    headers.setdefault("Sec-Fetch-Mode", "cors")
    if 'Authorization' not in headers and not noauth and 'cookies' not in kwargs:
        headers.update(auth_header(app.db, name))

    url = ujoin(base_url, 'api', *api_path)

    if 'cookies' in kwargs:
        # for cookie-authenticated requests,
        # add _xsrf to url params
        if "_xsrf" in kwargs['cookies'] and not noauth:
            url = url_concat(url, {"_xsrf": kwargs['cookies']['_xsrf']})

    f = getattr(async_requests, method)
    if app.internal_ssl:
        kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
        kwargs["verify"] = app.internal_ssl_ca
    resp = await f(url, **kwargs)

    # every API response must carry the hub's Content-Security-Policy
    assert "frame-ancestors 'none'" in resp.headers['Content-Security-Policy']
    assert (
        ujoin(app.hub.base_url, "security/csp-report")
        in resp.headers['Content-Security-Policy']
    )
    assert 'http' not in resp.headers['Content-Security-Policy']
    if not kwargs.get('stream', False) and resp.content:
        assert resp.headers.get('content-type') == 'application/json'
    return resp
|
|
|
|
|
|
def get_page(path, app, hub=True, **kw):
    """GET a page of the running JupyterHub via ``async_requests``.

    *path* is joined onto the public host plus ``app.hub.base_url``
    (``hub=True``) or ``app.base_url`` (``hub=False``). Extra keyword
    arguments are passed on to ``async_requests.get``; a
    ``Sec-Fetch-Mode: navigate`` header is set unless one is given.

    Raises ValueError if *path* contains ``://``, i.e. looks like a full
    URL rather than a page path.
    """
    if "://" in path:
        raise ValueError(
            f"Not a hub page path: {path!r}. Did you mean async_requests.get?"
        )

    prefix = app.hub.base_url if hub else app.base_url
    base_url = ujoin(public_host(app), prefix)

    # Sec-Fetch-Mode=navigate to look like a regular page view
    kw.setdefault("headers", {}).setdefault("Sec-Fetch-Mode", "navigate")

    page_url = ujoin(base_url, path)
    return async_requests.get(page_url, **kw)
|
|
|
|
|
|
def public_host(app):
    """Return the public *host* (no URL prefix) of the given JupyterHub instance.

    This is ``app.subdomain_host`` when that is set, otherwise the host
    of the Server built from the proxy's public URL.
    """
    subdomain_host = app.subdomain_host
    if not subdomain_host:
        return Server.from_url(app.proxy.public_url).host
    return subdomain_host
|
|
|
|
|
|
def public_url(app, user_or_service=None, path=''):
    """Return the full, public base URL (including prefix) of the given JupyterHub instance.

    With *user_or_service*, its ``prefix`` is used, and its own ``host``
    when the app has a ``subdomain_host``. Without one, the prefix is the
    base URL of the Server built from the proxy's public URL. A
    non-empty *path* is joined onto the prefix.
    """
    if not user_or_service:
        host = public_host(app)
        prefix = Server.from_url(app.proxy.public_url).base_url
    else:
        host = user_or_service.host if app.subdomain_host else public_host(app)
        prefix = user_or_service.prefix

    if not path:
        return host + prefix
    return host + ujoin(prefix, path)
|