use relationships everywhere

in order to use sqlalchemy's expire_on_commit=False optimization,
we need to make sure that objects are kept up to date.

This means we cannot rely on ForeignKey ondelete/onupdate behavior,
we must use sqlalchemy's local relationship cascades

The main key here is that we must use relationships to set foreign-key relations,
e.g. APIToken.user = user instead of APIToken.user_id = user.id.

It also means that we cannot use passive_deletes,
which allows sqlalchemy to defer to the database's more efficient ON DELETE behavior.

This makes deletions more expensive in particular,
but should improve db performance overall.
This commit is contained in:
Min RK
2018-04-17 10:54:14 +02:00
parent 15e4b1ad8b
commit b1840e8be7
4 changed files with 152 additions and 27 deletions

View File

@@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
"""
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
user = self.db.query(orm.User).filter_by(id=access_token.user_id).first()
if user is None:
raise ValueError("No user for access token: %s" % access_token.user_id)
client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first()
orm_access_token = orm.OAuthAccessToken(
client_id=access_token.client_id,
client=client,
grant_type=access_token.grant_type,
expires_at=access_token.expires_at,
refresh_token=access_token.refresh_token,
@@ -127,10 +128,10 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
:class:`oauth2.datatype.AuthorizationCode`.
"""
orm_code = orm.OAuthCode(
client_id=authorization_code.client_id,
client=authorization_code.client,
code=authorization_code.code,
expires_at=authorization_code.expires_at,
user_id=authorization_code.user_id,
user=authorization_code.user,
redirect_uri=authorization_code.redirect_uri,
session_id=authorization_code.data.get('session_id', ''),
)

View File

