fix and test deletion cascades

- ensure foreign keys are enabled on sqlite
- fix deletion cascades where relationships were causing dissociation instead of deletion
This commit is contained in:
Min RK
2018-04-13 20:02:24 +02:00
parent 33ba9fb5cf
commit 078bd8c627
2 changed files with 142 additions and 22 deletions

View File

@@ -19,8 +19,8 @@ from sqlalchemy import (
DateTime, Enum
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.interfaces import PoolListener
from sqlalchemy.orm import backref, sessionmaker, relationship
from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
@@ -78,8 +78,8 @@ class Server(Base):
# user:group many:many mapping table
user_group_map = Table('user_group_map', Base.metadata,
Column('user_id', ForeignKey('users.id'), primary_key=True),
Column('group_id', ForeignKey('groups.id'), primary_key=True),
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
)
@@ -129,7 +129,14 @@ class User(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True)
_orm_spawners = relationship("Spawner", backref="user")
_orm_spawners = relationship(
"Spawner",
backref="user",
cascade="all, delete-orphan",
# can't use passive-deletes on this one
# because we rely on orm-level delete
# for Spawner.server
)
@property
def orm_spawners(self):
return {s.name: s for s in self._orm_spawners}
@@ -138,7 +145,12 @@ class User(Base):
created = Column(DateTime, default=datetime.utcnow)
last_activity = Column(DateTime, nullable=True)
api_tokens = relationship("APIToken", backref="user")
api_tokens = relationship(
"APIToken",
backref="user",
cascade="all, delete-orphan",
passive_deletes=True,
)
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state
# We will need to figure something else out if/when we have multiple spawners per user
@@ -179,7 +191,7 @@ class Spawner(Base):
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(Server)
server = relationship(Server, cascade="all")
state = Column(JSONDict)
name = Column(Unicode(255))
@@ -212,11 +224,16 @@ class Service(Base):
name = Column(Unicode(255), unique=True)
admin = Column(Boolean, default=False)
api_tokens = relationship("APIToken", backref="service")
api_tokens = relationship(
"APIToken",
backref="service",
cascade="all, delete-orphan",
passive_deletes=True,
)
# service-specific interface
_server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
server = relationship(Server, primaryjoin=_server_id == Server.id)
server = relationship(Server, cascade='all')
pid = Column(Integer)
def new_api_token(self, token=None, generated=True, note=''):
@@ -312,13 +329,8 @@ class APIToken(Hashed, Base):
"""An API token"""
__tablename__ = 'api_tokens'
@declared_attr
def user_id(cls):
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
@declared_attr
def service_id(cls):
return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
id = Column(Integer, primary_key=True)
hashed = Column(Unicode(255), unique=True)
@@ -419,7 +431,6 @@ class OAuthAccessToken(Hashed, Base):
refresh_token = Column(Unicode(255))
refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(User)
service = None # for API-equivalence with APIToken
# the browser session id associated with a given token
@@ -433,8 +444,9 @@ class OAuthAccessToken(Hashed, Base):
last_activity = Column(DateTime, nullable=True)
def __repr__(self):
return "<{cls}('{prefix}...', user='{user}'>".format(
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
cls=self.__class__.__name__,
client_id=self.client_id,
user=self.user and self.user.name,
prefix=self.prefix,
)

View File

@@ -17,6 +17,11 @@ from .mocking import MockSpawner
from ..emptyclass import EmptyClass
def assert_not_found(db, ORMType, id):
"""Assert that an item with a given id is not found"""
assert db.query(ORMType).filter(ORMType.id==id).first() is None
def test_server(db):
server = orm.Server()
db.add(server)
@@ -116,14 +121,20 @@ def test_service_server(db):
service = orm.Service(name='has_servers')
db.add(service)
db.commit()
assert service.server is None
server = service.server = orm.Server()
assert service
assert server.id is None
db.commit()
assert isinstance(server.id, int)
server_id = server.id
# deleting service should delete its server
db.delete(service)
db.commit()
assert_not_found(db, orm.Server, server_id)
def test_token_find(db):
service = db.query(orm.Service).first()
@@ -160,7 +171,7 @@ def test_spawn_fails(db):
orm_user = orm.User(name='aeofel')
db.add(orm_user)
db.commit()
class BadSpawner(MockSpawner):
@gen.coroutine
def start(self):
@@ -181,7 +192,7 @@ def test_spawn_fails(db):
def test_groups(db):
user = orm.User.find(db, name='aeofel')
db.add(user)
group = orm.Group(name='lives')
db.add(group)
db.commit()
@@ -191,6 +202,9 @@ def test_groups(db):
db.commit()
assert group.users == [user]
assert user.groups == [group]
db.delete(user)
db.commit()
assert group.users == []
@pytest.mark.gen_test
@@ -224,7 +238,7 @@ def test_auth_state(db):
assert user.encrypted_auth_state is not None
decrypted_state = yield user.get_auth_state()
assert decrypted_state == state
# can't read auth_state without keys
ck.keys = []
auth_state = yield user.get_auth_state()
@@ -256,3 +270,97 @@ def test_auth_state(db):
decrypted_state = yield user.get_auth_state()
assert decrypted_state is None
def test_spawner_delete_cascade(db):
user = orm.User(name='spawner-delete')
db.add(user)
db.commit()
spawner = orm.Spawner(user=user)
db.commit()
spawner.server = server = orm.Server()
db.commit()
db.delete(spawner)
server_id = server.id
# delete the user
db.delete(spawner)
db.commit()
# verify that server gets deleted
assert_not_found(db, orm.Server, server_id)
def test_user_delete_cascade(db):
user = orm.User(name='db-delete')
oauth_client = orm.OAuthClient(identifier='db-delete-client')
db.add(user)
db.add(oauth_client)
db.commit()
# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
user.new_api_token()
api_token = user.api_tokens[0]
spawner = orm.Spawner(user=user)
db.commit()
spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
# record all of the ids
spawner_id = spawner.id
server_id = server.id
api_token_id = api_token.id
oauth_code_id = oauth_code.id
oauth_token_id = oauth_token.id
# delete the user
db.delete(user)
db.commit()
# verify that everything gets deleted
assert_not_found(db, orm.APIToken, api_token_id)
assert_not_found(db, orm.Spawner, spawner_id)
assert_not_found(db, orm.Server, server_id)
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
def test_oauth_client_delete_cascade(db):
user = orm.User(name='oauth-delete')
oauth_client = orm.OAuthClient(identifier='oauth-delete-client')
db.add(user)
db.add(oauth_client)
db.commit()
# create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
db.add(oauth_code)
oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()
# record all of the ids
oauth_code_id = oauth_code.id
oauth_token_id = oauth_token.id
# delete the user
db.delete(oauth_client)
db.commit()
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)