Merge pull request #1809 from minrk/no-expire-again

don't expire objects on commit
This commit is contained in:
Carol Willing
2018-04-23 09:26:57 -07:00
committed by GitHub
6 changed files with 210 additions and 48 deletions

View File

@@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
""" """
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() user = self.db.query(orm.User).filter_by(id=access_token.user_id).first()
if user is None: if user is None:
raise ValueError("No user for access token: %s" % access_token.user_id) raise ValueError("No user for access token: %s" % access_token.user_id)
client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first()
orm_access_token = orm.OAuthAccessToken( orm_access_token = orm.OAuthAccessToken(
client_id=access_token.client_id, client=client,
grant_type=access_token.grant_type, grant_type=access_token.grant_type,
expires_at=access_token.expires_at, expires_at=access_token.expires_at,
refresh_token=access_token.refresh_token, refresh_token=access_token.refresh_token,
@@ -101,10 +102,12 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
given code. given code.
""" """
orm_code = self.db\ orm_code = (
.query(orm.OAuthCode)\ self.db
.filter(orm.OAuthCode.code == code)\ .query(orm.OAuthCode)
.filter_by(code=code)
.first() .first()
)
if orm_code is None: if orm_code is None:
raise AuthCodeNotFound() raise AuthCodeNotFound()
else: else:
@@ -118,7 +121,6 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
data={'session_id': orm_code.session_id}, data={'session_id': orm_code.session_id},
) )
def save_code(self, authorization_code): def save_code(self, authorization_code):
""" """
Stores the data belonging to an authorization code token. Stores the data belonging to an authorization code token.
@@ -126,11 +128,29 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
:param authorization_code: An instance of :param authorization_code: An instance of
:class:`oauth2.datatype.AuthorizationCode`. :class:`oauth2.datatype.AuthorizationCode`.
""" """
orm_client = (
self.db
.query(orm.OAuthClient)
.filter_by(identifier=authorization_code.client_id)
.first()
)
if orm_client is None:
raise ValueError("No such client: %s" % authorization_code.client_id)
orm_user = (
self.db
.query(orm.User)
.filter_by(id=authorization_code.user_id)
.first()
)
if orm_user is None:
raise ValueError("No such user: %s" % authorization_code.user_id)
orm_code = orm.OAuthCode( orm_code = orm.OAuthCode(
client_id=authorization_code.client_id, client=orm_client,
code=authorization_code.code, code=authorization_code.code,
expires_at=authorization_code.expires_at, expires_at=authorization_code.expires_at,
user_id=authorization_code.user_id, user=orm_user,
redirect_uri=authorization_code.redirect_uri, redirect_uri=authorization_code.redirect_uri,
session_id=authorization_code.data.get('session_id', ''), session_id=authorization_code.data.get('session_id', ''),
) )
@@ -146,7 +166,7 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
:param code: The authorization code. :param code: The authorization code.
""" """
orm_code = self.db.query(orm.OAuthCode).filter(orm.OAuthCode.code == code).first() orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
if orm_code is not None: if orm_code is not None:
self.db.delete(orm_code) self.db.delete(orm_code)
self.db.commit() self.db.commit()
@@ -185,10 +205,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore):
:raises: :class:`oauth2.error.ClientNotFoundError` if no data could be retrieved for :raises: :class:`oauth2.error.ClientNotFoundError` if no data could be retrieved for
given client_id. given client_id.
""" """
orm_client = self.db\ orm_client = (
.query(orm.OAuthClient)\ self.db
.filter(orm.OAuthClient.identifier == client_id)\ .query(orm.OAuthClient)
.filter_by(identifier=client_id)
.first() .first()
)
if orm_client is None: if orm_client is None:
raise ClientNotFoundError() raise ClientNotFoundError()
return Client(identifier=client_id, return Client(identifier=client_id,
@@ -202,10 +224,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore):
hash its client_secret before putting it in the database. hash its client_secret before putting it in the database.
""" """
# clear existing clients with same ID # clear existing clients with same ID
for client in self.db\ for orm_client in (
self.db
.query(orm.OAuthClient)\ .query(orm.OAuthClient)\
.filter(orm.OAuthClient.identifier == client_id): .filter_by(identifier=client_id)
self.db.delete(client) ):
self.db.delete(orm_client)
self.db.commit() self.db.commit()
orm_client = orm.OAuthClient( orm_client = orm.OAuthClient(

View File

@@ -14,16 +14,19 @@ from tornado.log import app_log
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
from sqlalchemy import ( from sqlalchemy import (
inspect, create_engine, event, inspect,
Column, Integer, ForeignKey, Unicode, Boolean, Column, Integer, ForeignKey, Unicode, Boolean,
DateTime, Enum DateTime, Enum, Table,
) )
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.interfaces import PoolListener from sqlalchemy.interfaces import PoolListener
from sqlalchemy.orm import backref, sessionmaker, relationship from sqlalchemy.orm import (
Session,
interfaces, object_session, relationship, sessionmaker,
)
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
from .utils import ( from .utils import (
random_port, random_port,
@@ -88,7 +91,7 @@ class Group(Base):
__tablename__ = 'groups' __tablename__ = 'groups'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(255), unique=True) name = Column(Unicode(255), unique=True)
users = relationship('User', secondary='user_group_map', back_populates='groups') users = relationship('User', secondary='user_group_map', backref='groups')
def __repr__(self): def __repr__(self):
return "<%s %s (%i users)>" % ( return "<%s %s (%i users)>" % (
@@ -133,9 +136,6 @@ class User(Base):
"Spawner", "Spawner",
backref="user", backref="user",
cascade="all, delete-orphan", cascade="all, delete-orphan",
# can't use passive-deletes on this one
# because we rely on orm-level delete
# for Spawner.server
) )
@property @property
def orm_spawners(self): def orm_spawners(self):
@@ -149,13 +149,16 @@ class User(Base):
"APIToken", "APIToken",
backref="user", backref="user",
cascade="all, delete-orphan", cascade="all, delete-orphan",
passive_deletes=True,
) )
oauth_tokens = relationship( oauth_tokens = relationship(
"OAuthAccessToken", "OAuthAccessToken",
backref="user", backref="user",
cascade="all, delete-orphan", cascade="all, delete-orphan",
passive_deletes=True, )
oauth_codes = relationship(
"OAuthCode",
backref="user",
cascade="all, delete-orphan",
) )
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True) cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
# User.state is actually Spawner state # User.state is actually Spawner state
@@ -164,8 +167,6 @@ class User(Base):
# Authenticators can store their state here: # Authenticators can store their state here:
# Encryption is handled elsewhere # Encryption is handled elsewhere
encrypted_auth_state = Column(LargeBinary) encrypted_auth_state = Column(LargeBinary)
# group mapping
groups = relationship('Group', secondary='user_group_map', back_populates='users')
def __repr__(self): def __repr__(self):
return "<{cls}({name} {running}/{total} running)>".format( return "<{cls}({name} {running}/{total} running)>".format(
@@ -234,7 +235,6 @@ class Service(Base):
"APIToken", "APIToken",
backref="service", backref="service",
cascade="all, delete-orphan", cascade="all, delete-orphan",
passive_deletes=True,
) )
# service-specific interface # service-specific interface
@@ -408,10 +408,10 @@ class APIToken(Hashed, Base):
orm_token.token = token orm_token.token = token
if user: if user:
assert user.id is not None assert user.id is not None
orm_token.user_id = user.id orm_token.user = user
else: else:
assert service.id is not None assert service.id is not None
orm_token.service_id = service.id orm_token.service = service
db.add(orm_token) db.add(orm_token)
db.commit() db.commit()
return token return token
@@ -498,6 +498,18 @@ class OAuthClient(Base):
secret = Column(Unicode(255)) secret = Column(Unicode(255))
redirect_uri = Column(Unicode(1023)) redirect_uri = Column(Unicode(1023))
access_tokens = relationship(
OAuthAccessToken,
backref='client',
cascade='all, delete-orphan',
)
codes = relationship(
OAuthCode,
backref='client',
cascade='all, delete-orphan',
)
# General database utilities
class DatabaseSchemaMismatch(Exception): class DatabaseSchemaMismatch(Exception):
"""Exception raised when the database schema version does not match """Exception raised when the database schema version does not match
@@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener):
dbapi_con.execute('pragma foreign_keys=ON') dbapi_con.execute('pragma foreign_keys=ON')
def _expire_relationship(target, relationship_prop):
"""Expire relationship backrefs
used when an object with relationships is deleted
"""
session = object_session(target)
# get peer objects to be expired
peers = getattr(target, relationship_prop.key)
if peers is None:
# no peer to clear
return
# many-to-many and one-to-many have a list of peers
# many-to-one has only one
if relationship_prop.direction is interfaces.MANYTOONE:
peers = [peers]
for obj in peers:
if inspect(obj).persistent:
session.expire(obj, [relationship_prop.back_populates])
@event.listens_for(Session, "persistent_to_deleted")
def _notify_deleted_relationships(session, obj):
"""Expire relationships when an object becomes deleted
Needed for
"""
mapper = inspect(obj).mapper
for prop in mapper.relationships:
if prop.back_populates:
_expire_relationship(obj, prop)
def check_db_revision(engine): def check_db_revision(engine):
"""Check the JupyterHub database revision """Check the JupyterHub database revision
@@ -576,7 +621,10 @@ def check_db_revision(engine):
)) ))
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): def new_session_factory(url="sqlite:///:memory:",
reset=False,
expire_on_commit=False,
**kwargs):
"""Create a new session at url""" """Create a new session at url"""
if url.startswith('sqlite'): if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False}) kwargs.setdefault('connect_args', {'check_same_thread': False})
@@ -599,5 +647,11 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
check_db_revision(engine) check_db_revision(engine)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine) # We set expire_on_commit=False, since we don't actually need
# SQLAlchemy to expire objects after commiting - we don't expect
# concurrent runs of the hub talking to the same db. Turning
# this off gives us a major performance boost
session_factory = sessionmaker(bind=engine,
expire_on_commit=expire_on_commit,
)
return session_factory return session_factory

