fix and test deletion cascades

- ensure foreign keys are enabled on sqlite
- fix deletion cascades where relationships were causing dissociation instead of deletion
This commit is contained in:
Min RK
2018-04-13 20:02:24 +02:00
parent 33ba9fb5cf
commit 078bd8c627
2 changed files with 142 additions and 22 deletions

View File

@@ -19,8 +19,8 @@ from sqlalchemy import (
DateTime, Enum DateTime, Enum
) )
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.interfaces import PoolListener from sqlalchemy.interfaces import PoolListener
from sqlalchemy.orm import backref, sessionmaker, relationship
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table from sqlalchemy import create_engine, Table
@@ -78,8 +78,8 @@ class Server(Base):
# user:group many:many mapping table # user:group many:many mapping table
user_group_map = Table('user_group_map', Base.metadata, user_group_map = Table('user_group_map', Base.metadata,
Column('user_id', ForeignKey('users.id'), primary_key=True), Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
Column('group_id', ForeignKey('groups.id'), primary_key=True), Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
) )
@@ -129,7 +129,14 @@ class User(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True) name = Column(Unicode(255), unique=True)
_orm_spawners = relationship("Spawner", backref="user") _orm_spawners = relationship(
"Spawner",
backref="user",
cascade="all, delete-orphan",
# can't use passive-deletes on this one
# because we rely on orm-level delete
# for Spawner.server
)
@property @property
def orm_spawners(self): def orm_spawners(self):
return {s.name: s for s in self._orm_spawners} return {s.name: s for s in self._orm_spawners}
@@ -138,7 +145,12 @@ class User(Base):
created = Column(DateTime, default=datetime.utcnow) created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True) last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user") api_tokens = relationship(
"APIToken",
backref="user",
cascade="all, delete-orphan",
passive_deletes=True,
)
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state # User.state is actually Spawner state
# We will need to figure something else out if/when we have multiple spawners per user # We will need to figure something else out if/when we have multiple spawners per user
@@ -179,7 +191,7 @@ class Spawner(Base):
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL')) server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(Server) server = relationship(Server, cascade="all")
state = Column(JSONDict) state = Column(JSONDict)
name = Column(Unicode(255)) name = Column(Unicode(255))
@@ -212,11 +224,16 @@ class Service(Base):
name = Column(Unicode(255), unique=True) name = Column(Unicode(255), unique=True)
admin = Column(Boolean, default=False) admin = Column(Boolean, default=False)
api_tokens = relationship("APIToken", backref="service") api_tokens = relationship(
"APIToken",
backref="service",
cascade="all, delete-orphan",
passive_deletes=True,
)
# service-specific interface # service-specific interface
_server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL')) _server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(Server, primaryjoin=_server_id == Server.id) server = relationship(Server, cascade='all')
pid = Column(Integer) pid = Column(Integer)
def new_api_token(self, token=None, generated=True, note=''): def new_api_token(self, token=None, generated=True, note=''):
@@ -312,13 +329,8 @@ class APIToken(Hashed, Base):
"""An API token""" """An API token"""
__tablename__ = 'api_tokens' __tablename__ = 'api_tokens'
@declared_attr user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
def user_id(cls): service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
@declared_attr
def service_id(cls):
return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
hashed = Column(Unicode(255), unique=True) hashed = Column(Unicode(255), unique=True)
@@ -419,7 +431,6 @@ class OAuthAccessToken(Hashed, Base):
refresh_token = Column(Unicode(255)) refresh_token = Column(Unicode(255))
refresh_expires_at = Column(Integer) refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(User)
service = None # for API-equivalence with APIToken service = None # for API-equivalence with APIToken
# the browser session id associated with a given token # the browser session id associated with a given token
@@ -433,8 +444,9 @@ class OAuthAccessToken(Hashed, Base):
last_activity = Column(DateTime, nullable=True) last_activity = Column(DateTime, nullable=True)
def __repr__(self): def __repr__(self):
return "<{cls}('{prefix}...', user='{user}'>".format( return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
cls=self.__class__.__name__, cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name, user=self.user and self.user.name,
prefix=self.prefix, prefix=self.prefix,
) )

View File

@@ -17,6 +17,11 @@ from .mocking import MockSpawner
from ..emptyclass import EmptyClass from ..emptyclass import EmptyClass
def assert_not_found(db, ORMType, id):
"""Assert that an item with a given id is not found"""
assert db.query(ORMType).filter(ORMType.id==id).first() is None
def test_server(db): def test_server(db):
server = orm.Server() server = orm.Server()
db.add(server) db.add(server)
@@ -123,6 +128,12 @@ def test_service_server(db):
assert server.id is None assert server.id is None
db.commit() db.commit()
assert isinstance(server.id, int) assert isinstance(server.id, int)
server_id = server.id
# deleting service should delete its server
db.delete(service)
db.commit()
assert_not_found(db, orm.Server, server_id)
def test_token_find(db): def test_token_find(db):
@@ -191,6 +202,9 @@ def test_groups(db):
db.commit() db.commit()
assert group.users == [user] assert group.users == [user]
assert user.groups == [group] assert user.groups == [group]
db.delete(user)
db.commit()
assert group.users == []
@pytest.mark.gen_test @pytest.mark.gen_test
@@ -256,3 +270,97 @@ def test_auth_state(db):
decrypted_state = yield user.get_auth_state() decrypted_state = yield user.get_auth_state()
assert decrypted_state is None assert decrypted_state is None
def test_spawner_delete_cascade(db):
user = orm.User(name='spawner-delete')
db.add(user)
db.commit()
spawner = orm.Spawner(user=user)
db.commit()
spawner.server = server = orm.Server()
db.commit()
db.delete(spawner)
server_id = server.id
# delete the user
db.delete(spawner)
db.commit()
# verify that server gets deleted
assert_not_found(db, orm.Server, server_id)
def test_user_delete_cascade(db):
user = orm.User(name='db-delete')
oauth_client = orm.OAuthClient(identifier='db-delete-client')
db.add(user)
db.add(oauth_client)
db.commit()
# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
user.new_api_token()
api_token = user.api_tokens[0]
spawner = orm.Spawner(user=user)
db.commit()
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
# record all of the ids
spawner_id = spawner.id
server_id = server.id
api_token_id = api_token.id
oauth_code_id = oauth_code.id
oauth_token_id = oauth_token.id
# delete the user
db.delete(user)
db.commit()
# verify that everything gets deleted
assert_not_found(db, orm.APIToken, api_token_id)
assert_not_found(db, orm.Spawner, spawner_id)
assert_not_found(db, orm.Server, server_id)
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
def test_oauth_client_delete_cascade(db):
user = orm.User(name='oauth-delete')
oauth_client = orm.OAuthClient(identifier='oauth-delete-client')
db.add(user)
db.add(oauth_client)
db.commit()
# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
# record all of the ids
oauth_code_id = oauth_code.id
oauth_token_id = oauth_token.id
# delete the user
db.delete(oauth_client)
db.commit()
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)