Replace @gen.coroutine/yield with async/await

This commit is contained in:
Erik Sundell
2020-11-04 04:41:48 +01:00
parent 2a1d341586
commit e1166ec834
10 changed files with 43 additions and 77 deletions

View File

@@ -235,10 +235,9 @@ to Spawner environment:
```python
class MyAuthenticator(Authenticator):
@gen.coroutine
def authenticate(self, handler, data=None):
username = yield identify_user(handler, data)
upstream_token = yield token_for_user(username)
async def authenticate(self, handler, data=None):
username = await identify_user(handler, data)
upstream_token = await token_for_user(username)
return {
'name': username,
'auth_state': {
@@ -246,10 +245,9 @@ class MyAuthenticator(Authenticator):
},
}
@gen.coroutine
def pre_spawn_start(self, user, spawner):
async def pre_spawn_start(self, user, spawner):
"""Pass upstream_token to spawner via environment variable"""
auth_state = yield user.get_auth_state()
auth_state = await user.get_auth_state()
if not auth_state:
# auth_state not enabled
return

View File

@@ -930,7 +930,7 @@ class JupyterHub(Application):
with an :meth:`authenticate` method that:
- is a coroutine (asyncio or tornado)
- is a coroutine (asyncio)
- returns username on success, None on failure
- takes two arguments: (handler, data),
where `handler` is the calling web.RequestHandler,

View File

@@ -23,7 +23,6 @@ from urllib.parse import quote
from urllib.parse import urlencode
import requests
from tornado.gen import coroutine
from tornado.httputil import url_concat
from tornado.log import app_log
from tornado.web import HTTPError
@@ -950,8 +949,7 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
.. versionadded: 0.8
"""
@coroutine
def get(self):
async def get(self):
error = self.get_argument("error", False)
if error:
msg = self.get_argument("error_description", error)

View File

@@ -1079,7 +1079,7 @@ class Spawner(LoggingConfigurable):
Return ip, port instead of setting on self.user.server directly.
"""
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
"Override in subclass. Must be a coroutine."
)
async def stop(self, now=False):
@@ -1094,7 +1094,7 @@ class Spawner(LoggingConfigurable):
Must be a coroutine.
"""
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
"Override in subclass. Must be a coroutine."
)
async def poll(self):
@@ -1122,7 +1122,7 @@ class Spawner(LoggingConfigurable):
"""
raise NotImplementedError(
"Override in subclass. Must be a Tornado gen.coroutine."
"Override in subclass. Must be a coroutine."
)
def add_poll_callback(self, callback, *args, **kwargs):

View File

@@ -36,7 +36,6 @@ from unittest import mock
from pytest import fixture
from pytest import raises
from tornado import gen
from tornado import ioloop
from tornado.httpclient import HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop
@@ -55,16 +54,6 @@ from .utils import ssl_setup
_db = None
def pytest_collection_modifyitems(items):
"""add asyncio marker to all async tests"""
for item in items:
if inspect.iscoroutinefunction(item.obj):
item.add_marker('asyncio')
if hasattr(inspect, 'isasyncgenfunction'):
# double-check that we aren't mixing yield and async def
assert not inspect.isasyncgenfunction(item.obj)
@fixture(scope='module')
def ssl_tmpdir(tmpdir_factory):
return tmpdir_factory.mktemp('ssl')
@@ -244,17 +233,14 @@ def _mockservice(request, app, url=False):
assert name in app._service_map
service = app._service_map[name]
@gen.coroutine
def start():
async def start():
# wait for proxy to be updated before starting the service
yield app.proxy.add_all_services(app._service_map)
await app.proxy.add_all_services(app._service_map)
service.start()
io_loop.run_sync(start)
def cleanup():
import asyncio
asyncio.get_event_loop().run_until_complete(service.stop())
app.services[:] = []
app._service_map.clear()

View File

@@ -37,7 +37,6 @@ from urllib.parse import urlparse
from pamela import PAMError
from tornado import gen
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from traitlets import Bool
from traitlets import default
@@ -110,19 +109,17 @@ class SlowSpawner(MockSpawner):
delay = 2
_start_future = None
@gen.coroutine
def start(self):
(ip, port) = yield super().start()
async def start(self):
(ip, port) = await super().start()
if self._start_future is not None:
yield self._start_future
await self._start_future
else:
yield gen.sleep(self.delay)
await gen.sleep(self.delay)
return ip, port
@gen.coroutine
def stop(self):
yield gen.sleep(self.delay)
yield super().stop()
async def stop(self):
await gen.sleep(self.delay)
await super().stop()
class NeverSpawner(MockSpawner):
@@ -134,14 +131,12 @@ class NeverSpawner(MockSpawner):
def start(self):
"""Return a Future that will never finish"""
return Future()
return asyncio.Future()
@gen.coroutine
def stop(self):
async def stop(self):
pass
@gen.coroutine
def poll(self):
async def poll(self):
return 0
@@ -215,8 +210,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
# skip the add-system-user bit
return not user.name.startswith('dne')
@gen.coroutine
def authenticate(self, *args, **kwargs):
async def authenticate(self, *args, **kwargs):
with mock.patch.multiple(
'pamela',
authenticate=mock_authenticate,
@@ -224,7 +218,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
close_session=mock_open_session,
check_account=mock_check_account,
):
username = yield super(MockPAMAuthenticator, self).authenticate(
username = await super(MockPAMAuthenticator, self).authenticate(
*args, **kwargs
)
if username is None:
@@ -320,14 +314,13 @@ class MockHub(JupyterHub):
self.db.delete(group)
self.db.commit()
@gen.coroutine
def initialize(self, argv=None):
async def initialize(self, argv=None):
self.pid_file = NamedTemporaryFile(delete=False).name
self.db_file = NamedTemporaryFile()
self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
if 'mysql' in self.db_url:
self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'}
yield super().initialize([])
await super().initialize([])
# add an initial user
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
@@ -358,14 +351,13 @@ class MockHub(JupyterHub):
self.cleanup = lambda: None
self.db_file.close()
@gen.coroutine
def login_user(self, name):
async def login_user(self, name):
"""Login a user by name, returning her cookies."""
base_url = public_url(self)
external_ca = None
if self.internal_ssl:
external_ca = self.external_certs['files']['ca']
r = yield async_requests.post(
r = await async_requests.post(
base_url + 'hub/login',
data={'username': name, 'password': name},
allow_redirects=False,
@@ -407,8 +399,7 @@ class StubSingleUserSpawner(MockSpawner):
_thread = None
@gen.coroutine
def start(self):
async def start(self):
ip = self.ip = '127.0.0.1'
port = self.port = random_port()
env = self.get_env()
@@ -435,14 +426,12 @@ class StubSingleUserSpawner(MockSpawner):
assert ready
return (ip, port)
@gen.coroutine
def stop(self):
async def stop(self):
self._app.stop()
self._thread.join(timeout=30)
assert not self._thread.is_alive()
@gen.coroutine
def poll(self):
async def poll(self):
if self._thread is None:
return 0
if self._thread.is_alive():

View File

@@ -1,9 +1,9 @@
"""Tests for the REST API."""
import asyncio
import json
import re
import sys
import uuid
from concurrent.futures import Future
from datetime import datetime
from datetime import timedelta
from unittest import mock
@@ -885,8 +885,8 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
# start two pending spawns
names = ['ykka', 'hjarka']
users = [add_user(db, app=app, name=name) for name in names]
users[0].spawner._start_future = Future()
users[1].spawner._start_future = Future()
users[0].spawner._start_future = asyncio.Future()
users[1].spawner._start_future = asyncio.Future()
for name in names:
await api_request(app, 'users', name, 'server', method='post')
assert app.users.count_active_users()['pending'] == 2
@@ -894,7 +894,7 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
# ykka and hjarka's spawns are both pending. Essun should fail with 429
name = 'essun'
user = add_user(db, app=app, name=name)
user.spawner._start_future = Future()
user.spawner._start_future = asyncio.Future()
r = await api_request(app, 'users', name, 'server', method='post')
assert r.status_code == 429

View File

@@ -20,8 +20,7 @@ ssl_enabled = True
SSL_ERROR = (SSLError, ConnectionError)
@gen.coroutine
def wait_for_spawner(spawner, timeout=10):
async def wait_for_spawner(spawner, timeout=10):
"""Wait for an http server to show up
polling at shorter intervals for early termination
@@ -32,15 +31,15 @@ def wait_for_spawner(spawner, timeout=10):
return spawner.server.wait_up(timeout=1, http=True)
while time.monotonic() < deadline:
status = yield spawner.poll()
status = await spawner.poll()
assert status is None
try:
yield wait()
await wait()
except TimeoutError:
continue
else:
break
yield wait()
await wait()
async def test_connection_hub_wrong_certs(app):

View File

@@ -222,8 +222,7 @@ async def test_spawn_fails(db):
db.commit()
class BadSpawner(MockSpawner):
@gen.coroutine
def start(self):
async def start(self):
raise RuntimeError("Split the party")
user = User(

View File

@@ -586,8 +586,7 @@ async def test_login_strip(app):
base_url = public_url(app)
called_with = []
@gen.coroutine
def mock_authenticate(handler, data):
async def mock_authenticate(handler, data):
called_with.append(data)
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
@@ -943,8 +942,7 @@ async def test_pre_spawn_start_exc_no_form(app):
exc = "pre_spawn_start error"
# throw exception from pre_spawn_start
@gen.coroutine
def mock_pre_spawn_start(user, spawner):
async def mock_pre_spawn_start(user, spawner):
raise Exception(exc)
with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start):
@@ -959,8 +957,7 @@ async def test_pre_spawn_start_exc_options_form(app):
exc = "pre_spawn_start error"
# throw exception from pre_spawn_start
@gen.coroutine
def mock_pre_spawn_start(user, spawner):
async def mock_pre_spawn_start(user, spawner):
raise Exception(exc)
with mock.patch.dict(