Merge pull request #3475 from minrk/async-check-db-locks

handle async functions in check_db_locks
This commit is contained in:
Min RK
2021-05-21 15:36:20 +02:00
committed by GitHub
3 changed files with 38 additions and 13 deletions

View File

@@ -3,6 +3,7 @@
# Distributed under the terms of the Modified BSD License.
import enum
import json
import warnings
from base64 import decodebytes
from base64 import encodebytes
from datetime import datetime
@@ -674,18 +675,29 @@ class APIToken(Hashed, Base):
orm_token.service = service
if expires_in is not None:
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
db.add(orm_token)
# load default roles if they haven't been initiated
# correct to have this here? otherwise some tests fail
token_role = Role.find(db, 'token')
if not token_role:
# FIXME: remove this.
# Creating a token before the db has roles defined should raise an error.
# PR #3460 should let us fix it by ensuring default roles are defined
warnings.warn(
"Token created before default roles!", RuntimeWarning, stacklevel=2
)
default_roles = get_default_roles()
for role in default_roles:
create_role(db, role)
if roles is not None:
update_roles(db, entity=orm_token, roles=roles)
else:
assign_default_roles(db, entity=orm_token)
try:
if roles is not None:
update_roles(db, entity=orm_token, roles=roles)
else:
assign_default_roles(db, entity=orm_token)
except Exception:
db.delete(orm_token)
db.commit()
raise
db.commit()
return token

View File

@@ -435,11 +435,11 @@ def assign_default_roles(db, entity):
"""Assigns the default roles to an entity:
users and services get 'user' role, or admin role if they have admin flag
Tokens get 'token' role"""
default_token_role = orm.Role.find(db, 'token')
if isinstance(entity, orm.Group):
pass
elif isinstance(entity, orm.APIToken):
app_log.debug('Assigning default roles to tokens')
default_token_role = orm.Role.find(db, 'token')
if not entity.roles and (entity.user or entity.service) is not None:
default_token_role.tokens.append(entity)
app_log.info('Added role %s to token %s', default_token_role.name, entity)

View File

@@ -1,4 +1,5 @@
import asyncio
import inspect
import os
from concurrent.futures import ThreadPoolExecutor
@@ -80,14 +81,26 @@ def check_db_locks(func):
"""
def new_func(app, *args, **kwargs):
retval = func(app, *args, **kwargs)
maybe_future = func(app, *args, **kwargs)
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
def _check(_=None):
temp_session = app.session_factory()
try:
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
finally:
temp_session.close()
return retval
async def await_then_check():
result = await maybe_future
_check()
return result
if inspect.isawaitable(maybe_future):
return await_then_check()
else:
_check()
return maybe_future
return new_func