diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index fb16f944..78af7c97 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -20,6 +20,11 @@ class SelfAPIHandler(APIHandler): @web.authenticated def get(self): user = self.get_current_user() + if user is None: + # whoami can be accessed via oauth token + user = self.get_current_user_oauth_token() + if user is None: + raise web.HTTPError(403) self.write(json.dumps(self.user_model(user))) diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 18755363..97c3255a 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -141,13 +141,35 @@ class BaseHandler(RequestHandler): def cookie_max_age_days(self): return self.settings.get('cookie_max_age_days', None) - def get_current_user_token(self): - """get_current_user from Authorization header token""" + def get_auth_token(self): + """Get the authorization token from Authorization header""" auth_header = self.request.headers.get('Authorization', '') match = auth_header_pat.match(auth_header) if not match: return None - token = match.group(1) + return match.group(1) + + def get_current_user_oauth_token(self): + """Get the current user identified by OAuth access token + + Separate from API token because OAuth access tokens + can only be used for identifying users, + not using the API. + """ + token = self.get_auth_token() + if token is None: + return None + orm_token = orm.OAuthAccessToken.find(self.db, token) + if orm_token is None: + return None + else: + return self._user_from_orm(orm_token.user) + + def get_current_user_token(self): + """get_current_user from Authorization header token""" + token = self.get_auth_token() + if token is None: + return None orm_token = orm.APIToken.find(self.db, token) if orm_token is None: return None diff --git a/jupyterhub/oauth/store.py b/jupyterhub/oauth/store.py index 941bd6de..e864320f 100644 --- a/jupyterhub/oauth/store.py +++ b/jupyterhub/oauth/store.py @@ -6,7 +6,7 @@ implements https://python-oauth2.readthedocs.io/en/latest/store.html import threading from oauth2.datatype import Client, AccessToken, AuthorizationCode -from oauth2.error import AccessTokenNotFound, AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated +from oauth2.error import AuthCodeNotFound, ClientNotFoundError, UserNotAuthenticated from oauth2.grant import AuthorizationCodeGrant from oauth2.web import AuthorizationCodeGrantSiteAdapter import oauth2.store @@ -17,8 +17,7 @@ from sqlalchemy.orm import scoped_session from tornado.escape import url_escape from .. import orm -from jupyterhub.orm import APIToken -from ..utils import url_path_join, hash_token, compare_token +from ..utils import url_path_join, hash_token, compare_token, new_token class JupyterHubSiteAdapter(AuthorizationCodeGrantSiteAdapter): @@ -66,17 +65,6 @@ class HubDBMixin(object): class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """OAuth2 AccessTokenStore, storing data in the Hub database""" - def _access_token_from_orm(self, orm_token): - """Transform an ORM AccessToken record into an oauth2 AccessToken instance""" - return AccessToken( - client_id=orm_token.client_id, - grant_type=orm_token.grant_type, - expires_at=orm_token.expires_at, - refresh_token=orm_token.refresh_token, - refresh_expires_at=orm_token.refresh_expires_at, - user_id=orm_token.user_id, - ) - def save_token(self, access_token): """ Stores an access token in the database. @@ -86,17 +74,14 @@ class AccessTokenStore(HubDBMixin, oauth2.store.AccessTokenStore): """ user = self.db.query(orm.User).filter(orm.User.id == access_token.user_id).first() - token = user.new_api_token(access_token.token) - orm_api_token = APIToken.find(self.db, token, kind='user') - orm_access_token = orm.OAuthAccessToken( client_id=access_token.client_id, grant_type=access_token.grant_type, expires_at=access_token.expires_at, refresh_token=access_token.refresh_token, refresh_expires_at=access_token.refresh_expires_at, + token=access_token.token, user=user, - api_token=orm_api_token, ) self.db.add(orm_access_token) self.db.commit() diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 2ddf1f57..d6a0ada4 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -506,8 +506,65 @@ class Service(Base): """ return db.query(cls).filter(cls.name == name).first() +class Hashed(object): + """Mixin for tables with hashed tokens""" + prefix_length = 4 + algorithm = "sha512" + rounds = 16384 + salt_bytes = 8 + min_length = 8 -class APIToken(Base): + @property + def token(self): + raise AttributeError("token is write-only") + + @token.setter + def token(self, token): + """Store the hashed value and prefix for a token""" + self.prefix = token[:self.prefix_length] + self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm) + + def match(self, token): + """Is this my token?""" + return compare_token(self.hashed, token) + + @classmethod + def check_token(cls, db, token): + """Check if a token is acceptable""" + if len(token) < cls.min_length: + raise ValueError("Tokens must be at least %i characters, got %r" % ( + cls.min_length, token) + ) + found = cls.find(db, token) + if found: + raise ValueError("Collision on token: %s..." % token[:cls.prefix_length]) + + @classmethod + def find_prefix(cls, db, token): + """Start the query for matching token. + + Returns an SQLAlchemy query already filtered by prefix-matches. + """ + prefix = token[:cls.prefix_length] + # since we can't filter on hashed values, filter on prefix + # so we aren't comparing with all tokens + return db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix)) + + @classmethod + def find(cls, db, token): + """Find a token object by value. + + Returns None if not found. + + `kind='user'` only returns API tokens for users + `kind='service'` only returns API tokens for services + """ + prefix_match = cls.find_prefix(db, token) + for orm_token in prefix_match: + if orm_token.match(token): + return orm_token + +class APIToken(Hashed, Base): """An API token""" __tablename__ = 'api_tokens' @@ -521,21 +578,7 @@ class APIToken(Base): id = Column(Integer, primary_key=True) hashed = Column(Unicode(1023)) - prefix = Column(Unicode(1023)) - prefix_length = 4 - algorithm = "sha512" - rounds = 16384 - salt_bytes = 8 - - @property - def token(self): - raise AttributeError("token is write-only") - - @token.setter - def token(self, token): - """Store the hashed value and prefix for a token""" - self.prefix = token[:self.prefix_length] - self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm) + prefix = Column(Unicode(16)) def __repr__(self): if self.user is not None: @@ -564,10 +607,7 @@ class APIToken(Base): `kind='user'` only returns API tokens for users `kind='service'` only returns API tokens for services """ - prefix = token[:cls.prefix_length] - # since we can't filter on hashed values, filter on prefix - # so we aren't comparing with all tokens - prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix)) + prefix_match = cls.find_prefix(db, token) if kind == 'user': prefix_match = prefix_match.filter(cls.user_id != None) elif kind == 'service': @@ -578,10 +618,6 @@ class APIToken(Base): if orm_token.match(token): return orm_token - def match(self, token): - """Is this my token?""" - return compare_token(self.hashed, token) - @classmethod def new(cls, token=None, user=None, service=None): """Generate a new API token for a user or service""" @@ -591,12 +627,8 @@ class APIToken(Base): if token is None: token = new_token() else: - if len(token) < 8: - raise ValueError("Tokens must be at least 8 characters, got %r" % token) - found = APIToken.find(db, token) - if found: - raise ValueError("Collision on token: %s..." % token[:4]) - orm_token = APIToken(token=token) + cls.check_token(db, token) + orm_token = cls(token=token) if user: assert user.id is not None orm_token.user_id = user.id @@ -622,19 +654,29 @@ class GrantType(enum.Enum): refresh_token = 'refresh_token' -class OAuthAccessToken(Base): +class OAuthAccessToken(Hashed, Base): __tablename__ = 'oauth_access_tokens' id = Column(Integer, primary_key=True, autoincrement=True) client_id = Column(Unicode(1023)) grant_type = Column(Enum(GrantType), nullable=False) expires_at = Column(Integer) - refresh_token = Column(Unicode(36)) + refresh_token = Column(Unicode(64)) refresh_expires_at = Column(Integer) user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) user = relationship(User) - api_token_id = Column(Integer, ForeignKey('api_tokens.id', ondelete='CASCADE')) - api_token = relationship(APIToken, backref='oauth_token') + session = None # for API-equivalence with APIToken + + # from Hashed + hashed = Column(Unicode(64)) + prefix = Column(Unicode(16)) + + def __repr__(self): + return "<{cls}('{prefix}...', user='{user}'>".format( + cls=self.__class__.__name__, + user=self.user and self.user.name, + prefix=self.prefix, + ) class OAuthCode(Base):