diff --git a/.travis.yml b/.travis.yml index 7cd48d22..af58c0f6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,6 @@ install: - pip install -f travis-wheels/wheelhouse -r dev-requirements.txt . - pip install -f travis-wheels/wheelhouse ipython[notebook] script: - - py.test --cov jupyterhub jupyterhub/tests + - py.test --cov jupyterhub jupyterhub/tests -v after_success: - coveralls diff --git a/jupyterhub/app.py b/jupyterhub/app.py index ef519507..0717a92f 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -11,6 +11,7 @@ import os import signal import socket import sys +import threading from datetime import datetime from distutils.version import LooseVersion as V from getpass import getuser @@ -540,9 +541,18 @@ class JupyterHub(Application): # store the loaded trait value self.cookie_secret = secret + _db_local = None + @property + def db(self): + if not hasattr(self._db_local, 'db'): + print("Making new connection", self) + self._db_local.db = scoped_session(self.session_factory)() + return self._db_local.db + def init_db(self): """Create the database connection""" self.log.debug("Connecting to db: %s", self.db_url) + self._db_local = threading.local() try: self.session_factory = orm.new_session_factory( self.db_url, @@ -550,7 +560,8 @@ class JupyterHub(Application): echo=self.debug_db, **self.db_kwargs ) - self.db = scoped_session(self.session_factory)() + # trigger constructing thread local db property + _ = self.db except OperationalError as e: self.log.error("Failed to connect to db: %s", self.db_url) self.log.debug("Database error was:", exc_info=True) diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 4388f3c9..ab3211bb 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -3,6 +3,7 @@ import json import time from datetime import timedelta +from queue import Queue import requests @@ -254,6 +255,18 @@ def test_make_admin(app): assert user.name == name assert user.admin +def get_app_user(app, name): + """Get the User object from the main thread + + Needed for access to the Spawner. + No ORM methods should be called on the result. + """ + q = Queue() + def get_user(): + user = find_user(app.db, name) + q.put(user) + app.io_loop.add_callback(get_user) + return q.get(timeout=2) def test_spawn(app, io_loop): db = app.db @@ -262,9 +275,10 @@ def test_spawn(app, io_loop): r = api_request(app, 'users', name, 'server', method='post') assert r.status_code == 201 assert 'pid' in user.state - assert user.spawner is not None - assert not user.spawn_pending - status = io_loop.run_sync(user.spawner.poll) + app_user = get_app_user(app, name) + assert app_user.spawner is not None + assert not app_user.spawn_pending + status = io_loop.run_sync(app_user.spawner.poll) assert status is None assert user.server.base_url == '/user/%s' % name @@ -282,7 +296,7 @@ def test_spawn(app, io_loop): assert r.status_code == 204 assert 'pid' not in user.state - status = io_loop.run_sync(user.spawner.poll) + status = io_loop.run_sync(app_user.spawner.poll) assert status == 0 def test_slow_spawn(app, io_loop): @@ -296,41 +310,42 @@ def test_slow_spawn(app, io_loop): r = api_request(app, 'users', name, 'server', method='post') r.raise_for_status() assert r.status_code == 202 - assert user.spawner is not None - assert user.spawn_pending - assert not user.stop_pending + app_user = get_app_user(app, name) + assert app_user.spawner is not None + assert app_user.spawn_pending + assert not app_user.stop_pending dt = timedelta(seconds=0.1) @gen.coroutine def wait_spawn(): - while user.spawn_pending: + while app_user.spawn_pending: yield gen.Task(io_loop.add_timeout, dt) io_loop.run_sync(wait_spawn) - assert not user.spawn_pending - status = io_loop.run_sync(user.spawner.poll) + assert not app_user.spawn_pending + status = io_loop.run_sync(app_user.spawner.poll) assert status is None @gen.coroutine def wait_stop(): - while user.stop_pending: + while app_user.stop_pending: yield gen.Task(io_loop.add_timeout, dt) r = api_request(app, 'users', name, 'server', method='delete') r.raise_for_status() assert r.status_code == 202 - assert user.spawner is not None - assert user.stop_pending + assert app_user.spawner is not None + assert app_user.stop_pending r = api_request(app, 'users', name, 'server', method='delete') r.raise_for_status() assert r.status_code == 202 - assert user.spawner is not None - assert user.stop_pending + assert app_user.spawner is not None + assert app_user.stop_pending io_loop.run_sync(wait_stop) - assert not user.stop_pending - assert user.spawner is not None + assert not app_user.stop_pending + assert app_user.spawner is not None r = api_request(app, 'users', name, 'server', method='delete') assert r.status_code == 400 @@ -343,18 +358,19 @@ def test_never_spawn(app, io_loop): name = 'badger' user = add_user(db, name=name) r = api_request(app, 'users', name, 'server', method='post') - assert user.spawner is not None - assert user.spawn_pending + app_user = get_app_user(app, name) + assert app_user.spawner is not None + assert app_user.spawn_pending dt = timedelta(seconds=0.1) @gen.coroutine def wait_pending(): - while user.spawn_pending: + while app_user.spawn_pending: yield gen.Task(io_loop.add_timeout, dt) io_loop.run_sync(wait_pending) - assert not user.spawn_pending - status = io_loop.run_sync(user.spawner.poll) + assert not app_user.spawn_pending + status = io_loop.run_sync(app_user.spawner.poll) assert status is not None