add missing session_id to newly merged API tokens

and remove grant_type which is not a property of the tokens themselves
This commit is contained in:
Min RK
2021-04-12 13:01:15 +02:00
parent e504fa4bf5
commit ad9ebdd60f
6 changed files with 49 additions and 25 deletions

View File

@@ -345,18 +345,19 @@ class JupyterHubRequestValidator(RequestValidator):
# FIXME: pick a role
# this will be empty for now
roles = list(self.db.query(orm.Role).filter_by(name='identify'))
orm_access_token = orm.APIToken.new(
# FIXME: support refresh tokens
# These should be in a new table
token.pop("refresh_token", None)
# APIToken.new commits the token to the db
orm.APIToken.new(
client_id=client.identifier,
grant_type=orm.GrantType.authorization_code,
expires_at=orm.APIToken.now() + timedelta(seconds=token['expires_in']),
refresh_token=token['refresh_token'],
expires_in=token['expires_in'],
roles=roles,
token=token['access_token'],
session_id=request.session_id,
user=request.user,
)
self.db.add(orm_access_token)
self.db.commit()
return client.redirect_uri
def validate_bearer_token(self, token, scopes, request):

View File

@@ -548,9 +548,15 @@ class APIToken(Hashed, Base):
__tablename__ = 'api_tokens'
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
user_id = Column(
Integer,
ForeignKey('users.id', ondelete="CASCADE"),
nullable=True,
)
service_id = Column(
Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True
Integer,
ForeignKey('services.id', ondelete="CASCADE"),
nullable=True,
)
id = Column(Integer, primary_key=True)
@@ -563,12 +569,23 @@ class APIToken(Hashed, Base):
# added in 2.0
client_id = Column(
Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE')
Unicode(255),
ForeignKey(
'oauth_clients.identifier',
ondelete='CASCADE',
),
)
grant_type = Column(Enum(GrantType), nullable=False)
refresh_token = Column(Unicode(255))
# the browser session id associated with a given token
session_id = Column(Unicode(255))
# FIXME: refresh_tokens not implemented
# should be a relation to another token table
# refresh_token = Column(
# Integer,
# ForeignKey('refresh_tokens.id', ondelete="CASCADE"),
# nullable=True,
# )
# the browser session id associated with a given token,
# if issued during oauth to be stored in a cookie
session_id = Column(Unicode(255), nullable=True)
# token metadata for bookkeeping
now = datetime.utcnow # for expiry
@@ -633,8 +650,10 @@ class APIToken(Hashed, Base):
roles=None,
note='',
generated=True,
session_id=None,
expires_in=None,
client_id='jupyterhub',
return_orm=False,
):
"""Generate a new API token for a user or service"""
assert user or service
@@ -652,8 +671,8 @@ class APIToken(Hashed, Base):
orm_token = cls(
generated=generated,
note=note or '',
grant_type=GrantType.authorization_code,
client_id=client_id,
session_id=session_id,
)
orm_token.token = token
if user:

View File

@@ -6,6 +6,7 @@ used in test_db.py
"""
import os
from datetime import datetime
from functools import partial
import jupyterhub
from jupyterhub import orm
@@ -69,13 +70,15 @@ def populate_db(url):
db.add(code)
db.commit()
if jupyterhub.version_info < (2, 0):
Token = orm.OAuthAccessToken
Token = partial(
orm.OAuthAccessToken,
grant_type=orm.GrantType.authorization_code,
)
else:
Token = orm.APIToken
access_token = Token(
client_id=client.identifier,
user_id=user.id,
grant_type=orm.GrantType.authorization_code,
)
db.add(access_token)
db.commit()

View File

@@ -277,7 +277,6 @@ async def test_get_self(app):
user=u.orm_user,
client=oauth_client,
token=token,
grant_type=orm.GrantType.authorization_code,
)
db.add(oauth_token)
db.commit()

View File

@@ -356,7 +356,8 @@ def test_user_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
client=oauth_client,
user=user,
)
db.add(oauth_token)
db.commit()
@@ -392,11 +393,12 @@ def test_oauth_client_delete_cascade(db):
oauth_code = orm.OAuthCode(client=oauth_client, user=user)
db.add(oauth_code)
oauth_token = orm.APIToken(
client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code
client=oauth_client,
user=user,
)
db.add(oauth_token)
db.commit()
assert user.tokens == [oauth_token]
assert user.api_tokens == [oauth_token]
# record all of the ids
oauth_code_id = oauth_code.id
@@ -409,7 +411,7 @@ def test_oauth_client_delete_cascade(db):
# verify that everything gets deleted
assert_not_found(db, orm.OAuthCode, oauth_code_id)
assert_not_found(db, orm.APIToken, oauth_token_id)
assert user.tokens == []
assert user.api_tokens == []
assert user.oauth_codes == []
@@ -515,10 +517,9 @@ def test_expiring_oauth_token(app, user):
db.add(client)
orm_token = orm.APIToken(
token=token,
grant_type=orm.GrantType.authorization_code,
client=client,
user=user,
expires_at=now() + datetime.timedelta(seconds=30),
expires_at=now() + timedelta(seconds=30),
)
db.add(orm_token)
db.commit()
@@ -530,7 +531,7 @@ def test_expiring_oauth_token(app, user):
found = orm.APIToken.find(db, token)
assert found is orm_token
with mock.patch.object(orm.APIToken, 'now', lambda: now() + 60):
with mock.patch.object(orm.APIToken, 'now', lambda: now() + timedelta(seconds=60)):
found = orm.APIToken.find(db, token)
assert found is None
assert orm_token in db.query(orm.APIToken)

View File

@@ -870,7 +870,8 @@ async def test_oauth_token_page(app):
client = orm.OAuthClient(identifier='token')
app.db.add(client)
oauth_token = orm.APIToken(
client=client, user=user, grant_type=orm.GrantType.authorization_code
client=client,
user=user,
)
app.db.add(oauth_token)
app.db.commit()