mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
Replace @gen.coroutine/yield with async/await
This commit is contained in:
@@ -235,10 +235,9 @@ to Spawner environment:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
class MyAuthenticator(Authenticator):
|
class MyAuthenticator(Authenticator):
|
||||||
@gen.coroutine
|
async def authenticate(self, handler, data=None):
|
||||||
def authenticate(self, handler, data=None):
|
username = await identify_user(handler, data)
|
||||||
username = yield identify_user(handler, data)
|
upstream_token = await token_for_user(username)
|
||||||
upstream_token = yield token_for_user(username)
|
|
||||||
return {
|
return {
|
||||||
'name': username,
|
'name': username,
|
||||||
'auth_state': {
|
'auth_state': {
|
||||||
@@ -246,10 +245,9 @@ class MyAuthenticator(Authenticator):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@gen.coroutine
|
async def pre_spawn_start(self, user, spawner):
|
||||||
def pre_spawn_start(self, user, spawner):
|
|
||||||
"""Pass upstream_token to spawner via environment variable"""
|
"""Pass upstream_token to spawner via environment variable"""
|
||||||
auth_state = yield user.get_auth_state()
|
auth_state = await user.get_auth_state()
|
||||||
if not auth_state:
|
if not auth_state:
|
||||||
# auth_state not enabled
|
# auth_state not enabled
|
||||||
return
|
return
|
||||||
|
@@ -930,7 +930,7 @@ class JupyterHub(Application):
|
|||||||
|
|
||||||
with an :meth:`authenticate` method that:
|
with an :meth:`authenticate` method that:
|
||||||
|
|
||||||
- is a coroutine (asyncio or tornado)
|
- is a coroutine (asyncio)
|
||||||
- returns username on success, None on failure
|
- returns username on success, None on failure
|
||||||
- takes two arguments: (handler, data),
|
- takes two arguments: (handler, data),
|
||||||
where `handler` is the calling web.RequestHandler,
|
where `handler` is the calling web.RequestHandler,
|
||||||
|
@@ -23,7 +23,6 @@ from urllib.parse import quote
|
|||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tornado.gen import coroutine
|
|
||||||
from tornado.httputil import url_concat
|
from tornado.httputil import url_concat
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.web import HTTPError
|
from tornado.web import HTTPError
|
||||||
@@ -950,8 +949,7 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
|||||||
.. versionadded: 0.8
|
.. versionadded: 0.8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@coroutine
|
async def get(self):
|
||||||
def get(self):
|
|
||||||
error = self.get_argument("error", False)
|
error = self.get_argument("error", False)
|
||||||
if error:
|
if error:
|
||||||
msg = self.get_argument("error_description", error)
|
msg = self.get_argument("error_description", error)
|
||||||
|
@@ -1079,7 +1079,7 @@ class Spawner(LoggingConfigurable):
|
|||||||
Return ip, port instead of setting on self.user.server directly.
|
Return ip, port instead of setting on self.user.server directly.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
"Override in subclass. Must be a coroutine."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stop(self, now=False):
|
async def stop(self, now=False):
|
||||||
@@ -1094,7 +1094,7 @@ class Spawner(LoggingConfigurable):
|
|||||||
Must be a coroutine.
|
Must be a coroutine.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
"Override in subclass. Must be a coroutine."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def poll(self):
|
async def poll(self):
|
||||||
@@ -1122,7 +1122,7 @@ class Spawner(LoggingConfigurable):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Override in subclass. Must be a Tornado gen.coroutine."
|
"Override in subclass. Must be a coroutine."
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_poll_callback(self, callback, *args, **kwargs):
|
def add_poll_callback(self, callback, *args, **kwargs):
|
||||||
|
@@ -36,7 +36,6 @@ from unittest import mock
|
|||||||
|
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
from tornado import gen
|
|
||||||
from tornado import ioloop
|
from tornado import ioloop
|
||||||
from tornado.httpclient import HTTPError
|
from tornado.httpclient import HTTPError
|
||||||
from tornado.platform.asyncio import AsyncIOMainLoop
|
from tornado.platform.asyncio import AsyncIOMainLoop
|
||||||
@@ -55,16 +54,6 @@ from .utils import ssl_setup
|
|||||||
_db = None
|
_db = None
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(items):
|
|
||||||
"""add asyncio marker to all async tests"""
|
|
||||||
for item in items:
|
|
||||||
if inspect.iscoroutinefunction(item.obj):
|
|
||||||
item.add_marker('asyncio')
|
|
||||||
if hasattr(inspect, 'isasyncgenfunction'):
|
|
||||||
# double-check that we aren't mixing yield and async def
|
|
||||||
assert not inspect.isasyncgenfunction(item.obj)
|
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope='module')
|
@fixture(scope='module')
|
||||||
def ssl_tmpdir(tmpdir_factory):
|
def ssl_tmpdir(tmpdir_factory):
|
||||||
return tmpdir_factory.mktemp('ssl')
|
return tmpdir_factory.mktemp('ssl')
|
||||||
@@ -244,17 +233,14 @@ def _mockservice(request, app, url=False):
|
|||||||
assert name in app._service_map
|
assert name in app._service_map
|
||||||
service = app._service_map[name]
|
service = app._service_map[name]
|
||||||
|
|
||||||
@gen.coroutine
|
async def start():
|
||||||
def start():
|
|
||||||
# wait for proxy to be updated before starting the service
|
# wait for proxy to be updated before starting the service
|
||||||
yield app.proxy.add_all_services(app._service_map)
|
await app.proxy.add_all_services(app._service_map)
|
||||||
service.start()
|
service.start()
|
||||||
|
|
||||||
io_loop.run_sync(start)
|
io_loop.run_sync(start)
|
||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(service.stop())
|
asyncio.get_event_loop().run_until_complete(service.stop())
|
||||||
app.services[:] = []
|
app.services[:] = []
|
||||||
app._service_map.clear()
|
app._service_map.clear()
|
||||||
|
@@ -37,7 +37,6 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from pamela import PAMError
|
from pamela import PAMError
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
from tornado.concurrent import Future
|
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from traitlets import Bool
|
from traitlets import Bool
|
||||||
from traitlets import default
|
from traitlets import default
|
||||||
@@ -110,19 +109,17 @@ class SlowSpawner(MockSpawner):
|
|||||||
delay = 2
|
delay = 2
|
||||||
_start_future = None
|
_start_future = None
|
||||||
|
|
||||||
@gen.coroutine
|
async def start(self):
|
||||||
def start(self):
|
(ip, port) = await super().start()
|
||||||
(ip, port) = yield super().start()
|
|
||||||
if self._start_future is not None:
|
if self._start_future is not None:
|
||||||
yield self._start_future
|
await self._start_future
|
||||||
else:
|
else:
|
||||||
yield gen.sleep(self.delay)
|
await gen.sleep(self.delay)
|
||||||
return ip, port
|
return ip, port
|
||||||
|
|
||||||
@gen.coroutine
|
async def stop(self):
|
||||||
def stop(self):
|
await gen.sleep(self.delay)
|
||||||
yield gen.sleep(self.delay)
|
await super().stop()
|
||||||
yield super().stop()
|
|
||||||
|
|
||||||
|
|
||||||
class NeverSpawner(MockSpawner):
|
class NeverSpawner(MockSpawner):
|
||||||
@@ -134,14 +131,12 @@ class NeverSpawner(MockSpawner):
|
|||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Return a Future that will never finish"""
|
"""Return a Future that will never finish"""
|
||||||
return Future()
|
return asyncio.Future()
|
||||||
|
|
||||||
@gen.coroutine
|
async def stop(self):
|
||||||
def stop(self):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@gen.coroutine
|
async def poll(self):
|
||||||
def poll(self):
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@@ -215,8 +210,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
|||||||
# skip the add-system-user bit
|
# skip the add-system-user bit
|
||||||
return not user.name.startswith('dne')
|
return not user.name.startswith('dne')
|
||||||
|
|
||||||
@gen.coroutine
|
async def authenticate(self, *args, **kwargs):
|
||||||
def authenticate(self, *args, **kwargs):
|
|
||||||
with mock.patch.multiple(
|
with mock.patch.multiple(
|
||||||
'pamela',
|
'pamela',
|
||||||
authenticate=mock_authenticate,
|
authenticate=mock_authenticate,
|
||||||
@@ -224,7 +218,7 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
|||||||
close_session=mock_open_session,
|
close_session=mock_open_session,
|
||||||
check_account=mock_check_account,
|
check_account=mock_check_account,
|
||||||
):
|
):
|
||||||
username = yield super(MockPAMAuthenticator, self).authenticate(
|
username = await super(MockPAMAuthenticator, self).authenticate(
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
)
|
)
|
||||||
if username is None:
|
if username is None:
|
||||||
@@ -320,14 +314,13 @@ class MockHub(JupyterHub):
|
|||||||
self.db.delete(group)
|
self.db.delete(group)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
@gen.coroutine
|
async def initialize(self, argv=None):
|
||||||
def initialize(self, argv=None):
|
|
||||||
self.pid_file = NamedTemporaryFile(delete=False).name
|
self.pid_file = NamedTemporaryFile(delete=False).name
|
||||||
self.db_file = NamedTemporaryFile()
|
self.db_file = NamedTemporaryFile()
|
||||||
self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
|
self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
|
||||||
if 'mysql' in self.db_url:
|
if 'mysql' in self.db_url:
|
||||||
self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'}
|
self.db_kwargs['connect_args'] = {'auth_plugin': 'mysql_native_password'}
|
||||||
yield super().initialize([])
|
await super().initialize([])
|
||||||
|
|
||||||
# add an initial user
|
# add an initial user
|
||||||
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
|
user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
|
||||||
@@ -358,14 +351,13 @@ class MockHub(JupyterHub):
|
|||||||
self.cleanup = lambda: None
|
self.cleanup = lambda: None
|
||||||
self.db_file.close()
|
self.db_file.close()
|
||||||
|
|
||||||
@gen.coroutine
|
async def login_user(self, name):
|
||||||
def login_user(self, name):
|
|
||||||
"""Login a user by name, returning her cookies."""
|
"""Login a user by name, returning her cookies."""
|
||||||
base_url = public_url(self)
|
base_url = public_url(self)
|
||||||
external_ca = None
|
external_ca = None
|
||||||
if self.internal_ssl:
|
if self.internal_ssl:
|
||||||
external_ca = self.external_certs['files']['ca']
|
external_ca = self.external_certs['files']['ca']
|
||||||
r = yield async_requests.post(
|
r = await async_requests.post(
|
||||||
base_url + 'hub/login',
|
base_url + 'hub/login',
|
||||||
data={'username': name, 'password': name},
|
data={'username': name, 'password': name},
|
||||||
allow_redirects=False,
|
allow_redirects=False,
|
||||||
@@ -407,8 +399,7 @@ class StubSingleUserSpawner(MockSpawner):
|
|||||||
|
|
||||||
_thread = None
|
_thread = None
|
||||||
|
|
||||||
@gen.coroutine
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
ip = self.ip = '127.0.0.1'
|
ip = self.ip = '127.0.0.1'
|
||||||
port = self.port = random_port()
|
port = self.port = random_port()
|
||||||
env = self.get_env()
|
env = self.get_env()
|
||||||
@@ -435,14 +426,12 @@ class StubSingleUserSpawner(MockSpawner):
|
|||||||
assert ready
|
assert ready
|
||||||
return (ip, port)
|
return (ip, port)
|
||||||
|
|
||||||
@gen.coroutine
|
async def stop(self):
|
||||||
def stop(self):
|
|
||||||
self._app.stop()
|
self._app.stop()
|
||||||
self._thread.join(timeout=30)
|
self._thread.join(timeout=30)
|
||||||
assert not self._thread.is_alive()
|
assert not self._thread.is_alive()
|
||||||
|
|
||||||
@gen.coroutine
|
async def poll(self):
|
||||||
def poll(self):
|
|
||||||
if self._thread is None:
|
if self._thread is None:
|
||||||
return 0
|
return 0
|
||||||
if self._thread.is_alive():
|
if self._thread.is_alive():
|
||||||
|
@@ -1,9 +1,9 @@
|
|||||||
"""Tests for the REST API."""
|
"""Tests for the REST API."""
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from concurrent.futures import Future
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
@@ -885,8 +885,8 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
|||||||
# start two pending spawns
|
# start two pending spawns
|
||||||
names = ['ykka', 'hjarka']
|
names = ['ykka', 'hjarka']
|
||||||
users = [add_user(db, app=app, name=name) for name in names]
|
users = [add_user(db, app=app, name=name) for name in names]
|
||||||
users[0].spawner._start_future = Future()
|
users[0].spawner._start_future = asyncio.Future()
|
||||||
users[1].spawner._start_future = Future()
|
users[1].spawner._start_future = asyncio.Future()
|
||||||
for name in names:
|
for name in names:
|
||||||
await api_request(app, 'users', name, 'server', method='post')
|
await api_request(app, 'users', name, 'server', method='post')
|
||||||
assert app.users.count_active_users()['pending'] == 2
|
assert app.users.count_active_users()['pending'] == 2
|
||||||
@@ -894,7 +894,7 @@ async def test_spawn_limit(app, no_patience, slow_spawn, request):
|
|||||||
# ykka and hjarka's spawns are both pending. Essun should fail with 429
|
# ykka and hjarka's spawns are both pending. Essun should fail with 429
|
||||||
name = 'essun'
|
name = 'essun'
|
||||||
user = add_user(db, app=app, name=name)
|
user = add_user(db, app=app, name=name)
|
||||||
user.spawner._start_future = Future()
|
user.spawner._start_future = asyncio.Future()
|
||||||
r = await api_request(app, 'users', name, 'server', method='post')
|
r = await api_request(app, 'users', name, 'server', method='post')
|
||||||
assert r.status_code == 429
|
assert r.status_code == 429
|
||||||
|
|
||||||
|
@@ -20,8 +20,7 @@ ssl_enabled = True
|
|||||||
SSL_ERROR = (SSLError, ConnectionError)
|
SSL_ERROR = (SSLError, ConnectionError)
|
||||||
|
|
||||||
|
|
||||||
@gen.coroutine
|
async def wait_for_spawner(spawner, timeout=10):
|
||||||
def wait_for_spawner(spawner, timeout=10):
|
|
||||||
"""Wait for an http server to show up
|
"""Wait for an http server to show up
|
||||||
|
|
||||||
polling at shorter intervals for early termination
|
polling at shorter intervals for early termination
|
||||||
@@ -32,15 +31,15 @@ def wait_for_spawner(spawner, timeout=10):
|
|||||||
return spawner.server.wait_up(timeout=1, http=True)
|
return spawner.server.wait_up(timeout=1, http=True)
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
while time.monotonic() < deadline:
|
||||||
status = yield spawner.poll()
|
status = await spawner.poll()
|
||||||
assert status is None
|
assert status is None
|
||||||
try:
|
try:
|
||||||
yield wait()
|
await wait()
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
yield wait()
|
await wait()
|
||||||
|
|
||||||
|
|
||||||
async def test_connection_hub_wrong_certs(app):
|
async def test_connection_hub_wrong_certs(app):
|
||||||
|
@@ -222,8 +222,7 @@ async def test_spawn_fails(db):
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
class BadSpawner(MockSpawner):
|
class BadSpawner(MockSpawner):
|
||||||
@gen.coroutine
|
async def start(self):
|
||||||
def start(self):
|
|
||||||
raise RuntimeError("Split the party")
|
raise RuntimeError("Split the party")
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
|
@@ -586,8 +586,7 @@ async def test_login_strip(app):
|
|||||||
base_url = public_url(app)
|
base_url = public_url(app)
|
||||||
called_with = []
|
called_with = []
|
||||||
|
|
||||||
@gen.coroutine
|
async def mock_authenticate(handler, data):
|
||||||
def mock_authenticate(handler, data):
|
|
||||||
called_with.append(data)
|
called_with.append(data)
|
||||||
|
|
||||||
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
|
with mock.patch.object(app.authenticator, 'authenticate', mock_authenticate):
|
||||||
@@ -943,8 +942,7 @@ async def test_pre_spawn_start_exc_no_form(app):
|
|||||||
exc = "pre_spawn_start error"
|
exc = "pre_spawn_start error"
|
||||||
|
|
||||||
# throw exception from pre_spawn_start
|
# throw exception from pre_spawn_start
|
||||||
@gen.coroutine
|
async def mock_pre_spawn_start(user, spawner):
|
||||||
def mock_pre_spawn_start(user, spawner):
|
|
||||||
raise Exception(exc)
|
raise Exception(exc)
|
||||||
|
|
||||||
with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start):
|
with mock.patch.object(app.authenticator, 'pre_spawn_start', mock_pre_spawn_start):
|
||||||
@@ -959,8 +957,7 @@ async def test_pre_spawn_start_exc_options_form(app):
|
|||||||
exc = "pre_spawn_start error"
|
exc = "pre_spawn_start error"
|
||||||
|
|
||||||
# throw exception from pre_spawn_start
|
# throw exception from pre_spawn_start
|
||||||
@gen.coroutine
|
async def mock_pre_spawn_start(user, spawner):
|
||||||
def mock_pre_spawn_start(user, spawner):
|
|
||||||
raise Exception(exc)
|
raise Exception(exc)
|
||||||
|
|
||||||
with mock.patch.dict(
|
with mock.patch.dict(
|
||||||
|
Reference in New Issue
Block a user