diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index f41b2b39..3cc98463 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """ - user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() + user = self.db.query(orm.User).filter_by(id=access_token.user_id).first() if user is None: raise ValueError("No user for access token: %s" % access_token.user_id) + client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first() orm_access_token = orm.OAuthAccessToken( - client_id=access_token.client_id, + client=client, grant_type=access_token.grant_type, expires_at=access_token.expires_at, refresh_token=access_token.refresh_token, @@ -127,10 +128,10 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :class:`oauth2.datatype.AuthorizationCode`. """ orm_code = orm.OAuthCode( - client_id=authorization_code.client_id, + client=authorization_code.client, code=authorization_code.code, expires_at=authorization_code.expires_at, - user_id=authorization_code.user_id, + user=authorization_code.user, redirect_uri=authorization_code.redirect_uri, session_id=authorization_code.data.get('session_id', ''), ) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 7b4715a1..38ba6522 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -14,16 +14,19 @@ from tornado.log import app_log from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( - inspect, + create_engine, event, inspect, Column, Integer, ForeignKey, Unicode, Boolean, - DateTime, Enum + DateTime, Enum, Table, ) -from sqlalchemy.ext.declarative import declarative_base, declared_attr +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.interfaces import PoolListener -from sqlalchemy.orm import backref, sessionmaker, relationship +from sqlalchemy.orm import ( + Session, + interfaces, object_session, relationship, sessionmaker, +) + from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam -from sqlalchemy import create_engine, Table from .utils import ( random_port, @@ -88,7 +91,7 @@ class Group(Base): __tablename__ = 'groups' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(255), unique=True) - users = relationship('User', secondary='user_group_map', back_populates='groups') + users = relationship('User', secondary='user_group_map', backref='groups') def __repr__(self): return "<%s %s (%i users)>" % ( @@ -133,9 +136,6 @@ class User(Base): "Spawner", backref="user", cascade="all, delete-orphan", - # can't use passive-deletes on this one - # because we rely on orm-level delete - # for Spawner.server ) @property def orm_spawners(self): @@ -149,13 +149,16 @@ class User(Base): "APIToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, ) oauth_tokens = relationship( "OAuthAccessToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, + ) + oauth_codes = relationship( + "OAuthCode", + backref="user", + cascade="all, delete-orphan", ) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) # User.state is actually Spawner state @@ -164,8 +167,6 @@ class User(Base): # Authenticators can store their state here: # Encryption is handled elsewhere encrypted_auth_state = Column(LargeBinary) - # group mapping - groups = relationship('Group', secondary='user_group_map', back_populates='users') def __repr__(self): return "<{cls}({name} {running}/{total} running)>".format( @@ -234,7 +235,6 @@ class Service(Base): "APIToken", backref="service", cascade="all, delete-orphan", - passive_deletes=True, ) # service-specific interface @@ -408,10 +408,10 @@ class APIToken(Hashed, Base): orm_token.token = token if user: assert user.id is not None - orm_token.user_id = user.id + orm_token.user = user else: assert service.id is not None - orm_token.service_id = service.id + orm_token.service = service db.add(orm_token) db.commit() return token @@ -498,6 +498,18 @@ class OAuthClient(Base): secret = Column(Unicode(255)) redirect_uri = Column(Unicode(1023)) + access_tokens = relationship( + OAuthAccessToken, + backref='client', + cascade='all, delete-orphan', + ) + codes = relationship( + OAuthCode, + backref='client', + cascade='all, delete-orphan', + ) + +# General database utilities class DatabaseSchemaMismatch(Exception): """Exception raised when the database schema version does not match @@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener): dbapi_con.execute('pragma foreign_keys=ON') +def _expire_relationship(target, relationship_prop): + """Expire relationship backrefs + + used when an object with relationships is deleted + """ + + session = object_session(target) + # get peer objects to be expired + peers = getattr(target, relationship_prop.key) + if peers is None: + # no peer to clear + return + # many-to-many and one-to-many have a list of peers + # many-to-one has only one + if relationship_prop.direction is interfaces.MANYTOONE: + peers = [peers] + for obj in peers: + if inspect(obj).persistent: + session.expire(obj, [relationship_prop.back_populates]) + + +@event.listens_for(Session, "persistent_to_deleted") +def _notify_deleted_relationships(session, obj): + """Expire relationships when an object becomes deleted + + Needed for + """ + mapper = inspect(obj).mapper + for prop in mapper.relationships: + if prop.back_populates: + _expire_relationship(obj, prop) + + def check_db_revision(engine): """Check the JupyterHub database revision diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 7f8cfa97..d74d5067 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -280,7 +280,7 @@ def test_get_self(app): db.commit() oauth_token = orm.OAuthAccessToken( user=u.orm_user, - client_id=oauth_client.identifier, + client=oauth_client, token=token, grant_type=orm.GrantType.authorization_code, ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 9e54843a..7edbf3e3 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db): # verify that server gets deleted assert_not_found(db, orm.Server, server_id) + assert user.orm_spawners == {} def test_user_delete_cascade(db): @@ -305,11 +306,11 @@ def test_user_delete_cascade(db): spawner = orm.Spawner(user=user) db.commit() spawner.server = server = orm.Server() - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) @@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db): # create a bunch of objects that reference the User # these should all be deleted automatically when the user goes away - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) db.commit() + assert user.oauth_tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + assert user.oauth_tokens == [] + assert user.oauth_codes == [] + + +def test_delete_token_cascade(db): + user = orm.User(name='mobs') + db.add(user) + db.commit() + user.new_api_token() + api_token = user.api_tokens[0] + db.delete(api_token) + db.commit() + assert user.api_tokens == [] + + +def test_group_delete_cascade(db): + user1 = orm.User(name='user1') + user2 = orm.User(name='user2') + group1 = orm.Group(name='group1') + group2 = orm.Group(name='group2') + db.add(user1) + db.add(user2) + db.add(group1) + db.add(group2) + db.commit() + # add user to group via user.groups works + user1.groups.append(group1) + db.commit() + assert user1 in group1.users + + # add user to group via groups.users works + group1.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert group1 in user1.groups + assert group1 in user2.groups + + # fill out the connections (no new concept) + group2.users.append(user1) + group2.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert user1 in group2.users + assert user2 in group2.users + assert group1 in user1.groups + assert group1 in user2.groups + assert group2 in user1.groups + assert group2 in user2.groups + + # now start deleting + # 1. remove group via user.groups + user1.groups.remove(group2) + db.commit() + assert user1 not in group2.users + assert group2 not in user1.groups + + # 2. remove user via group.users + group1.users.remove(user2) + db.commit() + assert user2 not in group1.users + assert group1 not in user2.groups + + # 3. delete group object + db.delete(group2) + db.commit() + assert group2 not in user1.groups + assert group2 not in user2.groups + + # 4. delete user object + db.delete(user1) + db.commit() + assert user1 not in group1.users + + +