DEV: Close transactions at the end of HTTP Requests.

Fixes #84
This commit is contained in:
Scott Sanderson
2014-10-29 16:00:41 -04:00
parent db5cf9cf99
commit 8cfbe9b38e
6 changed files with 62 additions and 13 deletions

View File

@@ -262,6 +262,7 @@ class JupyterHubApp(Application):
help="log all database transactions. This has A LOT of output" help="log all database transactions. This has A LOT of output"
) )
db = Any() db = Any()
session_factory = Any()
admin_users = Set(config=True, admin_users = Set(config=True,
help="""set of usernames of admin users help="""set of usernames of admin users
@@ -364,9 +365,13 @@ class JupyterHubApp(Application):
"""Create the database connection""" """Create the database connection"""
self.log.debug("Connecting to db: %s", self.db_url) self.log.debug("Connecting to db: %s", self.db_url)
try: try:
self.db = orm.new_session(self.db_url, reset=self.reset_db, echo=self.debug_db, self.session_factory = orm.new_session_factory(
self.db_url,
reset=self.reset_db,
echo=self.debug_db,
**self.db_kwargs **self.db_kwargs
) )
self.db = self.session_factory()
except OperationalError as e: except OperationalError as e:
self.log.error("Failed to connect to db: %s", self.db_url) self.log.error("Failed to connect to db: %s", self.db_url)
self.log.debug("Database error was:", exc_info=True) self.log.debug("Database error was:", exc_info=True)

View File

@@ -58,6 +58,11 @@ class BaseHandler(RequestHandler):
def authenticator(self): def authenticator(self):
return self.settings.get('authenticator', None) return self.settings.get('authenticator', None)
def finish(self, *args, **kwargs):
"""Roll back any uncommitted transactions from the handler."""
self.db.rollback()
super(BaseHandler, self).finish(*args, **kwargs)
#--------------------------------------------------------------- #---------------------------------------------------------------
# Login and cookie-related # Login and cookie-related
#--------------------------------------------------------------- #---------------------------------------------------------------

View File

@@ -381,17 +381,20 @@ class CookieToken(Token, Base):
__tablename__ = 'cookie_tokens' __tablename__ = 'cookie_tokens'
def new_session(url="sqlite:///:memory:", reset=False, **kwargs): def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url""" """Create a new session at url"""
if url.startswith('sqlite'): if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False}) kwargs.setdefault('connect_args', {'check_same_thread': False})
kwargs.setdefault('poolclass', StaticPool)
if url.endswith(':memory:'):
# If we're using an in-memory database, ensure that only one connection
# is ever created.
kwargs.setdefault('poolclass', StaticPool)
engine = create_engine(url, **kwargs) engine = create_engine(url, **kwargs)
Session = sessionmaker(bind=engine)
session = Session()
if reset: if reset:
Base.metadata.drop_all(engine) Base.metadata.drop_all(engine)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
return session
session_factory = sessionmaker(bind=engine)
return session_factory

View File

@@ -22,7 +22,7 @@ def db():
"""Get a db session""" """Get a db session"""
global _db global _db
if _db is None: if _db is None:
_db = orm.new_session('sqlite:///:memory:', echo=True) _db = orm.new_session_factory('sqlite:///:memory:', echo=True)()
user = orm.User( user = orm.User(
name=getuser_unicode(), name=getuser_unicode(),
server=orm.Server(), server=orm.Server(),

View File

@@ -1,5 +1,6 @@
"""mock utilities for testing""" """mock utilities for testing"""
import os
import sys import sys
import threading import threading
@@ -56,12 +57,19 @@ class MockPAMAuthenticator(PAMAuthenticator):
class MockHubApp(JupyterHubApp): class MockHubApp(JupyterHubApp):
"""HubApp with various mock bits""" """HubApp with various mock bits"""
db_path = os.path.join(
os.path.dirname(
os.path.realpath(__file__),
),
"test.sqlite",
)
def _ip_default(self): def _ip_default(self):
return 'localhost' return 'localhost'
def _db_url_default(self): def _db_url_default(self):
return 'sqlite:///:memory:' return "sqlite:///" + self.db_path
def _authenticator_class_default(self): def _authenticator_class_default(self):
return MockPAMAuthenticator return MockPAMAuthenticator
@@ -71,8 +79,13 @@ class MockHubApp(JupyterHubApp):
def _admin_users_default(self): def _admin_users_default(self):
return {'admin'} return {'admin'}
def rm_db(self):
if os.path.exists(self.db_path):
os.remove(self.db_path)
def start(self, argv=None): def start(self, argv=None):
self.rm_db()
evt = threading.Event() evt = threading.Event()
def _start(): def _start():
self.io_loop = IOLoop.current() self.io_loop = IOLoop.current()
@@ -91,6 +104,6 @@ class MockHubApp(JupyterHubApp):
evt.wait(timeout=5) evt.wait(timeout=5)
def stop(self): def stop(self):
self.rm_db()
self.io_loop.add_callback(self.io_loop.stop) self.io_loop.add_callback(self.io_loop.stop)
self._thread.join() self._thread.join()

View File

@@ -7,6 +7,30 @@ import requests
from ..utils import url_path_join as ujoin from ..utils import url_path_join as ujoin
from .. import orm from .. import orm
def check_db_locks(func):
"""
Decorator for test functions that verifies no locks are held on the
application's database upon exit by creating and dropping a dummy table.
Relies on an instance of JupyterhubApp being the first argument to the
decorated function.
"""
def new_func(*args, **kwargs):
retval = func(*args, **kwargs)
app = args[0]
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
return retval
return new_func
def find_user(db, name): def find_user(db, name):
return db.query(orm.User).filter(orm.User.name==name).first() return db.query(orm.User).filter(orm.User.name==name).first()
@@ -28,6 +52,7 @@ def auth_header(db, name):
token = user.api_tokens[0] token = user.api_tokens[0]
return {'Authorization': 'token %s' % token.token} return {'Authorization': 'token %s' % token.token}
@check_db_locks
def api_request(app, *api_path, **kwargs): def api_request(app, *api_path, **kwargs):
"""Make an API request""" """Make an API request"""
base_url = app.hub.server.url base_url = app.hub.server.url
@@ -70,7 +95,6 @@ def test_auth_api(app):
headers={'Authorization': 'token: %s' % cookie_token.token}, headers={'Authorization': 'token: %s' % cookie_token.token},
) )
assert r.status_code == 403 assert r.status_code == 403
def test_get_users(app): def test_get_users(app):
db = app.db db = app.db
@@ -179,4 +203,3 @@ def test_spawn(app, io_loop):
assert 'pid' not in user.state assert 'pid' not in user.state
status = io_loop.run_sync(user.spawner.poll) status = io_loop.run_sync(user.spawner.poll)
assert status == 0 assert status == 0