add utils.awaitable replacement for gen.maybe_future

gen.maybe_future doesn't accept asyncio coroutines
and asyncio.ensure_future doesn't accept *tornado* coroutines, so do our own thing
This commit is contained in:
Min RK
2018-03-01 17:36:03 +01:00
parent 7b4de150cc
commit b6f634368c
8 changed files with 39 additions and 38 deletions

View File

@@ -8,7 +8,7 @@ import json
from tornado import gen, web from tornado import gen, web
from .. import orm from .. import orm
from ..utils import admin_only from ..utils import admin_only, awaitable
from .base import APIHandler from .base import APIHandler
@@ -76,7 +76,7 @@ class UserListAPIHandler(APIHandler):
user.admin = True user.admin = True
self.db.commit() self.db.commit()
try: try:
await gen.maybe_future(self.authenticator.add_user(user)) await awaitable(self.authenticator.add_user(user))
except Exception as e: except Exception as e:
self.log.error("Failed to create user: %s" % name, exc_info=True) self.log.error("Failed to create user: %s" % name, exc_info=True)
self.users.delete(user) self.users.delete(user)
@@ -125,7 +125,7 @@ class UserAPIHandler(APIHandler):
self.db.commit() self.db.commit()
try: try:
await gen.maybe_future(self.authenticator.add_user(user)) await awaitable(self.authenticator.add_user(user))
except Exception: except Exception:
self.log.error("Failed to create user: %s" % name, exc_info=True) self.log.error("Failed to create user: %s" % name, exc_info=True)
# remove from registry # remove from registry
@@ -149,7 +149,7 @@ class UserAPIHandler(APIHandler):
if user.spawner._stop_pending: if user.spawner._stop_pending:
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name) raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
await gen.maybe_future(self.authenticator.delete_user(user)) await awaitable(self.authenticator.delete_user(user))
# remove from registry # remove from registry
self.users.delete(user) self.users.delete(user)

View File

