diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 316ec594..65d3b8e1 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -19,8 +19,8 @@ from sqlalchemy import ( DateTime, Enum ) from sqlalchemy.ext.declarative import declarative_base, declared_attr -from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.interfaces import PoolListener +from sqlalchemy.orm import backref, sessionmaker, relationship from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam from sqlalchemy import create_engine, Table @@ -78,8 +78,8 @@ class Server(Base): # user:group many:many mapping table user_group_map = Table('user_group_map', Base.metadata, - Column('user_id', ForeignKey('users.id'), primary_key=True), - Column('group_id', ForeignKey('groups.id'), primary_key=True), + Column('user_id', ForeignKey('users.id', ondelete='CASCADE'), primary_key=True), + Column('group_id', ForeignKey('groups.id', ondelete='CASCADE'), primary_key=True), ) @@ -129,7 +129,14 @@ class User(Base): id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(255), unique=True) - _orm_spawners = relationship("Spawner", backref="user") + _orm_spawners = relationship( + "Spawner", + backref="user", + cascade="all, delete-orphan", + # can't use passive-deletes on this one + # because we rely on orm-level delete + # for Spawner.server + ) @property def orm_spawners(self): return {s.name: s for s in self._orm_spawners} @@ -138,7 +145,12 @@ class User(Base): created = Column(DateTime, default=datetime.utcnow) last_activity = Column(DateTime, nullable=True) - api_tokens = relationship("APIToken", backref="user") + api_tokens = relationship( + "APIToken", + backref="user", + cascade="all, delete-orphan", + passive_deletes=True, + ) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) # User.state is actually Spawner state # We will need to figure something else out if/when we have multiple spawners per user @@ -179,7 +191,7 @@ class Spawner(Base): user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL')) - server = relationship(Server) + server = relationship(Server, cascade="all") state = Column(JSONDict) name = Column(Unicode(255)) @@ -212,11 +224,16 @@ class Service(Base): name = Column(Unicode(255), unique=True) admin = Column(Boolean, default=False) - api_tokens = relationship("APIToken", backref="service") + api_tokens = relationship( + "APIToken", + backref="service", + cascade="all, delete-orphan", + passive_deletes=True, + ) # service-specific interface _server_id = Column(Integer, ForeignKey('servers.id', ondelete='SET NULL')) - server = relationship(Server, primaryjoin=_server_id == Server.id) + server = relationship(Server, cascade='all') pid = Column(Integer) def new_api_token(self, token=None, generated=True, note=''): @@ -312,13 +329,8 @@ class APIToken(Hashed, Base): """An API token""" __tablename__ = 'api_tokens' - @declared_attr - def user_id(cls): - return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) - - @declared_attr - def service_id(cls): - return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True) + user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) + service_id = Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True) id = Column(Integer, primary_key=True) hashed = Column(Unicode(255), unique=True) @@ -419,7 +431,6 @@ class OAuthAccessToken(Hashed, Base): refresh_token = Column(Unicode(255)) refresh_expires_at = Column(Integer) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) - user = relationship(User) service = None # for API-equivalence with APIToken # the browser session id associated with a given token @@ -433,8 +444,9 @@ class OAuthAccessToken(Hashed, Base): last_activity = Column(DateTime, nullable=True) def __repr__(self): - return "<{cls}('{prefix}...', user='{user}'>".format( + return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}>".format( cls=self.__class__.__name__, + client_id=self.client_id, user=self.user and self.user.name, prefix=self.prefix, ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 33ffc203..9e54843a 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -17,6 +17,11 @@ from .mocking import MockSpawner from ..emptyclass import EmptyClass +def assert_not_found(db, ORMType, id): + """Assert that an item with a given id is not found""" + assert db.query(ORMType).filter(ORMType.id==id).first() is None + + def test_server(db): server = orm.Server() db.add(server) @@ -116,14 +121,20 @@ def test_service_server(db): service = orm.Service(name='has_servers') db.add(service) db.commit() - + assert service.server is None server = service.server = orm.Server() assert service assert server.id is None db.commit() assert isinstance(server.id, int) - + server_id = server.id + + # deleting service should delete its server + db.delete(service) + db.commit() + assert_not_found(db, orm.Server, server_id) + def test_token_find(db): service = db.query(orm.Service).first() @@ -160,7 +171,7 @@ def test_spawn_fails(db): orm_user = orm.User(name='aeofel') db.add(orm_user) db.commit() - + class BadSpawner(MockSpawner): @gen.coroutine def start(self): @@ -181,7 +192,7 @@ def test_spawn_fails(db): def test_groups(db): user = orm.User.find(db, name='aeofel') db.add(user) - + group = orm.Group(name='lives') db.add(group) db.commit() @@ -191,6 +202,9 @@ def test_groups(db): db.commit() assert group.users == [user] assert user.groups == [group] + db.delete(user) + db.commit() + assert group.users == [] @pytest.mark.gen_test @@ -224,7 +238,7 @@ def test_auth_state(db): assert user.encrypted_auth_state is not None decrypted_state = yield user.get_auth_state() assert decrypted_state == state - + # can't read auth_state without keys ck.keys = [] auth_state = yield user.get_auth_state() @@ -256,3 +270,97 @@ def test_auth_state(db): decrypted_state = yield user.get_auth_state() assert decrypted_state is None + +def test_spawner_delete_cascade(db): + user = orm.User(name='spawner-delete') + db.add(user) + db.commit() + + spawner = orm.Spawner(user=user) + db.commit() + spawner.server = server = orm.Server() + db.commit() + db.delete(spawner) + server_id = server.id + + # delete the user + db.delete(spawner) + db.commit() + + # verify that server gets deleted + assert_not_found(db, orm.Server, server_id) + + +def test_user_delete_cascade(db): + user = orm.User(name='db-delete') + oauth_client = orm.OAuthClient(identifier='db-delete-client') + db.add(user) + db.add(oauth_client) + db.commit() + + # create a bunch of objects that reference the User + # these should all be deleted automatically when the user goes away + user.new_api_token() + api_token = user.api_tokens[0] + spawner = orm.Spawner(user=user) + db.commit() + spawner.server = server = orm.Server() + oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + db.add(oauth_code) + oauth_token = orm.OAuthAccessToken( + client_id=oauth_client.identifier, + user_id=user.id, + grant_type=orm.GrantType.authorization_code, + ) + db.add(oauth_token) + db.commit() + + # record all of the ids + spawner_id = spawner.id + server_id = server.id + api_token_id = api_token.id + oauth_code_id = oauth_code.id + oauth_token_id = oauth_token.id + + # delete the user + db.delete(user) + db.commit() + + # verify that everything gets deleted + assert_not_found(db, orm.APIToken, api_token_id) + assert_not_found(db, orm.Spawner, spawner_id) + assert_not_found(db, orm.Server, server_id) + assert_not_found(db, orm.OAuthCode, oauth_code_id) + assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + + +def test_oauth_client_delete_cascade(db): + user = orm.User(name='oauth-delete') + oauth_client = orm.OAuthClient(identifier='oauth-delete-client') + db.add(user) + db.add(oauth_client) + db.commit() + + # create a bunch of objects that reference the User + # these should all be deleted automatically when the user goes away + oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + db.add(oauth_code) + oauth_token = orm.OAuthAccessToken( + client_id=oauth_client.identifier, + user_id=user.id, + grant_type=orm.GrantType.authorization_code, + ) + db.add(oauth_token) + db.commit() + + # record all of the ids + oauth_code_id = oauth_code.id + oauth_token_id = oauth_token.id + + # delete the user + db.delete(oauth_client) + db.commit() + + # verify that everything gets deleted + assert_not_found(db, orm.OAuthCode, oauth_code_id) + assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)