add dbutil.upgrade_if_needed

so it's reusable now that we want to use it in more than one place
This commit is contained in:
Min RK
2017-10-27 14:27:02 +02:00
parent 5356954240
commit f002c67343
3 changed files with 49 additions and 32 deletions

View File

@@ -12,7 +12,6 @@ import logging
from operator import itemgetter
import os
import re
import shutil
import signal
import sys
from textwrap import dedent
@@ -23,7 +22,6 @@ if sys.version_info[:2] < (3, 3):
from jinja2 import Environment, FileSystemLoader
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from tornado.httpclient import AsyncHTTPClient
@@ -167,39 +165,11 @@ class UpgradeDB(Application):
aliases = common_aliases
classes = []
def _backup_db_file(self, db_file):
"""Backup a database file"""
if not os.path.exists(db_file):
return
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
backup_db_file = db_file + timestamp
for i in range(1, 10):
if not os.path.exists(backup_db_file):
break
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
if os.path.exists(backup_db_file):
self.exit("backup db file already exists: %s" % backup_db_file)
self.log.info("Backing up %s => %s", db_file, backup_db_file)
shutil.copy(db_file, backup_db_file)
def start(self):
hub = JupyterHub(parent=self)
hub.load_config_file(hub.config_file)
self.log = hub.log
if (hub.db_url.startswith('sqlite:///')):
db_file = hub.db_url.split(':///', 1)[1]
self._backup_db_file(db_file)
self.log.info("Upgrading %s", hub.db_url)
# run check-db-revision first
engine = create_engine(hub.db_url)
try:
orm.check_db_revision(engine)
except orm.DatabaseSchemaMismatch:
# ignore mismatch error because that's what we are here for!
pass
dbutil.upgrade(hub.db_url)
dbutil.upgrade_if_needed(hub.db_url, log=self.log)
class JupyterHub(Application):

View File

@@ -5,11 +5,17 @@
# Based on pgcontents.utils.migrate, used under the Apache license.
from contextlib import contextmanager
from datetime import datetime
import os
import shutil
from subprocess import check_call
import sys
from tempfile import TemporaryDirectory
from sqlalchemy import create_engine
from . import orm
_here = os.path.abspath(os.path.dirname(__file__))
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
@@ -84,6 +90,46 @@ def upgrade(db_url, revision='head'):
)
def backup_db_file(db_file, log=None):
"""Backup a database file if it exists"""
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
backup_db_file = db_file + timestamp
for i in range(1, 10):
if not os.path.exists(backup_db_file):
break
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
#
if os.path.exists(backup_db_file):
raise OSError("backup db file already exists: %s" % backup_db_file)
if log:
log.info("Backing up %s => %s", db_file, backup_db_file)
shutil.copy(db_file, backup_db_file)
def upgrade_if_needed(db_url, backup=True, log=None):
"""Upgrade a database if needed
If the database is sqlite, a backup file will be created with a timestamp.
Other database systems should perform their own backups prior to calling this.
"""
# run check-db-revision first
engine = create_engine(db_url)
try:
orm.check_db_revision(engine)
except orm.DatabaseSchemaMismatch:
# ignore mismatch error because that's what we are here for!
pass
else:
# nothing to do
return
log.info("Upgrading %s", db_url)
# we need to upgrade, backup the database
if backup and db_url.startswith('sqlite:///'):
db_file = db_url.split(':///', 1)[1]
backup_db_file(db_file, log=log)
upgrade(db_url)
def _alembic(*args):
"""Run an alembic command with a temporary alembic.ini"""
with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:

View File

@@ -24,7 +24,6 @@ from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
from .dbutil import _temp_alembic_ini
from .utils import (
random_port,
new_token, hash_token, compare_token,
@@ -463,6 +462,8 @@ def check_db_revision(engine):
current_table_names = set(engine.table_names())
my_table_names = set(Base.metadata.tables.keys())
from .dbutil import _temp_alembic_ini
with _temp_alembic_ini(engine.url) as ini:
cfg = alembic.config.Config(ini)
scripts = ScriptDirectory.from_config(cfg)