mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 07:23:00 +00:00
Merge pull request #3475 from minrk/async-check-db-locks
handle async functions in check_db_locks
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
|
import warnings
|
||||||
from base64 import decodebytes
|
from base64 import decodebytes
|
||||||
from base64 import encodebytes
|
from base64 import encodebytes
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -674,18 +675,29 @@ class APIToken(Hashed, Base):
|
|||||||
orm_token.service = service
|
orm_token.service = service
|
||||||
if expires_in is not None:
|
if expires_in is not None:
|
||||||
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
|
orm_token.expires_at = cls.now() + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
db.add(orm_token)
|
db.add(orm_token)
|
||||||
# load default roles if they haven't been initiated
|
|
||||||
# correct to have this here? otherwise some tests fail
|
|
||||||
token_role = Role.find(db, 'token')
|
token_role = Role.find(db, 'token')
|
||||||
if not token_role:
|
if not token_role:
|
||||||
|
# FIXME: remove this.
|
||||||
|
# Creating a token before the db has roles defined should raise an error.
|
||||||
|
# PR #3460 should let us fix it by ensuring default roles are defined
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"Token created before default roles!", RuntimeWarning, stacklevel=2
|
||||||
|
)
|
||||||
default_roles = get_default_roles()
|
default_roles = get_default_roles()
|
||||||
for role in default_roles:
|
for role in default_roles:
|
||||||
create_role(db, role)
|
create_role(db, role)
|
||||||
if roles is not None:
|
try:
|
||||||
update_roles(db, entity=orm_token, roles=roles)
|
if roles is not None:
|
||||||
else:
|
update_roles(db, entity=orm_token, roles=roles)
|
||||||
assign_default_roles(db, entity=orm_token)
|
else:
|
||||||
|
assign_default_roles(db, entity=orm_token)
|
||||||
|
except Exception:
|
||||||
|
db.delete(orm_token)
|
||||||
|
db.commit()
|
||||||
|
raise
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return token
|
return token
|
||||||
|
@@ -435,11 +435,11 @@ def assign_default_roles(db, entity):
|
|||||||
"""Assigns the default roles to an entity:
|
"""Assigns the default roles to an entity:
|
||||||
users and services get 'user' role, or admin role if they have admin flag
|
users and services get 'user' role, or admin role if they have admin flag
|
||||||
Tokens get 'token' role"""
|
Tokens get 'token' role"""
|
||||||
default_token_role = orm.Role.find(db, 'token')
|
|
||||||
if isinstance(entity, orm.Group):
|
if isinstance(entity, orm.Group):
|
||||||
pass
|
pass
|
||||||
elif isinstance(entity, orm.APIToken):
|
elif isinstance(entity, orm.APIToken):
|
||||||
app_log.debug('Assigning default roles to tokens')
|
app_log.debug('Assigning default roles to tokens')
|
||||||
|
default_token_role = orm.Role.find(db, 'token')
|
||||||
if not entity.roles and (entity.user or entity.service) is not None:
|
if not entity.roles and (entity.user or entity.service) is not None:
|
||||||
default_token_role.tokens.append(entity)
|
default_token_role.tokens.append(entity)
|
||||||
app_log.info('Added role %s to token %s', default_token_role.name, entity)
|
app_log.info('Added role %s to token %s', default_token_role.name, entity)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
@@ -80,14 +81,26 @@ def check_db_locks(func):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def new_func(app, *args, **kwargs):
|
def new_func(app, *args, **kwargs):
|
||||||
retval = func(app, *args, **kwargs)
|
maybe_future = func(app, *args, **kwargs)
|
||||||
|
|
||||||
temp_session = app.session_factory()
|
def _check(_=None):
|
||||||
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
temp_session = app.session_factory()
|
||||||
temp_session.execute('DROP TABLE dummy')
|
try:
|
||||||
temp_session.close()
|
temp_session.execute('CREATE TABLE dummy (foo INT)')
|
||||||
|
temp_session.execute('DROP TABLE dummy')
|
||||||
|
finally:
|
||||||
|
temp_session.close()
|
||||||
|
|
||||||
return retval
|
async def await_then_check():
|
||||||
|
result = await maybe_future
|
||||||
|
_check()
|
||||||
|
return result
|
||||||
|
|
||||||
|
if inspect.isawaitable(maybe_future):
|
||||||
|
return await_then_check()
|
||||||
|
else:
|
||||||
|
_check()
|
||||||
|
return maybe_future
|
||||||
|
|
||||||
return new_func
|
return new_func
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user