Replace @gen.coroutine/yield with async/await

This commit is contained in:
Erik Sundell
2020-11-04 04:41:48 +01:00
parent 2a1d341586
commit e1166ec834
10 changed files with 43 additions and 77 deletions

View File

@@ -235,10 +235,9 @@ to Spawner environment:
```python ```python
class MyAuthenticator(Authenticator): class MyAuthenticator(Authenticator):
@gen.coroutine async def authenticate(self, handler, data=None):
def authenticate(self, handler, data=None): username = await identify_user(handler, data)
username = yield identify_user(handler, data) upstream_token = await token_for_user(username)
upstream_token = yield token_for_user(username)
return { return {
'name': username, 'name': username,
'auth_state': { 'auth_state': {
@@ -246,10 +245,9 @@ class MyAuthenticator(Authenticator):
}, },
} }
@gen.coroutine async def pre_spawn_start(self, user, spawner):
def pre_spawn_start(self, user, spawner):
"""Pass upstream_token to spawner via environment variable""" """Pass upstream_token to spawner via environment variable"""
auth_state = yield user.get_auth_state() auth_state = await user.get_auth_state()
if not auth_state: if not auth_state:
# auth_state not enabled # auth_state not enabled
return return

View File

@@ -930,7 +930,7 @@ class JupyterHub(Application):
with an :meth:`authenticate` method that: with an :meth:`authenticate` method that:
- is a coroutine (asyncio or tornado) - is a coroutine (asyncio)
- returns username on success, None on failure - returns username on success, None on failure
- takes two arguments: (handler, data), - takes two arguments: (handler, data),
where `handler` is the calling web.RequestHandler, where `handler` is the calling web.RequestHandler,

View File

@@ -23,7 +23,6 @@ from urllib.parse import quote
from urllib.parse import urlencode from urllib.parse import urlencode
import requests import requests
from tornado.gen import coroutine
from tornado.httputil import url_concat from tornado.httputil import url_concat
from tornado.log import app_log from tornado.log import app_log
from tornado.web import HTTPError from tornado.web import HTTPError
@@ -950,8 +949,7 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
.. versionadded: 0.8 .. versionadded: 0.8
""" """
@coroutine async def get(self):
def get(self):
error = self.get_argument("error", False) error = self.get_argument("error", False)
if error: if error:
msg = self.get_argument("error_description", error) msg = self.get_argument("error_description", error)

View File

@@ -1079,7 +1079,7 @@ class Spawner(LoggingConfigurable):
Return ip, port instead of setting on self.user.server directly. Return ip, port instead of setting on self.user.server directly.
""" """
raise NotImplementedError( raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine." "Override in subclass. Must be a coroutine."
) )
async def stop(self, now=False): async def stop(self, now=False):
@@ -1094,7 +1094,7 @@ class Spawner(LoggingConfigurable):
Must be a coroutine. Must be a coroutine.
""" """
raise NotImplementedError( raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine." "Override in subclass. Must be a coroutine."
) )
async def poll(self): async def poll(self):
@@ -1122,7 +1122,7 @@ class Spawner(LoggingConfigurable):
""" """
raise NotImplementedError( raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine." "Override in subclass. Must be a coroutine."
) )
def add_poll_callback(self, callback, *args, **kwargs): def add_poll_callback(self, callback, *args, **kwargs):

View File

@@ -36,7 +36,6 @@ from unittest import mock
from pytest import fixture from pytest import fixture
from pytest import raises from pytest import raises
from tornado import gen
from tornado import ioloop from tornado import ioloop
from tornado.httpclient import HTTPError from tornado.httpclient import HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop from tornado.platform.asyncio import AsyncIOMainLoop
@@ -55,16 +54,6 @@ from .utils import ssl_setup
_db = None _db = None
def pytest_collection_modifyitems(items):
"""add asyncio marker to all async tests"""
for item in items:
if inspect.iscoroutinefunction(item.obj):
item.add_marker('asyncio')
if hasattr(inspect, 'isasyncgenfunction'):
# double-check that we aren't mixing yield and async def
assert not inspect.isasyncgenfunction(item.obj)
@fixture(scope='module') @fixture(scope='module')
def ssl_tmpdir(tmpdir_factory): def ssl_tmpdir(tmpdir_factory):
return tmpdir_factory.mktemp('ssl') return tmpdir_factory.mktemp('ssl')
@@ -244,17 +233,14 @@ def _mockservice(request, app, url=False):
assert name in app._service_map assert name in app._service_map
service = app._service_map[name] service = app._service_map[name]
@gen.coroutine async def start():
def start():
# wait for proxy to be updated before starting the service # wait for proxy to be updated before starting the service
yield app.proxy.add_all_services(app._service_map) await app.proxy.add_all_services(app._service_map)
service.start() service.start()
io_loop.run_sync(start) io_loop.run_sync(start)
def cleanup(): def cleanup():
import asyncio
asyncio.get_event_loop().run_until_complete(service.stop()) asyncio.get_event_loop().run_until_complete(service.stop())
app.services[:] = [] app.services[:] = []
app._service_map.clear() app._service_map.clear()

