mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-12 20:43:02 +00:00
add utils.awaitable replacement for gen.maybe_future
gen.maybe_future doesn't accept asyncio coroutines and asyncio.ensure_future doesn't accept *tornado* coroutines, so do our own thing
This commit is contained in:
@@ -8,7 +8,7 @@ import json
|
||||
from tornado import gen, web
|
||||
|
||||
from .. import orm
|
||||
from ..utils import admin_only
|
||||
from ..utils import admin_only, awaitable
|
||||
from .base import APIHandler
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class UserListAPIHandler(APIHandler):
|
||||
user.admin = True
|
||||
self.db.commit()
|
||||
try:
|
||||
await gen.maybe_future(self.authenticator.add_user(user))
|
||||
await awaitable(self.authenticator.add_user(user))
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create user: %s" % name, exc_info=True)
|
||||
self.users.delete(user)
|
||||
@@ -125,7 +125,7 @@ class UserAPIHandler(APIHandler):
|
||||
self.db.commit()
|
||||
|
||||
try:
|
||||
await gen.maybe_future(self.authenticator.add_user(user))
|
||||
await awaitable(self.authenticator.add_user(user))
|
||||
except Exception:
|
||||
self.log.error("Failed to create user: %s" % name, exc_info=True)
|
||||
# remove from registry
|
||||
@@ -149,7 +149,7 @@ class UserAPIHandler(APIHandler):
|
||||
if user.spawner._stop_pending:
|
||||
raise web.HTTPError(400, "%s's server is in the process of stopping, please wait." % name)
|
||||
|
||||
await gen.maybe_future(self.authenticator.delete_user(user))
|
||||
await awaitable(self.authenticator.delete_user(user))
|
||||
# remove from registry
|
||||
self.users.delete(user)
|
||||
|
||||
|
@@ -55,6 +55,7 @@ from .log import CoroutineLogFormatter, log_request
|
||||
from .proxy import Proxy, ConfigurableHTTPProxy
|
||||
from .traitlets import URLPrefix, Command
|
||||
from .utils import (
|
||||
awaitable,
|
||||
url_path_join,
|
||||
ISO8601_ms, ISO8601_s,
|
||||
print_stacks, print_ps_info,
|
||||
@@ -1046,7 +1047,7 @@ class JupyterHub(Application):
|
||||
# and persist across sessions.
|
||||
for user in db.query(orm.User):
|
||||
try:
|
||||
await gen.maybe_future(self.authenticator.add_user(user))
|
||||
await awaitable(self.authenticator.add_user(user))
|
||||
except Exception:
|
||||
self.log.exception("Error adding user %s already in db", user.name)
|
||||
if self.authenticator.delete_invalid_users:
|
||||
@@ -1077,7 +1078,7 @@ class JupyterHub(Application):
|
||||
db.add(group)
|
||||
for username in usernames:
|
||||
username = self.authenticator.normalize_username(username)
|
||||
if not (await gen.maybe_future(self.authenticator.check_whitelist(username))):
|
||||
if not (await awaitable(self.authenticator.check_whitelist(username))):
|
||||
raise ValueError("Username %r is not in whitelist" % username)
|
||||
user = orm.User.find(db, name=username)
|
||||
if user is None:
|
||||
@@ -1101,7 +1102,7 @@ class JupyterHub(Application):
|
||||
for token, name in token_dict.items():
|
||||
if kind == 'user':
|
||||
name = self.authenticator.normalize_username(name)
|
||||
if not (await gen.maybe_future(self.authenticator.check_whitelist(name))):
|
||||
if not (await awaitable(self.authenticator.check_whitelist(name))):
|
||||
raise ValueError("Token name %r is not in whitelist" % name)
|
||||
if not self.authenticator.validate_username(name):
|
||||
raise ValueError("Token name %r is not valid" % name)
|
||||
@@ -1491,7 +1492,7 @@ class JupyterHub(Application):
|
||||
# clean up proxy while single-user servers are shutting down
|
||||
if self.cleanup_proxy:
|
||||
if self.proxy.should_start:
|
||||
await gen.maybe_future(self.proxy.stop())
|
||||
await awaitable(self.proxy.stop())
|
||||
else:
|
||||
self.log.info("I didn't start the proxy, I can't clean it up")
|
||||
else:
|
||||
|
@@ -23,7 +23,7 @@ from traitlets.config import LoggingConfigurable
|
||||
from traitlets import Bool, Set, Unicode, Dict, Any, default, observe
|
||||
|
||||
from .handlers.login import LoginHandler
|
||||
from .utils import url_path_join
|
||||
from .utils import awaitable, url_path_join
|
||||
from .traitlets import Command
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ class Authenticator(LoggingConfigurable):
|
||||
self.log.warning("Disallowing invalid username %r.", username)
|
||||
return
|
||||
|
||||
whitelist_pass = await gen.maybe_future(self.check_whitelist(username))
|
||||
whitelist_pass = await awaitable(self.check_whitelist(username))
|
||||
if whitelist_pass:
|
||||
return authenticated
|
||||
else:
|
||||
@@ -481,14 +481,14 @@ class LocalAuthenticator(Authenticator):
|
||||
|
||||
If self.create_system_users, the user will attempt to be created if it doesn't exist.
|
||||
"""
|
||||
user_exists = await gen.maybe_future(self.system_user_exists(user))
|
||||
user_exists = await awaitable(self.system_user_exists(user))
|
||||
if not user_exists:
|
||||
if self.create_system_users:
|
||||
await gen.maybe_future(self.add_system_user(user))
|
||||
await awaitable(self.add_system_user(user))
|
||||
else:
|
||||
raise KeyError("User %s does not exist." % user.name)
|
||||
|
||||
await gen.maybe_future(super().add_user(user))
|
||||
await awaitable(super().add_user(user))
|
||||
|
||||
@staticmethod
|
||||
def system_user_exists(user):
|
||||
|
@@ -19,6 +19,7 @@ except ImportError:
|
||||
class InvalidToken(Exception):
|
||||
pass
|
||||
|
||||
from .utils import awaitable
|
||||
|
||||
KEY_ENV = 'JUPYTERHUB_CRYPT_KEY'
|
||||
|
||||
@@ -132,7 +133,7 @@ class CryptKeeper(SingletonConfigurable):
|
||||
def encrypt(self, data):
|
||||
"""Encrypt an object with cryptography"""
|
||||
self.check_available()
|
||||
return self.executor.submit(self._encrypt, data)
|
||||
return awaitable(self.executor.submit(self._encrypt, data))
|
||||
|
||||
def _decrypt(self, encrypted):
|
||||
decrypted = self.fernet.decrypt(encrypted)
|
||||
@@ -141,7 +142,7 @@ class CryptKeeper(SingletonConfigurable):
|
||||
def decrypt(self, encrypted):
|
||||
"""Decrypt an object with cryptography"""
|
||||
self.check_available()
|
||||
return self.executor.submit(self._decrypt, encrypted)
|
||||
return awaitable(self.executor.submit(self._decrypt, encrypted))
|
||||
|
||||
|
||||
def encrypt(data):
|
||||
@@ -158,4 +159,3 @@ def decrypt(data):
|
||||
Returns a Future whose result will be the decrypted, deserialized data.
|
||||
"""
|
||||
return CryptKeeper.instance().decrypt(data)
|
||||
|
@@ -23,7 +23,7 @@ from .. import __version__
|
||||
from .. import orm
|
||||
from ..objects import Server
|
||||
from ..spawner import LocalProcessSpawner
|
||||
from ..utils import url_path_join
|
||||
from ..utils import awaitable, url_path_join
|
||||
from ..metrics import (
|
||||
SERVER_SPAWN_DURATION_SECONDS, ServerSpawnStatus,
|
||||
PROXY_ADD_DURATION_SECONDS, ProxyAddStatus
|
||||
@@ -387,7 +387,7 @@ class BaseHandler(RequestHandler):
|
||||
self.set_hub_cookie(user)
|
||||
|
||||
def authenticate(self, data):
|
||||
return gen.maybe_future(self.authenticator.get_authenticated_user(self, data))
|
||||
return awaitable(self.authenticator.get_authenticated_user(self, data))
|
||||
|
||||
def get_next_url(self, user=None):
|
||||
"""Get the next_url for login redirect
|
||||
@@ -421,7 +421,7 @@ class BaseHandler(RequestHandler):
|
||||
new_user = username not in self.users
|
||||
user = self.user_from_username(username)
|
||||
if new_user:
|
||||
await gen.maybe_future(self.authenticator.add_user(user))
|
||||
await awaitable(self.authenticator.add_user(user))
|
||||
# Only set `admin` if the authenticator returned an explicit value.
|
||||
if admin is not None and admin != user.admin:
|
||||
user.admin = admin
|
||||
@@ -577,7 +577,7 @@ class BaseHandler(RequestHandler):
|
||||
|
||||
# hook up spawner._spawn_future so that other requests can await
|
||||
# this result
|
||||
finish_spawn_future = spawner._spawn_future = finish_user_spawn()
|
||||
finish_spawn_future = spawner._spawn_future = awaitable(finish_user_spawn())
|
||||
def _clear_spawn_future(f):
|
||||
# clear spawner._spawn_future when it's done
|
||||
# keep an exception around, though, to prevent repeated implicit spawns
|
||||
|
@@ -28,7 +28,7 @@ from traitlets import (
|
||||
|
||||
from .objects import Server
|
||||
from .traitlets import Command, ByteSpecification, Callable
|
||||
from .utils import random_port, url_path_join, exponential_backoff
|
||||
from .utils import awaitable, random_port, url_path_join, exponential_backoff
|
||||
|
||||
|
||||
class Spawner(LoggingConfigurable):
|
||||
@@ -269,7 +269,7 @@ class Spawner(LoggingConfigurable):
|
||||
Introduced.
|
||||
"""
|
||||
if callable(self.options_form):
|
||||
options_form = await gen.maybe_future(self.options_form(self))
|
||||
options_form = await awaitable(self.options_form(self))
|
||||
else:
|
||||
options_form = self.options_form
|
||||
|
||||
@@ -783,7 +783,7 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
await gen.maybe_future(callback())
|
||||
await awaitable(callback())
|
||||
except Exception:
|
||||
self.log.exception("Unhandled error in poll callback for %s", self)
|
||||
return status
|
||||
|
@@ -12,7 +12,7 @@ from tornado import gen
|
||||
from tornado.log import app_log
|
||||
from traitlets import HasTraits, Any, Dict, default
|
||||
|
||||
from .utils import url_path_join
|
||||
from .utils import awaitable, url_path_join
|
||||
|
||||
from . import orm
|
||||
from ._version import _check_version, __version__
|
||||
@@ -378,13 +378,13 @@ class User:
|
||||
# trigger pre-spawn hook on authenticator
|
||||
authenticator = self.authenticator
|
||||
if (authenticator):
|
||||
await gen.maybe_future(authenticator.pre_spawn_start(self, spawner))
|
||||
await awaitable(authenticator.pre_spawn_start(self, spawner))
|
||||
|
||||
spawner._start_pending = True
|
||||
# wait for spawner.start to return
|
||||
try:
|
||||
# run optional preparation work to bootstrap the notebook
|
||||
await gen.maybe_future(spawner.run_pre_spawn_hook())
|
||||
await awaitable(spawner.run_pre_spawn_hook())
|
||||
f = spawner.start()
|
||||
# commit any changes in spawner.start (always commit db changes before yield)
|
||||
db.commit()
|
||||
@@ -533,7 +533,7 @@ class User:
|
||||
auth = spawner.authenticator
|
||||
try:
|
||||
if auth:
|
||||
await gen.maybe_future(
|
||||
await awaitable(
|
||||
auth.post_spawn_stop(self, spawner)
|
||||
)
|
||||
except Exception:
|
||||
|
@@ -20,7 +20,7 @@ import uuid
|
||||
import warnings
|
||||
|
||||
from tornado import gen, ioloop, web
|
||||
from tornado.concurrent import to_asyncio_future
|
||||
from tornado.platform.asyncio import to_asyncio_future
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPError
|
||||
from tornado.log import app_log
|
||||
|
||||
@@ -123,7 +123,7 @@ async def exponential_backoff(
|
||||
deadline = random.uniform(deadline - tol, deadline + tol)
|
||||
scale = 1
|
||||
while True:
|
||||
ret = await gen.maybe_future(pass_func(*args, **kwargs))
|
||||
ret = await awaitable(pass_func(*args, **kwargs))
|
||||
# Truthy!
|
||||
if ret:
|
||||
return ret
|
||||
@@ -428,8 +428,8 @@ def awaitable(obj):
|
||||
- asyncio Future (works both ways)
|
||||
"""
|
||||
if inspect.isawaitable(obj):
|
||||
# return obj that's already awaitable
|
||||
return obj
|
||||
# already awaitable, use ensure_future
|
||||
return asyncio.ensure_future(obj)
|
||||
elif isinstance(obj, concurrent.futures.Future):
|
||||
return asyncio.wrap_future(obj)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user