diff --git a/jupyterhub/app.py b/jupyterhub/app.py index b34ef2ac..3a2ce80c 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -12,7 +12,6 @@ import logging from operator import itemgetter import os import re -import shutil import signal import sys from textwrap import dedent @@ -23,7 +22,6 @@ if sys.version_info[:2] < (3, 3): from jinja2 import Environment, FileSystemLoader -from sqlalchemy import create_engine from sqlalchemy.exc import OperationalError from tornado.httpclient import AsyncHTTPClient @@ -167,39 +165,11 @@ class UpgradeDB(Application): aliases = common_aliases classes = [] - def _backup_db_file(self, db_file): - """Backup a database file""" - if not os.path.exists(db_file): - return - - timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S') - backup_db_file = db_file + timestamp - for i in range(1, 10): - if not os.path.exists(backup_db_file): - break - backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i) - if os.path.exists(backup_db_file): - self.exit("backup db file already exists: %s" % backup_db_file) - - self.log.info("Backing up %s => %s", db_file, backup_db_file) - shutil.copy(db_file, backup_db_file) - def start(self): hub = JupyterHub(parent=self) hub.load_config_file(hub.config_file) self.log = hub.log - if (hub.db_url.startswith('sqlite:///')): - db_file = hub.db_url.split(':///', 1)[1] - self._backup_db_file(db_file) - self.log.info("Upgrading %s", hub.db_url) - # run check-db-revision first - engine = create_engine(hub.db_url) - try: - orm.check_db_revision(engine) - except orm.DatabaseSchemaMismatch: - # ignore mismatch error because that's what we are here for! - pass - dbutil.upgrade(hub.db_url) + dbutil.upgrade_if_needed(hub.db_url, log=self.log) class JupyterHub(Application): diff --git a/jupyterhub/dbutil.py b/jupyterhub/dbutil.py index 59887ddb..ee608524 100644 --- a/jupyterhub/dbutil.py +++ b/jupyterhub/dbutil.py @@ -5,11 +5,17 @@ # Based on pgcontents.utils.migrate, used under the Apache license. from contextlib import contextmanager +from datetime import datetime import os +import shutil from subprocess import check_call import sys from tempfile import TemporaryDirectory +from sqlalchemy import create_engine + +from . import orm + _here = os.path.abspath(os.path.dirname(__file__)) ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini') @@ -84,6 +90,46 @@ def upgrade(db_url, revision='head'): ) +def backup_db_file(db_file, log=None): + """Backup a database file if it exists""" + timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S') + backup_db_file = db_file + timestamp + for i in range(1, 10): + if not os.path.exists(backup_db_file): + break + backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i) + # + if os.path.exists(backup_db_file): + raise OSError("backup db file already exists: %s" % backup_db_file) + if log: + log.info("Backing up %s => %s", db_file, backup_db_file) + shutil.copy(db_file, backup_db_file) + + +def upgrade_if_needed(db_url, backup=True, log=None): + """Upgrade a database if needed + + If the database is sqlite, a backup file will be created with a timestamp. + Other database systems should perform their own backups prior to calling this. + """ + # run check-db-revision first + engine = create_engine(db_url) + try: + orm.check_db_revision(engine) + except orm.DatabaseSchemaMismatch: + # ignore mismatch error because that's what we are here for! + pass + else: + # nothing to do + return + log.info("Upgrading %s", db_url) + # we need to upgrade, backup the database + if backup and db_url.startswith('sqlite:///'): + db_file = db_url.split(':///', 1)[1] + backup_db_file(db_file, log=log) + upgrade(db_url) + + def _alembic(*args): """Run an alembic command with a temporary alembic.ini""" with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini: diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 11aa0cc3..8dcf1eae 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -24,7 +24,6 @@ from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam from sqlalchemy import create_engine, Table -from .dbutil import _temp_alembic_ini from .utils import ( random_port, new_token, hash_token, compare_token, @@ -463,6 +462,8 @@ def check_db_revision(engine): current_table_names = set(engine.table_names()) my_table_names = set(Base.metadata.tables.keys()) + from .dbutil import _temp_alembic_ini + with _temp_alembic_ini(engine.url) as ini: cfg = alembic.config.Config(ini) scripts = ScriptDirectory.from_config(cfg)