use thread local db sessions

to avoid segfaults on Travis
This commit is contained in:
Min RK
2015-07-09 10:09:46 -05:00
parent c8487c2117
commit 48fe642c44
3 changed files with 51 additions and 24 deletions

View File

@@ -12,6 +12,6 @@ install:
- pip install -f travis-wheels/wheelhouse -r dev-requirements.txt . - pip install -f travis-wheels/wheelhouse -r dev-requirements.txt .
- pip install -f travis-wheels/wheelhouse ipython[notebook] - pip install -f travis-wheels/wheelhouse ipython[notebook]
script: script:
- py.test --cov jupyterhub jupyterhub/tests - py.test --cov jupyterhub jupyterhub/tests -v
after_success: after_success:
- coveralls - coveralls

View File

@@ -11,6 +11,7 @@ import os
import signal import signal
import socket import socket
import sys import sys
import threading
from datetime import datetime from datetime import datetime
from distutils.version import LooseVersion as V from distutils.version import LooseVersion as V
from getpass import getuser from getpass import getuser
@@ -540,9 +541,18 @@ class JupyterHub(Application):
# store the loaded trait value # store the loaded trait value
self.cookie_secret = secret self.cookie_secret = secret
_db_local = None
@property
def db(self):
if not hasattr(self._db_local, 'db'):
print("Making new connection", self)
self._db_local.db = scoped_session(self.session_factory)()
return self._db_local.db
def init_db(self): def init_db(self):
"""Create the database connection""" """Create the database connection"""
self.log.debug("Connecting to db: %s", self.db_url) self.log.debug("Connecting to db: %s", self.db_url)
self._db_local = threading.local()
try: try:
self.session_factory = orm.new_session_factory( self.session_factory = orm.new_session_factory(
self.db_url, self.db_url,
@@ -550,7 +560,8 @@ class JupyterHub(Application):
echo=self.debug_db, echo=self.debug_db,
**self.db_kwargs **self.db_kwargs
) )
self.db = scoped_session(self.session_factory)() # trigger constructing thread local db property
_ = self.db
except OperationalError as e: except OperationalError as e:
self.log.error("Failed to connect to db: %s", self.db_url) self.log.error("Failed to connect to db: %s", self.db_url)
self.log.debug("Database error was:", exc_info=True) self.log.debug("Database error was:", exc_info=True)

View File

@@ -3,6 +3,7 @@
import json import json
import time import time
from datetime import timedelta from datetime import timedelta
from queue import Queue
import requests import requests
@@ -254,6 +255,18 @@ def test_make_admin(app):
assert user.name == name assert user.name == name
assert user.admin assert user.admin
def get_app_user(app, name):
"""Get the User object from the main thread
Needed for access to the Spawner.
No ORM methods should be called on the result.
"""
q = Queue()
def get_user():
user = find_user(app.db, name)
q.put(user)
app.io_loop.add_callback(get_user)
return q.get(timeout=2)
def test_spawn(app, io_loop): def test_spawn(app, io_loop):
db = app.db db = app.db
@@ -262,9 +275,10 @@ def test_spawn(app, io_loop):
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
assert r.status_code == 201 assert r.status_code == 201
assert 'pid' in user.state assert 'pid' in user.state
assert user.spawner is not None app_user = get_app_user(app, name)
assert not user.spawn_pending assert app_user.spawner is not None
status = io_loop.run_sync(user.spawner.poll) assert not app_user.spawn_pending
status = io_loop.run_sync(app_user.spawner.poll)
assert status is None assert status is None
assert user.server.base_url == '/user/%s' % name assert user.server.base_url == '/user/%s' % name
@@ -282,7 +296,7 @@ def test_spawn(app, io_loop):
assert r.status_code == 204 assert r.status_code == 204
assert 'pid' not in user.state assert 'pid' not in user.state
status = io_loop.run_sync(user.spawner.poll) status = io_loop.run_sync(app_user.spawner.poll)
assert status == 0 assert status == 0
def test_slow_spawn(app, io_loop): def test_slow_spawn(app, io_loop):
@@ -296,41 +310,42 @@ def test_slow_spawn(app, io_loop):
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
r.raise_for_status() r.raise_for_status()
assert r.status_code == 202 assert r.status_code == 202
assert user.spawner is not None app_user = get_app_user(app, name)
assert user.spawn_pending assert app_user.spawner is not None
assert not user.stop_pending assert app_user.spawn_pending
assert not app_user.stop_pending
dt = timedelta(seconds=0.1) dt = timedelta(seconds=0.1)
@gen.coroutine @gen.coroutine
def wait_spawn(): def wait_spawn():
while user.spawn_pending: while app_user.spawn_pending:
yield gen.Task(io_loop.add_timeout, dt) yield gen.Task(io_loop.add_timeout, dt)
io_loop.run_sync(wait_spawn) io_loop.run_sync(wait_spawn)
assert not user.spawn_pending assert not app_user.spawn_pending
status = io_loop.run_sync(user.spawner.poll) status = io_loop.run_sync(app_user.spawner.poll)
assert status is None assert status is None
@gen.coroutine @gen.coroutine
def wait_stop(): def wait_stop():
while user.stop_pending: while app_user.stop_pending:
yield gen.Task(io_loop.add_timeout, dt) yield gen.Task(io_loop.add_timeout, dt)
r = api_request(app, 'users', name, 'server', method='delete') r = api_request(app, 'users', name, 'server', method='delete')
r.raise_for_status() r.raise_for_status()
assert r.status_code == 202 assert r.status_code == 202
assert user.spawner is not None assert app_user.spawner is not None
assert user.stop_pending assert app_user.stop_pending
r = api_request(app, 'users', name, 'server', method='delete') r = api_request(app, 'users', name, 'server', method='delete')
r.raise_for_status() r.raise_for_status()
assert r.status_code == 202 assert r.status_code == 202
assert user.spawner is not None assert app_user.spawner is not None
assert user.stop_pending assert app_user.stop_pending
io_loop.run_sync(wait_stop) io_loop.run_sync(wait_stop)
assert not user.stop_pending assert not app_user.stop_pending
assert user.spawner is not None assert app_user.spawner is not None
r = api_request(app, 'users', name, 'server', method='delete') r = api_request(app, 'users', name, 'server', method='delete')
assert r.status_code == 400 assert r.status_code == 400
@@ -343,18 +358,19 @@ def test_never_spawn(app, io_loop):
name = 'badger' name = 'badger'
user = add_user(db, name=name) user = add_user(db, name=name)
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
assert user.spawner is not None app_user = get_app_user(app, name)
assert user.spawn_pending assert app_user.spawner is not None
assert app_user.spawn_pending
dt = timedelta(seconds=0.1) dt = timedelta(seconds=0.1)
@gen.coroutine @gen.coroutine
def wait_pending(): def wait_pending():
while user.spawn_pending: while app_user.spawn_pending:
yield gen.Task(io_loop.add_timeout, dt) yield gen.Task(io_loop.add_timeout, dt)
io_loop.run_sync(wait_pending) io_loop.run_sync(wait_pending)
assert not user.spawn_pending assert not app_user.spawn_pending
status = io_loop.run_sync(user.spawner.poll) status = io_loop.run_sync(app_user.spawner.poll)
assert status is not None assert status is not None