diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index c2585eca..4e098a1a 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -20,7 +20,7 @@ from .. import __version__ from .. import orm from ..objects import Server from ..spawner import LocalProcessSpawner -from ..utils import url_path_join, DT_SCALE +from ..utils import url_path_join, exponential_backoff # pattern for the authentication token header auth_header_pat = re.compile(r'^(?:token|bearer)\s+([^\s]+)$', flags=re.IGNORECASE) @@ -641,7 +641,8 @@ class UserSpawnHandler(BaseHandler): # record redirect count in query parameter if redirects: self.log.warning("Redirect loop detected on %s", self.request.uri) - yield gen.sleep(min(1 * (DT_SCALE ** redirects), 10)) + # add capped exponential backoff where cap is 10s + yield gen.sleep(min(1 * (2 ** redirects), 10)) # rewrite target url with new `redirects` query value url_parts = urlparse(target) query_parts = parse_qs(url_parts.query) diff --git a/jupyterhub/spawner.py b/jupyterhub/spawner.py index 0abada4b..2ac5c8c8 100644 --- a/jupyterhub/spawner.py +++ b/jupyterhub/spawner.py @@ -26,7 +26,7 @@ from traitlets import ( from .objects import Server from .traitlets import Command, ByteSpecification -from .utils import random_port, url_path_join, DT_MIN, DT_MAX, DT_SCALE +from .utils import random_port, url_path_join, exponential_backoff class Spawner(LoggingConfigurable): @@ -666,21 +666,25 @@ class Spawner(LoggingConfigurable): self.log.exception("Unhandled error in poll callback for %s", self) return status - death_interval = Float(DT_MIN) - + death_interval = Float(0.1) @gen.coroutine def wait_for_death(self, timeout=10): """Wait for the single-user server to die, up to timeout seconds""" - loop = IOLoop.current() - tic = loop.time() - dt = self.death_interval - while dt > 0: + @gen.coroutine + def _wait_for_death(): status = yield self.poll() - if status is not None: - break - else: - yield gen.sleep(dt) - dt = min(dt * DT_SCALE, DT_MAX, timeout - (loop.time() - tic)) + return status is not None + + try: + r = yield exponential_backoff( + _wait_for_death, + 'Process did not die in {timeout} seconds'.format(timeout=timeout), + start_wait=self.death_interval, + timeout=timeout, + ) + return r + except TimeoutError: + return False def _try_setcwd(path): diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 38ece304..fd833b0b 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -4,6 +4,7 @@ # Distributed under the terms of the Modified BSD License. from binascii import b2a_hex +import random import errno import hashlib from hmac import compare_digest @@ -48,29 +49,100 @@ def can_connect(ip, port): else: return True -# exponential falloff factors: -# start at 100ms, falloff by 2x -# never longer than 5s -DT_MIN = 0.1 -DT_SCALE = 2 -DT_MAX = 5 +@gen.coroutine +def exponential_backoff( + pass_func, + fail_message, + start_wait=0.2, + scale_factor=2, + max_wait=5, + timeout=10, + timeout_tolerance=0.1, + *args, **kwargs): + """ + Exponentially backoff until `pass_func` is true. + + The `pass_func` function will wait with **exponential backoff** and + **random jitter** for as many needed iterations of the Tornado loop, + until reaching maximum `timeout` or truthiness. If `pass_func` is still + returning false at `timeout`, a `TimeoutError` will be raised. + + The first iteration will begin with a wait time of `start_wait` seconds. + Each subsequent iteration's wait time will scale up by continuously + multiplying itself by `scale_factor`. This continues for each iteration + until `pass_func` returns true or an iteration's wait time has reached + the `max_wait` seconds per iteration. + + `pass_func` may be a future, although that is not entirely recommended. + + Parameters + ---------- + pass_func + function that is to be run + fail_message : str + message for a `TimeoutError` + start_wait : optional + initial wait time for the first iteration in seconds + scale_factor : optional + a multiplier to increase the wait time for each iteration + max_wait : optional + maximum wait time per iteration in seconds + timeout : optional + maximum time of total wait in seconds + timeout_tolerance : optional + a small multiplier used to add jitter to `timeout`'s deadline + *args, **kwargs + passed to `pass_func(*args, **kwargs)` + + Returns + ------- + value of `pass_func(*args, **kwargs)` + + Raises + ------ + TimeoutError + If `pass_func` is still false at the end of the `timeout` period. + + Notes + ----- + See https://www.awsarchitectureblog.com/2015/03/backoff.html + for information about the algorithm and examples. We're using their + full Jitter implementation equivalent. + """ + loop = ioloop.IOLoop.current() + deadline = loop.time() + timeout + # add jitter to the deadline itself to prevent re-align of a bunch of + # timing out calls once the deadline is reached. + if timeout_tolerance: + tol = timeout_tolerance * timeout + deadline = random.uniform(deadline - tol, deadline + tol) + scale = 1 + while True: + ret = yield gen.maybe_future(pass_func(*args, **kwargs)) + # Truthy! + if ret: + return ret + remaining = deadline - loop.time() + if remaining < 0: + # timeout exceeded + break + # add some random jitter to improve performance + # this prevents overloading any single tornado loop iteration with + # too many things + dt = min(max_wait, remaining, random.uniform(0, start_wait * scale)) + scale *= scale_factor + yield gen.sleep(dt) + raise TimeoutError(fail_message) + @gen.coroutine def wait_for_server(ip, port, timeout=10): """Wait for any server to show up at ip:port.""" if ip in {'', '0.0.0.0'}: ip = '127.0.0.1' - loop = ioloop.IOLoop.current() - tic = loop.time() - dt = DT_MIN - while dt > 0: - if can_connect(ip, port): - return - else: - yield gen.sleep(dt) - dt = min(dt * DT_SCALE, DT_MAX, timeout - (loop.time() - tic)) - raise TimeoutError( - "Server at {ip}:{port} didn't respond in {timeout} seconds".format(**locals()) + yield exponential_backoff( + lambda: can_connect(ip, port), + "Server at {ip}:{port} didn't respond in {timeout} seconds".format(ip=ip, port=port, timeout=timeout) ) @@ -80,13 +152,12 @@ def wait_for_http_server(url, timeout=10): Any non-5XX response code will do, even 404. """ - loop = ioloop.IOLoop.current() - tic = loop.time() client = AsyncHTTPClient() - dt = DT_MIN - while dt > 0: + @gen.coroutine + def is_reachable(): try: r = yield client.fetch(url, follow_redirects=False) + return r except HTTPError as e: if e.code >= 500: # failed to respond properly, wait and try again @@ -95,25 +166,21 @@ def wait_for_http_server(url, timeout=10): # but 502 or other proxy error is conceivable app_log.warning( "Server at %s responded with error: %s", url, e.code) - yield gen.sleep(dt) else: app_log.debug("Server at %s responded with %s", url, e.code) return e.response except (OSError, socket.error) as e: if e.errno not in {errno.ECONNABORTED, errno.ECONNREFUSED, errno.ECONNRESET}: app_log.warning("Failed to connect to %s (%s)", url, e) - yield gen.sleep(dt) - else: - return r - dt = min(dt * DT_SCALE, DT_MAX, timeout - (loop.time() - tic)) - - raise TimeoutError( - "Server at {url} didn't respond in {timeout} seconds".format(**locals()) + return False + re = yield exponential_backoff( + is_reachable, + "Server at {url} didn't respond in {timeout} seconds".format(url=url, timeout=timeout) ) + return re # Decorators for authenticated Handlers - def auth_decorator(check_auth): """Make an authentication decorator.