@@ -55,6 +55,7 @@ from .log import CoroutineLogFormatter, log_request
from .proxy import Proxy, ConfigurableHTTPProxy from .proxy import Proxy, ConfigurableHTTPProxy
from .traitlets import URLPrefix, Command from .traitlets import URLPrefix, Command
from .utils import ( from .utils import (
awaitable,
url_path_join, url_path_join,
ISO8601_ms, ISO8601_s, ISO8601_ms, ISO8601_s,
print_stacks, print_ps_info, print_stacks, print_ps_info,
@@ -1046,7 +1047,7 @@ class JupyterHub(Application):
# and persist across sessions. # and persist across sessions.
for user in db.query(orm.User): for user in db.query(orm.User):
try: try:
await gen.maybe_future(self.authenticator.add_user(user)) await awaitable(self.authenticator.add_user(user))
except Exception: except Exception:
self.log.exception("Error adding user %s already in db", user.name) self.log.exception("Error adding user %s already in db", user.name)
if self.authenticator.delete_invalid_users: if self.authenticator.delete_invalid_users:
@@ -1077,7 +1078,7 @@ class JupyterHub(Application):
db.add(group) db.add(group)
for username in usernames: for username in usernames:
username = self.authenticator.normalize_username(username) username = self.authenticator.normalize_username(username)
if not (await gen.maybe_future(self.authenticator.check_whitelist(username))): if not (await awaitable(self.authenticator.check_whitelist(username))):
raise ValueError("Username %r is not in whitelist" % username) raise ValueError("Username %r is not in whitelist" % username)
user = orm.User.find(db, name=username) user = orm.User.find(db, name=username)
if user is None: if user is None:
@@ -1101,7 +1102,7 @@ class JupyterHub(Application):
for token, name in token_dict.items(): for token, name in token_dict.items():
if kind == 'user': if kind == 'user':
name = self.authenticator.normalize_username(name) name = self.authenticator.normalize_username(name)
if not (await gen.maybe_future(self.authenticator.check_whitelist(name))): if not (await awaitable(self.authenticator.check_whitelist(name))):
raise ValueError("Token name %r is not in whitelist" % name) raise ValueError("Token name %r is not in whitelist" % name)
if not self.authenticator.validate_username(name): if not self.authenticator.validate_username(name):
raise ValueError("Token name %r is not valid" % name) raise ValueError("Token name %r is not valid" % name)
@@ -1491,7 +1492,7 @@ class JupyterHub(Application):
# clean up proxy while single-user servers are shutting down # clean up proxy while single-user servers are shutting down
if self.cleanup_proxy: if self.cleanup_proxy:
if self.proxy.should_start: if self.proxy.should_start:
await gen.maybe_future(self.proxy.stop()) await awaitable(self.proxy.stop())
else: else:
self.log.info("I didn't start the proxy, I can't clean it up") self.log.info("I didn't start the proxy, I can't clean it up")
else: else:

View File

@@ -23,7 +23,7 @@ from traitlets.config import LoggingConfigurable
from traitlets import Bool, Set, Unicode, Dict, Any, default, observe from traitlets import Bool, Set, Unicode, Dict, Any, default, observe
from .handlers.login import LoginHandler from .handlers.login import LoginHandler
from .utils import url_path_join from .utils import awaitable, url_path_join
from .traitlets import Command from .traitlets import Command
@@ -244,7 +244,7 @@ class Authenticator(LoggingConfigurable):
self.log.warning("Disallowing invalid username %r.", username) self.log.warning("Disallowing invalid username %r.", username)
return return
whitelist_pass = await gen.maybe_future(self.check_whitelist(username)) whitelist_pass = await awaitable(self.check_whitelist(username))
if whitelist_pass: if whitelist_pass:
return authenticated return authenticated
else: else:
@@ -481,14 +481,14 @@ class LocalAuthenticator(Authenticator):
If self.create_system_users, the user will attempt to be created if it doesn't exist. If self.create_system_users, the user will attempt to be created if it doesn't exist.
""" """
user_exists = await gen.maybe_future(self.system_user_exists(user)) user_exists = await awaitable(self.system_user_exists(user))
if not user_exists: if not user_exists:
if self.create_system_users: if self.create_system_users:
await gen.maybe_future(self.add_system_user(user)) await awaitable(self.add_system_user(user))
else: else:
raise KeyError("User %s does not exist." % user.name) raise KeyError("User %s does not exist." % user.name)
await gen.maybe_future(super().add_user(user)) await awaitable(super().add_user(user))
@staticmethod @staticmethod
def system_user_exists(user): def system_user_exists(user):

View File

@@ -19,6 +19,7 @@ except ImportError:
class InvalidToken(Exception): class InvalidToken(Exception):
pass pass
from .utils import awaitable
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY' KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
@@ -132,7 +133,7 @@ class CryptKeeper(SingletonConfigurable):
def encrypt(self, data): def encrypt(self, data):
"""Encrypt an object with cryptography""" """Encrypt an object with cryptography"""
self.check_available() self.check_available()
return self.executor.submit(self._encrypt, data) return awaitable(self.executor.submit(self._encrypt, data))
def _decrypt(self, encrypted): def _decrypt(self, encrypted):
decrypted = self.fernet.decrypt(encrypted) decrypted = self.fernet.decrypt(encrypted)
@@ -141,7 +142,7 @@ class CryptKeeper(SingletonConfigurable):
def decrypt(self, encrypted): def decrypt(self, encrypted):
"""Decrypt an object with cryptography""" """Decrypt an object with cryptography"""
self.check_available() self.check_available()
return self.executor.submit(self._decrypt, encrypted) return awaitable(self.executor.submit(self._decrypt, encrypted))
def encrypt(data): def encrypt(data):
@@ -158,4 +159,3 @@ def decrypt(data):
Returns a Future whose result will be the decrypted, deserialized data. Returns a Future whose result will be the decrypted, deserialized data.
""" """
return CryptKeeper.instance().decrypt(data) return CryptKeeper.instance().decrypt(data)

View File

@@ -23,7 +23,7 @@ from .. import __version__
from .. import orm from .. import orm
from ..objects import Server from ..objects import Server
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner
from ..utils import url_path_join from ..utils import awaitable, url_path_join
from ..metrics import ( from ..metrics import (
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus, SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
PROXY_ADD_DURATION_SECONDS, ProxyAddStatus PROXY_ADD_DURATION_SECONDS, ProxyAddStatus
@@ -387,7 +387,7 @@ class BaseHandler(RequestHandler):
self.set_hub_cookie(user) self.set_hub_cookie(user)
def authenticate(self, data): def authenticate(self, data):
return gen.maybe_future(self.authenticator.get_authenticated_user(self, data)) return awaitable(self.authenticator.get_authenticated_user(self, data))
def get_next_url(self, user=None): def get_next_url(self, user=None):
"""Get the next_url for login redirect """Get the next_url for login redirect
@@ -421,7 +421,7 @@ class BaseHandler(RequestHandler):
new_user = username not in self.users new_user = username not in self.users
user = self.user_from_username(username) user = self.user_from_username(username)
if new_user: if new_user:
await gen.maybe_future(self.authenticator.add_user(user)) await awaitable(self.authenticator.add_user(user))
# Only set `admin` if the authenticator returned an explicit value. # Only set `admin` if the authenticator returned an explicit value.
if admin is not None and admin != user.admin: if admin is not None and admin != user.admin:
user.admin = admin user.admin = admin
@@ -577,7 +577,7 @@ class BaseHandler(RequestHandler):
# hook up spawner._spawn_future so that other requests can await # hook up spawner._spawn_future so that other requests can await
# this result # this result
finish_spawn_future = spawner._spawn_future = finish_user_spawn() finish_spawn_future = spawner._spawn_future = awaitable(finish_user_spawn())
def _clear_spawn_future(f): def _clear_spawn_future(f):
# clear spawner._spawn_future when it's done # clear spawner._spawn_future when it's done
# keep an exception around, though, to prevent repeated implicit spawns # keep an exception around, though, to prevent repeated implicit spawns

View File

@@ -28,7 +28,7 @@ from traitlets import (
from .objects import Server from .objects import Server
from .traitlets import Command, ByteSpecification, Callable from .traitlets import Command, ByteSpecification, Callable
from .utils import random_port, url_path_join, exponential_backoff from .utils import awaitable, random_port, url_path_join, exponential_backoff
class Spawner(LoggingConfigurable): class Spawner(LoggingConfigurable):
@@ -269,7 +269,7 @@ class Spawner(LoggingConfigurable):
Introduced. Introduced.
""" """
if callable(self.options_form): if callable(self.options_form):
options_form = await gen.maybe_future(self.options_form(self)) options_form = await awaitable(self.options_form(self))
else: else:
options_form = self.options_form options_form = self.options_form
@@ -783,7 +783,7 @@ class Spawner(LoggingConfigurable):
for callback in callbacks: for callback in callbacks:
try: try:
await gen.maybe_future(callback()) await awaitable(callback())
except Exception: except Exception:
self.log.exception("Unhandled error in poll callback for %s", self) self.log.exception("Unhandled error in poll callback for %s", self)
return status return status

View File

@@ -12,7 +12,7 @@ from tornado import gen
from tornado.log import app_log from tornado.log import app_log
from traitlets import HasTraits, Any, Dict, default from traitlets import HasTraits, Any, Dict, default
from .utils import url_path_join from .utils import awaitable, url_path_join
from . import orm from . import orm
from ._version import _check_version, __version__ from ._version import _check_version, __version__
@@ -378,13 +378,13 @@ class User:
# trigger pre-spawn hook on authenticator # trigger pre-spawn hook on authenticator
authenticator = self.authenticator authenticator = self.authenticator
if (authenticator): if (authenticator):
await gen.maybe_future(authenticator.pre_spawn_start(self, spawner)) await awaitable(authenticator.pre_spawn_start(self, spawner))
spawner._start_pending = True spawner._start_pending = True
# wait for spawner.start to return # wait for spawner.start to return
try: try:
# run optional preparation work to bootstrap the notebook # run optional preparation work to bootstrap the notebook
await gen.maybe_future(spawner.run_pre_spawn_hook()) await awaitable(spawner.run_pre_spawn_hook())
f = spawner.start() f = spawner.start()
# commit any changes in spawner.start (always commit db changes before yield) # commit any changes in spawner.start (always commit db changes before yield)
db.commit() db.commit()
@@ -533,7 +533,7 @@ class User:
auth = spawner.authenticator auth = spawner.authenticator
try: try:
if auth: if auth:
await gen.maybe_future( await awaitable(
auth.post_spawn_stop(self, spawner) auth.post_spawn_stop(self, spawner)
) )
except Exception: except Exception:

View File

@@ -20,7 +20,7 @@ import uuid
import warnings import warnings
from tornado import gen, ioloop, web from tornado import gen, ioloop, web
from tornado.concurrent import to_asyncio_future from tornado.platform.asyncio import to_asyncio_future
from tornado.httpclient import AsyncHTTPClient, HTTPError from tornado.httpclient import AsyncHTTPClient, HTTPError
from tornado.log import app_log from tornado.log import app_log
@@ -123,7 +123,7 @@ async def exponential_backoff(
deadline = random.uniform(deadline - tol, deadline + tol) deadline = random.uniform(deadline - tol, deadline + tol)
scale = 1 scale = 1
while True: while True:
ret = await gen.maybe_future(pass_func(*args, **kwargs)) ret = await awaitable(pass_func(*args, **kwargs))
# Truthy! # Truthy!
if ret: if ret:
return ret return ret
@@ -428,8 +428,8 @@ def awaitable(obj):
- asyncio Future (works both ways) - asyncio Future (works both ways)
""" """
if inspect.isawaitable(obj): if inspect.isawaitable(obj):
# return obj that's already awaitable # already awaitable, use ensure_future
return obj return asyncio.ensure_future(obj)
elif isinstance(obj, concurrent.futures.Future): elif isinstance(obj, concurrent.futures.Future):
return asyncio.wrap_future(obj) return asyncio.wrap_future(obj)
else: else: