mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-16 06:22:59 +00:00
add dbutil.upgrade_if_needed
so it's reusable now that we want to use it in more than one place
This commit is contained in:
@@ -12,7 +12,6 @@ import logging
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
@@ -23,7 +22,6 @@ if sys.version_info[:2] < (3, 3):
|
|||||||
|
|
||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.exc import OperationalError
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
from tornado.httpclient import AsyncHTTPClient
|
from tornado.httpclient import AsyncHTTPClient
|
||||||
@@ -167,39 +165,11 @@ class UpgradeDB(Application):
|
|||||||
aliases = common_aliases
|
aliases = common_aliases
|
||||||
classes = []
|
classes = []
|
||||||
|
|
||||||
def _backup_db_file(self, db_file):
|
|
||||||
"""Backup a database file"""
|
|
||||||
if not os.path.exists(db_file):
|
|
||||||
return
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
|
|
||||||
backup_db_file = db_file + timestamp
|
|
||||||
for i in range(1, 10):
|
|
||||||
if not os.path.exists(backup_db_file):
|
|
||||||
break
|
|
||||||
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
|
|
||||||
if os.path.exists(backup_db_file):
|
|
||||||
self.exit("backup db file already exists: %s" % backup_db_file)
|
|
||||||
|
|
||||||
self.log.info("Backing up %s => %s", db_file, backup_db_file)
|
|
||||||
shutil.copy(db_file, backup_db_file)
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
hub = JupyterHub(parent=self)
|
hub = JupyterHub(parent=self)
|
||||||
hub.load_config_file(hub.config_file)
|
hub.load_config_file(hub.config_file)
|
||||||
self.log = hub.log
|
self.log = hub.log
|
||||||
if (hub.db_url.startswith('sqlite:///')):
|
dbutil.upgrade_if_needed(hub.db_url, log=self.log)
|
||||||
db_file = hub.db_url.split(':///', 1)[1]
|
|
||||||
self._backup_db_file(db_file)
|
|
||||||
self.log.info("Upgrading %s", hub.db_url)
|
|
||||||
# run check-db-revision first
|
|
||||||
engine = create_engine(hub.db_url)
|
|
||||||
try:
|
|
||||||
orm.check_db_revision(engine)
|
|
||||||
except orm.DatabaseSchemaMismatch:
|
|
||||||
# ignore mismatch error because that's what we are here for!
|
|
||||||
pass
|
|
||||||
dbutil.upgrade(hub.db_url)
|
|
||||||
|
|
||||||
|
|
||||||
class JupyterHub(Application):
|
class JupyterHub(Application):
|
||||||
|
@@ -5,11 +5,17 @@
|
|||||||
# Based on pgcontents.utils.migrate, used under the Apache license.
|
# Based on pgcontents.utils.migrate, used under the Apache license.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from subprocess import check_call
|
from subprocess import check_call
|
||||||
import sys
|
import sys
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from . import orm
|
||||||
|
|
||||||
_here = os.path.abspath(os.path.dirname(__file__))
|
_here = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
|
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
|
||||||
@@ -84,6 +90,46 @@ def upgrade(db_url, revision='head'):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def backup_db_file(db_file, log=None):
|
||||||
|
"""Backup a database file if it exists"""
|
||||||
|
timestamp = datetime.now().strftime('.%Y-%m-%d-%H%M%S')
|
||||||
|
backup_db_file = db_file + timestamp
|
||||||
|
for i in range(1, 10):
|
||||||
|
if not os.path.exists(backup_db_file):
|
||||||
|
break
|
||||||
|
backup_db_file = '{}.{}.{}'.format(db_file, timestamp, i)
|
||||||
|
#
|
||||||
|
if os.path.exists(backup_db_file):
|
||||||
|
raise OSError("backup db file already exists: %s" % backup_db_file)
|
||||||
|
if log:
|
||||||
|
log.info("Backing up %s => %s", db_file, backup_db_file)
|
||||||
|
shutil.copy(db_file, backup_db_file)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade_if_needed(db_url, backup=True, log=None):
|
||||||
|
"""Upgrade a database if needed
|
||||||
|
|
||||||
|
If the database is sqlite, a backup file will be created with a timestamp.
|
||||||
|
Other database systems should perform their own backups prior to calling this.
|
||||||
|
"""
|
||||||
|
# run check-db-revision first
|
||||||
|
engine = create_engine(db_url)
|
||||||
|
try:
|
||||||
|
orm.check_db_revision(engine)
|
||||||
|
except orm.DatabaseSchemaMismatch:
|
||||||
|
# ignore mismatch error because that's what we are here for!
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# nothing to do
|
||||||
|
return
|
||||||
|
log.info("Upgrading %s", db_url)
|
||||||
|
# we need to upgrade, backup the database
|
||||||
|
if backup and db_url.startswith('sqlite:///'):
|
||||||
|
db_file = db_url.split(':///', 1)[1]
|
||||||
|
backup_db_file(db_file, log=log)
|
||||||
|
upgrade(db_url)
|
||||||
|
|
||||||
|
|
||||||
def _alembic(*args):
|
def _alembic(*args):
|
||||||
"""Run an alembic command with a temporary alembic.ini"""
|
"""Run an alembic command with a temporary alembic.ini"""
|
||||||
with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:
|
with _temp_alembic_ini('sqlite:///jupyterhub.sqlite') as alembic_ini:
|
||||||
|
@@ -24,7 +24,6 @@ from sqlalchemy.pool import StaticPool
|
|||||||
from sqlalchemy.sql.expression import bindparam
|
from sqlalchemy.sql.expression import bindparam
|
||||||
from sqlalchemy import create_engine, Table
|
from sqlalchemy import create_engine, Table
|
||||||
|
|
||||||
from .dbutil import _temp_alembic_ini
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
random_port,
|
random_port,
|
||||||
new_token, hash_token, compare_token,
|
new_token, hash_token, compare_token,
|
||||||
@@ -463,6 +462,8 @@ def check_db_revision(engine):
|
|||||||
current_table_names = set(engine.table_names())
|
current_table_names = set(engine.table_names())
|
||||||
my_table_names = set(Base.metadata.tables.keys())
|
my_table_names = set(Base.metadata.tables.keys())
|
||||||
|
|
||||||
|
from .dbutil import _temp_alembic_ini
|
||||||
|
|
||||||
with _temp_alembic_ini(engine.url) as ini:
|
with _temp_alembic_ini(engine.url) as ini:
|
||||||
cfg = alembic.config.Config(ini)
|
cfg = alembic.config.Config(ini)
|
||||||
scripts = ScriptDirectory.from_config(cfg)
|
scripts = ScriptDirectory.from_config(cfg)
|
||||||
|
Reference in New Issue
Block a user