add dbutil.upgrade_if_needed

so it's reusable now that we want to use it in more than one place
This commit is contained in:
Min RK
2017-10-27 14:27:02 +02:00
parent 5356954240
commit f002c67343
3 changed files with 49 additions and 32 deletions

View File

@@ -12,7 +12,6 @@ import logging
from operator import itemgetter from operator import itemgetter
import os import os
import re import re
import shutil
import signal import signal
import sys import sys
from textwrap import dedent from textwrap import dedent
@@ -23,7 +22,6 @@ if sys.version_info[:2] < (3, 3):
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
@@ -167,39 +165,11 @@ class UpgradeDB(Application):
aliases = common_aliases aliases = common_aliases
classes = [] classes = []
def _backup_db_file(self, db_file):
"""Backup a database file"""
if not os.path.exists(db_file):
return
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
backup_db_file = db_file + timestamp
for i in range(1, 10):
if not os.path.exists(backup_db_file):
break
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
if os.path.exists(backup_db_file):
self.exit("backup db file already exists: %s" % backup_db_file)
self.log.info("Backing up %s => %s", db_file, backup_db_file)
shutil.copy(db_file, backup_db_file)
def start(self): def start(self):
hub = JupyterHub(parent=self) hub = JupyterHub(parent=self)
hub.load_config_file(hub.config_file) hub.load_config_file(hub.config_file)
self.log = hub.log self.log = hub.log
if (hub.db_url.startswith('sqlite:///')): dbutil.upgrade_if_needed(hub.db_url, log=self.log)
db_file = hub.db_url.split(':///', 1)[1]
self._backup_db_file(db_file)
self.log.info("Upgrading %s", hub.db_url)
# run check-db-revision first
engine = create_engine(hub.db_url)
try:
orm.check_db_revision(engine)
except orm.DatabaseSchemaMismatch:
# ignore mismatch error because that's what we are here for!
pass
dbutil.upgrade(hub.db_url)
class JupyterHub(Application): class JupyterHub(Application):

View File

@@ -5,11 +5,17 @@
# Based on pgcontents.utils.migrate, used under the Apache license. # Based on pgcontents.utils.migrate, used under the Apache license.
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime
import os import os
import shutil
from subprocess import check_call from subprocess import check_call
import sys import sys
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from sqlalchemy import create_engine
from . import orm
_here = os.path.abspath(os.path.dirname(__file__)) _here = os.path.abspath(os.path.dirname(__file__))
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini') ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
@@ -84,6 +90,46 @@ def upgrade(db_url, revision='head'):
) )
def backup_db_file(db_file, log=None):
"""Backup a database file if it exists"""
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
backup_db_file = db_file + timestamp
for i in range(1, 10):
if not os.path.exists(backup_db_file):
break
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
#
if os.path.exists(backup_db_file):
raise OSError("backup db file already exists: %s" % backup_db_file)
if log:
log.info("Backing up %s => %s", db_file, backup_db_file)
shutil.copy(db_file, backup_db_file)
def upgrade_if_needed(db_url, backup=True, log=None):
"""Upgrade a database if needed
If the database is sqlite, a backup file will be created with a timestamp.
Other database systems should perform their own backups prior to calling this.
"""
# run check-db-revision first
engine = create_engine(db_url)
try:
orm.check_db_revision(engine)
except orm.DatabaseSchemaMismatch:
# ignore mismatch error because that's what we are here for!
pass
else:
# nothing to do
return
log.info("Upgrading %s", db_url)
# we need to upgrade, backup the database
if backup and db_url.startswith('sqlite:///'):
db_file = db_url.split(':///', 1)[1]
backup_db_file(db_file, log=log)
upgrade(db_url)
def _alembic(*args): def _alembic(*args):
"""Run an alembic command with a temporary alembic.ini""" """Run an alembic command with a temporary alembic.ini"""
with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini: with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:

View File

@@ -24,7 +24,6 @@ from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table from sqlalchemy import create_engine, Table
from .dbutil import _temp_alembic_ini
from .utils import ( from .utils import (
random_port, random_port,
new_token, hash_token, compare_token, new_token, hash_token, compare_token,
@@ -463,6 +462,8 @@ def check_db_revision(engine):
current_table_names = set(engine.table_names()) current_table_names = set(engine.table_names())
my_table_names = set(Base.metadata.tables.keys()) my_table_names = set(Base.metadata.tables.keys())
from .dbutil import _temp_alembic_ini
with _temp_alembic_ini(engine.url) as ini: with _temp_alembic_ini(engine.url) as ini:
cfg = alembic.config.Config(ini) cfg = alembic.config.Config(ini)
scripts = ScriptDirectory.from_config(cfg) scripts = ScriptDirectory.from_config(cfg)