diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index e78fcc02..b16926ac 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -3,6 +3,7 @@ # Distributed under the terms of the Modified BSD License. import enum import json +import warnings from base64 import decodebytes from base64 import encodebytes from datetime import datetime @@ -674,18 +675,29 @@ class APIToken(Hashed, Base): orm_token.service = service if expires_in is not None: orm_token.expires_at = cls.now() + timedelta(seconds=expires_in) + db.add(orm_token) - # load default roles if they haven't been initiated - # correct to have this here? otherwise some tests fail token_role = Role.find(db, 'token') if not token_role: + # FIXME: remove this. + # Creating a token before the db has roles defined should raise an error. + # PR #3460 should let us fix it by ensuring default roles are defined + + warnings.warn( + "Token created before default roles!", RuntimeWarning, stacklevel=2 + ) default_roles = get_default_roles() for role in default_roles: create_role(db, role) - if roles is not None: - update_roles(db, entity=orm_token, roles=roles) - else: - assign_default_roles(db, entity=orm_token) + try: + if roles is not None: + update_roles(db, entity=orm_token, roles=roles) + else: + assign_default_roles(db, entity=orm_token) + except Exception: + db.delete(orm_token) + db.commit() + raise db.commit() return token diff --git a/jupyterhub/roles.py b/jupyterhub/roles.py index d6ade30c..dbd2e75a 100644 --- a/jupyterhub/roles.py +++ b/jupyterhub/roles.py @@ -435,11 +435,11 @@ def assign_default_roles(db, entity): """Assigns the default roles to an entity: users and services get 'user' role, or admin role if they have admin flag Tokens get 'token' role""" - default_token_role = orm.Role.find(db, 'token') if isinstance(entity, orm.Group): pass elif isinstance(entity, orm.APIToken): app_log.debug('Assigning default roles to tokens') + default_token_role = orm.Role.find(db, 'token') if not entity.roles and (entity.user or entity.service) is not None: default_token_role.tokens.append(entity) app_log.info('Added role %s to token %s', default_token_role.name, entity) diff --git a/jupyterhub/tests/utils.py b/jupyterhub/tests/utils.py index 73ddcf51..d300c84a 100644 --- a/jupyterhub/tests/utils.py +++ b/jupyterhub/tests/utils.py @@ -1,4 +1,5 @@ import asyncio +import inspect import os from concurrent.futures import ThreadPoolExecutor @@ -80,14 +81,26 @@ def check_db_locks(func): """ def new_func(app, *args, **kwargs): - retval = func(app, *args, **kwargs) + maybe_future = func(app, *args, **kwargs) - temp_session = app.session_factory() - temp_session.execute('CREATE TABLE dummy (foo INT)') - temp_session.execute('DROP TABLE dummy') - temp_session.close() + def _check(_=None): + temp_session = app.session_factory() + try: + temp_session.execute('CREATE TABLE dummy (foo INT)') + temp_session.execute('DROP TABLE dummy') + finally: + temp_session.close() - return retval + async def await_then_check(): + result = await maybe_future + _check() + return result + + if inspect.isawaitable(maybe_future): + return await_then_check() + else: + _check() + return maybe_future return new_func