sqlalchemy 2 compatibility

- avoid backref warnings by adding objects to session explicitly before creating any relationships
- remove unnecessary `[]` around scalar query
- use `text()` wrapper on connection.execute
- engine.execute is removed
- update import of declarative_base
- ensure RemovedIn20Warning is available for warnings filters on sqlalchemy < 1.4 (needs editable install to avoid pytest path mismatch)
- explicitly relay password in engine.url to alembic
This commit is contained in:
Min RK
2023-01-18 09:44:35 +01:00
parent 0a84738fe9
commit 2db7c47fbf
12 changed files with 114 additions and 50 deletions

View File

@@ -29,6 +29,7 @@ env:
# UTF-8 content may be interpreted as ascii and causes errors without this. # UTF-8 content may be interpreted as ascii and causes errors without this.
LANG: C.UTF-8 LANG: C.UTF-8
PYTEST_ADDOPTS: "--verbose --color=yes" PYTEST_ADDOPTS: "--verbose --color=yes"
SQLALCHEMY_WARN_20: "1"
permissions: permissions:
contents: read contents: read
@@ -140,7 +141,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install ".[test]" pip install -e ".[test]"
if [ "${{ matrix.oldest_dependencies }}" != "" ]; then if [ "${{ matrix.oldest_dependencies }}" != "" ]; then
# take any dependencies in requirements.txt such as tornado>=5.0 # take any dependencies in requirements.txt such as tornado>=5.0
@@ -152,6 +153,7 @@ jobs:
if [ "${{ matrix.main_dependencies }}" != "" ]; then if [ "${{ matrix.main_dependencies }}" != "" ]; then
pip install git+https://github.com/ipython/traitlets#egg=traitlets --force pip install git+https://github.com/ipython/traitlets#egg=traitlets --force
pip install --upgrade --pre sqlalchemy
fi fi
if [ "${{ matrix.legacy_notebook }}" != "" ]; then if [ "${{ matrix.legacy_notebook }}" != "" ]; then
pip uninstall jupyter_server --yes pip uninstall jupyter_server --yes

View File

@@ -94,8 +94,9 @@ class GroupListAPIHandler(_GroupAPIHandler):
# create the group # create the group
self.log.info("Creating new group %s with %i users", name, len(users)) self.log.info("Creating new group %s with %i users", name, len(users))
self.log.debug("Users: %s", usernames) self.log.debug("Users: %s", usernames)
group = orm.Group(name=name, users=users) group = orm.Group(name=name)
self.db.add(group) self.db.add(group)
group.users = users
self.db.commit() self.db.commit()
created.append(group) created.append(group)
self.write(json.dumps([self.group_model(group) for group in created])) self.write(json.dumps([self.group_model(group) for group in created]))
@@ -131,8 +132,9 @@ class GroupAPIHandler(_GroupAPIHandler):
# create the group # create the group
self.log.info("Creating new group %s with %i users", group_name, len(users)) self.log.info("Creating new group %s with %i users", group_name, len(users))
self.log.debug("Users: %s", usernames) self.log.debug("Users: %s", usernames)
group = orm.Group(name=group_name, users=users) group = orm.Group(name=group_name)
self.db.add(group) self.db.add(group)
group.users = users
self.db.commit() self.db.commit()
self.write(json.dumps(self.group_model(group))) self.write(json.dumps(self.group_model(group)))
self.set_status(201) self.set_status(201)

View File

