mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
ensure foreign keys are enabled on sqlite
This commit is contained in:
@@ -20,6 +20,7 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.orm import sessionmaker, relationship
|
||||
from sqlalchemy.interfaces import PoolListener
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy import create_engine, Table
|
||||
@@ -464,6 +465,13 @@ class DatabaseSchemaMismatch(Exception):
|
||||
the current version of JupyterHub.
|
||||
"""
|
||||
|
||||
|
||||
class ForeignKeysListener(PoolListener):
|
||||
"""Enable foreign keys on sqlite"""
|
||||
def connect(self, dbapi_con, con_record):
|
||||
dbapi_con.execute('pragma foreign_keys=ON')
|
||||
|
||||
|
||||
def check_db_revision(engine):
|
||||
"""Check the JupyterHub database revision
|
||||
|
||||
@@ -527,10 +535,14 @@ def check_db_revision(engine):
|
||||
head=head,
|
||||
))
|
||||
|
||||
|
||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
"""Create a new session at url"""
|
||||
if url.startswith('sqlite'):
|
||||
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
||||
listeners = kwargs.setdefault('listeners', [])
|
||||
listeners.append(ForeignKeysListener())
|
||||
|
||||
elif url.startswith('mysql'):
|
||||
kwargs.setdefault('pool_recycle', 60)
|
||||
|
||||
|
Reference in New Issue
Block a user