"""Miscellaneous utilities"""
|
|
|
|
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
import asyncio
|
|
import concurrent.futures
|
|
import errno
|
|
import functools
|
|
import hashlib
|
|
import inspect
|
|
import random
|
|
import re
|
|
import secrets
|
|
import socket
|
|
import ssl
|
|
import string
|
|
import sys
|
|
import threading
|
|
import time
|
|
import uuid
|
|
import warnings
|
|
from binascii import b2a_hex
|
|
from datetime import datetime, timezone
|
|
from functools import lru_cache
|
|
from hmac import compare_digest
|
|
from operator import itemgetter
|
|
from urllib.parse import quote
|
|
|
|
import idna
|
|
from async_generator import aclosing
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from tornado import gen, ioloop, web
|
|
from tornado.httpclient import AsyncHTTPClient, HTTPError
|
|
from tornado.log import app_log
|
|
|
|
|
|
# Deprecated aliases: no longer needed now that we require 3.7
def asyncio_all_tasks(loop=None):
    warnings.warn(
        "jupyterhub.utils.asyncio_all_tasks is deprecated in JupyterHub 2.4."
        " Use asyncio.all_tasks().",
        DeprecationWarning,
        stacklevel=2,
    )
    return asyncio.all_tasks(loop=loop)


def asyncio_current_task(loop=None):
    warnings.warn(
        "jupyterhub.utils.asyncio_current_task is deprecated in JupyterHub 2.4."
        " Use asyncio.current_task().",
        DeprecationWarning,
        stacklevel=2,
    )
    return asyncio.current_task(loop=loop)


def random_port():
    """Get a single random port."""
    sock = socket.socket()
    sock.bind(('', 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


# ISO8601 for strptime with/without milliseconds
ISO8601_ms = '%Y-%m-%dT%H:%M:%S.%fZ'
ISO8601_s = '%Y-%m-%dT%H:%M:%SZ'


def isoformat(dt):
    """Render a datetime object as an ISO 8601 UTC timestamp

    Naive datetime objects are assumed to be UTC
    """
    # allow null timestamps to remain None without
    # having to check if isoformat should be called
    if dt is None:
        return None
    if dt.tzinfo:
        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    return dt.isoformat() + 'Z'


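# Example (illustrative):
#
#     isoformat(datetime(2024, 1, 1, 12, 0))  # -> '2024-01-01T12:00:00Z'
#     isoformat(None)                         # -> None

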
def can_connect(ip, port):
    """Check if we can connect to an ip:port.

    Return True if we can connect, False otherwise.
    """
    if ip in {'', '0.0.0.0', '::'}:
        ip = '127.0.0.1'
    try:
        socket.create_connection((ip, port)).close()
    except OSError as e:
        if e.errno not in {errno.ECONNREFUSED, errno.ETIMEDOUT}:
            app_log.error("Unexpected error connecting to %s:%i %s", ip, port, e)
        else:
            app_log.debug("Server at %s:%i not ready: %s", ip, port, e)
        return False
    else:
        return True


def make_ssl_context(
    keyfile,
    certfile,
    cafile=None,
    verify=None,
    check_hostname=None,
    purpose=ssl.Purpose.SERVER_AUTH,
):
    """Set up a context for starting an https server or making requests over ssl.

    Used for verifying internal ssl connections.
    Certificates are always verified in both directions.
    Hostnames are checked for client sockets.

    Client sockets are created with `purpose=ssl.Purpose.SERVER_AUTH` (default);
    server sockets are created with `purpose=ssl.Purpose.CLIENT_AUTH`.
    """
    if not keyfile or not certfile:
        return None
    if verify is not None:
        purpose = ssl.Purpose.SERVER_AUTH if verify else ssl.Purpose.CLIENT_AUTH
        warnings.warn(
            f"make_ssl_context(verify={verify}) is deprecated in jupyterhub 2.4."
            f" Use make_ssl_context(purpose={purpose!s}).",
            DeprecationWarning,
            stacklevel=2,
        )
    if check_hostname is not None:
        purpose = ssl.Purpose.SERVER_AUTH if check_hostname else ssl.Purpose.CLIENT_AUTH
        warnings.warn(
            f"make_ssl_context(check_hostname={check_hostname}) is deprecated in jupyterhub 2.4."
            f" Use make_ssl_context(purpose={purpose!s}).",
            DeprecationWarning,
            stacklevel=2,
        )

    ssl_context = ssl.create_default_context(purpose, cafile=cafile)
    # always verify
    ssl_context.verify_mode = ssl.CERT_REQUIRED

    if purpose == ssl.Purpose.SERVER_AUTH:
        # SERVER_AUTH is authenticating servers (i.e. for a client)
        ssl_context.check_hostname = True
    ssl_context.load_default_certs()

    ssl_context.load_cert_chain(certfile, keyfile)
    # only override check_hostname when the deprecated argument is passed
    # explicitly; assigning the default None would coerce check_hostname
    # to False and silently disable hostname verification for clients
    if check_hostname is not None:
        ssl_context.check_hostname = check_hostname
    return ssl_context


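# Example (illustrative; the file paths are hypothetical):
#
#     # client-side context: verifies servers and checks hostnames
#     client_ctx = make_ssl_context('client.key', 'client.crt', cafile='ca.crt')
#     # server-side context: requires client certificates
#     server_ctx = make_ssl_context(
#         'server.key', 'server.crt', cafile='ca.crt',
#         purpose=ssl.Purpose.CLIENT_AUTH,
#     )

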
# AnyTimeoutError catches TimeoutErrors coming from asyncio, tornado, stdlib
AnyTimeoutError = (gen.TimeoutError, asyncio.TimeoutError, TimeoutError)


async def exponential_backoff(
    pass_func,
    fail_message,
    start_wait=0.2,
    scale_factor=2,
    max_wait=5,
    timeout=10,
    timeout_tolerance=0.1,
    *args,
    **kwargs,
):
    """
    Exponentially back off until `pass_func` is true.

    `pass_func` is called with **exponential backoff** and **random jitter**
    for as many iterations of the Tornado loop as needed, until it returns
    a truthy value or `timeout` is reached. If `pass_func` is still
    returning false at `timeout`, a `TimeoutError` will be raised.

    The first iteration begins with a wait time of `start_wait` seconds.
    Each subsequent iteration's wait time scales up by a factor of
    `scale_factor`, capped at `max_wait` seconds per iteration.

    `pass_func` may also return a future, although that is not entirely
    recommended.

    Parameters
    ----------
    pass_func
        function that is to be run
    fail_message : str
        message for a `TimeoutError`
    start_wait : optional
        initial wait time for the first iteration in seconds
    scale_factor : optional
        a multiplier to increase the wait time for each iteration
    max_wait : optional
        maximum wait time per iteration in seconds
    timeout : optional
        maximum time of total wait in seconds
    timeout_tolerance : optional
        a small multiplier used to add jitter to `timeout`'s deadline
    *args, **kwargs
        passed to `pass_func(*args, **kwargs)`

    Returns
    -------
    value of `pass_func(*args, **kwargs)`

    Raises
    ------
    TimeoutError
        If `pass_func` is still false at the end of the `timeout` period.

    Notes
    -----
    See https://www.awsarchitectureblog.com/2015/03/backoff.html
    for information about the algorithm and examples. We use the
    equivalent of their "Full Jitter" implementation.
    """
    loop = ioloop.IOLoop.current()
    deadline = loop.time() + timeout
    # add jitter to the deadline itself to prevent re-alignment of a bunch
    # of timing-out calls once the deadline is reached.
    if timeout_tolerance:
        tol = timeout_tolerance * timeout
        deadline = random.uniform(deadline - tol, deadline + tol)
    scale = 1
    while True:
        ret = await maybe_future(pass_func(*args, **kwargs))
        # Truthy!
        if ret:
            return ret
        remaining = deadline - loop.time()
        if remaining < 0:
            # timeout exceeded
            break
        # add some random jitter to improve performance
        # this prevents overloading any single tornado loop iteration with
        # too many things
        limit = min(max_wait, start_wait * scale)
        if limit < max_wait:
            scale *= scale_factor
        dt = min(remaining, random.uniform(0, limit))
        await asyncio.sleep(dt)
    raise asyncio.TimeoutError(fail_message)


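# Example (illustrative; `server_is_up` is a hypothetical check function,
# awaited from inside a coroutine):
#
#     await exponential_backoff(
#         server_is_up,
#         "Server never came up",
#         timeout=30,
#     )

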
async def wait_for_server(ip, port, timeout=10):
    """Wait for any server to show up at ip:port."""
    if ip in {'', '0.0.0.0', '::'}:
        ip = '127.0.0.1'
    app_log.debug("Waiting %ss for server at %s:%s", timeout, ip, port)
    tic = time.perf_counter()
    await exponential_backoff(
        lambda: can_connect(ip, port),
        "Server at {ip}:{port} didn't respond in {timeout} seconds".format(
            ip=ip, port=port, timeout=timeout
        ),
        timeout=timeout,
    )
    toc = time.perf_counter()
    app_log.debug("Server at %s:%s responded in %.2fs", ip, port, toc - tic)


async def wait_for_http_server(url, timeout=10, ssl_context=None):
    """Wait for an HTTP Server to respond at url.

    Any non-5XX response code will do, even 404.
    """
    client = AsyncHTTPClient()
    if ssl_context:
        client.ssl_options = ssl_context

    app_log.debug("Waiting %ss for server at %s", timeout, url)
    tic = time.perf_counter()

    async def is_reachable():
        try:
            r = await client.fetch(url, follow_redirects=False)
            return r
        except HTTPError as e:
            if e.code >= 500:
                # failed to respond properly, wait and try again
                if e.code != 599:
                    # we expect 599 for no connection,
                    # but 502 or other proxy error is conceivable
                    app_log.warning(
                        "Server at %s responded with error: %s", url, e.code
                    )
            else:
                app_log.debug("Server at %s responded with %s", url, e.code)
                return e.response
        except OSError as e:
            if e.errno not in {
                errno.ECONNABORTED,
                errno.ECONNREFUSED,
                errno.ECONNRESET,
            }:
                app_log.warning("Failed to connect to %s (%s)", url, e)
        return False

    # named `resp` to avoid shadowing the stdlib `re` module imported above
    resp = await exponential_backoff(
        is_reachable,
        "Server at {url} didn't respond in {timeout} seconds".format(
            url=url, timeout=timeout
        ),
        timeout=timeout,
    )
    toc = time.perf_counter()
    app_log.debug("Server at %s responded in %.2fs", url, toc - tic)
    return resp


# Decorators for authenticated Handlers
def auth_decorator(check_auth):
    """Make an authentication decorator.

    I heard you like decorators, so I put a decorator
    in your decorator, so you can decorate while you decorate.
    """

    def decorator(method):
        def decorated(self, *args, **kwargs):
            check_auth(self, **kwargs)
            return method(self, *args, **kwargs)

        # Perhaps replace with functools.wraps
        decorated.__name__ = method.__name__
        decorated.__doc__ = method.__doc__
        return decorated

    decorator.__name__ = check_auth.__name__
    decorator.__doc__ = check_auth.__doc__
    return decorator


@auth_decorator
def token_authenticated(self):
    """Decorator for method authenticated only by Authorization token header

    (no cookies)
    """
    if self.get_current_user_token() is None:
        raise web.HTTPError(403)


@auth_decorator
def authenticated_403(self):
    """Decorator for method to raise 403 error instead of redirect to login

    Like tornado.web.authenticated, this decorator raises a 403 error
    instead of redirecting to login.
    """
    if self.current_user is None:
        raise web.HTTPError(403)


def admin_only(f):
    """Deprecated!"""
    # write it this way to trigger deprecation warning at decoration time,
    # not on the method call
    warnings.warn(
        """@jupyterhub.utils.admin_only is deprecated in JupyterHub 2.0.

        Use the new `@jupyterhub.scopes.needs_scope` decorator to resolve permissions,
        or check against `self.current_user.parsed_scopes`.
        """,
        DeprecationWarning,
        stacklevel=2,
    )

    # the original decorator
    @auth_decorator
    def admin_only(self):
        """Decorator for restricting access to admin users"""
        user = self.current_user
        if user is None or not user.admin:
            raise web.HTTPError(403)

    return admin_only(f)


@auth_decorator
def metrics_authentication(self):
    """Decorator for restricting access to metrics"""
    if not self.authenticate_prometheus:
        return
    scope = 'read:metrics'
    if scope not in self.parsed_scopes:
        raise web.HTTPError(403, f"Access to metrics requires scope '{scope}'")


# Token utilities


def new_token(*args, **kwargs):
    """Generate a new random token

    For now, just UUIDs.
    """
    return uuid.uuid4().hex


def hash_token(token, salt=8, rounds=16384, algorithm='sha512'):
    """Hash a token, and return it as `algorithm:rounds:salt:hash`.

    If `salt` is an integer, a random salt of that many bytes will be used.
    """
    h = hashlib.new(algorithm)
    if isinstance(salt, int):
        salt = b2a_hex(secrets.token_bytes(salt))
    if isinstance(salt, bytes):
        bsalt = salt
        salt = salt.decode('utf8')
    else:
        bsalt = salt.encode('utf8')
    btoken = token.encode('utf8', 'replace')
    h.update(bsalt)
    for i in range(rounds):
        h.update(btoken)
    digest = h.hexdigest()

    return f"{algorithm}:{rounds}:{salt}:{digest}"


def compare_token(compare, token):
    """Compare a token with a hashed token.

    Uses the same algorithm and salt as the hashed token for comparison.
    """
    algorithm, srounds, salt, _ = compare.split(':')
    hashed = hash_token(
        token, salt=salt, rounds=int(srounds), algorithm=algorithm
    ).encode('utf8')
    compare = compare.encode('utf8')
    # constant-time comparison to avoid timing attacks
    return compare_digest(compare, hashed)


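# Example round-trip (illustrative; the salt and digest vary per call):
#
#     hashed = hash_token("abc123")    # e.g. 'sha512:16384:<salt>:<digest>'
#     compare_token(hashed, "abc123")  # -> True
#     compare_token(hashed, "wrong")   # -> False

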
def url_escape_path(value):
    """Escape a value to be used in URLs, cookies, etc."""
    return quote(value, safe='@~')


def url_path_join(*pieces):
    """Join components of url into a relative url.

    Use to prevent double slash when joining subpath. This will leave the
    initial and final / in place.

    Copied from `notebook.utils.url_path_join`.
    """
    initial = pieces[0].startswith('/')
    final = pieces[-1].endswith('/')
    stripped = [s.strip('/') for s in pieces]
    result = '/'.join(s for s in stripped if s)

    if initial:
        result = '/' + result
    if final:
        result = result + '/'
    if result == '//':
        result = '/'

    return result


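# Example (illustrative):
#
#     url_path_join('/hub/', '/user', 'name/')  # -> '/hub/user/name/'

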
def print_ps_info(file=sys.stderr):
    """Print process summary info from psutil

    warns if psutil is unavailable
    """
    try:
        import psutil
    except ImportError:
        # nothing to print
        warnings.warn(
            "psutil unavailable. Install psutil to get CPU and memory stats",
            stacklevel=2,
        )
        return
    p = psutil.Process()
    # format CPU percentage
    cpu = p.cpu_percent(0.1)
    if cpu >= 10:
        cpu_s = "%i" % cpu
    else:
        cpu_s = "%.1f" % cpu

    # format memory (only resident set)
    rss = p.memory_info().rss
    if rss >= 1e9:
        mem_s = '%.1fG' % (rss / 1e9)
    elif rss >= 1e7:
        mem_s = '%.0fM' % (rss / 1e6)
    elif rss >= 1e6:
        mem_s = '%.1fM' % (rss / 1e6)
    else:
        mem_s = '%.0fk' % (rss / 1e3)

    # left-justify and shrink-to-fit columns
    cpulen = max(len(cpu_s), 4)
    memlen = max(len(mem_s), 3)
    fd_s = str(p.num_fds())
    fdlen = max(len(fd_s), 3)
    threadlen = len('threads')

    print(
        "%s %s %s %s"
        % ('%CPU'.ljust(cpulen), 'MEM'.ljust(memlen), 'FDs'.ljust(fdlen), 'threads'),
        file=file,
    )

    print(
        "%s %s %s %s"
        % (
            cpu_s.ljust(cpulen),
            mem_s.ljust(memlen),
            fd_s.ljust(fdlen),
            str(p.num_threads()).ljust(threadlen),
        ),
        file=file,
    )

    # trailing blank line
    print('', file=file)


def print_stacks(file=sys.stderr):
    """Print current status of the process

    For debugging purposes.
    Used as part of the SIGINFO handler.

    - Shows active thread count
    - Shows current stack for all threads

    Parameters:

    file: file to write output to (default: stderr)

    """
    # local imports because these will not be used often,
    # no need to add them to startup
    import traceback

    from .log import coroutine_frames

    print("Active threads: %i" % threading.active_count(), file=file)
    for thread in threading.enumerate():
        print("Thread %s:" % thread.name, end='', file=file)
        frame = sys._current_frames()[thread.ident]
        stack = traceback.extract_stack(frame)
        if thread is threading.current_thread():
            # truncate last two frames of the current thread,
            # which are this function and its caller
            stack = stack[:-2]
        stack = coroutine_frames(stack)
        if stack:
            last_frame = stack[-1]
            if (
                last_frame[0].endswith('threading.py')
                and last_frame[-1] == 'waiter.acquire()'
            ) or (
                last_frame[0].endswith('thread.py')
                and last_frame[-1].endswith('work_queue.get(block=True)')
            ):
                # thread is waiting on a condition
                # call it idle rather than showing the uninteresting stack
                # most threadpools will be in this state
                print(' idle', file=file)
                continue

        print(''.join(['\n'] + traceback.format_list(stack)), file=file)

    # also show asyncio tasks, if any
    # this will increase over time as we transition from tornado
    # coroutines to native `async def`
    tasks = asyncio.all_tasks()
    if tasks:
        print("AsyncIO tasks: %i" % len(tasks), file=file)
        for task in tasks:
            task.print_stack(file=file)


def maybe_future(obj):
    """Return an asyncio Future

    Use instead of gen.maybe_future

    For our compatibility, this must accept:

    - asyncio coroutine (gen.maybe_future doesn't work in tornado < 5)
    - tornado coroutine (asyncio.ensure_future doesn't work)
    - scalar (asyncio.ensure_future doesn't work)
    - concurrent.futures.Future (asyncio.ensure_future doesn't work)
    - tornado Future (works both ways)
    - asyncio Future (works both ways)
    """
    if inspect.isawaitable(obj):
        # already awaitable, use ensure_future
        return asyncio.ensure_future(obj)
    elif isinstance(obj, concurrent.futures.Future):
        return asyncio.wrap_future(obj)
    else:
        # could also check for tornado.concurrent.Future
        # but with tornado >= 5.1 tornado.Future is asyncio.Future
        f = asyncio.Future()
        f.set_result(obj)
        return f


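# Example (illustrative; `some_coroutine` is hypothetical):
#
#     await maybe_future(42)               # -> 42 (scalar wrapped in a Future)
#     await maybe_future(some_coroutine()) # -> result of the coroutine

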
async def iterate_until(deadline_future, generator):
    """An async generator that yields items from a generator
    until a deadline future resolves

    This could *almost* be implemented as a context manager
    like asyncio_timeout with a Future for the cutoff.

    However, we want one distinction: continue yielding items
    after the future is complete, as long as they are already finished.

    Usage::

        async for item in iterate_until(some_future, some_async_generator()):
            print(item)

    """
    async with aclosing(generator.__aiter__()) as aiter:
        while True:
            item_future = asyncio.ensure_future(aiter.__anext__())
            await asyncio.wait(
                [item_future, deadline_future], return_when=asyncio.FIRST_COMPLETED
            )
            if item_future.done():
                try:
                    yield item_future.result()
                except (StopAsyncIteration, asyncio.CancelledError):
                    break
            elif deadline_future.done():
                # deadline is done *and* next item is not ready
                # cancel item future to avoid warnings about
                # unawaited tasks
                if not item_future.cancelled():
                    item_future.cancel()
                # resolve cancellation to avoid garbage collection issues
                try:
                    await item_future
                except asyncio.CancelledError:
                    pass
                break
            else:
                # neither is done, this shouldn't happen
                continue


def utcnow(*, with_tz=True):
    """Return utcnow

    with_tz (default): returns tz-aware datetime in UTC

    if with_tz=False, returns UTC timestamp without tzinfo
    (used for most internal timestamp storage because databases often don't preserve tz info)
    """
    now = datetime.now(timezone.utc)
    if not with_tz:
        now = now.replace(tzinfo=None)
    return now


def _parse_accept_header(accept):
    """
    Parse the Accept header

    Return a list of 2-tuples
    [(str(media_type), float(q_value)),] ordered by q values (descending).

    Default `q` for values that are not specified is 1.0
    """
    result = []
    if not accept:
        return result
    for media_range in accept.split(","):
        media_type, *parts = media_range.split(";")
        media_type = media_type.strip()
        if not media_type:
            continue

        q = 1.0
        for part in parts:
            (key, _, value) = part.partition("=")
            key = key.strip()
            if key == "q":
                try:
                    q = float(value)
                except ValueError:
                    pass
                break
        result.append((media_type, q))
    result.sort(key=itemgetter(1), reverse=True)
    return result


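# Example (illustrative):
#
#     _parse_accept_header("text/html, application/json;q=0.9")
#     # -> [('text/html', 1.0), ('application/json', 0.9)]

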
def get_accepted_mimetype(accept_header, choices=None):
    """Return the preferred mimetype from an Accept header

    If `choices` is given, return the first match,
    otherwise return the first accepted item

    Return `None` if choices is given and no match is found,
    or nothing is specified.
    """
    for mime, q in _parse_accept_header(accept_header):
        if choices:
            if mime in choices:
                return mime
            else:
                continue
        else:
            return mime
    return None


def catch_db_error(f):
    """Catch and rollback database errors"""

    @functools.wraps(f)
    async def catching(self, *args, **kwargs):
        try:
            r = f(self, *args, **kwargs)
            if inspect.isawaitable(r):
                r = await r
        except SQLAlchemyError:
            self.log.exception("Rolling back session due to database error")
            self.db.rollback()
        else:
            return r

    return catching


def get_browser_protocol(request):
    """Get the _protocol_ seen by the browser

    Like tornado's _apply_xheaders,
    but in the case of multiple proxy hops,
    use the outermost value (what the browser likely sees)
    instead of the innermost value,
    which would be the most trustworthy.

    We care about what the browser sees,
    not where the request actually came from,
    so trusting possible spoofs is the right thing to do.
    """
    headers = request.headers
    # first choice: Forwarded header
    forwarded_header = headers.get("Forwarded")
    if forwarded_header:
        first_forwarded = forwarded_header.split(",", 1)[0].strip()
        fields = {}
        for field in first_forwarded.split(";"):
            key, _, value = field.partition("=")
            fields[key.strip().lower()] = value.strip()
        if "proto" in fields and fields["proto"].lower() in {"http", "https"}:
            return fields["proto"].lower()
        else:
            app_log.warning(
                f"Forwarded header present without protocol: {forwarded_header}"
            )

    # second choice: X-Scheme or X-Forwarded-Proto
    proto_header = headers.get("X-Scheme", headers.get("X-Forwarded-Proto", None))
    if proto_header:
        proto_header = proto_header.split(",")[0].strip().lower()
        if proto_header in {"http", "https"}:
            return proto_header

    # no forwarded headers
    return request.protocol


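# Example (illustrative): for a request carrying the header
#
#     Forwarded: for=192.0.2.1;proto=https, for=10.0.0.1
#
# get_browser_protocol returns 'https', the protocol at the outermost hop.

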
# set of chars that are safe in dns labels
# (allow '.' because we don't mind multiple levels of subdomains)
_dns_safe = set(string.ascii_letters + string.digits + '-.')
# chars that do _not_ need escaping: the safe chars plus '%',
# which is our escape char and is handled separately
_dns_no_escape = _dns_safe | {"%"}


@lru_cache()
def _dns_quote(name):
    """Escape a name for use in a dns label

    this is _NOT_ fully domain-safe, but works often enough for realistic usernames.
    Fully safe would be full IDNA encoding,
    PLUS escaping non-IDNA-legal ascii,
    PLUS some encoding of boundary conditions
    """
    # escape name for subdomain label
    label = quote(name, safe="").lower()
    # some characters are not handled by quote,
    # because they are legal in URLs but not domains,
    # specifically _ and ~ (starting in 3.7).
    # Escape these in the same way (%{hex_codepoint}).
    unique_chars = set(label)
    for c in unique_chars:
        if c not in _dns_no_escape:
            label = label.replace(c, f"%{ord(c):x}")

    # underscore is our escape char -
    # it's not officially legal in hostnames,
    # but is valid in _domain_ names (?),
    # and seems to always work in practice.
    label = label.replace("%", "_")
    return label


def subdomain_hook_legacy(name, domain, kind):
    """Legacy (default) hook for subdomains

    Users are at '$user.$host' where $user is _mostly_ DNS-safe.
    Services all share 'services.$host'.
    """
    if kind == "user":
        # backward-compatibility
        return f"{_dns_quote(name)}.{domain}"
    elif kind == "service":
        return f"services.{domain}"
    else:
        raise ValueError(f"kind must be 'service' or 'user', not {kind!r}")


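# Examples (illustrative):
#
#     subdomain_hook_legacy("rollo", "hub.example.com", "user")
#     # -> 'rollo.hub.example.com'
#     subdomain_hook_legacy("Jo User", "hub.example.com", "user")
#     # -> 'jo_20user.hub.example.com' (' ' quoted to %20, then '%' -> '_')

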
# strict dns-safe characters (excludes '-')
_strict_dns_safe = set(string.ascii_lowercase) | set(string.digits)


def _trim_and_hash(name):
    """Always-safe fallback for a DNS label

    Produces a valid and unique DNS label for any string

    - prefix with 'u-' to avoid collisions and first-character rules
    - select the first N characters that are safe ('x' if none are safe)
    - suffix with a truncated hash of the true name
    - length is guaranteed to be < 32 characters,
      leaving room for additional components to build a DNS label.
      Will currently be between 12-19 characters:
      4 (prefix, delimiters) + 7 (hash) + 1-8 (name stub)
    """
    name_hash = hashlib.sha256(name.encode('utf8')).hexdigest()[:7]

    safe_chars = [c for c in name.lower() if c in _strict_dns_safe]
    name_stub = ''.join(safe_chars[:8])
    # we MUST NOT put '--' in the 3rd and 4th positions (RFC 5891),
    # which are reserved for IDNA A-labels.
    # That would happen if name_stub were empty, so use 'x' in that case
    # (the value doesn't matter, as uniqueness is in the hash - the stub is more of a hint, anyway)
    if not name_stub:
        name_stub = "x"
    return f"u-{name_stub}--{name_hash}"


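# Examples (illustrative; <hash7> stands for the first 7 hex chars of the
# name's sha256 digest):
#
#     _trim_and_hash("JoUser!!")  # -> 'u-jouser--<hash7>'
#     _trim_and_hash("日本語")     # -> 'u-x--<hash7>' (no safe chars survive)

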
# A host name (label) can start or end with a letter or a number
# this pattern doesn't need to handle the boundary conditions,
# which are handled more simply with starts/endswith
_dns_re = re.compile(r'^[a-z0-9-]{1,63}$', flags=re.IGNORECASE)


def _is_dns_safe(label, max_length=63):
    # A host name (label) MUST NOT consist of all numeric values
    if label.isnumeric():
        return False
    # A host name (label) can be up to 63 characters
    if not 0 < len(label) <= max_length:
        return False
    # A host name (label) MUST NOT start or end with a '-' (dash)
    if label.startswith('-') or label.endswith('-'):
        return False
    return bool(_dns_re.match(label))


def _strict_dns_safe_encode(name, max_length=63):
    """Encode a username to a guaranteed-safe DNS label

    - if it contains '--' at all, jump straight to the hash route
      to avoid collisions with escaped results
    - if safe, use it
    - if not, use IDNA encoding
    - if a safe encoding cannot be produced, use stripped safe characters + '--{hash}'
    - allow specifying a max_length, to give room for additional components,
      if used as only a _part_ of a DNS label.
    """
    # short-circuit: avoid accepting already-encoded results,
    # which all include '--'
    if '--' in name:
        return _trim_and_hash(name)

    # if name is already safe (and can't collide with an escaped result) use it
    if _is_dns_safe(name, max_length=max_length):
        return name

    # next: use IDNA encoding, if applicable
    try:
        idna_name = idna.encode(name).decode("ascii")
    except ValueError:
        idna_name = None

    if (
        idna_name
        and idna_name != name
        and _is_dns_safe(idna_name, max_length=max_length)
    ):
        return idna_name

    # fallback, always works: trim to safe characters and hash
    return _trim_and_hash(name)


def subdomain_hook_idna(name, domain, kind):
    """New, reliable subdomain hook

    More reliable than the legacy hook; should always produce valid domains

    - uses IDNA encoding for simple unicode names
    - separate domain for each service
    - uses stripped name and hash, where the above schemes fail to produce a valid domain
    """
    safe_name = _strict_dns_safe_encode(name)
    if kind == 'user':
        # 'user' namespace is special-cased as the default
        # for aesthetics and backward-compatibility for names that don't need escaping
        suffix = ""
    else:
        suffix = f"--{kind}"
    return f"{safe_name}{suffix}.{domain}"
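

# Examples (illustrative):
#
#     subdomain_hook_idna("münchen", "hub.example.com", "user")
#     # -> 'xn--mnchen-3ya.hub.example.com' (IDNA-encoded)
#     subdomain_hook_idna("grafana", "hub.example.com", "service")
#     # -> 'grafana--service.hub.example.com'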