diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index bb395752..c1369863 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -345,18 +345,19 @@ class JupyterHubRequestValidator(RequestValidator): # FIXME: pick a role # this will be empty for now roles = list(self.db.query(orm.Role).filter_by(name='identify')) - orm_access_token = orm.APIToken.new( + # FIXME: support refresh tokens + # These should be in a new table + token.pop("refresh_token", None) + + # APIToken.new commits the token to the db + orm.APIToken.new( client_id=client.identifier, - grant_type=orm.GrantType.authorization_code, - expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']), - refresh_token=token['refresh_token'], + expires_in=token['expires_in'], roles=roles, token=token['access_token'], session_id=request.session_id, user=request.user, ) - self.db.add(orm_access_token) - self.db.commit() return client.redirect_uri def validate_bearer_token(self, token, scopes, request): diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index a00e265a..c35a7255 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -548,9 +548,15 @@ class APIToken(Hashed, Base): __tablename__ = 'api_tokens' - user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) + user_id = Column( + Integer, + ForeignKey('users.id', ondelete="CASCADE"), + nullable=True, + ) service_id = Column( - Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True + Integer, + ForeignKey('services.id', ondelete="CASCADE"), + nullable=True, ) id = Column(Integer, primary_key=True) @@ -563,12 +569,23 @@ class APIToken(Hashed, Base): # added in 2.0 client_id = Column( - Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE') + Unicode(255), + ForeignKey( + 'oauth_clients.identifier', + ondelete='CASCADE', + ), ) - grant_type = Column(Enum(GrantType), nullable=False) - refresh_token = Column(Unicode(255)) - # the browser session id associated with a given token - session_id = Column(Unicode(255)) + # FIXME: refresh_tokens not implemented + # should be a relation to another token table + # refresh_token = Column( + # Integer, + # ForeignKey('refresh_tokens.id', ondelete="CASCADE"), + # nullable=True, + # ) + + # the browser session id associated with a given token, + # if issued during oauth to be stored in a cookie + session_id = Column(Unicode(255), nullable=True) # token metadata for bookkeeping now = datetime.utcnow # for expiry @@ -633,8 +650,10 @@ class APIToken(Hashed, Base): roles=None, note='', generated=True, + session_id=None, expires_in=None, client_id='jupyterhub', + return_orm=False, ): """Generate a new API token for a user or service""" assert user or service @@ -652,8 +671,8 @@ class APIToken(Hashed, Base): orm_token = cls( generated=generated, note=note or '', - grant_type=GrantType.authorization_code, client_id=client_id, + session_id=session_id, ) orm_token.token = token if user: diff --git a/jupyterhub/tests/populate_db.py b/jupyterhub/tests/populate_db.py index ba95104a..4504a13c 100644 --- a/jupyterhub/tests/populate_db.py +++ b/jupyterhub/tests/populate_db.py @@ -6,6 +6,7 @@ used in test_db.py """ import os from datetime import datetime +from functools import partial import jupyterhub from jupyterhub import orm @@ -69,13 +70,15 @@ def populate_db(url): db.add(code) db.commit() if jupyterhub.version_info < (2, 0): - Token = orm.OAuthAccessToken + Token = partial( + orm.OAuthAccessToken, + grant_type=orm.GrantType.authorization_code, + ) else: Token = orm.APIToken access_token = Token( client_id=client.identifier, user_id=user.id, - grant_type=orm.GrantType.authorization_code, ) db.add(access_token) db.commit() diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index a8074b81..2881a69c 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -277,7 +277,6 @@ async def test_get_self(app): user=u.orm_user, client=oauth_client, token=token, - grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) db.commit() diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 8187c481..093c29be 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -356,7 +356,8 @@ def test_user_delete_cascade(db): oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.APIToken( - client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code + client=oauth_client, + user=user, ) db.add(oauth_token) db.commit() @@ -392,11 +393,12 @@ def test_oauth_client_delete_cascade(db): oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) oauth_token = orm.APIToken( - client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code + client=oauth_client, + user=user, ) db.add(oauth_token) db.commit() - assert user.tokens == [oauth_token] + assert user.api_tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -409,7 +411,7 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) assert_not_found(db, orm.APIToken, oauth_token_id) - assert user.tokens == [] + assert user.api_tokens == [] assert user.oauth_codes == [] @@ -515,10 +517,9 @@ def test_expiring_oauth_token(app, user): db.add(client) orm_token = orm.APIToken( token=token, - grant_type=orm.GrantType.authorization_code, client=client, user=user, - expires_at=now() + datetime.timedelta(seconds=30), + expires_at=now() + timedelta(seconds=30), ) db.add(orm_token) db.commit() @@ -530,7 +531,7 @@ def test_expiring_oauth_token(app, user): found = orm.APIToken.find(db, token) assert found is orm_token - with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60): + with mock.patch.object(orm.APIToken, 'now', lambda: now() + timedelta(seconds=60)): found = orm.APIToken.find(db, token) assert found is None assert orm_token in db.query(orm.APIToken) diff --git a/jupyterhub/tests/test_pages.py b/jupyterhub/tests/test_pages.py index 3b4b6a39..1bba9926 100644 --- a/jupyterhub/tests/test_pages.py +++ b/jupyterhub/tests/test_pages.py @@ -870,7 +870,8 @@ async def test_oauth_token_page(app): client = orm.OAuthClient(identifier='token') app.db.add(client) oauth_token = orm.APIToken( - client=client, user=user, grant_type=orm.GrantType.authorization_code + client=client, + user=user, ) app.db.add(oauth_token) app.db.commit()