Merge pull request #1097 from minrk/whoami-only

Don't give OAuth access tokens access to the REST API
This commit is contained in:
Carol Willing
2017-05-02 03:03:58 -07:00
committed by GitHub
7 changed files with 117 additions and 56 deletions

View File

@@ -18,6 +18,8 @@ class TokenAPIHandler(APIHandler):
@token_authenticated @token_authenticated
def get(self, token): def get(self, token):
orm_token = orm.APIToken.find(self.db, token) orm_token = orm.APIToken.find(self.db, token)
if orm_token is None:
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None: if orm_token is None:
raise web.HTTPError(404) raise web.HTTPError(404)
if orm_token.user: if orm_token.user:

View File

@@ -20,6 +20,11 @@ class SelfAPIHandler(APIHandler):
@web.authenticated @web.authenticated
def get(self): def get(self):
user = self.get_current_user() user = self.get_current_user()
if user is None:
# whoami can be accessed via oauth token
user = self.get_current_user_oauth_token()
if user is None:
raise web.HTTPError(403)
self.write(json.dumps(self.user_model(user))) self.write(json.dumps(self.user_model(user)))

View File

@@ -141,13 +141,35 @@ class BaseHandler(RequestHandler):
def cookie_max_age_days(self): def cookie_max_age_days(self):
return self.settings.get('cookie_max_age_days', None) return self.settings.get('cookie_max_age_days', None)
def get_current_user_token(self): def get_auth_token(self):
"""get_current_user from Authorization header token""" """Get the authorization token from Authorization header"""
auth_header = self.request.headers.get('Authorization', '') auth_header = self.request.headers.get('Authorization', '')
match = auth_header_pat.match(auth_header) match = auth_header_pat.match(auth_header)
if not match: if not match:
return None return None
token = match.group(1) return match.group(1)
def get_current_user_oauth_token(self):
"""Get the current user identified by OAuth access token
Separate from API token because OAuth access tokens
can only be used for identifying users,
not using the API.
"""
token = self.get_auth_token()
if token is None:
return None
orm_token = orm.OAuthAccessToken.find(self.db, token)
if orm_token is None:
return None
else:
return self._user_from_orm(orm_token.user)
def get_current_user_token(self):
"""get_current_user from Authorization header token"""
token = self.get_auth_token()
if token is None:
return None
orm_token = orm.APIToken.find(self.db, token) orm_token = orm.APIToken.find(self.db, token)
if orm_token is None: if orm_token is None:
return None return None

View File

@@ -5,8 +5,8 @@ implements https://python-oauth2.readthedocs.io/en/latest/store.html
import threading import threading
from oauth2.datatype import Client, AccessToken, AuthorizationCode from oauth2.datatype import Client, AuthorizationCode
from oauth2.error import AccessTokenNotFound, AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated from oauth2.error import AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated
from oauth2.grant import AuthorizationCodeGrant from oauth2.grant import AuthorizationCodeGrant
from oauth2.web import AuthorizationCodeGrantSiteAdapter from oauth2.web import AuthorizationCodeGrantSiteAdapter
import oauth2.store import oauth2.store
@@ -17,7 +17,6 @@ from sqlalchemy.orm import scoped_session
from tornado.escape import url_escape from tornado.escape import url_escape
from .. import orm from .. import orm
from jupyterhub.orm import APIToken
from ..utils import url_path_join, hash_token, compare_token from ..utils import url_path_join, hash_token, compare_token
@@ -66,17 +65,6 @@ class HubDBMixin(object):
class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
"""OAuth2 AccessTokenStore, storing data in the Hub database""" """OAuth2 AccessTokenStore, storing data in the Hub database"""
def _access_token_from_orm(self, orm_token):
"""Transform an ORM AccessToken record into an oauth2 AccessToken instance"""
return AccessToken(
client_id=orm_token.client_id,
grant_type=orm_token.grant_type,
expires_at=orm_token.expires_at,
refresh_token=orm_token.refresh_token,
refresh_expires_at=orm_token.refresh_expires_at,
user_id=orm_token.user_id,
)
def save_token(self, access_token): def save_token(self, access_token):
""" """
Stores an access token in the database. Stores an access token in the database.
@@ -86,17 +74,14 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
""" """
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
token = user.new_api_token(access_token.token)
orm_api_token = APIToken.find(self.db, token, kind='user')
orm_access_token = orm.OAuthAccessToken( orm_access_token = orm.OAuthAccessToken(
client_id=access_token.client_id, client_id=access_token.client_id,
grant_type=access_token.grant_type, grant_type=access_token.grant_type,
expires_at=access_token.expires_at, expires_at=access_token.expires_at,
refresh_token=access_token.refresh_token, refresh_token=access_token.refresh_token,
refresh_expires_at=access_token.refresh_expires_at, refresh_expires_at=access_token.refresh_expires_at,
token=access_token.token,
user=user, user=user,
api_token=orm_api_token,
) )
self.db.add(orm_access_token) self.db.add(orm_access_token)
self.db.commit() self.db.commit()

View File

