Files
jupyterhub/jupyterhub/utils.py
2016-03-29 09:22:32 -07:00

208 lines
5.8 KiB
Python

"""Miscellaneous utilities"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from binascii import b2a_hex
import errno
import functools
import hashlib
from hmac import compare_digest
import os
import socket
from threading import Thread
import uuid
import warnings

from tornado import web, gen, ioloop
from tornado.httpclient import AsyncHTTPClient, HTTPError
from tornado.log import app_log
def random_port():
    """Return a port number that is currently unused on this host.

    Binds a throwaway socket to port 0 on all interfaces so the OS assigns a
    free port, reads the assigned port back, and releases the socket. The
    port is only known to be free at the moment of the call; another process
    may take it before the caller binds it.
    """
    # Using the socket as a context manager guarantees it is closed even if
    # bind() raises, instead of leaking the file descriptor until GC.
    with socket.socket() as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]
# ISO 8601 timestamp format strings for datetime.strptime/strftime.
# ISO8601_ms includes fractional seconds (%f); ISO8601_s does not.
# Both end in a literal 'Z' (the ISO 8601 UTC designator); the format
# string itself performs no timezone conversion.
ISO8601_ms = '%Y-%m-%dT%H:%M:%S.%fZ'
ISO8601_s = '%Y-%m-%dT%H:%M:%SZ'
def can_connect(ip, port):
    """Check whether a TCP connection to ip:port can be opened.

    Returns True if the connection succeeds, closing the probe connection
    immediately, and False on any socket error. ECONNREFUSED and ETIMEDOUT
    are treated as the ordinary "nothing listening there" case and are not
    logged. Any other socket error is logged at error level before returning
    False.
    """
    try:
        # Close the probe connection right away instead of dropping the
        # socket object on the floor. Dropping it leaks the fd until garbage
        # collection and emits ResourceWarning on CPython.
        with socket.create_connection((ip, port)):
            pass
    except socket.error as e:
        if e.errno not in {errno.ECONNREFUSED, errno.ETIMEDOUT}:
            app_log.error("Unexpected error connecting to %s:%i %s",
                ip, port, e
            )
        return False
    else:
        return True
@gen.coroutine
def wait_for_server(ip, port, timeout=10):
    """Poll ip:port until something accepts a connection.

    Probes with can_connect, pausing 0.1s between attempts. Raises
    TimeoutError if nothing accepts within `timeout` seconds.
    """
    io_loop = ioloop.IOLoop.current()
    started = io_loop.time()
    while io_loop.time() - started < timeout:
        if can_connect(ip, port):
            return
        yield gen.sleep(0.1)
    raise TimeoutError(
        "Server at {ip}:{port} didn't respond in {timeout} seconds".format(
            ip=ip, port=port, timeout=timeout,
        )
    )
@gen.coroutine
def wait_for_http_server(url, timeout=10):
    """Wait until an HTTP server answers at `url`.

    Any response that is not a 5XX error counts as the server being up,
    even a 404. Between failed attempts it pauses 0.1s. Raises TimeoutError
    if no acceptable response arrives within `timeout` seconds.
    """
    io_loop = ioloop.IOLoop.current()
    started = io_loop.time()
    http_client = AsyncHTTPClient()
    while io_loop.time() - started < timeout:
        try:
            yield http_client.fetch(url, follow_redirects=False)
        except HTTPError as err:
            if err.code < 500:
                # A non-5XX error response (e.g. 404) still proves the server is up.
                app_log.debug("Server at %s responded with %s", url, err.code)
                return
            # 5XX: not ready yet. 599 is the code expected when no connection
            # could be made. Others (e.g. a 502 from a proxy) get a warning.
            if err.code != 599:
                app_log.warning("Server at %s responded with error: %s", url, err.code)
            yield gen.sleep(0.1)
        except (OSError, socket.error) as err:
            # Aborted/refused/reset connections are expected while the server
            # starts; anything else is logged.
            if err.errno not in {errno.ECONNABORTED, errno.ECONNREFUSED, errno.ECONNRESET}:
                app_log.warning("Failed to connect to %s (%s)", url, err)
            yield gen.sleep(0.1)
        else:
            return
    raise TimeoutError(
        "Server at {url} didn't respond in {timeout} seconds".format(
            url=url, timeout=timeout,
        )
    )
# Decorators for authenticated Handlers
def auth_decorator(check_auth):
    """Make an authentication decorator.

    I heard you like decorators, so I put a decorator
    in your decorator, so you can decorate while you decorate.

    `check_auth(handler)` is called with the handler instance before the
    decorated method runs. It denies access by raising (the checks in this
    module raise ``web.HTTPError(403)``). The returned decorator wraps a
    handler method so that the check runs first, then the method is called
    with the original arguments and its result returned.
    """
    def decorator(method):
        # functools.wraps copies __module__, __name__, __qualname__, __doc__
        # and __dict__, and sets __wrapped__, so the decorated handler method
        # keeps its full identity. Copying only name and doc by hand loses
        # the rest.
        @functools.wraps(method)
        def decorated(self, *args, **kwargs):
            check_auth(self)
            return method(self, *args, **kwargs)
        return decorated
    # The decorator itself takes the check's name and docstring, so e.g.
    # admin_only.__name__ == 'admin_only'. functools.wraps is deliberately
    # not used here: its __wrapped__ would make inspect.signature report
    # check_auth's (self) signature for a decorator that takes a method.
    decorator.__name__ = check_auth.__name__
    decorator.__doc__ = check_auth.__doc__
    return decorator
@auth_decorator
def token_authenticated(self):
    """Decorator for handler methods authenticated only by API token.

    Only the token from the Authorization header counts (no cookies). If
    ``self.get_current_user_token()`` returns None, the request is rejected
    with HTTP 403.
    """
    token_owner = self.get_current_user_token()
    if token_owner is not None:
        return
    raise web.HTTPError(403)
@auth_decorator
def authenticated_403(self):
    """Decorator like ``web.authenticated``, but answering 403.

    Instead of redirecting an anonymous request to the login page, raise
    ``web.HTTPError(403)`` when ``self.get_current_user()`` returns None.
    """
    current = self.get_current_user()
    if current is not None:
        return
    raise web.HTTPError(403)
@auth_decorator
def admin_only(self):
    """Decorator restricting a handler method to admin users.

    Raises ``web.HTTPError(403)`` unless ``self.get_current_user()`` returns
    a user whose ``admin`` attribute is truthy.
    """
    user = self.get_current_user()
    is_admin = user is not None and user.admin
    if not is_admin:
        raise web.HTTPError(403)
# Token utilities
def new_token(*args, **kwargs):
    """Generate a new random token string.

    Currently a random UUID (uuid4) rendered as 32 lowercase hex characters.
    Any positional or keyword arguments are accepted and ignored.
    """
    fresh = uuid.uuid4()
    return fresh.hex
def hash_token(token, salt=8, rounds=16384, algorithm='sha512'):
    """Hash a token and return it as ``algorithm:rounds:salt:hexdigest``.

    `salt` may be an int (generate that many random bytes and use their hex
    encoding as the salt), bytes, or str. The digest is a single `algorithm`
    hash over the UTF-8 salt followed by the UTF-8 token repeated `rounds`
    times. Unencodable characters in the token are replaced, not raised.
    The returned string carries everything compare_token needs to re-hash
    a candidate token the same way.
    """
    hasher = hashlib.new(algorithm)
    if isinstance(salt, int):
        salt = b2a_hex(os.urandom(salt))
    if isinstance(salt, bytes):
        salt_bytes, salt = salt, salt.decode('utf8')
    else:
        salt_bytes = salt.encode('utf8')
    token_bytes = token.encode('utf8', 'replace')
    hasher.update(salt_bytes)
    for _ in range(rounds):
        hasher.update(token_bytes)
    digest = hasher.hexdigest()
    return "{algorithm}:{rounds}:{salt}:{digest}".format(
        algorithm=algorithm, rounds=rounds, salt=salt, digest=digest,
    )
def compare_token(compare, token):
    """Check whether `token` matches the stored hash `compare`.

    `compare` is a string ``algorithm:rounds:salt:hexdigest`` (the format
    hash_token produces). `token` is re-hashed with the same algorithm,
    rounds and salt, and the two strings are compared with
    hmac.compare_digest, a timing-attack-resistant comparison. Returns a bool.
    """
    algo, rounds_str, salt, _digest = compare.split(':')
    rehashed = hash_token(token, salt=salt, rounds=int(rounds_str), algorithm=algo)
    rehashed_bytes = rehashed.encode('utf8')
    return compare_digest(compare.encode('utf8'), rehashed_bytes)
def url_path_join(*pieces):
    """Join components of a URL path into a single path.

    Prevents double slashes when joining subpaths. Each piece has its leading
    and trailing slashes stripped, empty pieces are dropped, and the rest are
    joined with '/'. A leading '/' on the first piece and a trailing '/' on
    the last piece are preserved. With no pieces, returns ''.

    Adapted from notebook.utils.url_path_join

    Examples::

        url_path_join('/hub/', '/api/', 'users') -> '/hub/api/users'
        url_path_join('hub', 'api/') -> 'hub/api/'
        url_path_join('/', '/') -> '/'
    """
    if not pieces:
        # Nothing to join. Mirrors '/'.join([]) instead of raising IndexError.
        return ''
    initial = pieces[0].startswith('/')
    final = pieces[-1].endswith('/')
    result = '/'.join(s for s in (p.strip('/') for p in pieces) if s)
    if initial:
        result = '/' + result
    if final:
        result = result + '/'
    # Joining only slashes/empties with both ends preserved yields '//';
    # collapse that to the root path.
    if result == '//':
        result = '/'
    return result