mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-08 10:34:10 +00:00
sqlalchemy 2 compatibility
- avoid backref warnings by adding objects to session explicitly before creating any relationships - remove unnecessary `[]` around scalar query - use `text()` wrapper on connection.execute - engine.execute is removed - update import of declarative_base - ensure RemovedIn20Warning is available for warnings filters on sqlalchemy < 1.4 (needs editable install to avoid pytest path mismatch) - explicitly relay password in engine.url to alembic
This commit is contained in:
@@ -8,7 +8,9 @@ from datetime import datetime, timedelta
|
||||
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import sqlalchemy
|
||||
from alembic.script import ScriptDirectory
|
||||
from packaging.version import parse as parse_version
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
@@ -24,8 +26,8 @@ from sqlalchemy import (
|
||||
inspect,
|
||||
or_,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import (
|
||||
Session,
|
||||
backref,
|
||||
@@ -34,6 +36,13 @@ from sqlalchemy.orm import (
|
||||
relationship,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
# sqlalchemy < 1.4
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.types import LargeBinary, Text, TypeDecorator
|
||||
from tornado.log import app_log
|
||||
@@ -750,6 +759,7 @@ class APIToken(Hashed, Base):
|
||||
session_id=session_id,
|
||||
scopes=list(scopes),
|
||||
)
|
||||
db.add(orm_token)
|
||||
orm_token.token = token
|
||||
if user:
|
||||
assert user.id is not None
|
||||
@@ -760,7 +770,6 @@ class APIToken(Hashed, Base):
|
||||
if expires_in is not None:
|
||||
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
|
||||
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
return token
|
||||
|
||||
@@ -902,7 +911,7 @@ def register_ping_connection(engine):
|
||||
"""
|
||||
|
||||
@event.listens_for(engine, "engine_connect")
|
||||
def ping_connection(connection, branch):
|
||||
def ping_connection(connection, branch=None):
|
||||
if branch:
|
||||
# "branch" refers to a sub-connection of a connection,
|
||||
# we don't want to bother pinging on these.
|
||||
@@ -913,11 +922,17 @@ def register_ping_connection(engine):
|
||||
save_should_close_with_result = connection.should_close_with_result
|
||||
connection.should_close_with_result = False
|
||||
|
||||
if parse_version(sqlalchemy.__version__) < parse_version("1.4"):
|
||||
one = [1]
|
||||
else:
|
||||
one = 1
|
||||
|
||||
try:
|
||||
# run a SELECT 1. use a core select() so that
|
||||
# run a SELECT 1. use a core select() so that
|
||||
# the SELECT of a scalar value without a table is
|
||||
# appropriately formatted for the backend
|
||||
connection.scalar(select([1]))
|
||||
with connection.begin() as transaction:
|
||||
connection.scalar(select(one))
|
||||
except exc.DBAPIError as err:
|
||||
# catch SQLAlchemy's DBAPIError, which is a wrapper
|
||||
# for the DBAPI's exception. It includes a .connection_invalidated
|
||||
@@ -932,7 +947,8 @@ def register_ping_connection(engine):
|
||||
# itself and establish a new connection. The disconnect detection
|
||||
# here also causes the whole connection pool to be invalidated
|
||||
# so that all stale connections are discarded.
|
||||
connection.scalar(select([1]))
|
||||
with connection.begin() as transaction:
|
||||
connection.scalar(select(one))
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
@@ -956,7 +972,13 @@ def check_db_revision(engine):
|
||||
|
||||
from .dbutil import _temp_alembic_ini
|
||||
|
||||
with _temp_alembic_ini(engine.url) as ini:
|
||||
if hasattr(engine.url, "render_as_string"):
|
||||
# sqlalchemy >= 1.4
|
||||
engine_url = engine.url.render_as_string(hide_password=False)
|
||||
else:
|
||||
engine_url = str(engine.url)
|
||||
|
||||
with _temp_alembic_ini(engine_url) as ini:
|
||||
cfg = alembic.config.Config(ini)
|
||||
scripts = ScriptDirectory.from_config(cfg)
|
||||
head = scripts.get_heads()[0]
|
||||
@@ -991,9 +1013,10 @@ def check_db_revision(engine):
|
||||
|
||||
# check database schema version
|
||||
# it should always be defined at this point
|
||||
alembic_revision = engine.execute(
|
||||
'SELECT version_num FROM alembic_version'
|
||||
).first()[0]
|
||||
with engine.begin() as connection:
|
||||
alembic_revision = connection.execute(
|
||||
text('SELECT version_num FROM alembic_version')
|
||||
).first()[0]
|
||||
if alembic_revision == head:
|
||||
app_log.debug("database schema version found: %s", alembic_revision)
|
||||
else:
|
||||
@@ -1010,13 +1033,16 @@ def mysql_large_prefix_check(engine):
|
||||
"""Check mysql has innodb_large_prefix set"""
|
||||
if not str(engine.url).startswith('mysql'):
|
||||
return False
|
||||
variables = dict(
|
||||
engine.execute(
|
||||
'show variables where variable_name like '
|
||||
'"innodb_large_prefix" or '
|
||||
'variable_name like "innodb_file_format";'
|
||||
).fetchall()
|
||||
)
|
||||
with engine.begin() as connection:
|
||||
variables = dict(
|
||||
connection.execute(
|
||||
text(
|
||||
'show variables where variable_name like '
|
||||
'"innodb_large_prefix" or '
|
||||
'variable_name like "innodb_file_format";'
|
||||
)
|
||||
).fetchall()
|
||||
)
|
||||
if (
|
||||
variables.get('innodb_file_format', 'Barracuda') == 'Barracuda'
|
||||
and variables.get('innodb_large_prefix', 'ON') == 'ON'
|
||||
|
Reference in New Issue
Block a user