mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-15 05:53:00 +00:00
fix and test deletion cascades
- ensure foreign keys are enabled on sqlite - fix deletion cascades where relationships were causing dissociation instead of deletion
This commit is contained in:
@@ -19,8 +19,8 @@ from sqlalchemy import (
|
|||||||
DateTime, Enum
|
DateTime, Enum
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||||
from sqlalchemy.orm import sessionmaker, relationship
|
|
||||||
from sqlalchemy.interfaces import PoolListener
|
from sqlalchemy.interfaces import PoolListener
|
||||||
|
from sqlalchemy.orm import backref, sessionmaker, relationship
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
from sqlalchemy.sql.expression import bindparam
|
from sqlalchemy.sql.expression import bindparam
|
||||||
from sqlalchemy import create_engine, Table
|
from sqlalchemy import create_engine, Table
|
||||||
@@ -78,8 +78,8 @@ class Server(Base):
|
|||||||
|
|
||||||
# user:group many:many mapping table
|
# user:group many:many mapping table
|
||||||
user_group_map = Table('user_group_map', Base.metadata,
|
user_group_map = Table('user_group_map', Base.metadata,
|
||||||
Column('user_id', ForeignKey('users.id'), primary_key=True),
|
Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True),
|
||||||
Column('group_id', ForeignKey('groups.id'), primary_key=True),
|
Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -129,7 +129,14 @@ class User(Base):
|
|||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
name = Column(Unicode(255), unique=True)
|
name = Column(Unicode(255), unique=True)
|
||||||
|
|
||||||
_orm_spawners = relationship("Spawner", backref="user")
|
_orm_spawners = relationship(
|
||||||
|
"Spawner",
|
||||||
|
backref="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
# can't use passive-deletes on this one
|
||||||
|
# because we rely on orm-level delete
|
||||||
|
# for Spawner.server
|
||||||
|
)
|
||||||
@property
|
@property
|
||||||
def orm_spawners(self):
|
def orm_spawners(self):
|
||||||
return {s.name: s for s in self._orm_spawners}
|
return {s.name: s for s in self._orm_spawners}
|
||||||
@@ -138,7 +145,12 @@ class User(Base):
|
|||||||
created = Column(DateTime, default=datetime.utcnow)
|
created = Column(DateTime, default=datetime.utcnow)
|
||||||
last_activity = Column(DateTime, nullable=True)
|
last_activity = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
api_tokens = relationship("APIToken", backref="user")
|
api_tokens = relationship(
|
||||||
|
"APIToken",
|
||||||
|
backref="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
passive_deletes=True,
|
||||||
|
)
|
||||||
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
|
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
|
||||||
# User.state is actually Spawner state
|
# User.state is actually Spawner state
|
||||||
# We will need to figure something else out if/when we have multiple spawners per user
|
# We will need to figure something else out if/when we have multiple spawners per user
|
||||||
@@ -179,7 +191,7 @@ class Spawner(Base):
|
|||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
|
|
||||||
server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
|
server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
|
||||||
server = relationship(Server)
|
server = relationship(Server, cascade="all")
|
||||||
|
|
||||||
state = Column(JSONDict)
|
state = Column(JSONDict)
|
||||||
name = Column(Unicode(255))
|
name = Column(Unicode(255))
|
||||||
@@ -212,11 +224,16 @@ class Service(Base):
|
|||||||
name = Column(Unicode(255), unique=True)
|
name = Column(Unicode(255), unique=True)
|
||||||
admin = Column(Boolean, default=False)
|
admin = Column(Boolean, default=False)
|
||||||
|
|
||||||
api_tokens = relationship("APIToken", backref="service")
|
api_tokens = relationship(
|
||||||
|
"APIToken",
|
||||||
|
backref="service",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
passive_deletes=True,
|
||||||
|
)
|
||||||
|
|
||||||
# service-specific interface
|
# service-specific interface
|
||||||
_server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
|
_server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL'))
|
||||||
server = relationship(Server, primaryjoin=_server_id == Server.id)
|
server = relationship(Server, cascade='all')
|
||||||
pid = Column(Integer)
|
pid = Column(Integer)
|
||||||
|
|
||||||
def new_api_token(self, token=None, generated=True, note=''):
|
def new_api_token(self, token=None, generated=True, note=''):
|
||||||
@@ -312,13 +329,8 @@ class APIToken(Hashed, Base):
|
|||||||
"""An API token"""
|
"""An API token"""
|
||||||
__tablename__ = 'api_tokens'
|
__tablename__ = 'api_tokens'
|
||||||
|
|
||||||
@declared_attr
|
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
|
||||||
def user_id(cls):
|
service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
|
||||||
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
|
|
||||||
|
|
||||||
@declared_attr
|
|
||||||
def service_id(cls):
|
|
||||||
return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
hashed = Column(Unicode(255), unique=True)
|
hashed = Column(Unicode(255), unique=True)
|
||||||
@@ -419,7 +431,6 @@ class OAuthAccessToken(Hashed, Base):
|
|||||||
refresh_token = Column(Unicode(255))
|
refresh_token = Column(Unicode(255))
|
||||||
refresh_expires_at = Column(Integer)
|
refresh_expires_at = Column(Integer)
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
user = relationship(User)
|
|
||||||
service = None # for API-equivalence with APIToken
|
service = None # for API-equivalence with APIToken
|
||||||
|
|
||||||
# the browser session id associated with a given token
|
# the browser session id associated with a given token
|
||||||
@@ -433,8 +444,9 @@ class OAuthAccessToken(Hashed, Base):
|
|||||||
last_activity = Column(DateTime, nullable=True)
|
last_activity = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<{cls}('{prefix}...', user='{user}'>".format(
|
return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format(
|
||||||
cls=self.__class__.__name__,
|
cls=self.__class__.__name__,
|
||||||
|
client_id=self.client_id,
|
||||||
user=self.user and self.user.name,
|
user=self.user and self.user.name,
|
||||||
prefix=self.prefix,
|
prefix=self.prefix,
|
||||||
)
|
)
|
||||||
|
@@ -17,6 +17,11 @@ from .mocking import MockSpawner
|
|||||||
from ..emptyclass import EmptyClass
|
from ..emptyclass import EmptyClass
|
||||||
|
|
||||||
|
|
||||||
|
def assert_not_found(db, ORMType, id):
|
||||||
|
"""Assert that an item with a given id is not found"""
|
||||||
|
assert db.query(ORMType).filter(ORMType.id==id).first() is None
|
||||||
|
|
||||||
|
|
||||||
def test_server(db):
|
def test_server(db):
|
||||||
server = orm.Server()
|
server = orm.Server()
|
||||||
db.add(server)
|
db.add(server)
|
||||||
@@ -116,14 +121,20 @@ def test_service_server(db):
|
|||||||
service = orm.Service(name='has_servers')
|
service = orm.Service(name='has_servers')
|
||||||
db.add(service)
|
db.add(service)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
assert service.server is None
|
assert service.server is None
|
||||||
server = service.server = orm.Server()
|
server = service.server = orm.Server()
|
||||||
assert service
|
assert service
|
||||||
assert server.id is None
|
assert server.id is None
|
||||||
db.commit()
|
db.commit()
|
||||||
assert isinstance(server.id, int)
|
assert isinstance(server.id, int)
|
||||||
|
server_id = server.id
|
||||||
|
|
||||||
|
# deleting service should delete its server
|
||||||
|
db.delete(service)
|
||||||
|
db.commit()
|
||||||
|
assert_not_found(db, orm.Server, server_id)
|
||||||
|
|
||||||
|
|
||||||
def test_token_find(db):
|
def test_token_find(db):
|
||||||
service = db.query(orm.Service).first()
|
service = db.query(orm.Service).first()
|
||||||
@@ -160,7 +171,7 @@ def test_spawn_fails(db):
|
|||||||
orm_user = orm.User(name='aeofel')
|
orm_user = orm.User(name='aeofel')
|
||||||
db.add(orm_user)
|
db.add(orm_user)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
class BadSpawner(MockSpawner):
|
class BadSpawner(MockSpawner):
|
||||||
@gen.coroutine
|
@gen.coroutine
|
||||||
def start(self):
|
def start(self):
|
||||||
@@ -181,7 +192,7 @@ def test_spawn_fails(db):
|
|||||||
def test_groups(db):
|
def test_groups(db):
|
||||||
user = orm.User.find(db, name='aeofel')
|
user = orm.User.find(db, name='aeofel')
|
||||||
db.add(user)
|
db.add(user)
|
||||||
|
|
||||||
group = orm.Group(name='lives')
|
group = orm.Group(name='lives')
|
||||||
db.add(group)
|
db.add(group)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -191,6 +202,9 @@ def test_groups(db):
|
|||||||
db.commit()
|
db.commit()
|
||||||
assert group.users == [user]
|
assert group.users == [user]
|
||||||
assert user.groups == [group]
|
assert user.groups == [group]
|
||||||
|
db.delete(user)
|
||||||
|
db.commit()
|
||||||
|
assert group.users == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.gen_test
|
@pytest.mark.gen_test
|
||||||
@@ -224,7 +238,7 @@ def test_auth_state(db):
|
|||||||
assert user.encrypted_auth_state is not None
|
assert user.encrypted_auth_state is not None
|
||||||
decrypted_state = yield user.get_auth_state()
|
decrypted_state = yield user.get_auth_state()
|
||||||
assert decrypted_state == state
|
assert decrypted_state == state
|
||||||
|
|
||||||
# can't read auth_state without keys
|
# can't read auth_state without keys
|
||||||
ck.keys = []
|
ck.keys = []
|
||||||
auth_state = yield user.get_auth_state()
|
auth_state = yield user.get_auth_state()
|
||||||
@@ -256,3 +270,97 @@ def test_auth_state(db):
|
|||||||
decrypted_state = yield user.get_auth_state()
|
decrypted_state = yield user.get_auth_state()
|
||||||
assert decrypted_state is None
|
assert decrypted_state is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_spawner_delete_cascade(db):
|
||||||
|
user = orm.User(name='spawner-delete')
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
spawner = orm.Spawner(user=user)
|
||||||
|
db.commit()
|
||||||
|
spawner.server = server = orm.Server()
|
||||||
|
db.commit()
|
||||||
|
db.delete(spawner)
|
||||||
|
server_id = server.id
|
||||||
|
|
||||||
|
# delete the user
|
||||||
|
db.delete(spawner)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# verify that server gets deleted
|
||||||
|
assert_not_found(db, orm.Server, server_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_delete_cascade(db):
|
||||||
|
user = orm.User(name='db-delete')
|
||||||
|
oauth_client = orm.OAuthClient(identifier='db-delete-client')
|
||||||
|
db.add(user)
|
||||||
|
db.add(oauth_client)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# create a bunch of objects that reference the User
|
||||||
|
# these should all be deleted automatically when the user goes away
|
||||||
|
user.new_api_token()
|
||||||
|
api_token = user.api_tokens[0]
|
||||||
|
spawner = orm.Spawner(user=user)
|
||||||
|
db.commit()
|
||||||
|
spawner.server = server = orm.Server()
|
||||||
|
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||||
|
db.add(oauth_code)
|
||||||
|
oauth_token = orm.OAuthAccessToken(
|
||||||
|
client_id=oauth_client.identifier,
|
||||||
|
user_id=user.id,
|
||||||
|
grant_type=orm.GrantType.authorization_code,
|
||||||
|
)
|
||||||
|
db.add(oauth_token)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# record all of the ids
|
||||||
|
spawner_id = spawner.id
|
||||||
|
server_id = server.id
|
||||||
|
api_token_id = api_token.id
|
||||||
|
oauth_code_id = oauth_code.id
|
||||||
|
oauth_token_id = oauth_token.id
|
||||||
|
|
||||||
|
# delete the user
|
||||||
|
db.delete(user)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# verify that everything gets deleted
|
||||||
|
assert_not_found(db, orm.APIToken, api_token_id)
|
||||||
|
assert_not_found(db, orm.Spawner, spawner_id)
|
||||||
|
assert_not_found(db, orm.Server, server_id)
|
||||||
|
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||||
|
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_client_delete_cascade(db):
|
||||||
|
user = orm.User(name='oauth-delete')
|
||||||
|
oauth_client = orm.OAuthClient(identifier='oauth-delete-client')
|
||||||
|
db.add(user)
|
||||||
|
db.add(oauth_client)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# create a bunch of objects that reference the User
|
||||||
|
# these should all be deleted automatically when the user goes away
|
||||||
|
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||||
|
db.add(oauth_code)
|
||||||
|
oauth_token = orm.OAuthAccessToken(
|
||||||
|
client_id=oauth_client.identifier,
|
||||||
|
user_id=user.id,
|
||||||
|
grant_type=orm.GrantType.authorization_code,
|
||||||
|
)
|
||||||
|
db.add(oauth_token)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# record all of the ids
|
||||||
|
oauth_code_id = oauth_code.id
|
||||||
|
oauth_token_id = oauth_token.id
|
||||||
|
|
||||||
|
# delete the user
|
||||||
|
db.delete(oauth_client)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# verify that everything gets deleted
|
||||||
|
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||||
|
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||||
|
Reference in New Issue
Block a user