View File

@@ -37,7 +37,6 @@ from urllib.parse import urlparse
from pamela import PAMError from pamela import PAMError
from tornado import gen from tornado import gen
from tornado.concurrent import Future
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from traitlets import Bool from traitlets import Bool
from traitlets import default from traitlets import default
@@ -110,19 +109,17 @@ class SlowSpawner(MockSpawner):
delay = 2 delay = 2
_start_future = None _start_future = None
@gen.coroutine async def start(self):
def start(self): (ip, port) = await super().start()
(ip, port) = yield super().start()
if self._start_future is not None: if self._start_future is not None:
yield self._start_future await self._start_future
else: else:
yield gen.sleep(self.delay) await gen.sleep(self.delay)
return ip, port return ip, port
@gen.coroutine async def stop(self):
def stop(self): await gen.sleep(self.delay)
yield gen.sleep(self.delay) await super().stop()
yield super().stop()
class NeverSpawner(MockSpawner): class NeverSpawner(MockSpawner):
@@ -134,14 +131,12 @@ class NeverSpawner(MockSpawner):
def start(self): def start(self):
"""Return a Future that will never finish""" """Return a Future that will never finish"""
return Future() return asyncio.Future()
@gen.coroutine async def stop(self):
def stop(self):
pass pass
@gen.coroutine async def poll(self):
def poll(self):
return 0 return 0
@@ -215,8 +210,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
# skip the add-system-user bit # skip the add-system-user bit
return not user.name.startswith('dne') return not user.name.startswith('dne')
@gen.coroutine async def authenticate(self, *args, **kwargs):
def authenticate(self, *args, **kwargs):
with mock.patch.multiple( with mock.patch.multiple(
'pamela', 'pamela',
authenticate=mock_authenticate, authenticate=mock_authenticate,
@@ -224,7 +218,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
close_session=mock_open_session, close_session=mock_open_session,
check_account=mock_check_account, check_account=mock_check_account,
): ):
username = yield super(MockPAMAuthenticator, self).authenticate( username = await super(MockPAMAuthenticator, self).authenticate(
*args, **kwargs *args, **kwargs
) )
if username is None: if username is None:
@@ -320,14 +314,13 @@ class MockHub(JupyterHub):
self.db.delete(group) self.db.delete(group)
self.db.commit() self.db.commit()
@gen.coroutine async def initialize(self, argv=None):
def initialize(self, argv=None):
self.pid_file = NamedTemporaryFile(delete=False).name self.pid_file = NamedTemporaryFile(delete=False).name
self.db_file = NamedTemporaryFile() self.db_file = NamedTemporaryFile()
self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
if 'mysql' in self.db_url: if 'mysql' in self.db_url:
self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'} self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'}
yield super().initialize([]) await super().initialize([])
# add an initial user # add an initial user
user = self.db.query(orm.User).filter(orm.User.name == 'user').first() user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
@@ -358,14 +351,13 @@ class MockHub(JupyterHub):
self.cleanup = lambda: None self.cleanup = lambda: None
self.db_file.close() self.db_file.close()
@gen.coroutine async def login_user(self, name):
def login_user(self, name):
"""Login a user by name, returning her cookies.""" """Login a user by name, returning her cookies."""
base_url = public_url(self) base_url = public_url(self)
external_ca = None external_ca = None
if self.internal_ssl: if self.internal_ssl:
external_ca = self.external_certs['files']['ca'] external_ca = self.external_certs['files']['ca']
r = yield async_requests.post( r = await async_requests.post(
base_url + 'hub/login', base_url + 'hub/login',
data={'username': name, 'password': name}, data={'username': name, 'password': name},
allow_redirects=False, allow_redirects=False,
@@ -407,8 +399,7 @@ class StubSingleUserSpawner(MockSpawner):
_thread = None _thread = None
@gen.coroutine async def start(self):
def start(self):
ip = self.ip = '127.0.0.1' ip = self.ip = '127.0.0.1'
port = self.port = random_port() port = self.port = random_port()
env = self.get_env() env = self.get_env()
@@ -435,14 +426,12 @@ class StubSingleUserSpawner(MockSpawner):
assert ready assert ready
return (ip, port) return (ip, port)
@gen.coroutine async def stop(self):
def stop(self):
self._app.stop() self._app.stop()
self._thread.join(timeout=30) self._thread.join(timeout=30)
assert not self._thread.is_alive() assert not self._thread.is_alive()
@gen.coroutine async def poll(self):
def poll(self):
if self._thread is None: if self._thread is None:
return 0 return 0
if self._thread.is_alive(): if self._thread.is_alive():

View File

@@ -1,9 +1,9 @@
"""Tests for the REST API.""" """Tests for the REST API."""
import asyncio
import json import json
import re import re
import sys import sys
import uuid import uuid
from concurrent.futures import Future
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
from unittest import mock from unittest import mock
@@ -885,8 +885,8 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
# start two pending spawns # start two pending spawns
names = ['ykka', 'hjarka'] names = ['ykka', 'hjarka']
users = [add_user(db, app=app, name=name) for name in names] users = [add_user(db, app=app, name=name) for name in names]
users[0].spawner._start_future = Future() users[0].spawner._start_future = asyncio.Future()
users[1].spawner._start_future = Future() users[1].spawner._start_future = asyncio.Future()
for name in names: for name in names:
await api_request(app, 'users', name, 'server', method='post') await api_request(app, 'users', name, 'server', method='post')
assert app.users.count_active_users()['pending'] == 2 assert app.users.count_active_users()['pending'] == 2
@@ -894,7 +894,7 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
# ykka and hjarka's spawns are both pending. Essun should fail with 429 # ykka and hjarka's spawns are both pending. Essun should fail with 429
name = 'essun' name = 'essun'
user = add_user(db, app=app, name=name) user = add_user(db, app=app, name=name)
user.spawner._start_future = Future() user.spawner._start_future = asyncio.Future()
r = await api_request(app, 'users', name, 'server', method='post') r = await api_request(app, 'users', name, 'server', method='post')
assert r.status_code == 429 assert r.status_code == 429

View File

@@ -20,8 +20,7 @@ ssl_enabled = True
SSL_ERROR = (SSLError, ConnectionError) SSL_ERROR = (SSLError, ConnectionError)
@gen.coroutine async def wait_for_spawner(spawner, timeout=10):
def wait_for_spawner(spawner, timeout=10):
"""Wait for an http server to show up """Wait for an http server to show up
polling at shorter intervals for early termination polling at shorter intervals for early termination
@@ -32,15 +31,15 @@ def wait_for_spawner(spawner, timeout=10):
return spawner.server.wait_up(timeout=1, http=True) return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline: while time.monotonic() < deadline:
status = yield spawner.poll() status = await spawner.poll()
assert status is None assert status is None
try: try:
yield wait() await wait()
except TimeoutError: except TimeoutError:
continue continue
else: else:
break break
yield wait() await wait()
async def test_connection_hub_wrong_certs(app): async def test_connection_hub_wrong_certs(app):

View File

@@ -222,8 +222,7 @@ async def test_spawn_fails(db):
db.commit() db.commit()
class BadSpawner(MockSpawner): class BadSpawner(MockSpawner):
@gen.coroutine async def start(self):
def start(self):
raise RuntimeError("Split the party") raise RuntimeError("Split the party")
user = User( user = User(

View File

@@ -586,8 +586,7 @@ async def test_login_strip(app):
base_url = public_url(app) base_url = public_url(app)
called_with = [] called_with = []
@gen.coroutine async def mock_authenticate(handler, data):
def mock_authenticate(handler, data):
called_with.append(data) called_with.append(data)
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate): with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
@@ -943,8 +942,7 @@ async def test_pre_spawn_start_exc_no_form(app):
exc = "pre_spawn_start error" exc = "pre_spawn_start error"
# throw exception from pre_spawn_start # throw exception from pre_spawn_start
@gen.coroutine async def mock_pre_spawn_start(user, spawner):
def mock_pre_spawn_start(user, spawner):
raise Exception(exc) raise Exception(exc)
with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start): with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start):
@@ -959,8 +957,7 @@ async def test_pre_spawn_start_exc_options_form(app):
exc = "pre_spawn_start error" exc = "pre_spawn_start error"
# throw exception from pre_spawn_start # throw exception from pre_spawn_start
@gen.coroutine async def mock_pre_spawn_start(user, spawner):
def mock_pre_spawn_start(user, spawner):
raise Exception(exc) raise Exception(exc)
with mock.patch.dict( with mock.patch.dict(