mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
check database revision on launch
fail with informative error if version mismatches Since we weren't always tagging before, we have to handle no tag being present: - database empty (use latest because we are about to create everything anew) - if 'spawners' is present, assume 0.8.dev - if 'services' is present, assume 0.7.x - else: assume base revision when we started tracking this stuff
This commit is contained in:
@@ -7,6 +7,9 @@ from datetime import datetime
|
||||
import enum
|
||||
import json
|
||||
|
||||
import alembic.config
|
||||
import alembic.command
|
||||
from alembic.script import ScriptDirectory
|
||||
from tornado.log import app_log
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
||||
@@ -21,6 +24,7 @@ from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy import create_engine, Table
|
||||
|
||||
from .dbutil import _temp_alembic_ini
|
||||
from .utils import (
|
||||
random_port,
|
||||
new_token, hash_token, compare_token,
|
||||
@@ -431,6 +435,73 @@ class OAuthClient(Base):
|
||||
redirect_uri = Column(Unicode(1023))
|
||||
|
||||
|
||||
class DatabaseSchemaMismatch(Exception):
|
||||
"""Exception raised when the database schema version does not match
|
||||
|
||||
the current version of JupyterHub.
|
||||
"""
|
||||
|
||||
def check_db_revision(engine):
|
||||
"""Check the JupyterHub database revision
|
||||
|
||||
After calling this function, an alembic tag is guaranteed to be stored in the db.
|
||||
|
||||
- Checks the alembic tag and raises a ValueError if it's not the current revision
|
||||
- If no tag is stored (Bug in Hub prior to 0.8),
|
||||
guess revision based on db contents and tag the revision.
|
||||
- Empty databases are tagged with the current revision
|
||||
"""
|
||||
# Check database schema version
|
||||
current_table_names = set(engine.table_names())
|
||||
my_table_names = set(Base.metadata.tables.keys())
|
||||
|
||||
with _temp_alembic_ini(engine.url) as ini:
|
||||
cfg = alembic.config.Config(ini)
|
||||
scripts = ScriptDirectory.from_config(cfg)
|
||||
head = scripts.get_heads()[0]
|
||||
base = scripts.get_base()
|
||||
|
||||
if not my_table_names.intersection(current_table_names):
|
||||
# no tables have been created, stamp with current revision
|
||||
app_log.debug("Stamping empty database with alembic revision %s", head)
|
||||
alembic.command.stamp(cfg, head)
|
||||
return
|
||||
|
||||
if 'alembic_version' not in current_table_names:
|
||||
# Has not been tagged or upgraded before.
|
||||
# we didn't start tagging revisions correctly except during `upgrade-db`
|
||||
# until 0.8
|
||||
# This should only occur for databases created prior to JupyterHub 0.8
|
||||
msg_t = "Database schema version not found, guessing that JupyterHub %s created this database."
|
||||
if 'spawners' in current_table_names:
|
||||
# 0.8
|
||||
app_log.warning(msg_t, '0.8.dev')
|
||||
rev = head
|
||||
elif 'services' in current_table_names:
|
||||
# services is present, tag for 0.7
|
||||
app_log.warning(msg_t, '0.7.x')
|
||||
rev = 'af4cbdb2d13c'
|
||||
else:
|
||||
# it's old, mark as first revision
|
||||
app_log.warning(msg_t, '0.6 or earlier')
|
||||
rev = base
|
||||
app_log.debug("Stamping database schema version %s", rev)
|
||||
alembic.command.stamp(cfg, rev)
|
||||
|
||||
# check database schema version
|
||||
# it should always be defined at this point
|
||||
alembic_revision = engine.execute('SELECT version_num FROM alembic_version').first()[0]
|
||||
if alembic_revision == head:
|
||||
app_log.debug("database schema version found: %s", alembic_revision)
|
||||
pass
|
||||
else:
|
||||
raise DatabaseSchemaMismatch("Found database schema version {found} != {head}. "
|
||||
"Backup your database and run `jupyterhub upgrade-db`"
|
||||
" to upgrade to the latest schema.".format(
|
||||
found=alembic_revision,
|
||||
head=head,
|
||||
))
|
||||
|
||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
"""Create a new session at url"""
|
||||
if url.startswith('sqlite'):
|
||||
@@ -446,6 +517,9 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
engine = create_engine(url, **kwargs)
|
||||
if reset:
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
# check the db revision (will raise, pointing to `upgrade-db` if version doesn't match)
|
||||
check_db_revision(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
|
Reference in New Issue
Block a user