finish up db rollback checks

- move catch_db_error to utils
- tidy catch/propagate errors in prepare, get_current_user
This commit is contained in:
Min RK
2021-08-10 15:03:41 +02:00
parent 044fb23a70
commit 3bcc542e27
3 changed files with 32 additions and 31 deletions

View File

@@ -5,9 +5,6 @@
import asyncio import asyncio
import atexit import atexit
import binascii import binascii
import functools
import inspect
import json
import logging import logging
import os import os
import re import re
@@ -94,6 +91,7 @@ from .pagination import Pagination
from .proxy import Proxy, ConfigurableHTTPProxy from .proxy import Proxy, ConfigurableHTTPProxy
from .traitlets import URLPrefix, Command, EntryPointType, Callable from .traitlets import URLPrefix, Command, EntryPointType, Callable
from .utils import ( from .utils import (
catch_db_error,
maybe_future, maybe_future,
url_path_join, url_path_join,
print_stacks, print_stacks,
@@ -1561,23 +1559,6 @@ class JupyterHub(Application):
if os.path.exists(path) and not os.access(path, os.W_OK): if os.path.exists(path) and not os.access(path, os.W_OK):
self.log.error("%s cannot edit %s", user, path) self.log.error("%s cannot edit %s", user, path)
def catch_db_error(f):
"""Catch and rollback database errors"""
@functools.wraps(f)
async def catching(self, *args, **kwargs):
try:
r = f(self, *args, **kwargs)
if inspect.isawaitable(r):
r = await r
except SQLAlchemyError:
self.log.exception("Rolling back session due to database error")
self.db.rollback()
else:
return r
return catching
def init_secrets(self): def init_secrets(self):
trait_name = 'cookie_secret' trait_name = 'cookie_secret'
trait = self.traits()[trait_name] trait = self.traits()[trait_name]
@@ -2236,7 +2217,7 @@ class JupyterHub(Application):
await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.service_tokens, kind='service')
await self._add_tokens(self.api_tokens, kind='user') await self._add_tokens(self.api_tokens, kind='user')
self.purge_expired_tokens() await self.purge_expired_tokens()
# purge expired tokens hourly # purge expired tokens hourly
# we don't need to be prompt about this # we don't need to be prompt about this
# because expired tokens cannot be used anyway # because expired tokens cannot be used anyway

View File

@@ -85,12 +85,14 @@ class BaseHandler(RequestHandler):
self.expanded_scopes = set() self.expanded_scopes = set()
try: try:
await self.get_current_user() await self.get_current_user()
except SQLAlchemyError: except Exception as e:
self.log.exception("Rolling back session due to database error") # ensure get_current_user is never called again for this handler,
self.db.rollback() # since it failed
except Exception:
self.log.exception("Failed to get current user")
self._jupyterhub_user = None self._jupyterhub_user = None
self.log.exception("Failed to get current user")
if isinstance(e, SQLAlchemyError):
self.log.error("Rolling back session due to database error")
self.db.rollback()
self._resolve_roles_and_scopes() self._resolve_roles_and_scopes()
return await maybe_future(super().prepare()) return await maybe_future(super().prepare())
@@ -414,12 +416,11 @@ class BaseHandler(RequestHandler):
if user and isinstance(user, User): if user and isinstance(user, User):
user = await self.refresh_auth(user) user = await self.refresh_auth(user)
self._jupyterhub_user = user self._jupyterhub_user = user
except Exception as e: except Exception:
if isinstance(e, SQLAlchemyError):
raise SQLAlchemyError()
# don't let errors here raise more than once # don't let errors here raise more than once
self._jupyterhub_user = None self._jupyterhub_user = None
self.log.exception("Error getting current user") # but still raise, which will get handled in .prepare()
raise
return self._jupyterhub_user return self._jupyterhub_user
def _resolve_roles_and_scopes(self): def _resolve_roles_and_scopes(self):

View File

@@ -4,9 +4,9 @@
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import errno import errno
import functools
import hashlib import hashlib
import inspect import inspect
import os
import random import random
import secrets import secrets
import socket import socket
@@ -22,6 +22,7 @@ from hmac import compare_digest
from operator import itemgetter from operator import itemgetter
from async_generator import aclosing from async_generator import aclosing
from sqlalchemy.exc import SQLAlchemyError
from tornado import ioloop from tornado import ioloop
from tornado import web from tornado import web
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
@@ -635,3 +636,21 @@ def get_accepted_mimetype(accept_header, choices=None):
else: else:
return mime return mime
return None return None
def catch_db_error(f):
"""Catch and rollback database errors"""
@functools.wraps(f)
async def catching(self, *args, **kwargs):
try:
r = f(self, *args, **kwargs)
if inspect.isawaitable(r):
r = await r
except SQLAlchemyError:
self.log.exception("Rolling back session due to database error")
self.db.rollback()
else:
return r
return catching