diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index dc98c4b8..316ec594 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -20,6 +20,7 @@ from sqlalchemy import ( ) from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import sessionmaker, relationship +from sqlalchemy.interfaces import PoolListener from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam from sqlalchemy import create_engine, Table @@ -464,6 +465,13 @@ class DatabaseSchemaMismatch(Exception): the current version of JupyterHub. """ + +class ForeignKeysListener(PoolListener): + """Enable foreign keys on sqlite""" + def connect(self, dbapi_con, con_record): + dbapi_con.execute('pragma foreign_keys=ON') + + def check_db_revision(engine): """Check the JupyterHub database revision @@ -527,10 +535,14 @@ def check_db_revision(engine): head=head, )) + def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): """Create a new session at url""" if url.startswith('sqlite'): kwargs.setdefault('connect_args', {'check_same_thread': False}) + listeners = kwargs.setdefault('listeners', []) + listeners.append(ForeignKeysListener()) + elif url.startswith('mysql'): kwargs.setdefault('pool_recycle', 60)