mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 14:03:02 +00:00
add dbutil.upgrade_if_needed
so it's reusable now that we want to use it in more than one place
This commit is contained in:
@@ -12,7 +12,6 @@ import logging
|
||||
from operator import itemgetter
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
@@ -23,7 +22,6 @@ if sys.version_info[:2] < (3, 3):
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
@@ -167,39 +165,11 @@ class UpgradeDB(Application):
|
||||
aliases = common_aliases
|
||||
classes = []
|
||||
|
||||
def _backup_db_file(self, db_file):
|
||||
"""Backup a database file"""
|
||||
if not os.path.exists(db_file):
|
||||
return
|
||||
|
||||
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
|
||||
backup_db_file = db_file + timestamp
|
||||
for i in range(1, 10):
|
||||
if not os.path.exists(backup_db_file):
|
||||
break
|
||||
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
|
||||
if os.path.exists(backup_db_file):
|
||||
self.exit("backup db file already exists: %s" % backup_db_file)
|
||||
|
||||
self.log.info("Backing up %s => %s", db_file, backup_db_file)
|
||||
shutil.copy(db_file, backup_db_file)
|
||||
|
||||
def start(self):
|
||||
hub = JupyterHub(parent=self)
|
||||
hub.load_config_file(hub.config_file)
|
||||
self.log = hub.log
|
||||
if (hub.db_url.startswith('sqlite:///')):
|
||||
db_file = hub.db_url.split(':///', 1)[1]
|
||||
self._backup_db_file(db_file)
|
||||
self.log.info("Upgrading %s", hub.db_url)
|
||||
# run check-db-revision first
|
||||
engine = create_engine(hub.db_url)
|
||||
try:
|
||||
orm.check_db_revision(engine)
|
||||
except orm.DatabaseSchemaMismatch:
|
||||
# ignore mismatch error because that's what we are here for!
|
||||
pass
|
||||
dbutil.upgrade(hub.db_url)
|
||||
dbutil.upgrade_if_needed(hub.db_url, log=self.log)
|
||||
|
||||
|
||||
class JupyterHub(Application):
|
||||
|
@@ -5,11 +5,17 @@
|
||||
# Based on pgcontents.utils.migrate, used under the Apache license.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
import os
|
||||
import shutil
|
||||
from subprocess import check_call
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from . import orm
|
||||
|
||||
_here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
|
||||
@@ -84,6 +90,46 @@ def upgrade(db_url, revision='head'):
|
||||
)
|
||||
|
||||
|
||||
def backup_db_file(db_file, log=None):
|
||||
"""Backup a database file if it exists"""
|
||||
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
|
||||
backup_db_file = db_file + timestamp
|
||||
for i in range(1, 10):
|
||||
if not os.path.exists(backup_db_file):
|
||||
break
|
||||
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
|
||||
#
|
||||
if os.path.exists(backup_db_file):
|
||||
raise OSError("backup db file already exists: %s" % backup_db_file)
|
||||
if log:
|
||||
log.info("Backing up %s => %s", db_file, backup_db_file)
|
||||
shutil.copy(db_file, backup_db_file)
|
||||
|
||||
|
||||
def upgrade_if_needed(db_url, backup=True, log=None):
|
||||
"""Upgrade a database if needed
|
||||
|
||||
If the database is sqlite, a backup file will be created with a timestamp.
|
||||
Other database systems should perform their own backups prior to calling this.
|
||||
"""
|
||||
# run check-db-revision first
|
||||
engine = create_engine(db_url)
|
||||
try:
|
||||
orm.check_db_revision(engine)
|
||||
except orm.DatabaseSchemaMismatch:
|
||||
# ignore mismatch error because that's what we are here for!
|
||||
pass
|
||||
else:
|
||||
# nothing to do
|
||||
return
|
||||
log.info("Upgrading %s", db_url)
|
||||
# we need to upgrade, backup the database
|
||||
if backup and db_url.startswith('sqlite:///'):
|
||||
db_file = db_url.split(':///', 1)[1]
|
||||
backup_db_file(db_file, log=log)
|
||||
upgrade(db_url)
|
||||
|
||||
|
||||
def _alembic(*args):
|
||||
"""Run an alembic command with a temporary alembic.ini"""
|
||||
with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:
|
||||
|
@@ -24,7 +24,6 @@ from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy import create_engine, Table
|
||||
|
||||
from .dbutil import _temp_alembic_ini
|
||||
from .utils import (
|
||||
random_port,
|
||||
new_token, hash_token, compare_token,
|
||||
@@ -463,6 +462,8 @@ def check_db_revision(engine):
|
||||
current_table_names = set(engine.table_names())
|
||||
my_table_names = set(Base.metadata.tables.keys())
|
||||
|
||||
from .dbutil import _temp_alembic_ini
|
||||
|
||||
with _temp_alembic_ini(engine.url) as ini:
|
||||
cfg = alembic.config.Config(ini)
|
||||
scripts = ScriptDirectory.from_config(cfg)
|
||||
|
Reference in New Issue
Block a user