diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 658f67e2..b90cd4fb 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -14,7 +14,7 @@ from tornado.log import app_log from sqlalchemy.types import TypeDecorator, TEXT, LargeBinary from sqlalchemy import ( - create_engine, event, inspect, or_, + create_engine, event, exc, inspect, or_, select, Column, Integer, ForeignKey, Unicode, Boolean, DateTime, Enum, Table, ) @@ -575,7 +575,7 @@ def _expire_relationship(target, relationship_prop): def _notify_deleted_relationships(session, obj): """Expire relationships when an object becomes deleted - Needed for + Needed to keep relationships up to date. """ mapper = inspect(obj).mapper for prop in mapper.relationships: @@ -583,6 +583,52 @@ def _notify_deleted_relationships(session, obj): _expire_relationship(obj, prop) +def register_ping_connection(engine): + """Check connections before using them. + + Avoids database errors when using stale connections. + + From SQLAlchemy docs on pessimistic disconnect handling: + + https://docs.sqlalchemy.org/en/rel_1_1/core/pooling.html#disconnect-handling-pessimistic + """ + @event.listens_for(engine, "engine_connect") + def ping_connection(connection, branch): + if branch: + # "branch" refers to a sub-connection of a connection, + # we don't want to bother pinging on these. + return + + # turn off "close with result". This flag is only used with + # "connectionless" execution, otherwise will be False in any case + save_should_close_with_result = connection.should_close_with_result + connection.should_close_with_result = False + + try: + # run a SELECT 1. use a core select() so that + # the SELECT of a scalar value without a table is + # appropriately formatted for the backend + connection.scalar(select([1])) + except exc.DBAPIError as err: + # catch SQLAlchemy's DBAPIError, which is a wrapper + # for the DBAPI's exception. It includes a .connection_invalidated + # attribute which specifies if this connection is a "disconnect" + # condition, which is based on inspection of the original exception + # by the dialect in use. + if err.connection_invalidated: + app_log.error("Database connection error, attempting to reconnect: %s", err) + # run the same SELECT again - the connection will re-validate + # itself and establish a new connection. The disconnect detection + # here also causes the whole connection pool to be invalidated + # so that all stale connections are discarded. + connection.scalar(select([1])) + else: + raise + finally: + # restore "close with result" + connection.should_close_with_result = save_should_close_with_result + + def check_db_revision(engine): """Check the JupyterHub database revision @@ -661,10 +707,12 @@ def mysql_large_prefix_check(engine): else: return False + def add_row_format(base): for t in base.metadata.tables.values(): t.dialect_kwargs['mysql_ROW_FORMAT'] = 'DYNAMIC' + def new_session_factory(url="sqlite:///:memory:", reset=False, expire_on_commit=False, @@ -684,6 +732,9 @@ def new_session_factory(url="sqlite:///:memory:", kwargs.setdefault('poolclass', StaticPool) engine = create_engine(url, **kwargs) + # enable pessimistic disconnect handling + register_ping_connection(engine) + if reset: Base.metadata.drop_all(engine)