DEV: Close transactions at the end of HTTP Requests.

Fixes #84
This commit is contained in:
Scott Sanderson
2014-10-29 16:00:41 -04:00
parent db5cf9cf99
commit 8cfbe9b38e
6 changed files with 62 additions and 13 deletions

View File

@@ -262,6 +262,7 @@ class JupyterHubApp(Application):
help="log all database transactions. This has A LOT of output"
)
db = Any()
session_factory = Any()
admin_users = Set(config=True,
help="""set of usernames of admin users
@@ -364,9 +365,13 @@ class JupyterHubApp(Application):
"""Create the database connection"""
self.log.debug("Connecting to db: %s", self.db_url)
try:
self.db = orm.new_session(self.db_url, reset=self.reset_db, echo=self.debug_db,
self.session_factory = orm.new_session_factory(
self.db_url,
reset=self.reset_db,
echo=self.debug_db,
**self.db_kwargs
)
self.db = self.session_factory()
except OperationalError as e:
self.log.error("Failed to connect to db: %s", self.db_url)
self.log.debug("Database error was:", exc_info=True)

View File

@@ -58,6 +58,11 @@ class BaseHandler(RequestHandler):
def authenticator(self):
return self.settings.get('authenticator', None)
def finish(self, *args, **kwargs):
"""Roll back any uncommitted transactions from the handler."""
self.db.rollback()
super(BaseHandler, self).finish(*args, **kwargs)
#---------------------------------------------------------------
# Login and cookie-related
#---------------------------------------------------------------

View File

@@ -381,17 +381,20 @@ class CookieToken(Token, Base):
__tablename__ = 'cookie_tokens'
def new_session(url="sqlite:///:memory:", reset=False, **kwargs):
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url"""
if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False})
if url.endswith(':memory:'):
# If we're using an in-memory database, ensure that only one connection
# is ever created.
kwargs.setdefault('poolclass', StaticPool)
engine = create_engine(url, **kwargs)
Session = sessionmaker(bind=engine)
session = Session()
if reset:
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
return session
session_factory = sessionmaker(bind=engine)
return session_factory

View File

@@ -22,7 +22,7 @@ def db():
"""Get a db session"""
global _db
if _db is None:
_db = orm.new_session('sqlite:///:memory:', echo=True)
_db = orm.new_session_factory('sqlite:///:memory:', echo=True)()
user = orm.User(
name=getuser_unicode(),
server=orm.Server(),

View File

@@ -1,5 +1,6 @@
"""mock utilities for testing"""
import os
import sys
import threading
@@ -57,11 +58,18 @@ class MockPAMAuthenticator(PAMAuthenticator):
class MockHubApp(JupyterHubApp):
"""HubApp with various mock bits"""
db_path = os.path.join(
os.path.dirname(
os.path.realpath(__file__),
),
"test.sqlite",
)
def _ip_default(self):
return 'localhost'
def _db_url_default(self):
return 'sqlite:///:memory:'
return "sqlite:///" + self.db_path
def _authenticator_class_default(self):
return MockPAMAuthenticator
@@ -72,7 +80,12 @@ class MockHubApp(JupyterHubApp):
def _admin_users_default(self):
return {'admin'}
def rm_db(self):
if os.path.exists(self.db_path):
os.remove(self.db_path)
def start(self, argv=None):
self.rm_db()
evt = threading.Event()
def _start():
self.io_loop = IOLoop.current()
@@ -91,6 +104,6 @@ class MockHubApp(JupyterHubApp):
evt.wait(timeout=5)
def stop(self):
self.rm_db()
self.io_loop.add_callback(self.io_loop.stop)
self._thread.join()

View File

@@ -7,6 +7,30 @@ import requests
from ..utils import url_path_join as ujoin
from .. import orm
def check_db_locks(func):
"""
Decorator for test functions that verifies no locks are held on the
application's database upon exit by creating and dropping a dummy table.
Relies on an instance of JupyterhubApp being the first argument to the
decorated function.
"""
def new_func(*args, **kwargs):
retval = func(*args, **kwargs)
app = args[0]
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
return retval
return new_func
def find_user(db, name):
return db.query(orm.User).filter(orm.User.name==name).first()
@@ -28,6 +52,7 @@ def auth_header(db, name):
token = user.api_tokens[0]
return {'Authorization': 'token %s' % token.token}
@check_db_locks
def api_request(app, *api_path, **kwargs):
"""Make an API request"""
base_url = app.hub.server.url
@@ -71,7 +96,6 @@ def test_auth_api(app):
)
assert r.status_code == 403
def test_get_users(app):
db = app.db
r = api_request(app, 'users')
@@ -179,4 +203,3 @@ def test_spawn(app, io_loop):
assert 'pid' not in user.state
status = io_loop.run_sync(user.spawner.poll)
assert status == 0