mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
Merge pull request #1097 from minrk/whoami-only
Don't give OAuth access tokens access to the REST API
This commit is contained in:
@@ -18,6 +18,8 @@ class TokenAPIHandler(APIHandler):
|
||||
@token_authenticated
|
||||
def get(self, token):
|
||||
orm_token = orm.APIToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
raise web.HTTPError(404)
|
||||
if orm_token.user:
|
||||
|
@@ -20,6 +20,11 @@ class SelfAPIHandler(APIHandler):
|
||||
@web.authenticated
|
||||
def get(self):
|
||||
user = self.get_current_user()
|
||||
if user is None:
|
||||
# whoami can be accessed via oauth token
|
||||
user = self.get_current_user_oauth_token()
|
||||
if user is None:
|
||||
raise web.HTTPError(403)
|
||||
self.write(json.dumps(self.user_model(user)))
|
||||
|
||||
|
||||
|
@@ -141,13 +141,35 @@ class BaseHandler(RequestHandler):
|
||||
def cookie_max_age_days(self):
|
||||
return self.settings.get('cookie_max_age_days', None)
|
||||
|
||||
def get_current_user_token(self):
|
||||
"""get_current_user from Authorization header token"""
|
||||
def get_auth_token(self):
|
||||
"""Get the authorization token from Authorization header"""
|
||||
auth_header = self.request.headers.get('Authorization', '')
|
||||
match = auth_header_pat.match(auth_header)
|
||||
if not match:
|
||||
return None
|
||||
token = match.group(1)
|
||||
return match.group(1)
|
||||
|
||||
def get_current_user_oauth_token(self):
|
||||
"""Get the current user identified by OAuth access token
|
||||
|
||||
Separate from API token because OAuth access tokens
|
||||
can only be used for identifying users,
|
||||
not using the API.
|
||||
"""
|
||||
token = self.get_auth_token()
|
||||
if token is None:
|
||||
return None
|
||||
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
return None
|
||||
else:
|
||||
return self._user_from_orm(orm_token.user)
|
||||
|
||||
def get_current_user_token(self):
|
||||
"""get_current_user from Authorization header token"""
|
||||
token = self.get_auth_token()
|
||||
if token is None:
|
||||
return None
|
||||
orm_token = orm.APIToken.find(self.db, token)
|
||||
if orm_token is None:
|
||||
return None
|
||||
|
@@ -5,8 +5,8 @@ implements https://python-oauth2.readthedocs.io/en/latest/store.html
|
||||
|
||||
import threading
|
||||
|
||||
from oauth2.datatype import Client, AccessToken, AuthorizationCode
|
||||
from oauth2.error import AccessTokenNotFound, AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated
|
||||
from oauth2.datatype import Client, AuthorizationCode
|
||||
from oauth2.error import AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated
|
||||
from oauth2.grant import AuthorizationCodeGrant
|
||||
from oauth2.web import AuthorizationCodeGrantSiteAdapter
|
||||
import oauth2.store
|
||||
@@ -17,7 +17,6 @@ from sqlalchemy.orm import scoped_session
|
||||
from tornado.escape import url_escape
|
||||
|
||||
from .. import orm
|
||||
from jupyterhub.orm import APIToken
|
||||
from ..utils import url_path_join, hash_token, compare_token
|
||||
|
||||
|
||||
@@ -66,17 +65,6 @@ class HubDBMixin(object):
|
||||
class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
||||
"""OAuth2 AccessTokenStore, storing data in the Hub database"""
|
||||
|
||||
def _access_token_from_orm(self, orm_token):
|
||||
"""Transform an ORM AccessToken record into an oauth2 AccessToken instance"""
|
||||
return AccessToken(
|
||||
client_id=orm_token.client_id,
|
||||
grant_type=orm_token.grant_type,
|
||||
expires_at=orm_token.expires_at,
|
||||
refresh_token=orm_token.refresh_token,
|
||||
refresh_expires_at=orm_token.refresh_expires_at,
|
||||
user_id=orm_token.user_id,
|
||||
)
|
||||
|
||||
def save_token(self, access_token):
|
||||
"""
|
||||
Stores an access token in the database.
|
||||
@@ -86,17 +74,14 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
||||
"""
|
||||
|
||||
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
||||
token = user.new_api_token(access_token.token)
|
||||
orm_api_token = APIToken.find(self.db, token, kind='user')
|
||||
|
||||
orm_access_token = orm.OAuthAccessToken(
|
||||
client_id=access_token.client_id,
|
||||
grant_type=access_token.grant_type,
|
||||
expires_at=access_token.expires_at,
|
||||
refresh_token=access_token.refresh_token,
|
||||
refresh_expires_at=access_token.refresh_expires_at,
|
||||
token=access_token.token,
|
||||
user=user,
|
||||
api_token=orm_api_token,
|
||||
)
|
||||
self.db.add(orm_access_token)
|
||||
self.db.commit()
|
||||
|
@@ -506,8 +506,65 @@ class Service(Base):
|
||||
"""
|
||||
return db.query(cls).filter(cls.name == name).first()
|
||||
|
||||
class Hashed(object):
|
||||
"""Mixin for tables with hashed tokens"""
|
||||
prefix_length = 4
|
||||
algorithm = "sha512"
|
||||
rounds = 16384
|
||||
salt_bytes = 8
|
||||
min_length = 8
|
||||
|
||||
class APIToken(Base):
|
||||
@property
|
||||
def token(self):
|
||||
raise AttributeError("token is write-only")
|
||||
|
||||
@token.setter
|
||||
def token(self, token):
|
||||
"""Store the hashed value and prefix for a token"""
|
||||
self.prefix = token[:self.prefix_length]
|
||||
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
|
||||
|
||||
def match(self, token):
|
||||
"""Is this my token?"""
|
||||
return compare_token(self.hashed, token)
|
||||
|
||||
@classmethod
|
||||
def check_token(cls, db, token):
|
||||
"""Check if a token is acceptable"""
|
||||
if len(token) < cls.min_length:
|
||||
raise ValueError("Tokens must be at least %i characters, got %r" % (
|
||||
cls.min_length, token)
|
||||
)
|
||||
found = cls.find(db, token)
|
||||
if found:
|
||||
raise ValueError("Collision on token: %s..." % token[:cls.prefix_length])
|
||||
|
||||
@classmethod
|
||||
def find_prefix(cls, db, token):
|
||||
"""Start the query for matching token.
|
||||
|
||||
Returns an SQLAlchemy query already filtered by prefix-matches.
|
||||
"""
|
||||
prefix = token[:cls.prefix_length]
|
||||
# since we can't filter on hashed values, filter on prefix
|
||||
# so we aren't comparing with all tokens
|
||||
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
||||
|
||||
@classmethod
|
||||
def find(cls, db, token):
|
||||
"""Find a token object by value.
|
||||
|
||||
Returns None if not found.
|
||||
|
||||
`kind='user'` only returns API tokens for users
|
||||
`kind='service'` only returns API tokens for services
|
||||
"""
|
||||
prefix_match = cls.find_prefix(db, token)
|
||||
for orm_token in prefix_match:
|
||||
if orm_token.match(token):
|
||||
return orm_token
|
||||
|
||||
class APIToken(Hashed, Base):
|
||||
"""An API token"""
|
||||
__tablename__ = 'api_tokens'
|
||||
|
||||
@@ -521,21 +578,7 @@ class APIToken(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
hashed = Column(Unicode(1023))
|
||||
prefix = Column(Unicode(1023))
|
||||
prefix_length = 4
|
||||
algorithm = "sha512"
|
||||
rounds = 16384
|
||||
salt_bytes = 8
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
raise AttributeError("token is write-only")
|
||||
|
||||
@token.setter
|
||||
def token(self, token):
|
||||
"""Store the hashed value and prefix for a token"""
|
||||
self.prefix = token[:self.prefix_length]
|
||||
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
|
||||
prefix = Column(Unicode(16))
|
||||
|
||||
def __repr__(self):
|
||||
if self.user is not None:
|
||||
@@ -564,10 +607,7 @@ class APIToken(Base):
|
||||
`kind='user'` only returns API tokens for users
|
||||
`kind='service'` only returns API tokens for services
|
||||
"""
|
||||
prefix = token[:cls.prefix_length]
|
||||
# since we can't filter on hashed values, filter on prefix
|
||||
# so we aren't comparing with all tokens
|
||||
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
||||
prefix_match = cls.find_prefix(db, token)
|
||||
if kind == 'user':
|
||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||
elif kind == 'service':
|
||||
@@ -578,10 +618,6 @@ class APIToken(Base):
|
||||
if orm_token.match(token):
|
||||
return orm_token
|
||||
|
||||
def match(self, token):
|
||||
"""Is this my token?"""
|
||||
return compare_token(self.hashed, token)
|
||||
|
||||
@classmethod
|
||||
def new(cls, token=None, user=None, service=None):
|
||||
"""Generate a new API token for a user or service"""
|
||||
@@ -591,12 +627,8 @@ class APIToken(Base):
|
||||
if token is None:
|
||||
token = new_token()
|
||||
else:
|
||||
if len(token) < 8:
|
||||
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
|
||||
found = APIToken.find(db, token)
|
||||
if found:
|
||||
raise ValueError("Collision on token: %s..." % token[:4])
|
||||
orm_token = APIToken(token=token)
|
||||
cls.check_token(db, token)
|
||||
orm_token = cls(token=token)
|
||||
if user:
|
||||
assert user.id is not None
|
||||
orm_token.user_id = user.id
|
||||
@@ -622,19 +654,29 @@ class GrantType(enum.Enum):
|
||||
refresh_token = 'refresh_token'
|
||||
|
||||
|
||||
class OAuthAccessToken(Base):
|
||||
class OAuthAccessToken(Hashed, Base):
|
||||
__tablename__ = 'oauth_access_tokens'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
client_id = Column(Unicode(1023))
|
||||
grant_type = Column(Enum(GrantType), nullable=False)
|
||||
expires_at = Column(Integer)
|
||||
refresh_token = Column(Unicode(36))
|
||||
refresh_token = Column(Unicode(64))
|
||||
refresh_expires_at = Column(Integer)
|
||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||
user = relationship(User)
|
||||
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
|
||||
api_token = relationship(APIToken, backref='oauth_token')
|
||||
session = None # for API-equivalence with APIToken
|
||||
|
||||
# from Hashed
|
||||
hashed = Column(Unicode(64))
|
||||
prefix = Column(Unicode(16))
|
||||
|
||||
def __repr__(self):
|
||||
return "<{cls}('{prefix}...', user='{user}'>".format(
|
||||
cls=self.__class__.__name__,
|
||||
user=self.user and self.user.name,
|
||||
prefix=self.prefix,
|
||||
)
|
||||
|
||||
|
||||
class OAuthCode(Base):
|
||||
|
@@ -638,6 +638,8 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
# TODO: make async (in a Thread?)
|
||||
token = self.hub_auth.token_for_code(code)
|
||||
user_model = self.hub_auth.user_for_token(token)
|
||||
if user_model is None:
|
||||
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||
app_log.info("Logged-in user %s", user_model)
|
||||
self.hub_auth.set_cookie(self, token)
|
||||
next_url = self.get_argument('next', '') or self.hub_auth.base_url
|
||||
|
@@ -5,12 +5,13 @@
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
from textwrap import dedent
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jinja2 import ChoiceLoader, FunctionLoader
|
||||
|
||||
from tornado import ioloop
|
||||
from textwrap import dedent
|
||||
from tornado.web import HTTPError
|
||||
|
||||
try:
|
||||
import notebook
|
||||
@@ -119,6 +120,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
|
||||
# TODO: make async (in a Thread?)
|
||||
token = self.hub_auth.token_for_code(code)
|
||||
user_model = self.hub_auth.user_for_token(token)
|
||||
if user_model is None:
|
||||
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||
self.log.info("Logged-in user %s", user_model)
|
||||
self.hub_auth.set_cookie(self, token)
|
||||
next_url = self.get_argument('next', '') or self.base_url
|
||||
|
Reference in New Issue
Block a user