From c29a5ca4cec9029f47cd01c993c33b8a7cd439f6 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 10 Aug 2021 15:03:41 +0200 Subject: [PATCH] finish up db rollback checks - move catch_db_error to utils - tidy catch/propagate errors in prepare, get_current_user (cherry picked from commit 3bcc542e27fb8590d8e4ee2141e24f7e853163d1) Conflicts: jupyterhub/handlers/base.py NOTE(mriedem): The conflict is due to e6845a68f not being in 1.5.0. --- jupyterhub/app.py | 23 ++--------------------- jupyterhub/handlers/base.py | 19 ++++++++++--------- jupyterhub/utils.py | 21 ++++++++++++++++++++- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/jupyterhub/app.py b/jupyterhub/app.py index d8a61055..e03bab34 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -5,9 +5,6 @@ import asyncio import atexit import binascii -import functools -import inspect -import json import logging import os import re @@ -92,6 +89,7 @@ from .pagination import Pagination from .proxy import Proxy, ConfigurableHTTPProxy from .traitlets import URLPrefix, Command, EntryPointType, Callable from .utils import ( + catch_db_error, maybe_future, url_path_join, print_stacks, @@ -1538,23 +1536,6 @@ class JupyterHub(Application): if os.path.exists(path) and not os.access(path, os.W_OK): self.log.error("%s cannot edit %s", user, path) - def catch_db_error(f): - """Catch and rollback database errors""" - - @functools.wraps(f) - async def catching(self, *args, **kwargs): - try: - r = f(self, *args, **kwargs) - if inspect.isawaitable(r): - r = await r - except SQLAlchemyError: - self.log.exception("Rolling back session due to database error") - self.db.rollback() - else: - return r - - return catching - def init_secrets(self): trait_name = 'cookie_secret' trait = self.traits()[trait_name] @@ -2035,7 +2016,7 @@ class JupyterHub(Application): await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') - self.purge_expired_tokens() + await self.purge_expired_tokens() # purge expired tokens hourly # we don't need to be prompt about this # because expired tokens cannot be used anyway diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 673c7290..25405cf3 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -81,12 +81,14 @@ class BaseHandler(RequestHandler): """ try: await self.get_current_user() - except SQLAlchemyError: - self.log.exception("Rolling back session due to database error") - self.db.rollback() - except Exception: - self.log.exception("Failed to get current user") + except Exception as e: + # ensure get_current_user is never called again for this handler, + # since it failed self._jupyterhub_user = None + self.log.exception("Failed to get current user") + if isinstance(e, SQLAlchemyError): + self.log.error("Rolling back session due to database error") + self.db.rollback() return await maybe_future(super().prepare()) @@ -426,12 +428,11 @@ class BaseHandler(RequestHandler): if user and isinstance(user, User): user = await self.refresh_auth(user) self._jupyterhub_user = user - except Exception as e: - if isinstance(e, SQLAlchemyError): - raise SQLAlchemyError() + except Exception: # don't let errors here raise more than once self._jupyterhub_user = None - self.log.exception("Error getting current user") + # but still raise, which will get handled in .prepare() + raise return self._jupyterhub_user @property diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 3ce9d2b7..4b1893dd 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -4,9 +4,9 @@ import asyncio import concurrent.futures import errno +import functools import hashlib import inspect -import os import random import secrets import socket @@ -22,6 +22,7 @@ from hmac import compare_digest from operator import itemgetter from async_generator import aclosing +from sqlalchemy.exc import SQLAlchemyError from tornado import ioloop from tornado import web from tornado.httpclient import AsyncHTTPClient @@ -642,3 +643,21 @@ def get_accepted_mimetype(accept_header, choices=None): else: return mime return None + + +def catch_db_error(f): + """Catch and rollback database errors""" + + @functools.wraps(f) + async def catching(self, *args, **kwargs): + try: + r = f(self, *args, **kwargs) + if inspect.isawaitable(r): + r = await r + except SQLAlchemyError: + self.log.exception("Rolling back session due to database error") + self.db.rollback() + else: + return r + + return catching