mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-10 19:43:01 +00:00
use relationships everywhere
in order to use sqlalchemy's expire_on_commit=False optimization, we need to make sure that objects are kept up to date. This means we cannot rely on ForeignKey ondelete/onupdate behavior, we must use sqlalchemy's local relationship cascades The main key here is that we must use relationships to set foreign-key relations, e.g. APIToken.user = user instead of APIToken.user_id = user.id. It also means that we cannot use passive_deletes, which allows sqlalchemy to defer to the database's more efficient ON DELETE behavior. This makes deletions more expensive in particular, but should improve db performance overall.
This commit is contained in:
@@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
||||
|
||||
"""
|
||||
|
||||
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
||||
user = self.db.query(orm.User).filter_by(id=access_token.user_id).first()
|
||||
if user is None:
|
||||
raise ValueError("No user for access token: %s" % access_token.user_id)
|
||||
client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first()
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client_id=access_token.client_id,
|
||||
client=client,
|
||||
grant_type=access_token.grant_type,
|
||||
expires_at=access_token.expires_at,
|
||||
refresh_token=access_token.refresh_token,
|
||||
@@ -127,10 +128,10 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
|
||||
:class:`oauth2.datatype.AuthorizationCode`.
|
||||
"""
|
||||
orm_code = orm.OAuthCode(
|
||||
client_id=authorization_code.client_id,
|
||||
client=authorization_code.client,
|
||||
code=authorization_code.code,
|
||||
expires_at=authorization_code.expires_at,
|
||||
user_id=authorization_code.user_id,
|
||||
user=authorization_code.user,
|
||||
redirect_uri=authorization_code.redirect_uri,
|
||||
session_id=authorization_code.data.get('session_id', ''),
|
||||
)
|
||||
|
@@ -14,16 +14,19 @@ from tornado.log import app_log
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
||||
from sqlalchemy import (
|
||||
inspect,
|
||||
create_engine, event, inspect,
|
||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||
DateTime, Enum
|
||||
DateTime, Enum, Table,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.interfaces import PoolListener
|
||||
from sqlalchemy.orm import backref, sessionmaker, relationship
|
||||
from sqlalchemy.orm import (
|
||||
Session,
|
||||
interfaces, object_session, relationship, sessionmaker,
|
||||
)
|
||||
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy import create_engine, Table
|
||||
|
||||
from .utils import (
|
||||
random_port,
|
||||
@@ -88,7 +91,7 @@ class Group(Base):
|
||||
__tablename__ = 'groups'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(Unicode(255), unique=True)
|
||||
users = relationship('User', secondary='user_group_map', back_populates='groups')
|
||||
users = relationship('User', secondary='user_group_map', backref='groups')
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s (%i users)>" % (
|
||||
@@ -133,9 +136,6 @@ class User(Base):
|
||||
"Spawner",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
# can't use passive-deletes on this one
|
||||
# because we rely on orm-level delete
|
||||
# for Spawner.server
|
||||
)
|
||||
@property
|
||||
def orm_spawners(self):
|
||||
@@ -149,13 +149,16 @@ class User(Base):
|
||||
"APIToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
oauth_tokens = relationship(
|
||||
"OAuthAccessToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
oauth_codes = relationship(
|
||||
"OAuthCode",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
|
||||
# User.state is actually Spawner state
|
||||
@@ -164,8 +167,6 @@ class User(Base):
|
||||
# Authenticators can store their state here:
|
||||
# Encryption is handled elsewhere
|
||||
encrypted_auth_state = Column(LargeBinary)
|
||||
# group mapping
|
||||
groups = relationship('Group', secondary='user_group_map', back_populates='users')
|
||||
|
||||
def __repr__(self):
|
||||
return "<{cls}({name} {running}/{total} running)>".format(
|
||||
@@ -234,7 +235,6 @@ class Service(Base):
|
||||
"APIToken",
|
||||
backref="service",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# service-specific interface
|
||||
@@ -408,10 +408,10 @@ class APIToken(Hashed, Base):
|
||||
orm_token.token = token
|
||||
if user:
|
||||
assert user.id is not None
|
||||
orm_token.user_id = user.id
|
||||
orm_token.user = user
|
||||
else:
|
||||
assert service.id is not None
|
||||
orm_token.service_id = service.id
|
||||
orm_token.service = service
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
return token
|
||||
@@ -498,6 +498,18 @@ class OAuthClient(Base):
|
||||
secret = Column(Unicode(255))
|
||||
redirect_uri = Column(Unicode(1023))
|
||||
|
||||
access_tokens = relationship(
|
||||
OAuthAccessToken,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
)
|
||||
codes = relationship(
|
||||
OAuthCode,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
)
|
||||
|
||||
# General database utilities
|
||||
|
||||
class DatabaseSchemaMismatch(Exception):
|
||||
"""Exception raised when the database schema version does not match
|
||||
@@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener):
|
||||
dbapi_con.execute('pragma foreign_keys=ON')
|
||||
|
||||
|
||||
def _expire_relationship(target, relationship_prop):
|
||||
"""Expire relationship backrefs
|
||||
|
||||
used when an object with relationships is deleted
|
||||
"""
|
||||
|
||||
session = object_session(target)
|
||||
# get peer objects to be expired
|
||||
peers = getattr(target, relationship_prop.key)
|
||||
if peers is None:
|
||||
# no peer to clear
|
||||
return
|
||||
# many-to-many and one-to-many have a list of peers
|
||||
# many-to-one has only one
|
||||
if relationship_prop.direction is interfaces.MANYTOONE:
|
||||
peers = [peers]
|
||||
for obj in peers:
|
||||
if inspect(obj).persistent:
|
||||
session.expire(obj, [relationship_prop.back_populates])
|
||||
|
||||
|
||||
@event.listens_for(Session, "persistent_to_deleted")
|
||||
def _notify_deleted_relationships(session, obj):
|
||||
"""Expire relationships when an object becomes deleted
|
||||
|
||||
Needed for
|
||||
"""
|
||||
mapper = inspect(obj).mapper
|
||||
for prop in mapper.relationships:
|
||||
if prop.back_populates:
|
||||
_expire_relationship(obj, prop)
|
||||
|
||||
|
||||
def check_db_revision(engine):
|
||||
"""Check the JupyterHub database revision
|
||||
|
||||
|
@@ -280,7 +280,7 @@ def test_get_self(app):
|
||||
db.commit()
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
user=u.orm_user,
|
||||
client_id=oauth_client.identifier,
|
||||
client=oauth_client,
|
||||
token=token,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
|
@@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db):
|
||||
|
||||
# verify that server gets deleted
|
||||
assert_not_found(db, orm.Server, server_id)
|
||||
assert user.orm_spawners == {}
|
||||
|
||||
|
||||
def test_user_delete_cascade(db):
|
||||
@@ -305,11 +306,11 @@ def test_user_delete_cascade(db):
|
||||
spawner = orm.Spawner(user=user)
|
||||
db.commit()
|
||||
spawner.server = server = orm.Server()
|
||||
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client_id=oauth_client.identifier,
|
||||
user_id=user.id,
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
@@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db):
|
||||
|
||||
# create a bunch of objects that reference the User
|
||||
# these should all be deleted automatically when the user goes away
|
||||
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client_id=oauth_client.identifier,
|
||||
user_id=user.id,
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
assert user.oauth_tokens == [oauth_token]
|
||||
|
||||
# record all of the ids
|
||||
oauth_code_id = oauth_code.id
|
||||
@@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db):
|
||||
# verify that everything gets deleted
|
||||
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||
assert user.oauth_tokens == []
|
||||
assert user.oauth_codes == []
|
||||
|
||||
|
||||
def test_delete_token_cascade(db):
|
||||
user = orm.User(name='mobs')
|
||||
db.add(user)
|
||||
db.commit()
|
||||
user.new_api_token()
|
||||
api_token = user.api_tokens[0]
|
||||
db.delete(api_token)
|
||||
db.commit()
|
||||
assert user.api_tokens == []
|
||||
|
||||
|
||||
def test_group_delete_cascade(db):
|
||||
user1 = orm.User(name='user1')
|
||||
user2 = orm.User(name='user2')
|
||||
group1 = orm.Group(name='group1')
|
||||
group2 = orm.Group(name='group2')
|
||||
db.add(user1)
|
||||
db.add(user2)
|
||||
db.add(group1)
|
||||
db.add(group2)
|
||||
db.commit()
|
||||
# add user to group via user.groups works
|
||||
user1.groups.append(group1)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
|
||||
# add user to group via groups.users works
|
||||
group1.users.append(user2)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
assert user2 in group1.users
|
||||
assert group1 in user1.groups
|
||||
assert group1 in user2.groups
|
||||
|
||||
# fill out the connections (no new concept)
|
||||
group2.users.append(user1)
|
||||
group2.users.append(user2)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
assert user2 in group1.users
|
||||
assert user1 in group2.users
|
||||
assert user2 in group2.users
|
||||
assert group1 in user1.groups
|
||||
assert group1 in user2.groups
|
||||
assert group2 in user1.groups
|
||||
assert group2 in user2.groups
|
||||
|
||||
# now start deleting
|
||||
# 1. remove group via user.groups
|
||||
user1.groups.remove(group2)
|
||||
db.commit()
|
||||
assert user1 not in group2.users
|
||||
assert group2 not in user1.groups
|
||||
|
||||
# 2. remove user via group.users
|
||||
group1.users.remove(user2)
|
||||
db.commit()
|
||||
assert user2 not in group1.users
|
||||
assert group1 not in user2.groups
|
||||
|
||||
# 3. delete group object
|
||||
db.delete(group2)
|
||||
db.commit()
|
||||
assert group2 not in user1.groups
|
||||
assert group2 not in user2.groups
|
||||
|
||||
# 4. delete user object
|
||||
db.delete(user1)
|
||||
db.commit()
|
||||
assert user1 not in group1.users
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user