mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
Merge pull request #1809 from minrk/no-expire-again
don't expire objects on commit
This commit is contained in:
@@ -70,11 +70,12 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
||||
|
||||
"""
|
||||
|
||||
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
||||
user = self.db.query(orm.User).filter_by(id=access_token.user_id).first()
|
||||
if user is None:
|
||||
raise ValueError("No user for access token: %s" % access_token.user_id)
|
||||
client = self.db.query(orm.OAuthClient).filter_by(identifier=access_token.client_id).first()
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client_id=access_token.client_id,
|
||||
client=client,
|
||||
grant_type=access_token.grant_type,
|
||||
expires_at=access_token.expires_at,
|
||||
refresh_token=access_token.refresh_token,
|
||||
@@ -101,10 +102,12 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
|
||||
given code.
|
||||
|
||||
"""
|
||||
orm_code = self.db\
|
||||
.query(orm.OAuthCode)\
|
||||
.filter(orm.OAuthCode.code == code)\
|
||||
orm_code = (
|
||||
self.db
|
||||
.query(orm.OAuthCode)
|
||||
.filter_by(code=code)
|
||||
.first()
|
||||
)
|
||||
if orm_code is None:
|
||||
raise AuthCodeNotFound()
|
||||
else:
|
||||
@@ -118,7 +121,6 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
|
||||
data={'session_id': orm_code.session_id},
|
||||
)
|
||||
|
||||
|
||||
def save_code(self, authorization_code):
|
||||
"""
|
||||
Stores the data belonging to an authorization code token.
|
||||
@@ -126,11 +128,29 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
|
||||
:param authorization_code: An instance of
|
||||
:class:`oauth2.datatype.AuthorizationCode`.
|
||||
"""
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=authorization_code.client_id)
|
||||
.first()
|
||||
)
|
||||
if orm_client is None:
|
||||
raise ValueError("No such client: %s" % authorization_code.client_id)
|
||||
|
||||
orm_user = (
|
||||
self.db
|
||||
.query(orm.User)
|
||||
.filter_by(id=authorization_code.user_id)
|
||||
.first()
|
||||
)
|
||||
if orm_user is None:
|
||||
raise ValueError("No such user: %s" % authorization_code.user_id)
|
||||
|
||||
orm_code = orm.OAuthCode(
|
||||
client_id=authorization_code.client_id,
|
||||
client=orm_client,
|
||||
code=authorization_code.code,
|
||||
expires_at=authorization_code.expires_at,
|
||||
user_id=authorization_code.user_id,
|
||||
user=orm_user,
|
||||
redirect_uri=authorization_code.redirect_uri,
|
||||
session_id=authorization_code.data.get('session_id', ''),
|
||||
)
|
||||
@@ -146,7 +166,7 @@ class AuthCodeStore(HubDBMixin, oauth2.store.AuthCodeStore):
|
||||
|
||||
:param code: The authorization code.
|
||||
"""
|
||||
orm_code = self.db.query(orm.OAuthCode).filter(orm.OAuthCode.code == code).first()
|
||||
orm_code = self.db.query(orm.OAuthCode).filter_by(code=code).first()
|
||||
if orm_code is not None:
|
||||
self.db.delete(orm_code)
|
||||
self.db.commit()
|
||||
@@ -166,7 +186,7 @@ class HashComparable:
|
||||
"""
|
||||
def __init__(self, hashed_token):
|
||||
self.hashed_token = hashed_token
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return "<{} '{}'>".format(self.__class__.__name__, self.hashed_token)
|
||||
|
||||
@@ -185,10 +205,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore):
|
||||
:raises: :class:`oauth2.error.ClientNotFoundError` if no data could be retrieved for
|
||||
given client_id.
|
||||
"""
|
||||
orm_client = self.db\
|
||||
.query(orm.OAuthClient)\
|
||||
.filter(orm.OAuthClient.identifier == client_id)\
|
||||
orm_client = (
|
||||
self.db
|
||||
.query(orm.OAuthClient)
|
||||
.filter_by(identifier=client_id)
|
||||
.first()
|
||||
)
|
||||
if orm_client is None:
|
||||
raise ClientNotFoundError()
|
||||
return Client(identifier=client_id,
|
||||
@@ -202,10 +224,12 @@ class ClientStore(HubDBMixin, oauth2.store.ClientStore):
|
||||
hash its client_secret before putting it in the database.
|
||||
"""
|
||||
# clear existing clients with same ID
|
||||
for client in self.db\
|
||||
.query(orm.OAuthClient)\
|
||||
.filter(orm.OAuthClient.identifier == client_id):
|
||||
self.db.delete(client)
|
||||
for orm_client in (
|
||||
self.db
|
||||
.query(orm.OAuthClient)\
|
||||
.filter_by(identifier=client_id)
|
||||
):
|
||||
self.db.delete(orm_client)
|
||||
self.db.commit()
|
||||
|
||||
orm_client = orm.OAuthClient(
|
||||
|
@@ -14,16 +14,19 @@ from tornado.log import app_log
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary
|
||||
from sqlalchemy import (
|
||||
inspect,
|
||||
create_engine, event, inspect,
|
||||
Column, Integer, ForeignKey, Unicode, Boolean,
|
||||
DateTime, Enum
|
||||
DateTime, Enum, Table,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.interfaces import PoolListener
|
||||
from sqlalchemy.orm import backref, sessionmaker, relationship
|
||||
from sqlalchemy.orm import (
|
||||
Session,
|
||||
interfaces, object_session, relationship, sessionmaker,
|
||||
)
|
||||
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from sqlalchemy import create_engine, Table
|
||||
|
||||
from .utils import (
|
||||
random_port,
|
||||
@@ -88,7 +91,7 @@ class Group(Base):
|
||||
__tablename__ = 'groups'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(Unicode(255), unique=True)
|
||||
users = relationship('User', secondary='user_group_map', back_populates='groups')
|
||||
users = relationship('User', secondary='user_group_map', backref='groups')
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s (%i users)>" % (
|
||||
@@ -133,9 +136,6 @@ class User(Base):
|
||||
"Spawner",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
# can't use passive-deletes on this one
|
||||
# because we rely on orm-level delete
|
||||
# for Spawner.server
|
||||
)
|
||||
@property
|
||||
def orm_spawners(self):
|
||||
@@ -149,13 +149,16 @@ class User(Base):
|
||||
"APIToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
oauth_tokens = relationship(
|
||||
"OAuthAccessToken",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
oauth_codes = relationship(
|
||||
"OAuthCode",
|
||||
backref="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
cookie_id = Column(Unicode(255), default=new_token, nullable=False, unique=True)
|
||||
# User.state is actually Spawner state
|
||||
@@ -164,8 +167,6 @@ class User(Base):
|
||||
# Authenticators can store their state here:
|
||||
# Encryption is handled elsewhere
|
||||
encrypted_auth_state = Column(LargeBinary)
|
||||
# group mapping
|
||||
groups = relationship('Group', secondary='user_group_map', back_populates='users')
|
||||
|
||||
def __repr__(self):
|
||||
return "<{cls}({name} {running}/{total} running)>".format(
|
||||
@@ -234,7 +235,6 @@ class Service(Base):
|
||||
"APIToken",
|
||||
backref="service",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# service-specific interface
|
||||
@@ -408,10 +408,10 @@ class APIToken(Hashed, Base):
|
||||
orm_token.token = token
|
||||
if user:
|
||||
assert user.id is not None
|
||||
orm_token.user_id = user.id
|
||||
orm_token.user = user
|
||||
else:
|
||||
assert service.id is not None
|
||||
orm_token.service_id = service.id
|
||||
orm_token.service = service
|
||||
db.add(orm_token)
|
||||
db.commit()
|
||||
return token
|
||||
@@ -498,6 +498,18 @@ class OAuthClient(Base):
|
||||
secret = Column(Unicode(255))
|
||||
redirect_uri = Column(Unicode(1023))
|
||||
|
||||
access_tokens = relationship(
|
||||
OAuthAccessToken,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
)
|
||||
codes = relationship(
|
||||
OAuthCode,
|
||||
backref='client',
|
||||
cascade='all, delete-orphan',
|
||||
)
|
||||
|
||||
# General database utilities
|
||||
|
||||
class DatabaseSchemaMismatch(Exception):
|
||||
"""Exception raised when the database schema version does not match
|
||||
@@ -512,6 +524,39 @@ class ForeignKeysListener(PoolListener):
|
||||
dbapi_con.execute('pragma foreign_keys=ON')
|
||||
|
||||
|
||||
def _expire_relationship(target, relationship_prop):
|
||||
"""Expire relationship backrefs
|
||||
|
||||
used when an object with relationships is deleted
|
||||
"""
|
||||
|
||||
session = object_session(target)
|
||||
# get peer objects to be expired
|
||||
peers = getattr(target, relationship_prop.key)
|
||||
if peers is None:
|
||||
# no peer to clear
|
||||
return
|
||||
# many-to-many and one-to-many have a list of peers
|
||||
# many-to-one has only one
|
||||
if relationship_prop.direction is interfaces.MANYTOONE:
|
||||
peers = [peers]
|
||||
for obj in peers:
|
||||
if inspect(obj).persistent:
|
||||
session.expire(obj, [relationship_prop.back_populates])
|
||||
|
||||
|
||||
@event.listens_for(Session, "persistent_to_deleted")
|
||||
def _notify_deleted_relationships(session, obj):
|
||||
"""Expire relationships when an object becomes deleted
|
||||
|
||||
Needed for
|
||||
"""
|
||||
mapper = inspect(obj).mapper
|
||||
for prop in mapper.relationships:
|
||||
if prop.back_populates:
|
||||
_expire_relationship(obj, prop)
|
||||
|
||||
|
||||
def check_db_revision(engine):
|
||||
"""Check the JupyterHub database revision
|
||||
|
||||
@@ -576,7 +621,10 @@ def check_db_revision(engine):
|
||||
))
|
||||
|
||||
|
||||
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
def new_session_factory(url="sqlite:///:memory:",
|
||||
reset=False,
|
||||
expire_on_commit=False,
|
||||
**kwargs):
|
||||
"""Create a new session at url"""
|
||||
if url.startswith('sqlite'):
|
||||
kwargs.setdefault('connect_args', {'check_same_thread': False})
|
||||
@@ -599,5 +647,11 @@ def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
|
||||
check_db_revision(engine)
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
# We set expire_on_commit=False, since we don't actually need
|
||||
# SQLAlchemy to expire objects after commiting - we don't expect
|
||||
# concurrent runs of the hub talking to the same db. Turning
|
||||
# this off gives us a major performance boost
|
||||
session_factory = sessionmaker(bind=engine,
|
||||
expire_on_commit=expire_on_commit,
|
||||
)
|
||||
return session_factory
|
||||
|
@@ -222,12 +222,21 @@ class MockHub(JupyterHub):
|
||||
|
||||
def load_config_file(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def init_tornado_application(self):
|
||||
"""Instantiate the tornado Application object"""
|
||||
super().init_tornado_application()
|
||||
# reconnect tornado_settings so that mocks can update the real thing
|
||||
self.tornado_settings = self.users.settings = self.tornado_application.settings
|
||||
|
||||
def init_services(self):
|
||||
# explicitly expire services before reinitializing
|
||||
# this only happens in tests because re-initialize
|
||||
# does not occur in a real instance
|
||||
for service in self.db.query(orm.Service):
|
||||
self.db.expire(service)
|
||||
return super().init_services()
|
||||
|
||||
@gen.coroutine
|
||||
def initialize(self, argv=None):
|
||||
self.pid_file = NamedTemporaryFile(delete=False).name
|
||||
|
@@ -280,7 +280,7 @@ def test_get_self(app):
|
||||
db.commit()
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
user=u.orm_user,
|
||||
client_id=oauth_client.identifier,
|
||||
client=oauth_client,
|
||||
token=token,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
|
@@ -289,6 +289,7 @@ def test_spawner_delete_cascade(db):
|
||||
|
||||
# verify that server gets deleted
|
||||
assert_not_found(db, orm.Server, server_id)
|
||||
assert user.orm_spawners == {}
|
||||
|
||||
|
||||
def test_user_delete_cascade(db):
|
||||
@@ -305,11 +306,11 @@ def test_user_delete_cascade(db):
|
||||
spawner = orm.Spawner(user=user)
|
||||
db.commit()
|
||||
spawner.server = server = orm.Server()
|
||||
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client_id=oauth_client.identifier,
|
||||
user_id=user.id,
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
@@ -343,15 +344,16 @@ def test_oauth_client_delete_cascade(db):
|
||||
|
||||
# create a bunch of objects that reference the User
|
||||
# these should all be deleted automatically when the user goes away
|
||||
oauth_code = orm.OAuthCode(client_id=oauth_client.identifier, user_id=user.id)
|
||||
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
|
||||
db.add(oauth_code)
|
||||
oauth_token = orm.OAuthAccessToken(
|
||||
client_id=oauth_client.identifier,
|
||||
user_id=user.id,
|
||||
client=oauth_client,
|
||||
user=user,
|
||||
grant_type=orm.GrantType.authorization_code,
|
||||
)
|
||||
db.add(oauth_token)
|
||||
db.commit()
|
||||
assert user.oauth_tokens == [oauth_token]
|
||||
|
||||
# record all of the ids
|
||||
oauth_code_id = oauth_code.id
|
||||
@@ -364,3 +366,80 @@ def test_oauth_client_delete_cascade(db):
|
||||
# verify that everything gets deleted
|
||||
assert_not_found(db, orm.OAuthCode, oauth_code_id)
|
||||
assert_not_found(db, orm.OAuthAccessToken, oauth_token_id)
|
||||
assert user.oauth_tokens == []
|
||||
assert user.oauth_codes == []
|
||||
|
||||
|
||||
def test_delete_token_cascade(db):
|
||||
user = orm.User(name='mobs')
|
||||
db.add(user)
|
||||
db.commit()
|
||||
user.new_api_token()
|
||||
api_token = user.api_tokens[0]
|
||||
db.delete(api_token)
|
||||
db.commit()
|
||||
assert user.api_tokens == []
|
||||
|
||||
|
||||
def test_group_delete_cascade(db):
|
||||
user1 = orm.User(name='user1')
|
||||
user2 = orm.User(name='user2')
|
||||
group1 = orm.Group(name='group1')
|
||||
group2 = orm.Group(name='group2')
|
||||
db.add(user1)
|
||||
db.add(user2)
|
||||
db.add(group1)
|
||||
db.add(group2)
|
||||
db.commit()
|
||||
# add user to group via user.groups works
|
||||
user1.groups.append(group1)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
|
||||
# add user to group via groups.users works
|
||||
group1.users.append(user2)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
assert user2 in group1.users
|
||||
assert group1 in user1.groups
|
||||
assert group1 in user2.groups
|
||||
|
||||
# fill out the connections (no new concept)
|
||||
group2.users.append(user1)
|
||||
group2.users.append(user2)
|
||||
db.commit()
|
||||
assert user1 in group1.users
|
||||
assert user2 in group1.users
|
||||
assert user1 in group2.users
|
||||
assert user2 in group2.users
|
||||
assert group1 in user1.groups
|
||||
assert group1 in user2.groups
|
||||
assert group2 in user1.groups
|
||||
assert group2 in user2.groups
|
||||
|
||||
# now start deleting
|
||||
# 1. remove group via user.groups
|
||||
user1.groups.remove(group2)
|
||||
db.commit()
|
||||
assert user1 not in group2.users
|
||||
assert group2 not in user1.groups
|
||||
|
||||
# 2. remove user via group.users
|
||||
group1.users.remove(user2)
|
||||
db.commit()
|
||||
assert user2 not in group1.users
|
||||
assert group1 not in user2.groups
|
||||
|
||||
# 3. delete group object
|
||||
db.delete(group2)
|
||||
db.commit()
|
||||
assert group2 not in user1.groups
|
||||
assert group2 not in user2.groups
|
||||
|
||||
# 4. delete user object
|
||||
db.delete(user1)
|
||||
db.commit()
|
||||
assert user1 not in group1.users
|
||||
|
||||
|
||||
|
||||
|
@@ -370,14 +370,10 @@ def test_spawner_delete_server(app):
|
||||
assert spawner.server is not None
|
||||
assert spawner.orm_spawner.server is not None
|
||||
|
||||
# trigger delete via db
|
||||
db.delete(spawner.orm_spawner.server)
|
||||
db.commit()
|
||||
assert spawner.orm_spawner.server is None
|
||||
|
||||
# setting server = None also triggers delete
|
||||
# setting server = None triggers delete
|
||||
spawner.server = None
|
||||
db.commit()
|
||||
assert spawner.orm_spawner.server is None
|
||||
# verify that the server was actually deleted from the db
|
||||
assert db.query(orm.Server).filter(orm.Server.id == server_id).first() is None
|
||||
# verify that both ORM and top-level references are None
|
||||
|
Reference in New Issue
Block a user