mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 07:23:00 +00:00
Replace @gen.coroutine/yield with async/await
This commit is contained in:
@@ -235,10 +235,9 @@ to Spawner environment:
|
||||
|
||||
```python
|
||||
class MyAuthenticator(Authenticator):
|
||||
@gen.coroutine
|
||||
def authenticate(self, handler, data=None):
|
||||
username = yield identify_user(handler, data)
|
||||
upstream_token = yield token_for_user(username)
|
||||
async def authenticate(self, handler, data=None):
|
||||
username = await identify_user(handler, data)
|
||||
upstream_token = await token_for_user(username)
|
||||
return {
|
||||
'name': username,
|
||||
'auth_state': {
|
||||
@@ -246,10 +245,9 @@ class MyAuthenticator(Authenticator):
|
||||
},
|
||||
}
|
||||
|
||||
@gen.coroutine
|
||||
def pre_spawn_start(self, user, spawner):
|
||||
async def pre_spawn_start(self, user, spawner):
|
||||
"""Pass upstream_token to spawner via environment variable"""
|
||||
auth_state = yield user.get_auth_state()
|
||||
auth_state = await user.get_auth_state()
|
||||
if not auth_state:
|
||||
# auth_state not enabled
|
||||
return
|
||||
|
@@ -930,7 +930,7 @@ class JupyterHub(Application):
|
||||
|
||||
with an :meth:`authenticate` method that:
|
||||
|
||||
- is a coroutine (asyncio or tornado)
|
||||
- is a coroutine (asyncio)
|
||||
- returns username on success, None on failure
|
||||
- takes two arguments: (handler, data),
|
||||
where `handler` is the calling web.RequestHandler,
|
||||
|
@@ -23,7 +23,6 @@ from urllib.parse import quote
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from tornado.gen import coroutine
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.log import app_log
|
||||
from tornado.web import HTTPError
|
||||
@@ -950,8 +949,7 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
.. versionadded: 0.8
|
||||
"""
|
||||
|
||||
@coroutine
|
||||
def get(self):
|
||||
async def get(self):
|
||||
error = self.get_argument("error", False)
|
||||
if error:
|
||||
msg = self.get_argument("error_description", error)
|
||||
|
@@ -1079,7 +1079,7 @@ class Spawner(LoggingConfigurable):
|
||||
Return ip, port instead of setting on self.user.server directly.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
"Override in subclass. Must be a coroutine."
|
||||
)
|
||||
|
||||
async def stop(self, now=False):
|
||||
@@ -1094,7 +1094,7 @@ class Spawner(LoggingConfigurable):
|
||||
Must be a coroutine.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
"Override in subclass. Must be a coroutine."
|
||||
)
|
||||
|
||||
async def poll(self):
|
||||
@@ -1122,7 +1122,7 @@ class Spawner(LoggingConfigurable):
|
||||
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
||||
"Override in subclass. Must be a coroutine."
|
||||
)
|
||||
|
||||
def add_poll_callback(self, callback, *args, **kwargs):
|
||||
|
@@ -36,7 +36,6 @@ from unittest import mock
|
||||
|
||||
from pytest import fixture
|
||||
from pytest import raises
|
||||
from tornado import gen
|
||||
from tornado import ioloop
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.platform.asyncio import AsyncIOMainLoop
|
||||
@@ -55,16 +54,6 @@ from .utils import ssl_setup
|
||||
_db = None
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
"""add asyncio marker to all async tests"""
|
||||
for item in items:
|
||||
if inspect.iscoroutinefunction(item.obj):
|
||||
item.add_marker('asyncio')
|
||||
if hasattr(inspect, 'isasyncgenfunction'):
|
||||
# double-check that we aren't mixing yield and async def
|
||||
assert not inspect.isasyncgenfunction(item.obj)
|
||||
|
||||
|
||||
@fixture(scope='module')
|
||||
def ssl_tmpdir(tmpdir_factory):
|
||||
return tmpdir_factory.mktemp('ssl')
|
||||
@@ -244,17 +233,14 @@ def _mockservice(request, app, url=False):
|
||||
assert name in app._service_map
|
||||
service = app._service_map[name]
|
||||
|
||||
@gen.coroutine
|
||||
def start():
|
||||
async def start():
|
||||
# wait for proxy to be updated before starting the service
|
||||
yield app.proxy.add_all_services(app._service_map)
|
||||
await app.proxy.add_all_services(app._service_map)
|
||||
service.start()
|
||||
|
||||
io_loop.run_sync(start)
|
||||
|
||||
def cleanup():
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(service.stop())
|
||||
app.services[:] = []
|
||||
app._service_map.clear()
|
||||
|
@@ -37,7 +37,6 @@ from urllib.parse import urlparse
|
||||
|
||||
from pamela import PAMError
|
||||
from tornado import gen
|
||||
from tornado.concurrent import Future
|
||||
from tornado.ioloop import IOLoop
|
||||
from traitlets import Bool
|
||||
from traitlets import default
|
||||
@@ -110,19 +109,17 @@ class SlowSpawner(MockSpawner):
|
||||
delay = 2
|
||||
_start_future = None
|
||||
|
||||
@gen.coroutine
|
||||
def start(self):
|
||||
(ip, port) = yield super().start()
|
||||
async def start(self):
|
||||
(ip, port) = await super().start()
|
||||
if self._start_future is not None:
|
||||
yield self._start_future
|
||||
await self._start_future
|
||||
else:
|
||||
yield gen.sleep(self.delay)
|
||||
await gen.sleep(self.delay)
|
||||
return ip, port
|
||||
|
||||
@gen.coroutine
|
||||
def stop(self):
|
||||
yield gen.sleep(self.delay)
|
||||
yield super().stop()
|
||||
async def stop(self):
|
||||
await gen.sleep(self.delay)
|
||||
await super().stop()
|
||||
|
||||
|
||||
class NeverSpawner(MockSpawner):
|
||||
@@ -134,14 +131,12 @@ class NeverSpawner(MockSpawner):
|
||||
|
||||
def start(self):
|
||||
"""Return a Future that will never finish"""
|
||||
return Future()
|
||||
return asyncio.Future()
|
||||
|
||||
@gen.coroutine
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
pass
|
||||
|
||||
@gen.coroutine
|
||||
def poll(self):
|
||||
async def poll(self):
|
||||
return 0
|
||||
|
||||
|
||||
@@ -215,8 +210,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
||||
# skip the add-system-user bit
|
||||
return not user.name.startswith('dne')
|
||||
|
||||
@gen.coroutine
|
||||
def authenticate(self, *args, **kwargs):
|
||||
async def authenticate(self, *args, **kwargs):
|
||||
with mock.patch.multiple(
|
||||
'pamela',
|
||||
authenticate=mock_authenticate,
|
||||
@@ -224,7 +218,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
||||
close_session=mock_open_session,
|
||||
check_account=mock_check_account,
|
||||
):
|
||||
username = yield super(MockPAMAuthenticator, self).authenticate(
|
||||
username = await super(MockPAMAuthenticator, self).authenticate(
|
||||
*args, **kwargs
|
||||
)
|
||||
if username is None:
|
||||
@@ -320,14 +314,13 @@ class MockHub(JupyterHub):
|
||||
self.db.delete(group)
|
||||
self.db.commit()
|
||||
|
||||
@gen.coroutine
|
||||
def initialize(self, argv=None):
|
||||
async def initialize(self, argv=None):
|
||||
self.pid_file = NamedTemporaryFile(delete=False).name
|
||||
self.db_file = NamedTemporaryFile()
|
||||
self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
|
||||
if 'mysql' in self.db_url:
|
||||
self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'}
|
||||
yield super().initialize([])
|
||||
await super().initialize([])
|
||||
|
||||
# add an initial user
|
||||
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
|
||||
@@ -358,14 +351,13 @@ class MockHub(JupyterHub):
|
||||
self.cleanup = lambda: None
|
||||
self.db_file.close()
|
||||
|
||||
@gen.coroutine
|
||||
def login_user(self, name):
|
||||
async def login_user(self, name):
|
||||
"""Login a user by name, returning her cookies."""
|
||||
base_url = public_url(self)
|
||||
external_ca = None
|
||||
if self.internal_ssl:
|
||||
external_ca = self.external_certs['files']['ca']
|
||||
r = yield async_requests.post(
|
||||
r = await async_requests.post(
|
||||
base_url + 'hub/login',
|
||||
data={'username': name, 'password': name},
|
||||
allow_redirects=False,
|
||||
@@ -407,8 +399,7 @@ class StubSingleUserSpawner(MockSpawner):
|
||||
|
||||
_thread = None
|
||||
|
||||
@gen.coroutine
|
||||
def start(self):
|
||||
async def start(self):
|
||||
ip = self.ip = '127.0.0.1'
|
||||
port = self.port = random_port()
|
||||
env = self.get_env()
|
||||
@@ -435,14 +426,12 @@ class StubSingleUserSpawner(MockSpawner):
|
||||
assert ready
|
||||
return (ip, port)
|
||||
|
||||
@gen.coroutine
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
self._app.stop()
|
||||
self._thread.join(timeout=30)
|
||||
assert not self._thread.is_alive()
|
||||
|
||||
@gen.coroutine
|
||||
def poll(self):
|
||||
async def poll(self):
|
||||
if self._thread is None:
|
||||
return 0
|
||||
if self._thread.is_alive():
|
||||
|
@@ -1,9 +1,9 @@
|
||||
"""Tests for the REST API."""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from unittest import mock
|
||||
@@ -885,8 +885,8 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
||||
# start two pending spawns
|
||||
names = ['ykka', 'hjarka']
|
||||
users = [add_user(db, app=app, name=name) for name in names]
|
||||
users[0].spawner._start_future = Future()
|
||||
users[1].spawner._start_future = Future()
|
||||
users[0].spawner._start_future = asyncio.Future()
|
||||
users[1].spawner._start_future = asyncio.Future()
|
||||
for name in names:
|
||||
await api_request(app, 'users', name, 'server', method='post')
|
||||
assert app.users.count_active_users()['pending'] == 2
|
||||
@@ -894,7 +894,7 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
||||
# ykka and hjarka's spawns are both pending. Essun should fail with 429
|
||||
name = 'essun'
|
||||
user = add_user(db, app=app, name=name)
|
||||
user.spawner._start_future = Future()
|
||||
user.spawner._start_future = asyncio.Future()
|
||||
r = await api_request(app, 'users', name, 'server', method='post')
|
||||
assert r.status_code == 429
|
||||
|
||||
|
@@ -20,8 +20,7 @@ ssl_enabled = True
|
||||
SSL_ERROR = (SSLError, ConnectionError)
|
||||
|
||||
|
||||
@gen.coroutine
|
||||
def wait_for_spawner(spawner, timeout=10):
|
||||
async def wait_for_spawner(spawner, timeout=10):
|
||||
"""Wait for an http server to show up
|
||||
|
||||
polling at shorter intervals for early termination
|
||||
@@ -32,15 +31,15 @@ def wait_for_spawner(spawner, timeout=10):
|
||||
return spawner.server.wait_up(timeout=1, http=True)
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
status = yield spawner.poll()
|
||||
status = await spawner.poll()
|
||||
assert status is None
|
||||
try:
|
||||
yield wait()
|
||||
await wait()
|
||||
except TimeoutError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
yield wait()
|
||||
await wait()
|
||||
|
||||
|
||||
async def test_connection_hub_wrong_certs(app):
|
||||
|
@@ -222,8 +222,7 @@ async def test_spawn_fails(db):
|
||||
db.commit()
|
||||
|
||||
class BadSpawner(MockSpawner):
|
||||
@gen.coroutine
|
||||
def start(self):
|
||||
async def start(self):
|
||||
raise RuntimeError("Split the party")
|
||||
|
||||
user = User(
|
||||
|
@@ -586,8 +586,7 @@ async def test_login_strip(app):
|
||||
base_url = public_url(app)
|
||||
called_with = []
|
||||
|
||||
@gen.coroutine
|
||||
def mock_authenticate(handler, data):
|
||||
async def mock_authenticate(handler, data):
|
||||
called_with.append(data)
|
||||
|
||||
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
|
||||
@@ -943,8 +942,7 @@ async def test_pre_spawn_start_exc_no_form(app):
|
||||
exc = "pre_spawn_start error"
|
||||
|
||||
# throw exception from pre_spawn_start
|
||||
@gen.coroutine
|
||||
def mock_pre_spawn_start(user, spawner):
|
||||
async def mock_pre_spawn_start(user, spawner):
|
||||
raise Exception(exc)
|
||||
|
||||
with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start):
|
||||
@@ -959,8 +957,7 @@ async def test_pre_spawn_start_exc_options_form(app):
|
||||
exc = "pre_spawn_start error"
|
||||
|
||||
# throw exception from pre_spawn_start
|
||||
@gen.coroutine
|
||||
def mock_pre_spawn_start(user, spawner):
|
||||
async def mock_pre_spawn_start(user, spawner):
|
||||
raise Exception(exc)
|
||||
|
||||
with mock.patch.dict(
|
||||
|
Reference in New Issue
Block a user