@@ -14,16 +14,19 @@ from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import (
inspect,
create_engine, event, inspect,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum
DateTime, Enum, Table,
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.interfaces import PoolListener
from sqlalchemy.orm import backref, sessionmaker, relationship
from sqlalchemy.orm import (
Session,
interfaces, object_session, relationship, sessionmaker,
)
from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
from .utils import (
random_port,
@@ -88,7 +91,7 @@ class Group(Base):
__tablename__ = 'groups'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
users = relationship('User', secondary='user_group_map', back_populates='groups')
users = relationship('User', secondary='user_group_map', backref='groups')
def __repr__(self):
return "<%s %s (%i users)>" % (
@@ -133,9 +136,6 @@ class User(Base):
"Spawner",
backref="user",
cascade="all, delete-orphan",
# can't use passive-deletes on this one
# because we rely on orm-level delete
# for Spawner.server
)
@property
def orm_spawners(self):
@@ -149,13 +149,16 @@ class User(Base):
"APIToken",
backref="user",
cascade="all, delete-orphan",
passive_deletes=True,
)
oauth_tokens = relationship(
"OAuthAccessToken",
backref="user",
cascade="all, delete-orphan",
passive_deletes=True,
)
oauth_codes = relationship(
"OAuthCode",
backref="user",
cascade="all, delete-orphan",
)
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state
@@ -164,8 +167,6 @@ class User(Base):
# Authenticators can store their state here:
# Encryption is handled elsewhere
encrypted_auth_state = Column(LargeBinary)
# group mapping
groups = relationship('Group', secondary='user_group_map', back_populates='users')
def __repr__(self):
return "<{cls}({name} {running}/{total} running)>".format(
@@ -234,7 +235,6 @@ class Service(Base):
"APIToken",
backref="service",
cascade="all, delete-orphan",
passive_deletes=True,
)
# service-specific interface
@@ -408,10 +408,10 @@ class APIToken(Hashed, Base):
orm_token.token = token
if user:
assert user.id is not None
orm_token.user_id = user.id
orm_token.user = user
else:
assert service.id is not None
orm_token.service_id = service.id
orm_token.service = service
db.add(orm_token)
db.commit()
return token
@@ -498,6 +498,18 @@ class OAuthClient(Base):
secret = Column(Unicode(255))
redirect_uri = Column(Unicode(1023))
access_tokens = relationship(
OAuthAccessToken,
backref='client',
cascade='all, delete-orphan',
)
codes = relationship(
OAuthCode,
backref='client',
cascade='all, delete-orphan',
)
# General database utilities
class DatabaseSchemaMismatch(Exception):
"""Exception raised when the database schema version does not match
@@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener):
dbapi_con.execute('pragma foreign_keys=ON')
def _expire_relationship(target, relationship_prop):
"""Expire relationship backrefs
used when an object with relationships is deleted
"""
session = object_session(target)
# get peer objects to be expired
peers = getattr(target, relationship_prop.key)
if peers is None:
# no peer to clear
return
# many-to-many and one-to-many have a list of peers
# many-to-one has only one
if relationship_prop.direction is interfaces.MANYTOONE:
peers = [peers]
for obj in peers:
if inspect(obj).persistent:
session.expire(obj, [relationship_prop.back_populates])
@event.listens_for(Session, "persistent_to_deleted")
def _notify_deleted_relationships(session, obj):
"""Expire relationships when an object becomes deleted
Needed for
"""
mapper = inspect(obj).mapper
for prop in mapper.relationships:
if prop.back_populates:
_expire_relationship(obj, prop)
def check_db_revision(engine):
"""Check the JupyterHub database revision

View File

@@ -280,7 +280,7 @@ def test_get_self(app):
db.commit()
oauth_token = orm.OAuthAccessToken(
user=u.orm_user,
client_id=oauth_client.identifier,
client=oauth_client,
token=token,
grant_type=orm.GrantType.authorization_code,
)

View File

@@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db):
# verify that server gets deleted
assert_not_found(db, orm.Server, server_id)
assert user.orm_spawners == {}
def test_user_delete_cascade(db):
@@ -305,11 +306,11 @@ def test_user_delete_cascade(db):
spawner = orm.Spawner(user=user)
db.commit()
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
client=oauth_client,
user=user,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
@@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db):
# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
client=oauth_client,
user=user,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
assert user.oauth_tokens == [oauth_token]
# record all of the ids
oauth_code_id = oauth_code.id
@@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert user.oauth_tokens == []
assert user.oauth_codes == []
def test_delete_token_cascade(db):
user = orm.User(name='mobs')
db.add(user)
db.commit()
user.new_api_token()
api_token = user.api_tokens[0]
db.delete(api_token)
db.commit()
assert user.api_tokens == []
def test_group_delete_cascade(db):
user1 = orm.User(name='user1')
user2 = orm.User(name='user2')
group1 = orm.Group(name='group1')
group2 = orm.Group(name='group2')
db.add(user1)
db.add(user2)
db.add(group1)
db.add(group2)
db.commit()
# add user to group via user.groups works
user1.groups.append(group1)
db.commit()
assert user1 in group1.users
# add user to group via groups.users works
group1.users.append(user2)
db.commit()
assert user1 in group1.users
assert user2 in group1.users
assert group1 in user1.groups
assert group1 in user2.groups
# fill out the connections (no new concept)
group2.users.append(user1)
group2.users.append(user2)
db.commit()
assert user1 in group1.users
assert user2 in group1.users
assert user1 in group2.users
assert user2 in group2.users
assert group1 in user1.groups
assert group1 in user2.groups
assert group2 in user1.groups
assert group2 in user2.groups
# now start deleting
# 1. remove group via user.groups
user1.groups.remove(group2)
db.commit()
assert user1 not in group2.users
assert group2 not in user1.groups
# 2. remove user via group.users
group1.users.remove(user2)
db.commit()
assert user2 not in group1.users
assert group1 not in user2.groups
# 3. delete group object
db.delete(group2)
db.commit()
assert group2 not in user1.groups
assert group2 not in user2.groups
# 4. delete user object
db.delete(user1)
db.commit()
assert user1 not in group1.users