mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 22:13:00 +00:00
@@ -262,6 +262,7 @@ class JupyterHubApp(Application):
|
||||
help="log all database transactions. This has A LOT of output"
|
||||
)
|
||||
db = Any()
|
||||
session_factory = Any()
|
||||
|
||||
admin_users = Set(config=True,
|
||||
help="""set of usernames of admin users
|
||||
@@ -364,9 +365,13 @@ class JupyterHubApp(Application):
|
||||
"""Create the database connection"""
|
||||
self.log.debug("Connecting to db: %s", self.db_url)
|
||||
try:
|
||||
self.db = orm.new_session(self.db_url, reset=self.reset_db, echo=self.debug_db,
|
||||
self.session_factory = orm.new_session_factory(
|
||||
self.db_url,
|
||||
reset=self.reset_db,
|
||||
echo=self.debug_db,
|
||||
**self.db_kwargs
|
||||
)
|
||||
self.db = self.session_factory()
|
||||
except OperationalError as e:
|
||||
self.log.error("Failed to connect to db: %s", self.db_url)
|
||||
self.log.debug("Database error was:", exc_info=True)
|
||||
|
@@ -58,6 +58,11 @@ class BaseHandler(RequestHandler):
|
||||
def authenticator(self):
|
||||
return self.settings.get('authenticator', None)
|
||||
|
||||
def finish(self, *args, **kwargs):
|
||||
"""Roll back any uncommitted transactions from the handler."""
|
||||
self.db.rollback()
|
||||
super(BaseHandler, self).finish(*args, **kwargs)
|
||||
|
||||
#---------------------------------------------------------------
|
||||
# Login and cookie-related
|
||||
#---------------------------------------------------------------
|
||||
|
@@ -381,17 +381,20 @@ class CookieToken(Token, Base):
|
||||
__tablename__ = 'cookie_tokens'
|
||||
|
||||
|
||||
def new_session(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
"""Create a new session at url"""
|
||||
if url.startswith('sqlite'):
|
||||
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
||||
kwargs.setdefault('poolclass', StaticPool)
|
||||
|
||||
if url.endswith(':memory:'):
|
||||
# If we're using an in-memory database, ensure that only one connection
|
||||
# is ever created.
|
||||
kwargs.setdefault('poolclass', StaticPool)
|
||||
|
||||
engine = create_engine(url, **kwargs)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
if reset:
|
||||
Base.metadata.drop_all(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
return session
|
||||
|
||||
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
return session_factory
|
||||
|
@@ -22,7 +22,7 @@ def db():
|
||||
"""Get a db session"""
|
||||
global _db
|
||||
if _db is None:
|
||||
_db = orm.new_session('sqlite:///:memory:', echo=True)
|
||||
_db = orm.new_session_factory('sqlite:///:memory:', echo=True)()
|
||||
user = orm.User(
|
||||
name=getuser_unicode(),
|
||||
server=orm.Server(),
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""mock utilities for testing"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
@@ -57,11 +58,18 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
||||
class MockHubApp(JupyterHubApp):
|
||||
"""HubApp with various mock bits"""
|
||||
|
||||
db_path = os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.realpath(__file__),
|
||||
),
|
||||
"test.sqlite",
|
||||
)
|
||||
|
||||
def _ip_default(self):
|
||||
return 'localhost'
|
||||
|
||||
def _db_url_default(self):
|
||||
return 'sqlite:///:memory:'
|
||||
return "sqlite:///" + self.db_path
|
||||
|
||||
def _authenticator_class_default(self):
|
||||
return MockPAMAuthenticator
|
||||
@@ -72,7 +80,12 @@ class MockHubApp(JupyterHubApp):
|
||||
def _admin_users_default(self):
|
||||
return {'admin'}
|
||||
|
||||
def rm_db(self):
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
|
||||
def start(self, argv=None):
|
||||
self.rm_db()
|
||||
evt = threading.Event()
|
||||
def _start():
|
||||
self.io_loop = IOLoop.current()
|
||||
@@ -91,6 +104,6 @@ class MockHubApp(JupyterHubApp):
|
||||
evt.wait(timeout=5)
|
||||
|
||||
def stop(self):
|
||||
self.rm_db()
|
||||
self.io_loop.add_callback(self.io_loop.stop)
|
||||
self._thread.join()
|
||||
|
||||
|
@@ -7,6 +7,30 @@ import requests
|
||||
from ..utils import url_path_join as ujoin
|
||||
from .. import orm
|
||||
|
||||
|
||||
def check_db_locks(func):
|
||||
"""
|
||||
Decorator for test functions that verifies no locks are held on the
|
||||
application's database upon exit by creating and dropping a dummy table.
|
||||
|
||||
Relies on an instance of JupyterhubApp being the first argument to the
|
||||
decorated function.
|
||||
"""
|
||||
|
||||
def new_func(*args, **kwargs):
|
||||
retval = func(*args, **kwargs)
|
||||
|
||||
app = args[0]
|
||||
temp_session = app.session_factory()
|
||||
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||
temp_session.execute('DROP TABLE dummy')
|
||||
temp_session.close()
|
||||
|
||||
return retval
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def find_user(db, name):
|
||||
return db.query(orm.User).filter(orm.User.name==name).first()
|
||||
|
||||
@@ -28,6 +52,7 @@ def auth_header(db, name):
|
||||
token = user.api_tokens[0]
|
||||
return {'Authorization': 'token %s' % token.token}
|
||||
|
||||
@check_db_locks
|
||||
def api_request(app, *api_path, **kwargs):
|
||||
"""Make an API request"""
|
||||
base_url = app.hub.server.url
|
||||
@@ -71,7 +96,6 @@ def test_auth_api(app):
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_get_users(app):
|
||||
db = app.db
|
||||
r = api_request(app, 'users')
|
||||
@@ -179,4 +203,3 @@ def test_spawn(app, io_loop):
|
||||
assert 'pid' not in user.state
|
||||
status = io_loop.run_sync(user.spawner.poll)
|
||||
assert status == 0
|
||||
|
Reference in New Issue
Block a user