diff --git a/jupyterhub/app.py b/jupyterhub/app.py index d8a61055..e03bab34 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -5,9 +5,6 @@ import asyncio import atexit import binascii -import functools -import inspect -import json import logging import os import re @@ -92,6 +89,7 @@ from .pagination import Pagination from .proxy import Proxy, ConfigurableHTTPProxy from .traitlets import URLPrefix, Command, EntryPointType, Callable from .utils import ( + catch_db_error, maybe_future, url_path_join, print_stacks, @@ -1538,23 +1536,6 @@ class JupyterHub(Application): if os.path.exists(path) and not os.access(path, os.W_OK): self.log.error("%s cannot edit %s", user, path) - def catch_db_error(f): - """Catch and rollback database errors""" - - @functools.wraps(f) - async def catching(self, *args, **kwargs): - try: - r = f(self, *args, **kwargs) - if inspect.isawaitable(r): - r = await r - except SQLAlchemyError: - self.log.exception("Rolling back session due to database error") - self.db.rollback() - else: - return r - - return catching - def init_secrets(self): trait_name = 'cookie_secret' trait = self.traits()[trait_name] @@ -2035,7 +2016,7 @@ class JupyterHub(Application): await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') - self.purge_expired_tokens() + await self.purge_expired_tokens() # purge expired tokens hourly # we don't need to be prompt about this # because expired tokens cannot be used anyway diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 673c7290..25405cf3 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -81,12 +81,14 @@ class BaseHandler(RequestHandler): """ try: await self.get_current_user() - except SQLAlchemyError: - self.log.exception("Rolling back session due to database error") - self.db.rollback() - except Exception: - self.log.exception("Failed to get current user") + except Exception as e: + # ensure get_current_user is never called again for this handler, + # since it failed self._jupyterhub_user = None + self.log.exception("Failed to get current user") + if isinstance(e, SQLAlchemyError): + self.log.error("Rolling back session due to database error") + self.db.rollback() return await maybe_future(super().prepare()) @@ -426,12 +428,11 @@ class BaseHandler(RequestHandler): if user and isinstance(user, User): user = await self.refresh_auth(user) self._jupyterhub_user = user - except Exception as e: - if isinstance(e, SQLAlchemyError): - raise SQLAlchemyError() + except Exception: # don't let errors here raise more than once self._jupyterhub_user = None - self.log.exception("Error getting current user") + # but still raise, which will get handled in .prepare() + raise return self._jupyterhub_user @property diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 3ce9d2b7..4b1893dd 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -4,9 +4,9 @@ import asyncio import concurrent.futures import errno +import functools import hashlib import inspect -import os import random import secrets import socket @@ -22,6 +22,7 @@ from hmac import compare_digest from operator import itemgetter from async_generator import aclosing +from sqlalchemy.exc import SQLAlchemyError from tornado import ioloop from tornado import web from tornado.httpclient import AsyncHTTPClient @@ -642,3 +643,21 @@ def get_accepted_mimetype(accept_header, choices=None): else: return mime return None + + +def catch_db_error(f): + """Catch and rollback database errors""" + + @functools.wraps(f) + async def catching(self, *args, **kwargs): + try: + r = f(self, *args, **kwargs) + if inspect.isawaitable(r): + r = await r + except SQLAlchemyError: + self.log.exception("Rolling back session due to database error") + self.db.rollback() + else: + return r + + return catching