diff --git a/jupyterhub/alembic/versions/833da8570507_rbac.py b/jupyterhub/alembic/versions/833da8570507_rbac.py new file mode 100644 index 00000000..a060f10c --- /dev/null +++ b/jupyterhub/alembic/versions/833da8570507_rbac.py @@ -0,0 +1,119 @@ +"""rbac + +Revision ID: 833da8570507 +Revises: 4dc2d5a8c53c +Create Date: 2021-02-17 15:03:04.360368 + +""" +# revision identifiers, used by Alembic. +revision = '833da8570507' +down_revision = '4dc2d5a8c53c' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # FIXME: currently drops all api tokens and forces recreation! + # this ensures a consistent database, but requires: + # 1. all servers to be stopped for upgrade (maybe unavoidable anyway) + # 2. any manually issued/stored tokens to be re-issued + + # tokens loaded via configuration will be recreated on launch and unaffected + op.drop_table('api_tokens') + op.drop_table('oauth_access_tokens') + return + # TODO: explore in-place migration. This seems hard! + # 1. add new columns in api tokens + # 2. fill default fields (client_id='jupyterhub') for all api tokens + # 3. copy oauth tokens into api tokens + # 4. give oauth tokens 'identify' scopes + + c = op.get_bind() + naming_convention = { + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + } + with op.batch_alter_table( + "api_tokens", + naming_convention=naming_convention, + ) as batch_op: + batch_op.add_column( + sa.Column( + 'client_id', + sa.Unicode(255), + # sa.ForeignKey('oauth_clients.identifier', ondelete='CASCADE'), + nullable=True, + ), + ) + # batch_cursor = op.get_bind() + # batch_cursor.execute( + # """ + # UPDATE api_tokens + # SET client_id='jupyterhub' + # WHERE client_id IS NULL + # """ + # ) + batch_op.create_foreign_key( + "fk_api_token_client_id", + # 'api_tokens', + 'oauth_clients', + ['client_id'], + ['identifier'], + ondelete='CASCADE', + ) + + c.execute( + """ + UPDATE api_tokens + SET client_id='jupyterhub' + WHERE client_id IS NULL + """ + ) + + op.add_column( + 'api_tokens', + sa.Column( + 'grant_type', + sa.Enum( + 'authorization_code', + 'implicit', + 'password', + 'client_credentials', + 'refresh_token', + name='granttype', + ), + server_default='authorization_code', + nullable=False, + ), + ) + op.add_column( + 'api_tokens', sa.Column('refresh_token', sa.Unicode(length=255), nullable=True) + ) + op.add_column( + 'api_tokens', sa.Column('session_id', sa.Unicode(length=255), nullable=True) + ) + + # TODO: migrate OAuth tokens into APIToken table + + op.drop_index('ix_oauth_access_tokens_prefix', table_name='oauth_access_tokens') + op.drop_table('oauth_access_tokens') + + +def downgrade(): + # delete OAuth tokens for non-jupyterhub clients + # drop new columns from api tokens + op.drop_constraint(None, 'api_tokens', type_='foreignkey') + op.drop_column('api_tokens', 'session_id') + op.drop_column('api_tokens', 'refresh_token') + op.drop_column('api_tokens', 'grant_type') + op.drop_column('api_tokens', 'client_id') + # FIXME: only drop tokens whose client id is not 'jupyterhub' + # until then, drop all tokens + op.drop_table("api_tokens") + + op.drop_table('api_token_role_map') + op.drop_table('service_role_map') + op.drop_table('user_role_map') + op.drop_table('roles') diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py index 938d88ec..e7f72880 100644 --- a/jupyterhub/apihandlers/auth.py +++ b/jupyterhub/apihandlers/auth.py @@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler): "/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user" ) orm_token = orm.APIToken.find(self.db, token) - if orm_token is None: - orm_token = orm.OAuthAccessToken.find(self.db, token) if orm_token is None: raise web.HTTPError(404) diff --git a/jupyterhub/apihandlers/base.py b/jupyterhub/apihandlers/base.py index bbbda5d8..9832afc5 100644 --- a/jupyterhub/apihandlers/base.py +++ b/jupyterhub/apihandlers/base.py @@ -205,23 +205,6 @@ class APIHandler(BaseHandler): def token_model(self, token): """Get the JSON model for an APIToken""" - expires_at = None - if isinstance(token, orm.APIToken): - kind = 'api_token' - roles = [r.name for r in token.roles] - extra = {'note': token.note} - expires_at = token.expires_at - elif isinstance(token, orm.OAuthAccessToken): - kind = 'oauth' - # oauth tokens do not bear roles - roles = [] - extra = {'oauth_client': token.client.description or token.client.client_id} - if token.expires_at: - expires_at = datetime.fromtimestamp(token.expires_at) - else: - raise TypeError( - "token must be an APIToken or OAuthAccessToken, not %s" % type(token) - ) if token.user: owner_key = 'user' @@ -234,13 +217,14 @@ class APIHandler(BaseHandler): model = { owner_key: owner, 'id': token.api_id, - 'kind': kind, - 'roles': [role for role in roles], + 'kind': 'api_token', + 'roles': [r.name for r in token.roles], 'created': isoformat(token.created), 'last_activity': isoformat(token.last_activity), - 'expires_at': isoformat(expires_at), + 'expires_at': isoformat(token.expires_at), + 'note': token.note, + 'oauth_client': token.client.description or token.client.client_id, } - model.update(extra) return model def user_model(self, user): diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index 56775e34..c70d35ad 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -32,9 +32,6 @@ class SelfAPIHandler(APIHandler): async def get(self): user = self.current_user - if user is None: - # whoami can be accessed via oauth token - user = self.get_current_user_oauth_token() if user is None: raise web.HTTPError(403) if isinstance(user, orm.Service): @@ -316,17 +313,7 @@ class UserTokenListAPIHandler(APIHandler): continue api_tokens.append(self.token_model(token)) - oauth_tokens = [] - # OAuth tokens use integer timestamps - now_timestamp = now.timestamp() - for token in sorted(user.oauth_tokens, key=sort_key): - if token.expires_at and token.expires_at < now_timestamp: - # exclude expired tokens - self.db.delete(token) - self.db.commit() - continue - oauth_tokens.append(self.token_model(token)) - self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens})) + self.write(json.dumps({'api_tokens': api_tokens})) # Todo: Set to @needs_scope('users:tokens') async def post(self, user_name): @@ -410,19 +397,15 @@ class UserTokenAPIHandler(APIHandler): (e.g. wrong owner, invalid key format, etc.) """ not_found = "No such token %s for user %s" % (token_id, user.name) - prefix, id_ = token_id[0], token_id[1:] - if prefix == 'a': - Token = orm.APIToken - elif prefix == 'o': - Token = orm.OAuthAccessToken - else: + prefix, id_ = token_id[:1], token_id[1:] + if prefix != 'a': raise web.HTTPError(404, not_found) try: id_ = int(id_) except ValueError: raise web.HTTPError(404, not_found) - orm_token = self.db.query(Token).filter(Token.id == id_).first() + orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first() if orm_token is None or orm_token.user is not user.orm_user: raise web.HTTPError(404, "Token not found %s", orm_token) return orm_token @@ -444,10 +427,10 @@ class UserTokenAPIHandler(APIHandler): raise web.HTTPError(404, "No such user: %s" % user_name) token = self.find_token_by_id(user, token_id) # deleting an oauth token deletes *all* oauth tokens for that client - if isinstance(token, orm.OAuthAccessToken): - client_id = token.client_id + client_id = token.client_id + if token.client_id != "jupyterhub": tokens = [ - token for token in user.oauth_tokens if token.client_id == client_id + token for token in user.api_tokens if token.client_id == client_id ] else: tokens = [token] diff --git a/jupyterhub/app.py b/jupyterhub/app.py index a0c5fde0..6f3cf45d 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -2014,12 +2014,13 @@ class JupyterHub(Application): run periodically """ # this should be all the subclasses of Expiring - for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode): + for cls in (orm.APIToken, orm.OAuthCode): self.log.debug("Purging expired {name}s".format(name=cls.__name__)) cls.purge_expired(self.db) async def init_api_tokens(self): """Load predefined API tokens (for services) into database""" + await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') @@ -2292,13 +2293,30 @@ class JupyterHub(Application): login_url=url_path_join(base_url, 'login'), token_expires_in=self.oauth_token_expires_in, ) + # ensure the default oauth client exists + if ( + not self.db.query(orm.OAuthClient) + .filter_by(identifier="jupyterhub") + .first() + ): + # create the oauth client for jupyterhub itself + # this allows us to distinguish between orphaned tokens + # (failed cascade deletion) and tokens issued by the hub + # it has no client_secret, which means it cannot be used + # to make requests + self.oauth_provider.add_client( + client_id="jupyterhub", + client_secret="", + redirect_uri="", + description="JupyterHub", + ) def cleanup_oauth_clients(self): """Cleanup any OAuth clients that shouldn't be in the database. This should mainly be services that have been removed from configuration or renamed. """ - oauth_client_ids = set() + oauth_client_ids = {"jupyterhub"} for service in self._service_map.values(): if service.oauth_available: oauth_client_ids.add(service.oauth_client_id) diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 4841b2a8..adf0c363 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -247,26 +247,6 @@ class BaseHandler(RequestHandler): return None return match.group(1) - def get_current_user_oauth_token(self): - """Get the current user identified by OAuth access token - - Separate from API token because OAuth access tokens - can only be used for identifying users, - not using the API. - """ - token = self.get_auth_token() - if token is None: - return None - orm_token = orm.OAuthAccessToken.find(self.db, token) - if orm_token is None: - return None - - now = datetime.utcnow() - recorded = self._record_activity(orm_token, now) - if self._record_activity(orm_token.user, now) or recorded: - self.db.commit() - return self._user_from_orm(orm_token.user) - def _record_activity(self, obj, timestamp=None): """record activity on an ORM object @@ -373,7 +353,7 @@ class BaseHandler(RequestHandler): # FIXME: scopes should give us better control than this # don't consider API requests originating from a server # to be activity from the user - if not orm_token.note.startswith("Server at "): + if not orm_token.note or not orm_token.note.startswith("Server at "): recorded = self._record_activity(orm_token.user, now) or recorded if recorded: self.db.commit() @@ -501,10 +481,8 @@ class BaseHandler(RequestHandler): # don't clear session tokens if not logged in, # because that could be a malicious logout request! count = 0 - for access_token in ( - self.db.query(orm.OAuthAccessToken) - .filter(orm.OAuthAccessToken.user_id == user.id) - .filter(orm.OAuthAccessToken.session_id == session_id) + for access_token in self.db.query(orm.APIToken).filter_by( + user_id=user.id, session_id=session_id ): self.db.delete(access_token) count += 1 diff --git a/jupyterhub/handlers/pages.py b/jupyterhub/handlers/pages.py index a9422699..356aae71 100644 --- a/jupyterhub/handlers/pages.py +++ b/jupyterhub/handlers/pages.py @@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler): return (token.last_activity or never, token.created or never) now = datetime.utcnow() - api_tokens = [] - for token in sorted(user.api_tokens, key=sort_key, reverse=True): - if token.expires_at and token.expires_at < now: - self.db.delete(token) - self.db.commit() - continue - api_tokens.append(token) # group oauth client tokens by client id - # AccessTokens have expires_at as an integer timestamp - now_timestamp = now.timestamp() - oauth_tokens = defaultdict(list) - for token in user.oauth_tokens: - if token.expires_at and token.expires_at < now_timestamp: - self.log.warning("Deleting expired token") + all_tokens = defaultdict(list) + for token in sorted(user.api_tokens, key=sort_key, reverse=True): + if token.expires_at and token.expires_at < now: + self.log.warning(f"Deleting expired token {token}") self.db.delete(token) self.db.commit() continue if not token.client_id: # token should have been deleted when client was deleted - self.log.warning("Deleting stale oauth token for %s", user.name) + self.log.warning("Deleting stale oauth token {token}") self.db.delete(token) self.db.commit() continue - oauth_tokens[token.client_id].append(token) + all_tokens[token.client_id].append(token) + # individually list tokens issued by jupyterhub itself + api_tokens = all_tokens.pop("jupyterhub", []) + + # group all other tokens issued under their owners # get the earliest created and latest last_activity # timestamp for a given oauth client oauth_clients = [] - for client_id, tokens in oauth_tokens.items(): + + for client_id, tokens in all_tokens.items(): created = tokens[0].created last_activity = tokens[0].last_activity for token in tokens[1:]: diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index 7dd7b160..ee96dfbe 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -2,18 +2,18 @@ implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html """ +from datetime import timedelta + from oauthlib import uri_validate from oauthlib.oauth2 import RequestValidator from oauthlib.oauth2 import WebApplicationServer from oauthlib.oauth2.rfc6749.grant_types import authorization_code from oauthlib.oauth2.rfc6749.grant_types import base -from tornado.escape import url_escape from tornado.log import app_log from .. import orm from ..utils import compare_token from ..utils import hash_token -from ..utils import url_path_join # patch absolute-uri check # because we want to allow relative uri oauth @@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator): ) if oauth_client is None: return False + if not client_secret or not oauth_client.secret: + # disallow authentication with no secret + return False if not compare_token(oauth_client.secret, client_secret): app_log.warning("Client secret mismatch for %s", client_id) return False @@ -339,10 +342,10 @@ class JupyterHubRequestValidator(RequestValidator): .filter_by(identifier=request.client.client_id) .first() ) - orm_access_token = orm.OAuthAccessToken( - client=client, + orm_access_token = orm.APIToken.new( + client_id=client.identifier, grant_type=orm.GrantType.authorization_code, - expires_at=orm.OAuthAccessToken.now() + token['expires_in'], + expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']), refresh_token=token['refresh_token'], # TODO: save scopes, # scopes=scopes, @@ -412,6 +415,8 @@ class JupyterHubRequestValidator(RequestValidator): ) if orm_client is None: return False + if not orm_client.secret: + return False request.client = orm_client return True @@ -574,14 +579,16 @@ class JupyterHubOAuthServer(WebApplicationServer): app_log.info(f'Creating oauth client {client_id}') else: app_log.info(f'Updating oauth client {client_id}') - orm_client.secret = hash_token(client_secret) + orm_client.secret = hash_token(client_secret) if client_secret else "" orm_client.redirect_uri = redirect_uri orm_client.description = description self.db.commit() def fetch_by_client_id(self, client_id): """Find a client by its id""" - return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first() + client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first() + if client and client.secret: + return client def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs): diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 6e6e8693..a00e265a 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -277,9 +277,6 @@ class User(Base): last_activity = Column(DateTime, nullable=True) api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan") - oauth_tokens = relationship( - "OAuthAccessToken", backref="user", cascade="all, delete-orphan" - ) oauth_codes = relationship( "OAuthCode", backref="user", cascade="all, delete-orphan" ) @@ -485,7 +482,9 @@ class Hashed(Expiring): @classmethod def check_token(cls, db, token): """Check if a token is acceptable""" + print("checking", cls, token, len(token), cls.min_length) if len(token) < cls.min_length: + print("raising") raise ValueError( "Tokens must be at least %i characters, got %r" % (cls.min_length, token) @@ -530,6 +529,20 @@ class Hashed(Expiring): return orm_token +# ------------------------------------ +# OAuth tables +# ------------------------------------ + + +class GrantType(enum.Enum): + # we only use authorization_code for now + authorization_code = 'authorization_code' + implicit = 'implicit' + password = 'password' + client_credentials = 'client_credentials' + refresh_token = 'refresh_token' + + class APIToken(Hashed, Base): """An API token""" @@ -548,6 +561,15 @@ class APIToken(Hashed, Base): def api_id(self): return 'a%i' % self.id + # added in 2.0 + client_id = Column( + Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE') + ) + grant_type = Column(Enum(GrantType), nullable=False) + refresh_token = Column(Unicode(255)) + # the browser session id associated with a given token + session_id = Column(Unicode(255)) + # token metadata for bookkeeping now = datetime.utcnow # for expiry created = Column(DateTime, default=datetime.utcnow) @@ -566,8 +588,12 @@ class APIToken(Hashed, Base): # this shouldn't happen kind = 'owner' name = 'unknown' - return "<{cls}('{pre}...', {kind}='{name}')>".format( - cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name + return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format( + cls=self.__class__.__name__, + pre=self.prefix, + kind=kind, + name=name, + client_id=self.client_id, ) @classmethod @@ -588,6 +614,14 @@ class APIToken(Hashed, Base): raise ValueError("kind must be 'user', 'service', or None, not %r" % kind) for orm_token in prefix_match: if orm_token.match(token): + if not orm_token.client_id: + app_log.warning( + "Deleting stale oauth token for %s with no client", + orm_token.user and orm_token.user.name, + ) + db.delete(orm_token) + db.commit() + return return orm_token @classmethod @@ -600,6 +634,7 @@ class APIToken(Hashed, Base): note='', generated=True, expires_in=None, + client_id='jupyterhub', ): """Generate a new API token for a user or service""" assert user or service @@ -614,7 +649,12 @@ class APIToken(Hashed, Base): cls.check_token(db, token) # two stages to ensure orm_token.generated has been set # before token setter is called - orm_token = cls(generated=generated, note=note or '') + orm_token = cls( + generated=generated, + note=note or '', + grant_type=GrantType.authorization_code, + client_id=client_id, + ) orm_token.token = token if user: assert user.id is not None @@ -641,76 +681,6 @@ class APIToken(Hashed, Base): return token -# ------------------------------------ -# OAuth tables -# ------------------------------------ - - -class GrantType(enum.Enum): - # we only use authorization_code for now - authorization_code = 'authorization_code' - implicit = 'implicit' - password = 'password' - client_credentials = 'client_credentials' - refresh_token = 'refresh_token' - - -class OAuthAccessToken(Hashed, Base): - __tablename__ = 'oauth_access_tokens' - id = Column(Integer, primary_key=True, autoincrement=True) - - @staticmethod - def now(): - return datetime.utcnow().timestamp() - - @property - def api_id(self): - return 'o%i' % self.id - - client_id = Column( - Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE') - ) - grant_type = Column(Enum(GrantType), nullable=False) - expires_at = Column(Integer) - refresh_token = Column(Unicode(255)) - # TODO: drop refresh_expires_at. Refresh tokens shouldn't expire - refresh_expires_at = Column(Integer) - user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) - service = None # for API-equivalence with APIToken - - # the browser session id associated with a given token - session_id = Column(Unicode(255)) - - # from Hashed - hashed = Column(Unicode(255), unique=True) - prefix = Column(Unicode(16), index=True) - - created = Column(DateTime, default=datetime.utcnow) - last_activity = Column(DateTime, nullable=True) - - def __repr__(self): - return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format( - cls=self.__class__.__name__, - client_id=self.client_id, - user=self.user and self.user.name, - prefix=self.prefix, - expires_in=self.expires_in, - ) - - @classmethod - def find(cls, db, token): - orm_token = super().find(db, token) - if orm_token and not orm_token.client_id: - app_log.warning( - "Deleting stale oauth token for %s with no client", - orm_token.user and orm_token.user.name, - ) - db.delete(orm_token) - db.commit() - return - return orm_token - - class OAuthCode(Expiring, Base): __tablename__ = 'oauth_codes' @@ -752,7 +722,7 @@ class OAuthClient(Base): return self.identifier access_tokens = relationship( - OAuthAccessToken, backref='client', cascade='all, delete-orphan' + APIToken, backref='client', cascade='all, delete-orphan' ) codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan') diff --git a/jupyterhub/services/service.py b/jupyterhub/services/service.py index 44fd763c..c72ae382 100644 --- a/jupyterhub/services/service.py +++ b/jupyterhub/services/service.py @@ -51,6 +51,7 @@ from traitlets import Dict from traitlets import HasTraits from traitlets import Instance from traitlets import Unicode +from traitlets import validate from traitlets.config import LoggingConfigurable from .. import orm @@ -284,6 +285,15 @@ class Service(LoggingConfigurable): def _default_client_id(self): return 'service-%s' % self.name + @validate("oauth_client_id") + def _validate_client_id(self, proposal): + if not proposal.value.startswith("service-"): + raise ValueError( + f"service {self.name} has oauth_client_id='{proposal.value}'." + " Service oauth client ids must start with 'service-'" + ) + return proposal.value + oauth_redirect_uri = Unicode( help="""OAuth redirect URI for this service. diff --git a/jupyterhub/tests/populate_db.py b/jupyterhub/tests/populate_db.py index 2b5c6007..4db35cf9 100644 --- a/jupyterhub/tests/populate_db.py +++ b/jupyterhub/tests/populate_db.py @@ -70,7 +70,11 @@ def populate_db(url): code = orm.OAuthCode(client_id=client.identifier) db.add(code) db.commit() - access_token = orm.OAuthAccessToken( + if jupyterhub.version_info < (2, 0): + Token = orm.OAuthAccessToken + else: + Token = orm.APIToken + access_token = Token( client_id=client.identifier, user_id=user.id, grant_type=orm.GrantType.authorization_code, diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index e22ee51c..a8074b81 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -273,7 +273,7 @@ async def test_get_self(app): oauth_client = orm.OAuthClient(identifier='eurydice') db.add(oauth_client) db.commit() - oauth_token = orm.OAuthAccessToken( + oauth_token = orm.APIToken( user=u.orm_user, client=oauth_client, token=token, @@ -1423,12 +1423,11 @@ async def test_token_list(app, as_user, for_user, status): if status != 200: return reply = r.json() - assert sorted(reply) == ['api_tokens', 'oauth_tokens'] + assert sorted(reply) == ['api_tokens'] assert len(reply['api_tokens']) == len(for_user_obj.api_tokens) assert all(token['user'] == for_user for token in reply['api_tokens']) - assert all(token['user'] == for_user for token in reply['oauth_tokens']) # validate individual token ids - for token in reply['api_tokens'] + reply['oauth_tokens']: + for token in reply['api_tokens']: r = await api_request( app, 'users', for_user, 'tokens', token['id'], headers=headers ) diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index c761a040..8187c481 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -355,7 +355,7 @@ def test_user_delete_cascade(db): spawner.server = server = orm.Server() oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) - oauth_token = orm.OAuthAccessToken( + oauth_token = orm.APIToken( client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code ) db.add(oauth_token) @@ -377,7 +377,7 @@ def test_user_delete_cascade(db): assert_not_found(db, orm.Spawner, spawner_id) assert_not_found(db, orm.Server, server_id) assert_not_found(db, orm.OAuthCode, oauth_code_id) - assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + assert_not_found(db, orm.APIToken, oauth_token_id) def test_oauth_client_delete_cascade(db): @@ -391,12 +391,12 @@ def test_oauth_client_delete_cascade(db): # these should all be deleted automatically when the user goes away oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) - oauth_token = orm.OAuthAccessToken( + oauth_token = orm.APIToken( client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code ) db.add(oauth_token) db.commit() - assert user.oauth_tokens == [oauth_token] + assert user.tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -408,8 +408,8 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) - assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) - assert user.oauth_tokens == [] + assert_not_found(db, orm.APIToken, oauth_token_id) + assert user.tokens == [] assert user.oauth_codes == [] @@ -510,32 +510,32 @@ def test_expiring_api_token(app, user): def test_expiring_oauth_token(app, user): db = app.db token = "abc123" - now = orm.OAuthAccessToken.now + now = orm.APIToken.now client = orm.OAuthClient(identifier="xxx", secret="yyy") db.add(client) - orm_token = orm.OAuthAccessToken( + orm_token = orm.APIToken( token=token, grant_type=orm.GrantType.authorization_code, client=client, user=user, - expires_at=now() + 30, + expires_at=now() + datetime.timedelta(seconds=30), ) db.add(orm_token) db.commit() - found = orm.OAuthAccessToken.find(db, token) + found = orm.APIToken.find(db, token) assert found is orm_token # purge_expired doesn't delete non-expired - orm.OAuthAccessToken.purge_expired(db) - found = orm.OAuthAccessToken.find(db, token) + orm.APIToken.purge_expired(db) + found = orm.APIToken.find(db, token) assert found is orm_token - with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60): - found = orm.OAuthAccessToken.find(db, token) + with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60): + found = orm.APIToken.find(db, token) assert found is None - assert orm_token in db.query(orm.OAuthAccessToken) - orm.OAuthAccessToken.purge_expired(db) - assert orm_token not in db.query(orm.OAuthAccessToken) + assert orm_token in db.query(orm.APIToken) + orm.APIToken.purge_expired(db) + assert orm_token not in db.query(orm.APIToken) def test_expiring_oauth_code(app, user): diff --git a/jupyterhub/tests/test_pages.py b/jupyterhub/tests/test_pages.py index 7140b823..3b4b6a39 100644 --- a/jupyterhub/tests/test_pages.py +++ b/jupyterhub/tests/test_pages.py @@ -869,7 +869,7 @@ async def test_oauth_token_page(app): user = app.users[orm.User.find(app.db, name)] client = orm.OAuthClient(identifier='token') app.db.add(client) - oauth_token = orm.OAuthAccessToken( + oauth_token = orm.APIToken( client=client, user=user, grant_type=orm.GrantType.authorization_code ) app.db.add(oauth_token) diff --git a/jupyterhub/tests/test_services_auth.py b/jupyterhub/tests/test_services_auth.py index b41d50db..540150bb 100644 --- a/jupyterhub/tests/test_services_auth.py +++ b/jupyterhub/tests/test_services_auth.py @@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url): def auth_tokens(): """Return list of OAuth access tokens for the user""" - return list( - app.db.query(orm.OAuthAccessToken).filter( - orm.OAuthAccessToken.user_id == app_user.id - ) - ) + return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id)) # ensure we start empty assert auth_tokens() == []