mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
Merge pull request #1097 from minrk/whoami-only
Don't give OAuth access tokens access to the REST API
This commit is contained in:
@@ -18,6 +18,8 @@ class TokenAPIHandler(APIHandler):
|
|||||||
@token_authenticated
|
@token_authenticated
|
||||||
def get(self, token):
|
def get(self, token):
|
||||||
orm_token = orm.APIToken.find(self.db, token)
|
orm_token = orm.APIToken.find(self.db, token)
|
||||||
|
if orm_token is None:
|
||||||
|
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||||
if orm_token is None:
|
if orm_token is None:
|
||||||
raise web.HTTPError(404)
|
raise web.HTTPError(404)
|
||||||
if orm_token.user:
|
if orm_token.user:
|
||||||
|
@@ -20,6 +20,11 @@ class SelfAPIHandler(APIHandler):
|
|||||||
@web.authenticated
|
@web.authenticated
|
||||||
def get(self):
|
def get(self):
|
||||||
user = self.get_current_user()
|
user = self.get_current_user()
|
||||||
|
if user is None:
|
||||||
|
# whoami can be accessed via oauth token
|
||||||
|
user = self.get_current_user_oauth_token()
|
||||||
|
if user is None:
|
||||||
|
raise web.HTTPError(403)
|
||||||
self.write(json.dumps(self.user_model(user)))
|
self.write(json.dumps(self.user_model(user)))
|
||||||
|
|
||||||
|
|
||||||
|
@@ -141,13 +141,35 @@ class BaseHandler(RequestHandler):
|
|||||||
def cookie_max_age_days(self):
|
def cookie_max_age_days(self):
|
||||||
return self.settings.get('cookie_max_age_days', None)
|
return self.settings.get('cookie_max_age_days', None)
|
||||||
|
|
||||||
def get_current_user_token(self):
|
def get_auth_token(self):
|
||||||
"""get_current_user from Authorization header token"""
|
"""Get the authorization token from Authorization header"""
|
||||||
auth_header = self.request.headers.get('Authorization', '')
|
auth_header = self.request.headers.get('Authorization', '')
|
||||||
match = auth_header_pat.match(auth_header)
|
match = auth_header_pat.match(auth_header)
|
||||||
if not match:
|
if not match:
|
||||||
return None
|
return None
|
||||||
token = match.group(1)
|
return match.group(1)
|
||||||
|
|
||||||
|
def get_current_user_oauth_token(self):
|
||||||
|
"""Get the current user identified by OAuth access token
|
||||||
|
|
||||||
|
Separate from API token because OAuth access tokens
|
||||||
|
can only be used for identifying users,
|
||||||
|
not using the API.
|
||||||
|
"""
|
||||||
|
token = self.get_auth_token()
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
orm_token = orm.OAuthAccessToken.find(self.db, token)
|
||||||
|
if orm_token is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return self._user_from_orm(orm_token.user)
|
||||||
|
|
||||||
|
def get_current_user_token(self):
|
||||||
|
"""get_current_user from Authorization header token"""
|
||||||
|
token = self.get_auth_token()
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
orm_token = orm.APIToken.find(self.db, token)
|
orm_token = orm.APIToken.find(self.db, token)
|
||||||
if orm_token is None:
|
if orm_token is None:
|
||||||
return None
|
return None
|
||||||
|
@@ -5,8 +5,8 @@ implements https://python-oauth2.readthedocs.io/en/latest/store.html
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from oauth2.datatype import Client, AccessToken, AuthorizationCode
|
from oauth2.datatype import Client, AuthorizationCode
|
||||||
from oauth2.error import AccessTokenNotFound, AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated
|
from oauth2.error import AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated
|
||||||
from oauth2.grant import AuthorizationCodeGrant
|
from oauth2.grant import AuthorizationCodeGrant
|
||||||
from oauth2.web import AuthorizationCodeGrantSiteAdapter
|
from oauth2.web import AuthorizationCodeGrantSiteAdapter
|
||||||
import oauth2.store
|
import oauth2.store
|
||||||
@@ -17,7 +17,6 @@ from sqlalchemy.orm import scoped_session
|
|||||||
from tornado.escape import url_escape
|
from tornado.escape import url_escape
|
||||||
|
|
||||||
from .. import orm
|
from .. import orm
|
||||||
from jupyterhub.orm import APIToken
|
|
||||||
from ..utils import url_path_join, hash_token, compare_token
|
from ..utils import url_path_join, hash_token, compare_token
|
||||||
|
|
||||||
|
|
||||||
@@ -66,17 +65,6 @@ class HubDBMixin(object):
|
|||||||
class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
||||||
"""OAuth2 AccessTokenStore, storing data in the Hub database"""
|
"""OAuth2 AccessTokenStore, storing data in the Hub database"""
|
||||||
|
|
||||||
def _access_token_from_orm(self, orm_token):
|
|
||||||
"""Transform an ORM AccessToken record into an oauth2 AccessToken instance"""
|
|
||||||
return AccessToken(
|
|
||||||
client_id=orm_token.client_id,
|
|
||||||
grant_type=orm_token.grant_type,
|
|
||||||
expires_at=orm_token.expires_at,
|
|
||||||
refresh_token=orm_token.refresh_token,
|
|
||||||
refresh_expires_at=orm_token.refresh_expires_at,
|
|
||||||
user_id=orm_token.user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def save_token(self, access_token):
|
def save_token(self, access_token):
|
||||||
"""
|
"""
|
||||||
Stores an access token in the database.
|
Stores an access token in the database.
|
||||||
@@ -86,17 +74,14 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first()
|
||||||
token = user.new_api_token(access_token.token)
|
|
||||||
orm_api_token = APIToken.find(self.db, token, kind='user')
|
|
||||||
|
|
||||||
orm_access_token = orm.OAuthAccessToken(
|
orm_access_token = orm.OAuthAccessToken(
|
||||||
client_id=access_token.client_id,
|
client_id=access_token.client_id,
|
||||||
grant_type=access_token.grant_type,
|
grant_type=access_token.grant_type,
|
||||||
expires_at=access_token.expires_at,
|
expires_at=access_token.expires_at,
|
||||||
refresh_token=access_token.refresh_token,
|
refresh_token=access_token.refresh_token,
|
||||||
refresh_expires_at=access_token.refresh_expires_at,
|
refresh_expires_at=access_token.refresh_expires_at,
|
||||||
|
token=access_token.token,
|
||||||
user=user,
|
user=user,
|
||||||
api_token=orm_api_token,
|
|
||||||
)
|
)
|
||||||
self.db.add(orm_access_token)
|
self.db.add(orm_access_token)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
@@ -506,8 +506,65 @@ class Service(Base):
|
|||||||
"""
|
"""
|
||||||
return db.query(cls).filter(cls.name == name).first()
|
return db.query(cls).filter(cls.name == name).first()
|
||||||
|
|
||||||
|
class Hashed(object):
|
||||||
|
"""Mixin for tables with hashed tokens"""
|
||||||
|
prefix_length = 4
|
||||||
|
algorithm = "sha512"
|
||||||
|
rounds = 16384
|
||||||
|
salt_bytes = 8
|
||||||
|
min_length = 8
|
||||||
|
|
||||||
class APIToken(Base):
|
@property
|
||||||
|
def token(self):
|
||||||
|
raise AttributeError("token is write-only")
|
||||||
|
|
||||||
|
@token.setter
|
||||||
|
def token(self, token):
|
||||||
|
"""Store the hashed value and prefix for a token"""
|
||||||
|
self.prefix = token[:self.prefix_length]
|
||||||
|
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
|
||||||
|
|
||||||
|
def match(self, token):
|
||||||
|
"""Is this my token?"""
|
||||||
|
return compare_token(self.hashed, token)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_token(cls, db, token):
|
||||||
|
"""Check if a token is acceptable"""
|
||||||
|
if len(token) < cls.min_length:
|
||||||
|
raise ValueError("Tokens must be at least %i characters, got %r" % (
|
||||||
|
cls.min_length, token)
|
||||||
|
)
|
||||||
|
found = cls.find(db, token)
|
||||||
|
if found:
|
||||||
|
raise ValueError("Collision on token: %s..." % token[:cls.prefix_length])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_prefix(cls, db, token):
|
||||||
|
"""Start the query for matching token.
|
||||||
|
|
||||||
|
Returns an SQLAlchemy query already filtered by prefix-matches.
|
||||||
|
"""
|
||||||
|
prefix = token[:cls.prefix_length]
|
||||||
|
# since we can't filter on hashed values, filter on prefix
|
||||||
|
# so we aren't comparing with all tokens
|
||||||
|
return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find(cls, db, token):
|
||||||
|
"""Find a token object by value.
|
||||||
|
|
||||||
|
Returns None if not found.
|
||||||
|
|
||||||
|
`kind='user'` only returns API tokens for users
|
||||||
|
`kind='service'` only returns API tokens for services
|
||||||
|
"""
|
||||||
|
prefix_match = cls.find_prefix(db, token)
|
||||||
|
for orm_token in prefix_match:
|
||||||
|
if orm_token.match(token):
|
||||||
|
return orm_token
|
||||||
|
|
||||||
|
class APIToken(Hashed, Base):
|
||||||
"""An API token"""
|
"""An API token"""
|
||||||
__tablename__ = 'api_tokens'
|
__tablename__ = 'api_tokens'
|
||||||
|
|
||||||
@@ -521,21 +578,7 @@ class APIToken(Base):
|
|||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
hashed = Column(Unicode(1023))
|
hashed = Column(Unicode(1023))
|
||||||
prefix = Column(Unicode(1023))
|
prefix = Column(Unicode(16))
|
||||||
prefix_length = 4
|
|
||||||
algorithm = "sha512"
|
|
||||||
rounds = 16384
|
|
||||||
salt_bytes = 8
|
|
||||||
|
|
||||||
@property
|
|
||||||
def token(self):
|
|
||||||
raise AttributeError("token is write-only")
|
|
||||||
|
|
||||||
@token.setter
|
|
||||||
def token(self, token):
|
|
||||||
"""Store the hashed value and prefix for a token"""
|
|
||||||
self.prefix = token[:self.prefix_length]
|
|
||||||
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.user is not None:
|
if self.user is not None:
|
||||||
@@ -564,10 +607,7 @@ class APIToken(Base):
|
|||||||
`kind='user'` only returns API tokens for users
|
`kind='user'` only returns API tokens for users
|
||||||
`kind='service'` only returns API tokens for services
|
`kind='service'` only returns API tokens for services
|
||||||
"""
|
"""
|
||||||
prefix = token[:cls.prefix_length]
|
prefix_match = cls.find_prefix(db, token)
|
||||||
# since we can't filter on hashed values, filter on prefix
|
|
||||||
# so we aren't comparing with all tokens
|
|
||||||
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
|
|
||||||
if kind == 'user':
|
if kind == 'user':
|
||||||
prefix_match = prefix_match.filter(cls.user_id != None)
|
prefix_match = prefix_match.filter(cls.user_id != None)
|
||||||
elif kind == 'service':
|
elif kind == 'service':
|
||||||
@@ -578,10 +618,6 @@ class APIToken(Base):
|
|||||||
if orm_token.match(token):
|
if orm_token.match(token):
|
||||||
return orm_token
|
return orm_token
|
||||||
|
|
||||||
def match(self, token):
|
|
||||||
"""Is this my token?"""
|
|
||||||
return compare_token(self.hashed, token)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls, token=None, user=None, service=None):
|
def new(cls, token=None, user=None, service=None):
|
||||||
"""Generate a new API token for a user or service"""
|
"""Generate a new API token for a user or service"""
|
||||||
@@ -591,12 +627,8 @@ class APIToken(Base):
|
|||||||
if token is None:
|
if token is None:
|
||||||
token = new_token()
|
token = new_token()
|
||||||
else:
|
else:
|
||||||
if len(token) < 8:
|
cls.check_token(db, token)
|
||||||
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
|
orm_token = cls(token=token)
|
||||||
found = APIToken.find(db, token)
|
|
||||||
if found:
|
|
||||||
raise ValueError("Collision on token: %s..." % token[:4])
|
|
||||||
orm_token = APIToken(token=token)
|
|
||||||
if user:
|
if user:
|
||||||
assert user.id is not None
|
assert user.id is not None
|
||||||
orm_token.user_id = user.id
|
orm_token.user_id = user.id
|
||||||
@@ -622,19 +654,29 @@ class GrantType(enum.Enum):
|
|||||||
refresh_token = 'refresh_token'
|
refresh_token = 'refresh_token'
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccessToken(Base):
|
class OAuthAccessToken(Hashed, Base):
|
||||||
__tablename__ = 'oauth_access_tokens'
|
__tablename__ = 'oauth_access_tokens'
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
|
||||||
client_id = Column(Unicode(1023))
|
client_id = Column(Unicode(1023))
|
||||||
grant_type = Column(Enum(GrantType), nullable=False)
|
grant_type = Column(Enum(GrantType), nullable=False)
|
||||||
expires_at = Column(Integer)
|
expires_at = Column(Integer)
|
||||||
refresh_token = Column(Unicode(36))
|
refresh_token = Column(Unicode(64))
|
||||||
refresh_expires_at = Column(Integer)
|
refresh_expires_at = Column(Integer)
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
|
||||||
user = relationship(User)
|
user = relationship(User)
|
||||||
api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE'))
|
session = None # for API-equivalence with APIToken
|
||||||
api_token = relationship(APIToken, backref='oauth_token')
|
|
||||||
|
# from Hashed
|
||||||
|
hashed = Column(Unicode(64))
|
||||||
|
prefix = Column(Unicode(16))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<{cls}('{prefix}...', user='{user}'>".format(
|
||||||
|
cls=self.__class__.__name__,
|
||||||
|
user=self.user and self.user.name,
|
||||||
|
prefix=self.prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OAuthCode(Base):
|
class OAuthCode(Base):
|
||||||
|
@@ -638,6 +638,8 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
|||||||
# TODO: make async (in a Thread?)
|
# TODO: make async (in a Thread?)
|
||||||
token = self.hub_auth.token_for_code(code)
|
token = self.hub_auth.token_for_code(code)
|
||||||
user_model = self.hub_auth.user_for_token(token)
|
user_model = self.hub_auth.user_for_token(token)
|
||||||
|
if user_model is None:
|
||||||
|
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||||
app_log.info("Logged-in user %s", user_model)
|
app_log.info("Logged-in user %s", user_model)
|
||||||
self.hub_auth.set_cookie(self, token)
|
self.hub_auth.set_cookie(self, token)
|
||||||
next_url = self.get_argument('next', '') or self.hub_auth.base_url
|
next_url = self.get_argument('next', '') or self.hub_auth.base_url
|
||||||
|
@@ -5,12 +5,13 @@
|
|||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from textwrap import dedent
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from jinja2 import ChoiceLoader, FunctionLoader
|
from jinja2 import ChoiceLoader, FunctionLoader
|
||||||
|
|
||||||
from tornado import ioloop
|
from tornado import ioloop
|
||||||
from textwrap import dedent
|
from tornado.web import HTTPError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import notebook
|
import notebook
|
||||||
@@ -119,6 +120,8 @@ class OAuthCallbackHandler(HubOAuthCallbackHandler, IPythonHandler):
|
|||||||
# TODO: make async (in a Thread?)
|
# TODO: make async (in a Thread?)
|
||||||
token = self.hub_auth.token_for_code(code)
|
token = self.hub_auth.token_for_code(code)
|
||||||
user_model = self.hub_auth.user_for_token(token)
|
user_model = self.hub_auth.user_for_token(token)
|
||||||
|
if user_model is None:
|
||||||
|
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||||
self.log.info("Logged-in user %s", user_model)
|
self.log.info("Logged-in user %s", user_model)
|
||||||
self.hub_auth.set_cookie(self, token)
|
self.hub_auth.set_cookie(self, token)
|
||||||
next_url = self.get_argument('next', '') or self.base_url
|
next_url = self.get_argument('next', '') or self.base_url
|
||||||
|
Reference in New Issue
Block a user