View File

@@ -222,12 +222,21 @@ class MockHub(JupyterHub):
def load_config_file(self, *args, **kwargs): def load_config_file(self, *args, **kwargs):
pass pass
def init_tornado_application(self): def init_tornado_application(self):
"""Instantiate the tornado Application object""" """Instantiate the tornado Application object"""
super().init_tornado_application() super().init_tornado_application()
# reconnect tornado_settings so that mocks can update the real thing # reconnect tornado_settings so that mocks can update the real thing
self.tornado_settings = self.users.settings = self.tornado_application.settings self.tornado_settings = self.users.settings = self.tornado_application.settings
def init_services(self):
# explicitly expire services before reinitializing
# this only happens in tests because re-initialize
# does not occur in a real instance
for service in self.db.query(orm.Service):
self.db.expire(service)
return super().init_services()
@gen.coroutine @gen.coroutine
def initialize(self, argv=None): def initialize(self, argv=None):
self.pid_file = NamedTemporaryFile(delete=False).name self.pid_file = NamedTemporaryFile(delete=False).name

View File

@@ -280,7 +280,7 @@ def test_get_self(app):
db.commit() db.commit()
oauth_token = orm.OAuthAccessToken( oauth_token = orm.OAuthAccessToken(
user=u.orm_user, user=u.orm_user,
client_id=oauth_client.identifier, client=oauth_client,
token=token, token=token,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
) )