@@ -506,8 +506,65 @@ class Service(Base):
""" """
return db.query(cls).filter(cls.name == name).first() return db.query(cls).filter(cls.name == name).first()
class Hashed(object):
"""Mixin for tables with hashed tokens"""
prefix_length = 4
algorithm = "sha512"
rounds = 16384
salt_bytes = 8
min_length = 8
class APIToken(Base): @property
def token(self):
raise AttributeError("token is write-only")
@token.setter
def token(self, token):
"""Store the hashed value and prefix for a token"""
self.prefix = token[:self.prefix_length]
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
def match(self, token):
"""Is this my token?"""
return compare_token(self.hashed, token)
@classmethod
def check_token(cls, db, token):
"""Check if a token is acceptable"""
if len(token) < cls.min_length:
raise ValueError("Tokens must be at least %i characters, got %r" % (
cls.min_length, token)
)
found = cls.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:cls.prefix_length])
@classmethod
def find_prefix(cls, db, token):
"""Start the query for matching token.
Returns an SQLAlchemy query already filtered by prefix-matches.
"""
prefix = token[:cls.prefix_length]
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
@classmethod
def find(cls, db, token):
"""Find a token object by value.
Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
"""
prefix_match = cls.find_prefix(db, token)
for orm_token in prefix_match:
if orm_token.match(token):
return orm_token
class APIToken(Hashed, Base):
"""An API token""" """An API token"""
__tablename__ = 'api_tokens' __tablename__ = 'api_tokens'
@@ -521,21 +578,7 @@ class APIToken(Base):
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
hashed = Column(Unicode(1023)) hashed = Column(Unicode(1023))
prefix = Column(Unicode(1023)) prefix = Column(Unicode(16))
prefix_length = 4
algorithm = "sha512"
rounds = 16384
salt_bytes = 8
@property
def token(self):
raise AttributeError("token is write-only")
@token.setter
def token(self, token):
"""Store the hashed value and prefix for a token"""
self.prefix = token[:self.prefix_length]
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
def __repr__(self): def __repr__(self):
if self.user is not None: if self.user is not None:
@@ -564,10 +607,7 @@ class APIToken(Base):
`kind='user'` only returns API tokens for users `kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services `kind='service'` only returns API tokens for services
""" """
prefix = token[:cls.prefix_length] prefix_match = cls.find_prefix(db, token)
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
if kind == 'user': if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None) prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service': elif kind == 'service':
@@ -578,10 +618,6 @@ class APIToken(Base):
if orm_token.match(token): if orm_token.match(token):
return orm_token return orm_token
def match(self, token):
"""Is this my token?"""
return compare_token(self.hashed, token)
@classmethod @classmethod
def new(cls, token=None, user=None, service=None): def new(cls, token=None, user=None, service=None):
"""Generate a new API token for a user or service""" """Generate a new API token for a user or service"""
@@ -591,12 +627,8 @@ class APIToken(Base):
if token is None: if token is None:
token = new_token() token = new_token()
else: else:
if len(token) < 8: cls.check_token(db, token)
raise ValueError("Tokens must be at least 8 characters, got %r" % token) orm_token = cls(token=token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])
orm_token = APIToken(token=token)
if user: if user:
assert user.id is not None assert user.id is not None
orm_token.user_id = user.id orm_token.user_id = user.id
@@ -622,19 +654,29 @@ class GrantType(enum.Enum):
refresh_token = 'refresh_token' refresh_token = 'refresh_token'
class OAuthAccessToken(Base): class OAuthAccessToken(Hashed, Base):
__tablename__ = 'oauth_access_tokens' __tablename__ = 'oauth_access_tokens'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
client_id = Column(Unicode(1023)) client_id = Column(Unicode(1023))
grant_type = Column(Enum(GrantType), nullable=False) grant_type = Column(Enum(GrantType), nullable=False)
expires_at = Column(Integer) expires_at = Column(Integer)
refresh_token = Column(Unicode(36)) refresh_token = Column(Unicode(64))
refresh_expires_at = Column(Integer) refresh_expires_at = Column(Integer)
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
user = relationship(User) user = relationship(User)
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE')) session = None # for API-equivalence with APIToken
api_token = relationship(APIToken, backref='oauth_token')
# from Hashed
hashed = Column(Unicode(64))
prefix = Column(Unicode(16))
def __repr__(self):
return "<{cls}('{prefix}...', user='{user}'>".format(
cls=self.__class__.__name__,
user=self.user and self.user.name,
prefix=self.prefix,
)
class OAuthCode(Base): class OAuthCode(Base):

View File

@@ -638,6 +638,8 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
# TODO: make async (in a Thread?) # TODO: make async (in a Thread?)
token = self.hub_auth.token_for_code(code) token = self.hub_auth.token_for_code(code)
user_model = self.hub_auth.user_for_token(token) user_model = self.hub_auth.user_for_token(token)
if user_model is None:
raise HTTPError(500, "oauth callback failed to identify a user")
app_log.info("Logged-in user %s", user_model) app_log.info("Logged-in user %s", user_model)
self.hub_auth.set_cookie(self, token) self.hub_auth.set_cookie(self, token)
next_url = self.get_argument('next', '') or self.hub_auth.base_url next_url = self.get_argument('next', '') or self.hub_auth.base_url

View File

@@ -5,12 +5,13 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import os import os
from textwrap import dedent
from urllib.parse import urlparse from urllib.parse import urlparse
from jinja2 import ChoiceLoader, FunctionLoader from jinja2 import ChoiceLoader, FunctionLoader
from tornado import ioloop from tornado import ioloop
from textwrap import dedent from tornado.web import HTTPError
try: try:
import notebook import notebook
@@ -119,6 +120,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
# TODO: make async (in a Thread?) # TODO: make async (in a Thread?)
token = self.hub_auth.token_for_code(code) token = self.hub_auth.token_for_code(code)
user_model = self.hub_auth.user_for_token(token) user_model = self.hub_auth.user_for_token(token)
if user_model is None:
raise HTTPError(500, "oauth callback failed to identify a user")
self.log.info("Logged-in user %s", user_model) self.log.info("Logged-in user %s", user_model)
self.hub_auth.set_cookie(self, token) self.hub_auth.set_cookie(self, token)
next_url = self.get_argument('next', '') or self.base_url next_url = self.get_argument('next', '') or self.base_url