@@ -1962,9 +1962,9 @@ class JupyterHub(Application):
user = orm.User.find(db, name) user = orm.User.find(db, name)
if user is None: if user is None:
user = orm.User(name=name, admin=True) user = orm.User(name=name, admin=True)
db.add(user)
roles.assign_default_roles(self.db, entity=user) roles.assign_default_roles(self.db, entity=user)
new_users.append(user) new_users.append(user)
db.add(user)
else: else:
user.admin = True user.admin = True
# the admin_users config variable will never be used after this point. # the admin_users config variable will never be used after this point.
@@ -2376,6 +2376,7 @@ class JupyterHub(Application):
if orm_service is None: if orm_service is None:
# not found, create a new one # not found, create a new one
orm_service = orm.Service(name=name) orm_service = orm.Service(name=name)
self.db.add(orm_service)
if spec.get('admin', False): if spec.get('admin', False):
self.log.warning( self.log.warning(
f"Service {name} sets `admin: True`, which is deprecated in JupyterHub 2.0." f"Service {name} sets `admin: True`, which is deprecated in JupyterHub 2.0."
@@ -2384,7 +2385,6 @@ class JupyterHub(Application):
"the Service admin flag will be ignored." "the Service admin flag will be ignored."
) )
roles.update_roles(self.db, entity=orm_service, roles=['admin']) roles.update_roles(self.db, entity=orm_service, roles=['admin'])
self.db.add(orm_service)
orm_service.admin = spec.get('admin', False) orm_service.admin = spec.get('admin', False)
self.db.commit() self.db.commit()
service = Service( service = Service(

View File

@@ -257,16 +257,16 @@ class JupyterHubRequestValidator(RequestValidator):
raise ValueError("No such client: %s" % client_id) raise ValueError("No such client: %s" % client_id)
orm_code = orm.OAuthCode( orm_code = orm.OAuthCode(
client=orm_client,
code=code['code'], code=code['code'],
# oauth has 5 minutes to complete # oauth has 5 minutes to complete
expires_at=int(orm.OAuthCode.now() + 300), expires_at=int(orm.OAuthCode.now() + 300),
scopes=list(request.scopes), scopes=list(request.scopes),
user=request.user.orm_user,
redirect_uri=orm_client.redirect_uri, redirect_uri=orm_client.redirect_uri,
session_id=request.session_id, session_id=request.session_id,
) )
self.db.add(orm_code) self.db.add(orm_code)
orm_code.client = orm_client
orm_code.user = request.user.orm_user
self.db.commit() self.db.commit()
def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): def get_authorization_code_scopes(self, client_id, code, redirect_uri, request):

View File

@@ -8,7 +8,9 @@ from datetime import datetime, timedelta
import alembic.command import alembic.command
import alembic.config import alembic.config
import sqlalchemy
from alembic.script import ScriptDirectory from alembic.script import ScriptDirectory
from packaging.version import parse as parse_version
from sqlalchemy import ( from sqlalchemy import (
Boolean, Boolean,
Column, Column,
@@ -24,8 +26,8 @@ from sqlalchemy import (
inspect, inspect,
or_, or_,
select, select,
text,
) )
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import ( from sqlalchemy.orm import (
Session, Session,
backref, backref,
@@ -34,6 +36,13 @@ from sqlalchemy.orm import (
relationship, relationship,
sessionmaker, sessionmaker,
) )
try:
from sqlalchemy.orm import declarative_base
except ImportError:
# sqlalchemy < 1.4
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.types import LargeBinary, Text, TypeDecorator from sqlalchemy.types import LargeBinary, Text, TypeDecorator
from tornado.log import app_log from tornado.log import app_log
@@ -750,6 +759,7 @@ class APIToken(Hashed, Base):
session_id=session_id, session_id=session_id,
scopes=list(scopes), scopes=list(scopes),
) )
db.add(orm_token)
orm_token.token = token orm_token.token = token
if user: if user:
assert user.id is not None assert user.id is not None
@@ -760,7 +770,6 @@ class APIToken(Hashed, Base):
if expires_in is not None: if expires_in is not None:
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in) orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
db.add(orm_token)
db.commit() db.commit()
return token return token
@@ -902,7 +911,7 @@ def register_ping_connection(engine):
""" """
@event.listens_for(engine, "engine_connect") @event.listens_for(engine, "engine_connect")
def ping_connection(connection, branch): def ping_connection(connection, branch=None):
if branch: if branch:
# "branch" refers to a sub-connection of a connection, # "branch" refers to a sub-connection of a connection,
# we don't want to bother pinging on these. # we don't want to bother pinging on these.
@@ -913,11 +922,17 @@ def register_ping_connection(engine):
save_should_close_with_result = connection.should_close_with_result save_should_close_with_result = connection.should_close_with_result
connection.should_close_with_result = False connection.should_close_with_result = False
if parse_version(sqlalchemy.__version__) < parse_version("1.4"):
one = [1]
else:
one = 1
try: try:
# run a SELECT 1. use a core select() so that # run a SELECT 1. use a core select() so that
# the SELECT of a scalar value without a table is # the SELECT of a scalar value without a table is
# appropriately formatted for the backend # appropriately formatted for the backend
connection.scalar(select([1])) with connection.begin() as transaction:
connection.scalar(select(one))
except exc.DBAPIError as err: except exc.DBAPIError as err:
# catch SQLAlchemy's DBAPIError, which is a wrapper # catch SQLAlchemy's DBAPIError, which is a wrapper
# for the DBAPI's exception. It includes a .connection_invalidated # for the DBAPI's exception. It includes a .connection_invalidated
@@ -932,7 +947,8 @@ def register_ping_connection(engine):
# itself and establish a new connection. The disconnect detection # itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated # here also causes the whole connection pool to be invalidated
# so that all stale connections are discarded. # so that all stale connections are discarded.
connection.scalar(select([1])) with connection.begin() as transaction:
connection.scalar(select(one))
else: else:
raise raise
finally: finally:
@@ -956,7 +972,13 @@ def check_db_revision(engine):
from .dbutil import _temp_alembic_ini from .dbutil import _temp_alembic_ini
with _temp_alembic_ini(engine.url) as ini: if hasattr(engine.url, "render_as_string"):
# sqlalchemy >= 1.4
engine_url = engine.url.render_as_string(hide_password=False)
else:
engine_url = str(engine.url)
with _temp_alembic_ini(engine_url) as ini:
cfg = alembic.config.Config(ini) cfg = alembic.config.Config(ini)
scripts = ScriptDirectory.from_config(cfg) scripts = ScriptDirectory.from_config(cfg)
head = scripts.get_heads()[0] head = scripts.get_heads()[0]
@@ -991,8 +1013,9 @@ def check_db_revision(engine):
# check database schema version # check database schema version
# it should always be defined at this point # it should always be defined at this point
alembic_revision = engine.execute( with engine.begin() as connection:
'SELECT version_num FROM alembic_version' alembic_revision = connection.execute(
text('SELECT version_num FROM alembic_version')
).first()[0] ).first()[0]
if alembic_revision == head: if alembic_revision == head:
app_log.debug("database schema version found: %s", alembic_revision) app_log.debug("database schema version found: %s", alembic_revision)
@@ -1010,11 +1033,14 @@ def mysql_large_prefix_check(engine):
"""Check mysql has innodb_large_prefix set""" """Check mysql has innodb_large_prefix set"""
if not str(engine.url).startswith('mysql'): if not str(engine.url).startswith('mysql'):
return False return False
with engine.begin() as connection:
variables = dict( variables = dict(
engine.execute( connection.execute(
text(
'show variables where variable_name like ' 'show variables where variable_name like '
'"innodb_large_prefix" or ' '"innodb_large_prefix" or '
'variable_name like "innodb_file_format";' 'variable_name like "innodb_file_format";'
)
).fetchall() ).fetchall()
) )
if ( if (

View File

@@ -444,11 +444,12 @@ async def test_get_self(app):
db.add(oauth_client) db.add(oauth_client)
db.commit() db.commit()
oauth_token = orm.APIToken( oauth_token = orm.APIToken(
user=u.orm_user,
oauth_client=oauth_client,
token=token, token=token,
) )
db.add(oauth_token) db.add(oauth_token)
oauth_token.user = u.orm_user
oauth_token.oauth_client = oauth_client
db.commit() db.commit()
r = await api_request( r = await api_request(
app, app,
@@ -2131,13 +2132,13 @@ def test_shutdown(app):
def stop(): def stop():
stop.called = True stop.called = True
loop.call_later(1, real_stop) loop.call_later(2, real_stop)
real_cleanup = app.cleanup real_cleanup = app.cleanup
def cleanup(): def cleanup():
cleanup.called = True cleanup.called = True
return real_cleanup() loop.call_later(1, real_cleanup)
app.cleanup = cleanup app.cleanup = cleanup

View File

@@ -323,7 +323,9 @@ def test_spawner_delete_cascade(db):
db.add(user) db.add(user)
db.commit() db.commit()
spawner = orm.Spawner(user=user) spawner = orm.Spawner()
db.add(spawner)
spawner.user = user
db.commit() db.commit()
spawner.server = server = orm.Server() spawner.server = server = orm.Server()
db.commit() db.commit()
@@ -350,16 +352,19 @@ def test_user_delete_cascade(db):
# these should all be deleted automatically when the user goes away # these should all be deleted automatically when the user goes away
user.new_api_token() user.new_api_token()
api_token = user.api_tokens[0] api_token = user.api_tokens[0]
spawner = orm.Spawner(user=user) spawner = orm.Spawner()
db.add(spawner)
spawner.user = user
db.commit() db.commit()
spawner.server = server = orm.Server() spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode()
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.APIToken( oauth_code.client = oauth_client
oauth_client=oauth_client, oauth_code.user = user
user=user, oauth_token = orm.APIToken()
)
db.add(oauth_token) db.add(oauth_token)
oauth_token.oauth_client = oauth_client
oauth_token.user = user
db.commit() db.commit()
# record all of the ids # record all of the ids
@@ -390,13 +395,14 @@ def test_oauth_client_delete_cascade(db):
# create a bunch of objects that reference the User # create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away # these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client=oauth_client, user=user) oauth_code = orm.OAuthCode()
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.APIToken( oauth_code.client = oauth_client
oauth_client=oauth_client, oauth_code.user = user
user=user, oauth_token = orm.APIToken()
)
db.add(oauth_token) db.add(oauth_token)
oauth_token.oauth_client = oauth_client
oauth_token.user = user
db.commit() db.commit()
assert user.api_tokens == [oauth_token] assert user.api_tokens == [oauth_token]
@@ -517,11 +523,11 @@ def test_expiring_oauth_token(app, user):
db.add(client) db.add(client)
orm_token = orm.APIToken( orm_token = orm.APIToken(
token=token, token=token,
oauth_client=client,
user=user,
expires_at=now() + timedelta(seconds=30), expires_at=now() + timedelta(seconds=30),
) )
db.add(orm_token) db.add(orm_token)
orm_token.oauth_client = client
orm_token.user = user
db.commit() db.commit()
found = orm.APIToken.find(db, token) found = orm.APIToken.find(db, token)

View File

@@ -1033,11 +1033,10 @@ async def test_oauth_token_page(app):
user = app.users[orm.User.find(app.db, name)] user = app.users[orm.User.find(app.db, name)]
client = orm.OAuthClient(identifier='token') client = orm.OAuthClient(identifier='token')
app.db.add(client) app.db.add(client)
oauth_token = orm.APIToken( oauth_token = orm.APIToken()
oauth_client=client,
user=user,
)
app.db.add(oauth_token) app.db.add(oauth_token)
oauth_token.oauth_client = client
oauth_token.user = user
app.db.commit() app.db.commit()
r = await get_page('token', app, cookies=cookies) r = await get_page('token', app, cookies=cookies)
r.raise_for_status() r.raise_for_status()

View File

@@ -3,9 +3,11 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import json import json
import os import os
import warnings
import pytest import pytest
from pytest import mark from pytest import mark
from sqlalchemy.exc import SADeprecationWarning
from tornado.log import app_log from tornado.log import app_log
from .. import orm, roles from .. import orm, roles
@@ -343,7 +345,13 @@ async def test_creating_roles(app, role, role_def, response_type, response):
# make sure no warnings/info logged when the role exists and its definition hasn't been changed # make sure no warnings/info logged when the role exists and its definition hasn't been changed
elif response_type == 'no-log': elif response_type == 'no-log':
with pytest.warns(response) as record: with pytest.warns(response) as record:
# don't catch already-suppressed sqlalchemy warnings
warnings.simplefilter("ignore", SADeprecationWarning)
roles.create_role(db, role_def) roles.create_role(db, role_def)
for warning in record.list:
# show warnings for debugging
print("Unexpected warning", warning)
assert not record.list assert not record.list
role = orm.Role.find(db, role_def['name']) role = orm.Role.find(db, role_def['name'])
assert role is not None assert role is not None

View File

@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
import requests import requests
from certipy import Certipy from certipy import Certipy
from sqlalchemy import text
from tornado.httputil import url_concat from tornado.httputil import url_concat
from jupyterhub import metrics, orm from jupyterhub import metrics, orm
@@ -13,6 +14,20 @@ from jupyterhub.objects import Server
from jupyterhub.roles import assign_default_roles, update_roles from jupyterhub.roles import assign_default_roles, update_roles
from jupyterhub.utils import url_path_join as ujoin from jupyterhub.utils import url_path_join as ujoin
try:
from sqlalchemy.exc import RemovedIn20Warning
except ImportError:
class RemovedIn20Warning(DeprecationWarning):
"""
I only exist so I can be used in warnings filters in pytest.ini
I will never be displayed.
sqlalchemy 1.4 introduces RemovedIn20Warning,
but we still test against older sqlalchemy.
"""
class _AsyncRequests: class _AsyncRequests:
"""Wrapper around requests to return a Future from request methods """Wrapper around requests to return a Future from request methods
@@ -85,8 +100,8 @@ def check_db_locks(func):
def _check(_=None): def _check(_=None):
temp_session = app.session_factory() temp_session = app.session_factory()
try: try:
temp_session.execute('CREATE TABLE dummy (foo INT)') temp_session.execute(text('CREATE TABLE dummy (foo INT)'))
temp_session.execute('DROP TABLE dummy') temp_session.execute(text('DROP TABLE dummy'))
finally: finally:
temp_session.close() temp_session.close()

View File

@@ -416,9 +416,10 @@ class User:
yield orm_spawner yield orm_spawner
def _new_orm_spawner(self, server_name): def _new_orm_spawner(self, server_name):
"""Creat the low-level orm Spawner object""" """Create the low-level orm Spawner object"""
orm_spawner = orm.Spawner(user=self.orm_user, name=server_name) orm_spawner = orm.Spawner(name=server_name)
self.db.add(orm_spawner) self.db.add(orm_spawner)
orm_spawner.user = self.orm_user
self.db.commit() self.db.commit()
assert server_name in self.orm_spawners assert server_name in self.orm_spawners
return orm_spawner return orm_spawner

View File

@@ -18,3 +18,7 @@ markers =
slow: mark a test as slow slow: mark a test as slow
role: mark as a test for roles role: mark as a test for roles
selenium: web tests that run with selenium selenium: web tests that run with selenium
filterwarnings =
error:.*:jupyterhub.tests.utils.RemovedIn20Warning
ignore:.*event listener has changed as of version 2.0.*:sqlalchemy.exc.SADeprecationWarning