mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
@@ -262,6 +262,7 @@ class JupyterHubApp(Application):
|
|||||||
help="log all database transactions. This has A LOT of output"
|
help="log all database transactions. This has A LOT of output"
|
||||||
)
|
)
|
||||||
db = Any()
|
db = Any()
|
||||||
|
session_factory = Any()
|
||||||
|
|
||||||
admin_users = Set(config=True,
|
admin_users = Set(config=True,
|
||||||
help="""set of usernames of admin users
|
help="""set of usernames of admin users
|
||||||
@@ -364,9 +365,13 @@ class JupyterHubApp(Application):
|
|||||||
"""Create the database connection"""
|
"""Create the database connection"""
|
||||||
self.log.debug("Connecting to db: %s", self.db_url)
|
self.log.debug("Connecting to db: %s", self.db_url)
|
||||||
try:
|
try:
|
||||||
self.db = orm.new_session(self.db_url, reset=self.reset_db, echo=self.debug_db,
|
self.session_factory = orm.new_session_factory(
|
||||||
|
self.db_url,
|
||||||
|
reset=self.reset_db,
|
||||||
|
echo=self.debug_db,
|
||||||
**self.db_kwargs
|
**self.db_kwargs
|
||||||
)
|
)
|
||||||
|
self.db = self.session_factory()
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
self.log.error("Failed to connect to db: %s", self.db_url)
|
self.log.error("Failed to connect to db: %s", self.db_url)
|
||||||
self.log.debug("Database error was:", exc_info=True)
|
self.log.debug("Database error was:", exc_info=True)
|
||||||
|
@@ -58,6 +58,11 @@ class BaseHandler(RequestHandler):
|
|||||||
def authenticator(self):
|
def authenticator(self):
|
||||||
return self.settings.get('authenticator', None)
|
return self.settings.get('authenticator', None)
|
||||||
|
|
||||||
|
def finish(self, *args, **kwargs):
|
||||||
|
"""Roll back any uncommitted transactions from the handler."""
|
||||||
|
self.db.rollback()
|
||||||
|
super(BaseHandler, self).finish(*args, **kwargs)
|
||||||
|
|
||||||
#---------------------------------------------------------------
|
#---------------------------------------------------------------
|
||||||
# Login and cookie-related
|
# Login and cookie-related
|
||||||
#---------------------------------------------------------------
|
#---------------------------------------------------------------
|
||||||
|
@@ -381,17 +381,20 @@ class CookieToken(Token, Base):
|
|||||||
__tablename__ = 'cookie_tokens'
|
__tablename__ = 'cookie_tokens'
|
||||||
|
|
||||||
|
|
||||||
def new_session(url="sqlite:///:memory:", reset=False, **kwargs):
|
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||||
"""Create a new session at url"""
|
"""Create a new session at url"""
|
||||||
if url.startswith('sqlite'):
|
if url.startswith('sqlite'):
|
||||||
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
||||||
kwargs.setdefault('poolclass', StaticPool)
|
|
||||||
|
if url.endswith(':memory:'):
|
||||||
|
# If we're using an in-memory database, ensure that only one connection
|
||||||
|
# is ever created.
|
||||||
|
kwargs.setdefault('poolclass', StaticPool)
|
||||||
|
|
||||||
engine = create_engine(url, **kwargs)
|
engine = create_engine(url, **kwargs)
|
||||||
Session = sessionmaker(bind=engine)
|
|
||||||
session = Session()
|
|
||||||
if reset:
|
if reset:
|
||||||
Base.metadata.drop_all(engine)
|
Base.metadata.drop_all(engine)
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
|
session_factory = sessionmaker(bind=engine)
|
||||||
|
return session_factory
|
||||||
|
@@ -22,7 +22,7 @@ def db():
|
|||||||
"""Get a db session"""
|
"""Get a db session"""
|
||||||
global _db
|
global _db
|
||||||
if _db is None:
|
if _db is None:
|
||||||
_db = orm.new_session('sqlite:///:memory:', echo=True)
|
_db = orm.new_session_factory('sqlite:///:memory:', echo=True)()
|
||||||
user = orm.User(
|
user = orm.User(
|
||||||
name=getuser_unicode(),
|
name=getuser_unicode(),
|
||||||
server=orm.Server(),
|
server=orm.Server(),
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
"""mock utilities for testing"""
|
"""mock utilities for testing"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
@@ -56,12 +57,19 @@ class MockPAMAuthenticator(PAMAuthenticator):
|
|||||||
|
|
||||||
class MockHubApp(JupyterHubApp):
|
class MockHubApp(JupyterHubApp):
|
||||||
"""HubApp with various mock bits"""
|
"""HubApp with various mock bits"""
|
||||||
|
|
||||||
|
db_path = os.path.join(
|
||||||
|
os.path.dirname(
|
||||||
|
os.path.realpath(__file__),
|
||||||
|
),
|
||||||
|
"test.sqlite",
|
||||||
|
)
|
||||||
|
|
||||||
def _ip_default(self):
|
def _ip_default(self):
|
||||||
return 'localhost'
|
return 'localhost'
|
||||||
|
|
||||||
def _db_url_default(self):
|
def _db_url_default(self):
|
||||||
return 'sqlite:///:memory:'
|
return "sqlite:///" + self.db_path
|
||||||
|
|
||||||
def _authenticator_class_default(self):
|
def _authenticator_class_default(self):
|
||||||
return MockPAMAuthenticator
|
return MockPAMAuthenticator
|
||||||
@@ -71,8 +79,13 @@ class MockHubApp(JupyterHubApp):
|
|||||||
|
|
||||||
def _admin_users_default(self):
|
def _admin_users_default(self):
|
||||||
return {'admin'}
|
return {'admin'}
|
||||||
|
|
||||||
|
def rm_db(self):
|
||||||
|
if os.path.exists(self.db_path):
|
||||||
|
os.remove(self.db_path)
|
||||||
|
|
||||||
def start(self, argv=None):
|
def start(self, argv=None):
|
||||||
|
self.rm_db()
|
||||||
evt = threading.Event()
|
evt = threading.Event()
|
||||||
def _start():
|
def _start():
|
||||||
self.io_loop = IOLoop.current()
|
self.io_loop = IOLoop.current()
|
||||||
@@ -91,6 +104,6 @@ class MockHubApp(JupyterHubApp):
|
|||||||
evt.wait(timeout=5)
|
evt.wait(timeout=5)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
self.rm_db()
|
||||||
self.io_loop.add_callback(self.io_loop.stop)
|
self.io_loop.add_callback(self.io_loop.stop)
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
|
|
||||||
|
@@ -7,6 +7,30 @@ import requests
|
|||||||
from ..utils import url_path_join as ujoin
|
from ..utils import url_path_join as ujoin
|
||||||
from .. import orm
|
from .. import orm
|
||||||
|
|
||||||
|
|
||||||
|
def check_db_locks(func):
|
||||||
|
"""
|
||||||
|
Decorator for test functions that verifies no locks are held on the
|
||||||
|
application's database upon exit by creating and dropping a dummy table.
|
||||||
|
|
||||||
|
Relies on an instance of JupyterhubApp being the first argument to the
|
||||||
|
decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def new_func(*args, **kwargs):
|
||||||
|
retval = func(*args, **kwargs)
|
||||||
|
|
||||||
|
app = args[0]
|
||||||
|
temp_session = app.session_factory()
|
||||||
|
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||||
|
temp_session.execute('DROP TABLE dummy')
|
||||||
|
temp_session.close()
|
||||||
|
|
||||||
|
return retval
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
|
||||||
def find_user(db, name):
|
def find_user(db, name):
|
||||||
return db.query(orm.User).filter(orm.User.name==name).first()
|
return db.query(orm.User).filter(orm.User.name==name).first()
|
||||||
|
|
||||||
@@ -28,6 +52,7 @@ def auth_header(db, name):
|
|||||||
token = user.api_tokens[0]
|
token = user.api_tokens[0]
|
||||||
return {'Authorization': 'token %s' % token.token}
|
return {'Authorization': 'token %s' % token.token}
|
||||||
|
|
||||||
|
@check_db_locks
|
||||||
def api_request(app, *api_path, **kwargs):
|
def api_request(app, *api_path, **kwargs):
|
||||||
"""Make an API request"""
|
"""Make an API request"""
|
||||||
base_url = app.hub.server.url
|
base_url = app.hub.server.url
|
||||||
@@ -70,7 +95,6 @@ def test_auth_api(app):
|
|||||||
headers={'Authorization': 'token: %s' % cookie_token.token},
|
headers={'Authorization': 'token: %s' % cookie_token.token},
|
||||||
)
|
)
|
||||||
assert r.status_code == 403
|
assert r.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
def test_get_users(app):
|
def test_get_users(app):
|
||||||
db = app.db
|
db = app.db
|
||||||
@@ -179,4 +203,3 @@ def test_spawn(app, io_loop):
|
|||||||
assert 'pid' not in user.state
|
assert 'pid' not in user.state
|
||||||
status = io_loop.run_sync(user.spawner.poll)
|
status = io_loop.run_sync(user.spawner.poll)
|
||||||
assert status == 0
|
assert status == 0
|
||||||
|
|
Reference in New Issue
Block a user