View File

@@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db):
# verify that server gets deleted # verify that server gets deleted
assert_not_found(db, orm.Server, server_id) assert_not_found(db, orm.Server, server_id)
assert user.orm_spawners == {}
def test_user_delete_cascade(db): def test_user_delete_cascade(db):
@@ -305,11 +306,11 @@ def test_user_delete_cascade(db):
spawner = orm.Spawner(user=user) spawner = orm.Spawner(user=user)
db.commit() db.commit()
spawner.server = server = orm.Server() spawner.server = server = orm.Server()
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier, client=oauth_client,
user_id=user.id, user=user,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
) )
db.add(oauth_token) db.add(oauth_token)
@@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db):
# create a bunch of objects that reference the User # create a bunch of objects that reference the User
# these should all be deleted automatically when the user goes away # these should all be deleted automatically when the user goes away
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id) oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code) db.add(oauth_code)
oauth_token = orm.OAuthAccessToken( oauth_token = orm.OAuthAccessToken(
client_id=oauth_client.identifier, client=oauth_client,
user_id=user.id, user=user,
grant_type=orm.GrantType.authorization_code, grant_type=orm.GrantType.authorization_code,
) )
db.add(oauth_token) db.add(oauth_token)
db.commit() db.commit()
assert user.oauth_tokens == [oauth_token]
# record all of the ids # record all of the ids
oauth_code_id = oauth_code.id oauth_code_id = oauth_code.id
@@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted # verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
assert user.oauth_tokens == []
assert user.oauth_codes == []
def test_delete_token_cascade(db):
user = orm.User(name='mobs')
db.add(user)
db.commit()
user.new_api_token()
api_token = user.api_tokens[0]
db.delete(api_token)
db.commit()
assert user.api_tokens == []
def test_group_delete_cascade(db):
user1 = orm.User(name='user1')
user2 = orm.User(name='user2')
group1 = orm.Group(name='group1')
group2 = orm.Group(name='group2')
db.add(user1)
db.add(user2)
db.add(group1)
db.add(group2)
db.commit()
# add user to group via user.groups works
user1.groups.append(group1)
db.commit()
assert user1 in group1.users
# add user to group via groups.users works
group1.users.append(user2)
db.commit()
assert user1 in group1.users
assert user2 in group1.users
assert group1 in user1.groups
assert group1 in user2.groups
# fill out the connections (no new concept)
group2.users.append(user1)
group2.users.append(user2)
db.commit()
assert user1 in group1.users
assert user2 in group1.users
assert user1 in group2.users
assert user2 in group2.users
assert group1 in user1.groups
assert group1 in user2.groups
assert group2 in user1.groups
assert group2 in user2.groups
# now start deleting
# 1. remove group via user.groups
user1.groups.remove(group2)
db.commit()
assert user1 not in group2.users
assert group2 not in user1.groups
# 2. remove user via group.users
group1.users.remove(user2)
db.commit()
assert user2 not in group1.users
assert group1 not in user2.groups
# 3. delete group object
db.delete(group2)
db.commit()
assert group2 not in user1.groups
assert group2 not in user2.groups
# 4. delete user object
db.delete(user1)
db.commit()
assert user1 not in group1.users

View File

@@ -370,14 +370,10 @@ def test_spawner_delete_server(app):
assert spawner.server is not None assert spawner.server is not None
assert spawner.orm_spawner.server is not None assert spawner.orm_spawner.server is not None
# trigger delete via db # setting server = None triggers delete
db.delete(spawner.orm_spawner.server)
db.commit()
assert spawner.orm_spawner.server is None
# setting server = None also triggers delete
spawner.server = None spawner.server = None
db.commit() db.commit()
assert spawner.orm_spawner.server is None
# verify that the server was actually deleted from the db # verify that the server was actually deleted from the db
assert db.query(orm.Server).filter(orm.Server.id == server_id).first() is None assert db.query(orm.Server).filter(orm.Server.id == server_id).first() is None
# verify that both ORM and top-level references are None # verify that both ORM and top-level references are None