diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 0ce8cfca..19b39e78 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -262,6 +262,7 @@ class JupyterHubApp(Application): help="log all database transactions. This has A LOT of output" ) db = Any() + session_factory = Any() admin_users = Set(config=True, help="""set of usernames of admin users @@ -364,9 +365,13 @@ class JupyterHubApp(Application): """Create the database connection""" self.log.debug("Connecting to db: %s", self.db_url) try: - self.db = orm.new_session(self.db_url, reset=self.reset_db, echo=self.debug_db, + self.session_factory = orm.new_session_factory( + self.db_url, + reset=self.reset_db, + echo=self.debug_db, **self.db_kwargs ) + self.db = self.session_factory() except OperationalError as e: self.log.error("Failed to connect to db: %s", self.db_url) self.log.debug("Database error was:", exc_info=True) diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 669388c2..90104286 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -58,6 +58,11 @@ class BaseHandler(RequestHandler): def authenticator(self): return self.settings.get('authenticator', None) + def finish(self, *args, **kwargs): + """Roll back any uncommitted transactions from the handler.""" + self.db.rollback() + super(BaseHandler, self).finish(*args, **kwargs) + #--------------------------------------------------------------- # Login and cookie-related #--------------------------------------------------------------- diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 5abbbfd8..8726b5b4 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -381,17 +381,20 @@ class CookieToken(Token, Base): __tablename__ = 'cookie_tokens' -def new_session(url="sqlite:///:memory:", reset=False, **kwargs): +def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): """Create a new session at url""" if url.startswith('sqlite'): kwargs.setdefault('connect_args', {'check_same_thread': False}) - kwargs.setdefault('poolclass', StaticPool) + + if url.endswith(':memory:'): + # If we're using an in-memory database, ensure that only one connection + # is ever created. + kwargs.setdefault('poolclass', StaticPool) + engine = create_engine(url, **kwargs) - Session = sessionmaker(bind=engine) - session = Session() if reset: Base.metadata.drop_all(engine) Base.metadata.create_all(engine) - return session - + session_factory = sessionmaker(bind=engine) + return session_factory diff --git a/jupyterhub/tests/conftest.py b/jupyterhub/tests/conftest.py index fe569048..948b36f8 100644 --- a/jupyterhub/tests/conftest.py +++ b/jupyterhub/tests/conftest.py @@ -22,7 +22,7 @@ def db(): """Get a db session""" global _db if _db is None: - _db = orm.new_session('sqlite:///:memory:', echo=True) + _db = orm.new_session_factory('sqlite:///:memory:', echo=True)() user = orm.User( name=getuser_unicode(), server=orm.Server(), diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index 9f771261..21b42483 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -1,6 +1,8 @@ """mock utilities for testing""" +import os import sys +from tempfile import NamedTemporaryFile import threading try: @@ -56,13 +58,12 @@ class MockPAMAuthenticator(PAMAuthenticator): class MockHubApp(JupyterHubApp): """HubApp with various mock bits""" - + + db_file = None + def _ip_default(self): return 'localhost' - def _db_url_default(self): - return 'sqlite:///:memory:' - def _authenticator_class_default(self): return MockPAMAuthenticator @@ -71,8 +72,10 @@ class MockHubApp(JupyterHubApp): def _admin_users_default(self): return {'admin'} - + def start(self, argv=None): + self.db_file = NamedTemporaryFile() + self.db_url = 'sqlite:///' + self.db_file.name evt = threading.Event() def _start(): self.io_loop = IOLoop.current() @@ -91,6 +94,6 @@ class MockHubApp(JupyterHubApp): evt.wait(timeout=5) def stop(self): + self.db_file.close() self.io_loop.add_callback(self.io_loop.stop) self._thread.join() - diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index ce914c6f..c31a2e5d 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -7,6 +7,30 @@ import requests from ..utils import url_path_join as ujoin from .. import orm + +def check_db_locks(func): + """ + Decorator for test functions that verifies no locks are held on the + application's database upon exit by creating and dropping a dummy table. + + Relies on an instance of JupyterhubApp being the first argument to the + decorated function. + """ + + def new_func(*args, **kwargs): + retval = func(*args, **kwargs) + + app = args[0] + temp_session = app.session_factory() + temp_session.execute('CREATE TABLE dummy (foo INT)') + temp_session.execute('DROP TABLE dummy') + temp_session.close() + + return retval + + return new_func + + def find_user(db, name): return db.query(orm.User).filter(orm.User.name==name).first() @@ -28,6 +52,7 @@ def auth_header(db, name): token = user.api_tokens[0] return {'Authorization': 'token %s' % token.token} +@check_db_locks def api_request(app, *api_path, **kwargs): """Make an API request""" base_url = app.hub.server.url @@ -70,7 +95,6 @@ def test_auth_api(app): headers={'Authorization': 'token: %s' % cookie_token.token}, ) assert r.status_code == 403 - def test_get_users(app): db = app.db @@ -179,4 +203,3 @@ def test_spawn(app, io_loop): assert 'pid' not in user.state status = io_loop.run_sync(user.spawner.poll) assert status == 0 - \ No newline at end of file