diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index f41b2b39..97ca345f 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """ - user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() + user = self.db.query(orm.User).filter_by(id=access_token.user_id).first() if user is None: raise ValueError("No user for access token: %s" % access_token.user_id) + client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first() orm_access_token = orm.OAuthAccessToken( - client_id=access_token.client_id, + client=client, grant_type=access_token.grant_type, expires_at=access_token.expires_at, refresh_token=access_token.refresh_token, @@ -101,10 +102,12 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): given code. """ - orm_code = self.db\ - .query(orm.OAuthCode)\ - .filter(orm.OAuthCode.code == code)\ + orm_code = ( + self.db + .query(orm.OAuthCode) + .filter_by(code=code) .first() + ) if orm_code is None: raise AuthCodeNotFound() else: @@ -118,7 +121,6 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): data={'session_id': orm_code.session_id}, ) - def save_code(self, authorization_code): """ Stores the data belonging to an authorization code token. @@ -126,11 +128,29 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :param authorization_code: An instance of :class:`oauth2.datatype.AuthorizationCode`. """ + orm_client = ( + self.db + .query(orm.OAuthClient) + .filter_by(identifier=authorization_code.client_id) + .first() + ) + if orm_client is None: + raise ValueError("No such client: %s" % authorization_code.client_id) + + orm_user = ( + self.db + .query(orm.User) + .filter_by(id=authorization_code.user_id) + .first() + ) + if orm_user is None: + raise ValueError("No such user: %s" % authorization_code.user_id) + orm_code = orm.OAuthCode( - client_id=authorization_code.client_id, + client=orm_client, code=authorization_code.code, expires_at=authorization_code.expires_at, - user_id=authorization_code.user_id, + user=orm_user, redirect_uri=authorization_code.redirect_uri, session_id=authorization_code.data.get('session_id', ''), ) @@ -146,7 +166,7 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :param code: The authorization code. """ - orm_code = self.db.query(orm.OAuthCode).filter(orm.OAuthCode.code == code).first() + orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first() if orm_code is not None: self.db.delete(orm_code) self.db.commit() @@ -166,7 +186,7 @@ class HashComparable: """ def __init__(self, hashed_token): self.hashed_token = hashed_token - + def __repr__(self): return "<{} '{}'>".format(self.__class__.__name__, self.hashed_token) @@ -185,10 +205,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore): :raises: :class:`oauth2.error.ClientNotFoundError` if no data could be retrieved for given client_id. """ - orm_client = self.db\ - .query(orm.OAuthClient)\ - .filter(orm.OAuthClient.identifier == client_id)\ + orm_client = ( + self.db + .query(orm.OAuthClient) + .filter_by(identifier=client_id) .first() + ) if orm_client is None: raise ClientNotFoundError() return Client(identifier=client_id, @@ -202,10 +224,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore): hash its client_secret before putting it in the database. """ # clear existing clients with same ID - for client in self.db\ - .query(orm.OAuthClient)\ - .filter(orm.OAuthClient.identifier == client_id): - self.db.delete(client) + for orm_client in ( + self.db + .query(orm.OAuthClient)\ + .filter_by(identifier=client_id) + ): + self.db.delete(orm_client) self.db.commit() orm_client = orm.OAuthClient( diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 5d147a2c..262ccb54 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -14,16 +14,19 @@ from tornado.log import app_log from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( - inspect, + create_engine, event, inspect, Column, Integer, ForeignKey, Unicode, Boolean, - DateTime, Enum + DateTime, Enum, Table, ) -from sqlalchemy.ext.declarative import declarative_base, declared_attr +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.interfaces import PoolListener -from sqlalchemy.orm import backref, sessionmaker, relationship +from sqlalchemy.orm import ( + Session, + interfaces, object_session, relationship, sessionmaker, +) + from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam -from sqlalchemy import create_engine, Table from .utils import ( random_port, @@ -88,7 +91,7 @@ class Group(Base): __tablename__ = 'groups' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(255), unique=True) - users = relationship('User', secondary='user_group_map', back_populates='groups') + users = relationship('User', secondary='user_group_map', backref='groups') def __repr__(self): return "<%s %s (%i users)>" % ( @@ -133,9 +136,6 @@ class User(Base): "Spawner", backref="user", cascade="all, delete-orphan", - # can't use passive-deletes on this one - # because we rely on orm-level delete - # for Spawner.server ) @property def orm_spawners(self): @@ -149,13 +149,16 @@ class User(Base): "APIToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, ) oauth_tokens = relationship( "OAuthAccessToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, + ) + oauth_codes = relationship( + "OAuthCode", + backref="user", + cascade="all, delete-orphan", ) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) # User.state is actually Spawner state @@ -164,8 +167,6 @@ class User(Base): # Authenticators can store their state here: # Encryption is handled elsewhere encrypted_auth_state = Column(LargeBinary) - # group mapping - groups = relationship('Group', secondary='user_group_map', back_populates='users') def __repr__(self): return "<{cls}({name} {running}/{total} running)>".format( @@ -234,7 +235,6 @@ class Service(Base): "APIToken", backref="service", cascade="all, delete-orphan", - passive_deletes=True, ) # service-specific interface @@ -408,10 +408,10 @@ class APIToken(Hashed, Base): orm_token.token = token if user: assert user.id is not None - orm_token.user_id = user.id + orm_token.user = user else: assert service.id is not None - orm_token.service_id = service.id + orm_token.service = service db.add(orm_token) db.commit() return token @@ -498,6 +498,18 @@ class OAuthClient(Base): secret = Column(Unicode(255)) redirect_uri = Column(Unicode(1023)) + access_tokens = relationship( + OAuthAccessToken, + backref='client', + cascade='all, delete-orphan', + ) + codes = relationship( + OAuthCode, + backref='client', + cascade='all, delete-orphan', + ) + +# General database utilities class DatabaseSchemaMismatch(Exception): """Exception raised when the database schema version does not match @@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener): dbapi_con.execute('pragma foreign_keys=ON') +def _expire_relationship(target, relationship_prop): + """Expire relationship backrefs + + used when an object with relationships is deleted + """ + + session = object_session(target) + # get peer objects to be expired + peers = getattr(target, relationship_prop.key) + if peers is None: + # no peer to clear + return + # many-to-many and one-to-many have a list of peers + # many-to-one has only one + if relationship_prop.direction is interfaces.MANYTOONE: + peers = [peers] + for obj in peers: + if inspect(obj).persistent: + session.expire(obj, [relationship_prop.back_populates]) + + +@event.listens_for(Session, "persistent_to_deleted") +def _notify_deleted_relationships(session, obj): + """Expire relationships when an object becomes deleted + + Needed for + """ + mapper = inspect(obj).mapper + for prop in mapper.relationships: + if prop.back_populates: + _expire_relationship(obj, prop) + + def check_db_revision(engine): """Check the JupyterHub database revision @@ -576,7 +621,10 @@ def check_db_revision(engine): )) -def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): +def new_session_factory(url="sqlite:///:memory:", + reset=False, + expire_on_commit=False, + **kwargs): """Create a new session at url""" if url.startswith('sqlite'): kwargs.setdefault('connect_args', {'check_same_thread': False}) @@ -599,5 +647,11 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): check_db_revision(engine) Base.metadata.create_all(engine) - session_factory = sessionmaker(bind=engine) + # We set expire_on_commit=False, since we don't actually need + # SQLAlchemy to expire objects after commiting - we don't expect + # concurrent runs of the hub talking to the same db. Turning + # this off gives us a major performance boost + session_factory = sessionmaker(bind=engine, + expire_on_commit=expire_on_commit, + ) return session_factory diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index cf49e849..4ea10c63 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -222,12 +222,21 @@ class MockHub(JupyterHub): def load_config_file(self, *args, **kwargs): pass + def init_tornado_application(self): """Instantiate the tornado Application object""" super().init_tornado_application() # reconnect tornado_settings so that mocks can update the real thing self.tornado_settings = self.users.settings = self.tornado_application.settings + def init_services(self): + # explicitly expire services before reinitializing + # this only happens in tests because re-initialize + # does not occur in a real instance + for service in self.db.query(orm.Service): + self.db.expire(service) + return super().init_services() + @gen.coroutine def initialize(self, argv=None): self.pid_file = NamedTemporaryFile(delete=False).name diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 7f8cfa97..d74d5067 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -280,7 +280,7 @@ def test_get_self(app): db.commit() oauth_token = orm.OAuthAccessToken( user=u.orm_user, - client_id=oauth_client.identifier, + client=oauth_client, token=token, grant_type=orm.GrantType.authorization_code, ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 9e54843a..7edbf3e3 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db): # verify that server gets deleted assert_not_found(db, orm.Server, server_id) + assert user.orm_spawners == {} def test_user_delete_cascade(db): @@ -305,11 +306,11 @@ def test_user_delete_cascade(db): spawner = orm.Spawner(user=user) db.commit() spawner.server = server = orm.Server() - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) @@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db): # create a bunch of objects that reference the User # these should all be deleted automatically when the user goes away - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) db.commit() + assert user.oauth_tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + assert user.oauth_tokens == [] + assert user.oauth_codes == [] + + +def test_delete_token_cascade(db): + user = orm.User(name='mobs') + db.add(user) + db.commit() + user.new_api_token() + api_token = user.api_tokens[0] + db.delete(api_token) + db.commit() + assert user.api_tokens == [] + + +def test_group_delete_cascade(db): + user1 = orm.User(name='user1') + user2 = orm.User(name='user2') + group1 = orm.Group(name='group1') + group2 = orm.Group(name='group2') + db.add(user1) + db.add(user2) + db.add(group1) + db.add(group2) + db.commit() + # add user to group via user.groups works + user1.groups.append(group1) + db.commit() + assert user1 in group1.users + + # add user to group via groups.users works + group1.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert group1 in user1.groups + assert group1 in user2.groups + + # fill out the connections (no new concept) + group2.users.append(user1) + group2.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert user1 in group2.users + assert user2 in group2.users + assert group1 in user1.groups + assert group1 in user2.groups + assert group2 in user1.groups + assert group2 in user2.groups + + # now start deleting + # 1. remove group via user.groups + user1.groups.remove(group2) + db.commit() + assert user1 not in group2.users + assert group2 not in user1.groups + + # 2. remove user via group.users + group1.users.remove(user2) + db.commit() + assert user2 not in group1.users + assert group1 not in user2.groups + + # 3. delete group object + db.delete(group2) + db.commit() + assert group2 not in user1.groups + assert group2 not in user2.groups + + # 4. delete user object + db.delete(user1) + db.commit() + assert user1 not in group1.users + + + diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index d115aae5..6eb061b3 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -370,14 +370,10 @@ def test_spawner_delete_server(app): assert spawner.server is not None assert spawner.orm_spawner.server is not None - # trigger delete via db - db.delete(spawner.orm_spawner.server) - db.commit() - assert spawner.orm_spawner.server is None - - # setting server = None also triggers delete + # setting server = None triggers delete spawner.server = None db.commit() + assert spawner.orm_spawner.server is None # verify that the server was actually deleted from the db assert db.query(orm.Server).filter(orm.Server.id == server_id).first() is None # verify that both ORM and top-level references are None