From 15e4b1ad8b04618f0a8bd491ed6eb8b332f3de48 Mon Sep 17 00:00:00 2001 From: yuvipanda Date: Thu, 27 Jul 2017 23:47:40 -0700 Subject: [PATCH 1/6] Don't expire objects on commit --- jupyterhub/orm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 5d147a2c..7b4715a1 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -599,5 +599,9 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): check_db_revision(engine) Base.metadata.create_all(engine) - session_factory = sessionmaker(bind=engine) + # We set expire_on_commit=False, since we don't actually need + # SQLAlchemy to expire objects after commiting - we don't expect + # concurrent runs of the hub talking to the same db. Turning + # this off gives us a major performance boost + session_factory = sessionmaker(bind=engine, expire_on_commit=False) return session_factory From b1840e8be7390b35d9f3a518df5291e64a745e1b Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 17 Apr 2018 10:54:14 +0200 Subject: [PATCH 2/6] use relationships everywhere in order to use sqlalchemy's expire_on_commit=False optimization, we need to make sure that objects are kept up to date. This means we cannot rely on ForeignKey ondelete/onupdate behavior, we must use sqlalchemy's local relationship cascades The main key here is that we must use relationships to set foreign-key relations, e.g. APIToken.user = user instead of APIToken.user_id = user.id. It also means that we cannot use passive_deletes, which allows sqlalchemy to defer to the database's more efficient ON DELETE behavior. This makes deletions more expensive in particular, but should improve db performance overall. --- jupyterhub/oauth/store.py | 9 ++-- jupyterhub/orm.py | 77 +++++++++++++++++++++++------- jupyterhub/tests/test_api.py | 2 +- jupyterhub/tests/test_orm.py | 91 +++++++++++++++++++++++++++++++++--- 4 files changed, 152 insertions(+), 27 deletions(-) diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index f41b2b39..3cc98463 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """ - user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() + user = self.db.query(orm.User).filter_by(id=access_token.user_id).first() if user is None: raise ValueError("No user for access token: %s" % access_token.user_id) + client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first() orm_access_token = orm.OAuthAccessToken( - client_id=access_token.client_id, + client=client, grant_type=access_token.grant_type, expires_at=access_token.expires_at, refresh_token=access_token.refresh_token, @@ -127,10 +128,10 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :class:`oauth2.datatype.AuthorizationCode`. """ orm_code = orm.OAuthCode( - client_id=authorization_code.client_id, + client=authorization_code.client, code=authorization_code.code, expires_at=authorization_code.expires_at, - user_id=authorization_code.user_id, + user=authorization_code.user, redirect_uri=authorization_code.redirect_uri, session_id=authorization_code.data.get('session_id', ''), ) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 7b4715a1..38ba6522 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -14,16 +14,19 @@ from tornado.log import app_log from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( - inspect, + create_engine, event, inspect, Column, Integer, ForeignKey, Unicode, Boolean, - DateTime, Enum + DateTime, Enum, Table, ) -from sqlalchemy.ext.declarative import declarative_base, declared_attr +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.interfaces import PoolListener -from sqlalchemy.orm import backref, sessionmaker, relationship +from sqlalchemy.orm import ( + Session, + interfaces, object_session, relationship, sessionmaker, +) + from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import bindparam -from sqlalchemy import create_engine, Table from .utils import ( random_port, @@ -88,7 +91,7 @@ class Group(Base): __tablename__ = 'groups' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(Unicode(255), unique=True) - users = relationship('User', secondary='user_group_map', back_populates='groups') + users = relationship('User', secondary='user_group_map', backref='groups') def __repr__(self): return "<%s %s (%i users)>" % ( @@ -133,9 +136,6 @@ class User(Base): "Spawner", backref="user", cascade="all, delete-orphan", - # can't use passive-deletes on this one - # because we rely on orm-level delete - # for Spawner.server ) @property def orm_spawners(self): @@ -149,13 +149,16 @@ class User(Base): "APIToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, ) oauth_tokens = relationship( "OAuthAccessToken", backref="user", cascade="all, delete-orphan", - passive_deletes=True, + ) + oauth_codes = relationship( + "OAuthCode", + backref="user", + cascade="all, delete-orphan", ) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) # User.state is actually Spawner state @@ -164,8 +167,6 @@ class User(Base): # Authenticators can store their state here: # Encryption is handled elsewhere encrypted_auth_state = Column(LargeBinary) - # group mapping - groups = relationship('Group', secondary='user_group_map', back_populates='users') def __repr__(self): return "<{cls}({name} {running}/{total} running)>".format( @@ -234,7 +235,6 @@ class Service(Base): "APIToken", backref="service", cascade="all, delete-orphan", - passive_deletes=True, ) # service-specific interface @@ -408,10 +408,10 @@ class APIToken(Hashed, Base): orm_token.token = token if user: assert user.id is not None - orm_token.user_id = user.id + orm_token.user = user else: assert service.id is not None - orm_token.service_id = service.id + orm_token.service = service db.add(orm_token) db.commit() return token @@ -498,6 +498,18 @@ class OAuthClient(Base): secret = Column(Unicode(255)) redirect_uri = Column(Unicode(1023)) + access_tokens = relationship( + OAuthAccessToken, + backref='client', + cascade='all, delete-orphan', + ) + codes = relationship( + OAuthCode, + backref='client', + cascade='all, delete-orphan', + ) + +# General database utilities class DatabaseSchemaMismatch(Exception): """Exception raised when the database schema version does not match @@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener): dbapi_con.execute('pragma foreign_keys=ON') +def _expire_relationship(target, relationship_prop): + """Expire relationship backrefs + + used when an object with relationships is deleted + """ + + session = object_session(target) + # get peer objects to be expired + peers = getattr(target, relationship_prop.key) + if peers is None: + # no peer to clear + return + # many-to-many and one-to-many have a list of peers + # many-to-one has only one + if relationship_prop.direction is interfaces.MANYTOONE: + peers = [peers] + for obj in peers: + if inspect(obj).persistent: + session.expire(obj, [relationship_prop.back_populates]) + + +@event.listens_for(Session, "persistent_to_deleted") +def _notify_deleted_relationships(session, obj): + """Expire relationships when an object becomes deleted + + Needed for + """ + mapper = inspect(obj).mapper + for prop in mapper.relationships: + if prop.back_populates: + _expire_relationship(obj, prop) + + def check_db_revision(engine): """Check the JupyterHub database revision diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 7f8cfa97..d74d5067 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -280,7 +280,7 @@ def test_get_self(app): db.commit() oauth_token = orm.OAuthAccessToken( user=u.orm_user, - client_id=oauth_client.identifier, + client=oauth_client, token=token, grant_type=orm.GrantType.authorization_code, ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 9e54843a..7edbf3e3 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db): # verify that server gets deleted assert_not_found(db, orm.Server, server_id) + assert user.orm_spawners == {} def test_user_delete_cascade(db): @@ -305,11 +306,11 @@ def test_user_delete_cascade(db): spawner = orm.Spawner(user=user) db.commit() spawner.server = server = orm.Server() - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) @@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db): # create a bunch of objects that reference the User # these should all be deleted automatically when the user goes away - oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) + oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.OAuthAccessToken( - client_id=oauth_client.identifier, - user_id=user.id, + client=oauth_client, + user=user, grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) db.commit() + assert user.oauth_tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + assert user.oauth_tokens == [] + assert user.oauth_codes == [] + + +def test_delete_token_cascade(db): + user = orm.User(name='mobs') + db.add(user) + db.commit() + user.new_api_token() + api_token = user.api_tokens[0] + db.delete(api_token) + db.commit() + assert user.api_tokens == [] + + +def test_group_delete_cascade(db): + user1 = orm.User(name='user1') + user2 = orm.User(name='user2') + group1 = orm.Group(name='group1') + group2 = orm.Group(name='group2') + db.add(user1) + db.add(user2) + db.add(group1) + db.add(group2) + db.commit() + # add user to group via user.groups works + user1.groups.append(group1) + db.commit() + assert user1 in group1.users + + # add user to group via groups.users works + group1.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert group1 in user1.groups + assert group1 in user2.groups + + # fill out the connections (no new concept) + group2.users.append(user1) + group2.users.append(user2) + db.commit() + assert user1 in group1.users + assert user2 in group1.users + assert user1 in group2.users + assert user2 in group2.users + assert group1 in user1.groups + assert group1 in user2.groups + assert group2 in user1.groups + assert group2 in user2.groups + + # now start deleting + # 1. remove group via user.groups + user1.groups.remove(group2) + db.commit() + assert user1 not in group2.users + assert group2 not in user1.groups + + # 2. remove user via group.users + group1.users.remove(user2) + db.commit() + assert user2 not in group1.users + assert group1 not in user2.groups + + # 3. delete group object + db.delete(group2) + db.commit() + assert group2 not in user1.groups + assert group2 not in user2.groups + + # 4. delete user object + db.delete(user1) + db.commit() + assert user1 not in group1.users + + + From e6c7b280579e45314f4f20e1cea7d9382fe0ff37 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 17 Apr 2018 10:54:35 +0200 Subject: [PATCH 3/6] expire before re-running init_services seems to be required, not sure why --- jupyterhub/tests/mocking.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index cf49e849..4ea10c63 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -222,12 +222,21 @@ class MockHub(JupyterHub): def load_config_file(self, *args, **kwargs): pass + def init_tornado_application(self): """Instantiate the tornado Application object""" super().init_tornado_application() # reconnect tornado_settings so that mocks can update the real thing self.tornado_settings = self.users.settings = self.tornado_application.settings + def init_services(self): + # explicitly expire services before reinitializing + # this only happens in tests because re-initialize + # does not occur in a real instance + for service in self.db.query(orm.Service): + self.db.expire(service) + return super().init_services() + @gen.coroutine def initialize(self, argv=None): self.pid_file = NamedTemporaryFile(delete=False).name From e6c2afc4db96abd565e2e2bb674741fe1244b330 Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 20 Apr 2018 15:59:29 +0200 Subject: [PATCH 4/6] fix oauth lookup use of relationships have to lookup orm client/user by id client/user attributes don't exist on oauth objects, which aren't orm objects --- jupyterhub/oauth/store.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index 3cc98463..97ca345f 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -102,10 +102,12 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): given code. """ - orm_code = self.db\ - .query(orm.OAuthCode)\ - .filter(orm.OAuthCode.code == code)\ + orm_code = ( + self.db + .query(orm.OAuthCode) + .filter_by(code=code) .first() + ) if orm_code is None: raise AuthCodeNotFound() else: @@ -119,7 +121,6 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): data={'session_id': orm_code.session_id}, ) - def save_code(self, authorization_code): """ Stores the data belonging to an authorization code token. @@ -127,11 +128,29 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :param authorization_code: An instance of :class:`oauth2.datatype.AuthorizationCode`. """ + orm_client = ( + self.db + .query(orm.OAuthClient) + .filter_by(identifier=authorization_code.client_id) + .first() + ) + if orm_client is None: + raise ValueError("No such client: %s" % authorization_code.client_id) + + orm_user = ( + self.db + .query(orm.User) + .filter_by(id=authorization_code.user_id) + .first() + ) + if orm_user is None: + raise ValueError("No such user: %s" % authorization_code.user_id) + orm_code = orm.OAuthCode( - client=authorization_code.client, + client=orm_client, code=authorization_code.code, expires_at=authorization_code.expires_at, - user=authorization_code.user, + user=orm_user, redirect_uri=authorization_code.redirect_uri, session_id=authorization_code.data.get('session_id', ''), ) @@ -147,7 +166,7 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore): :param code: The authorization code. """ - orm_code = self.db.query(orm.OAuthCode).filter(orm.OAuthCode.code == code).first() + orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first() if orm_code is not None: self.db.delete(orm_code) self.db.commit() @@ -167,7 +186,7 @@ class HashComparable: """ def __init__(self, hashed_token): self.hashed_token = hashed_token - + def __repr__(self): return "<{} '{}'>".format(self.__class__.__name__, self.hashed_token) @@ -186,10 +205,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore): :raises: :class:`oauth2.error.ClientNotFoundError` if no data could be retrieved for given client_id. """ - orm_client = self.db\ - .query(orm.OAuthClient)\ - .filter(orm.OAuthClient.identifier == client_id)\ + orm_client = ( + self.db + .query(orm.OAuthClient) + .filter_by(identifier=client_id) .first() + ) if orm_client is None: raise ClientNotFoundError() return Client(identifier=client_id, @@ -203,10 +224,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore): hash its client_secret before putting it in the database. """ # clear existing clients with same ID - for client in self.db\ - .query(orm.OAuthClient)\ - .filter(orm.OAuthClient.identifier == client_id): - self.db.delete(client) + for orm_client in ( + self.db + .query(orm.OAuthClient)\ + .filter_by(identifier=client_id) + ): + self.db.delete(orm_client) self.db.commit() orm_client = orm.OAuthClient( From a021f910c8841e15d2926f07ee14311a0cfed8df Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 20 Apr 2018 16:02:29 +0200 Subject: [PATCH 5/6] expose expire_on_commit option conservative deployments may set c.JupyterHub.db_kwargs['expire_on_commit'] = True as an escape if the optimization is causing problems. --- jupyterhub/orm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 38ba6522..262ccb54 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -621,7 +621,10 @@ def check_db_revision(engine): )) -def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): +def new_session_factory(url="sqlite:///:memory:", + reset=False, + expire_on_commit=False, + **kwargs): """Create a new session at url""" if url.startswith('sqlite'): kwargs.setdefault('connect_args', {'check_same_thread': False}) @@ -648,5 +651,7 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): # SQLAlchemy to expire objects after commiting - we don't expect # concurrent runs of the hub talking to the same db. Turning # this off gives us a major performance boost - session_factory = sessionmaker(bind=engine, expire_on_commit=False) + session_factory = sessionmaker(bind=engine, + expire_on_commit=expire_on_commit, + ) return session_factory From 453e1198086ac7a4d3bfd4c7e28e10463f88fea2 Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 20 Apr 2018 16:33:14 +0200 Subject: [PATCH 6/6] don't bypass spawner.server to delete server this shouldn't happen, it's just breaking things --- jupyterhub/tests/test_spawner.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/jupyterhub/tests/test_spawner.py b/jupyterhub/tests/test_spawner.py index d115aae5..6eb061b3 100644 --- a/jupyterhub/tests/test_spawner.py +++ b/jupyterhub/tests/test_spawner.py @@ -370,14 +370,10 @@ def test_spawner_delete_server(app): assert spawner.server is not None assert spawner.orm_spawner.server is not None - # trigger delete via db - db.delete(spawner.orm_spawner.server) - db.commit() - assert spawner.orm_spawner.server is None - - # setting server = None also triggers delete + # setting server = None triggers delete spawner.server = None db.commit() + assert spawner.orm_spawner.server is None # verify that the server was actually deleted from the db assert db.query(orm.Server).filter(orm.Server.id == server_id).first() is None # verify that both ORM and top-level references are None