diff --git a/ci/init-db.sh b/ci/init-db.sh index b510f549..32a73d29 100755 --- a/ci/init-db.sh +++ b/ci/init-db.sh @@ -20,7 +20,7 @@ fi # Configure a set of databases in the database server for upgrade tests set -x -for SUFFIX in '' _upgrade_072 _upgrade_081 _upgrade_094; do +for SUFFIX in '' _upgrade_100 _upgrade_122 _upgrade_130; do $SQL_CLIENT "DROP DATABASE jupyterhub${SUFFIX};" 2>/dev/null || true $SQL_CLIENT "CREATE DATABASE jupyterhub${SUFFIX} ${EXTRA_CREATE_DATABASE_ARGS:-};" done diff --git a/jupyterhub/_version.py b/jupyterhub/_version.py index 591bafc5..ed389f1e 100644 --- a/jupyterhub/_version.py +++ b/jupyterhub/_version.py @@ -3,8 +3,8 @@ # Distributed under the terms of the Modified BSD License. version_info = ( - 1, - 4, + 2, + 0, 0, "", # release (b1, rc1, or "" for final or dev) "dev", # dev or nothing for beta/rc/stable releases diff --git a/jupyterhub/alembic/versions/833da8570507_rbac.py b/jupyterhub/alembic/versions/833da8570507_rbac.py new file mode 100644 index 00000000..b76fc707 --- /dev/null +++ b/jupyterhub/alembic/versions/833da8570507_rbac.py @@ -0,0 +1,49 @@ +"""rbac + +Revision ID: 833da8570507 +Revises: 4dc2d5a8c53c +Create Date: 2021-02-17 15:03:04.360368 + +""" +# revision identifiers, used by Alembic. +revision = '833da8570507' +down_revision = '4dc2d5a8c53c' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # FIXME, maybe: currently drops all api tokens and forces recreation! + # this ensures a consistent database, but requires: + # 1. all servers to be stopped for upgrade (maybe unavoidable anyway) + # 2. any manually issued/stored tokens to be re-issued + + # tokens loaded via configuration will be recreated on launch and unaffected + op.drop_table('api_tokens') + op.drop_table('oauth_access_tokens') + return + # TODO: explore in-place migration. This seems hard! + # 1. add new columns in api tokens + # 2. fill default fields (client_id='jupyterhub') for all api tokens + # 3. copy oauth tokens into api tokens + # 4. give oauth tokens 'identify' scopes + + +def downgrade(): + # delete OAuth tokens for non-jupyterhub clients + # drop new columns from api tokens + op.drop_constraint(None, 'api_tokens', type_='foreignkey') + op.drop_column('api_tokens', 'session_id') + op.drop_column('api_tokens', 'client_id') + + # FIXME: only drop tokens whose client id is not 'jupyterhub' + # until then, drop all tokens + op.drop_table("api_tokens") + + op.drop_table('api_token_role_map') + op.drop_table('service_role_map') + op.drop_table('user_role_map') + op.drop_table('roles') diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py index 938d88ec..e7f72880 100644 --- a/jupyterhub/apihandlers/auth.py +++ b/jupyterhub/apihandlers/auth.py @@ -29,8 +29,6 @@ class TokenAPIHandler(APIHandler): "/authorizations/token/:token endpoint is deprecated in JupyterHub 2.0. Use /api/user" ) orm_token = orm.APIToken.find(self.db, token) - if orm_token is None: - orm_token = orm.OAuthAccessToken.find(self.db, token) if orm_token is None: raise web.HTTPError(404) diff --git a/jupyterhub/apihandlers/base.py b/jupyterhub/apihandlers/base.py index bbbda5d8..9832afc5 100644 --- a/jupyterhub/apihandlers/base.py +++ b/jupyterhub/apihandlers/base.py @@ -205,23 +205,6 @@ class APIHandler(BaseHandler): def token_model(self, token): """Get the JSON model for an APIToken""" - expires_at = None - if isinstance(token, orm.APIToken): - kind = 'api_token' - roles = [r.name for r in token.roles] - extra = {'note': token.note} - expires_at = token.expires_at - elif isinstance(token, orm.OAuthAccessToken): - kind = 'oauth' - # oauth tokens do not bear roles - roles = [] - extra = {'oauth_client': token.client.description or token.client.client_id} - if token.expires_at: - expires_at = datetime.fromtimestamp(token.expires_at) - else: - raise TypeError( - "token must be an APIToken or OAuthAccessToken, not %s" % type(token) - ) if token.user: owner_key = 'user' @@ -234,13 +217,14 @@ class APIHandler(BaseHandler): model = { owner_key: owner, 'id': token.api_id, - 'kind': kind, - 'roles': [role for role in roles], + 'kind': 'api_token', + 'roles': [r.name for r in token.roles], 'created': isoformat(token.created), 'last_activity': isoformat(token.last_activity), - 'expires_at': isoformat(expires_at), + 'expires_at': isoformat(token.expires_at), + 'note': token.note, + 'oauth_client': token.client.description or token.client.client_id, } - model.update(extra) return model def user_model(self, user): diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index 56775e34..492a3a29 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -14,6 +14,7 @@ from tornado import web from tornado.iostream import StreamClosedError from .. import orm +from .. import scopes from ..roles import assign_default_roles from ..scopes import needs_scope from ..user import User @@ -32,14 +33,16 @@ class SelfAPIHandler(APIHandler): async def get(self): user = self.current_user - if user is None: - # whoami can be accessed via oauth token - user = self.get_current_user_oauth_token() if user is None: raise web.HTTPError(403) if isinstance(user, orm.Service): + # ensure we have the minimal 'identify' scopes for the token owner + self.raw_scopes.update(scopes.identify_scopes(user)) + self.parsed_scopes = scopes.parse_scopes(self.raw_scopes) model = self.service_model(user) else: + self.raw_scopes.update(scopes.identify_scopes(user.orm_user)) + self.parsed_scopes = scopes.parse_scopes(self.raw_scopes) model = self.user_model(user) self.write(json.dumps(model)) @@ -316,17 +319,7 @@ class UserTokenListAPIHandler(APIHandler): continue api_tokens.append(self.token_model(token)) - oauth_tokens = [] - # OAuth tokens use integer timestamps - now_timestamp = now.timestamp() - for token in sorted(user.oauth_tokens, key=sort_key): - if token.expires_at and token.expires_at < now_timestamp: - # exclude expired tokens - self.db.delete(token) - self.db.commit() - continue - oauth_tokens.append(self.token_model(token)) - self.write(json.dumps({'api_tokens': api_tokens, 'oauth_tokens': oauth_tokens})) + self.write(json.dumps({'api_tokens': api_tokens})) # Todo: Set to @needs_scope('users:tokens') async def post(self, user_name): @@ -410,19 +403,15 @@ class UserTokenAPIHandler(APIHandler): (e.g. wrong owner, invalid key format, etc.) """ not_found = "No such token %s for user %s" % (token_id, user.name) - prefix, id_ = token_id[0], token_id[1:] - if prefix == 'a': - Token = orm.APIToken - elif prefix == 'o': - Token = orm.OAuthAccessToken - else: + prefix, id_ = token_id[:1], token_id[1:] + if prefix != 'a': raise web.HTTPError(404, not_found) try: id_ = int(id_) except ValueError: raise web.HTTPError(404, not_found) - orm_token = self.db.query(Token).filter(Token.id == id_).first() + orm_token = self.db.query(orm.APIToken).filter_by(id=id_).first() if orm_token is None or orm_token.user is not user.orm_user: raise web.HTTPError(404, "Token not found %s", orm_token) return orm_token @@ -444,10 +433,10 @@ class UserTokenAPIHandler(APIHandler): raise web.HTTPError(404, "No such user: %s" % user_name) token = self.find_token_by_id(user, token_id) # deleting an oauth token deletes *all* oauth tokens for that client - if isinstance(token, orm.OAuthAccessToken): - client_id = token.client_id + client_id = token.client_id + if token.client_id != "jupyterhub": tokens = [ - token for token in user.oauth_tokens if token.client_id == client_id + token for token in user.api_tokens if token.client_id == client_id ] else: tokens = [token] diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 23036575..dd191206 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -1692,6 +1692,26 @@ class JupyterHub(Application): except orm.DatabaseSchemaMismatch as e: self.exit(e) + # ensure the default oauth client exists + if ( + not self.db.query(orm.OAuthClient) + .filter_by(identifier="jupyterhub") + .one_or_none() + ): + # create the oauth client for jupyterhub itself + # this allows us to distinguish between orphaned tokens + # (failed cascade deletion) and tokens issued by the hub + # it has no client_secret, which means it cannot be used + # to make requests + client = orm.OAuthClient( + identifier="jupyterhub", + secret="", + redirect_uri="", + description="JupyterHub", + ) + self.db.add(client) + self.db.commit() + def init_hub(self): """Load the Hub URL config""" hub_args = dict( @@ -2014,12 +2034,13 @@ class JupyterHub(Application): run periodically """ # this should be all the subclasses of Expiring - for cls in (orm.APIToken, orm.OAuthAccessToken, orm.OAuthCode): + for cls in (orm.APIToken, orm.OAuthCode): self.log.debug("Purging expired {name}s".format(name=cls.__name__)) cls.purge_expired(self.db) async def init_api_tokens(self): """Load predefined API tokens (for services) into database""" + await self._add_tokens(self.service_tokens, kind='service') await self._add_tokens(self.api_tokens, kind='user') @@ -2298,7 +2319,7 @@ class JupyterHub(Application): This should mainly be services that have been removed from configuration or renamed. """ - oauth_client_ids = set() + oauth_client_ids = {"jupyterhub"} for service in self._service_map.values(): if service.oauth_available: oauth_client_ids.add(service.oauth_client_id) diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 4841b2a8..3c49577c 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -247,26 +247,6 @@ class BaseHandler(RequestHandler): return None return match.group(1) - def get_current_user_oauth_token(self): - """Get the current user identified by OAuth access token - - Separate from API token because OAuth access tokens - can only be used for identifying users, - not using the API. - """ - token = self.get_auth_token() - if token is None: - return None - orm_token = orm.OAuthAccessToken.find(self.db, token) - if orm_token is None: - return None - - now = datetime.utcnow() - recorded = self._record_activity(orm_token, now) - if self._record_activity(orm_token.user, now) or recorded: - self.db.commit() - return self._user_from_orm(orm_token.user) - def _record_activity(self, obj, timestamp=None): """record activity on an ORM object @@ -373,7 +353,7 @@ class BaseHandler(RequestHandler): # FIXME: scopes should give us better control than this # don't consider API requests originating from a server # to be activity from the user - if not orm_token.note.startswith("Server at "): + if not orm_token.note or not orm_token.note.startswith("Server at "): recorded = self._record_activity(orm_token.user, now) or recorded if recorded: self.db.commit() @@ -439,17 +419,10 @@ class BaseHandler(RequestHandler): def _resolve_scopes(self): self.raw_scopes = set() app_log.debug("Loading and parsing scopes") - if not self.current_user: - # check for oauth tokens as long as #3380 not merged - user_from_oauth = self.get_current_user_oauth_token() - if user_from_oauth is not None: - self.raw_scopes = {f'read:users!user={user_from_oauth.name}'} - else: - app_log.debug("No user found, no scopes loaded") - else: - api_token = self.get_token() - if api_token: - self.raw_scopes = scopes.get_scopes_for(api_token) + if self.current_user: + orm_token = self.get_token() + if orm_token: + self.raw_scopes = scopes.get_scopes_for(orm_token) else: self.raw_scopes = scopes.get_scopes_for(self.current_user) self.parsed_scopes = scopes.parse_scopes(self.raw_scopes) @@ -501,10 +474,8 @@ class BaseHandler(RequestHandler): # don't clear session tokens if not logged in, # because that could be a malicious logout request! count = 0 - for access_token in ( - self.db.query(orm.OAuthAccessToken) - .filter(orm.OAuthAccessToken.user_id == user.id) - .filter(orm.OAuthAccessToken.session_id == session_id) + for access_token in self.db.query(orm.APIToken).filter_by( + user_id=user.id, session_id=session_id ): self.db.delete(access_token) count += 1 diff --git a/jupyterhub/handlers/pages.py b/jupyterhub/handlers/pages.py index a9422699..356aae71 100644 --- a/jupyterhub/handlers/pages.py +++ b/jupyterhub/handlers/pages.py @@ -552,36 +552,32 @@ class TokenPageHandler(BaseHandler): return (token.last_activity or never, token.created or never) now = datetime.utcnow() - api_tokens = [] - for token in sorted(user.api_tokens, key=sort_key, reverse=True): - if token.expires_at and token.expires_at < now: - self.db.delete(token) - self.db.commit() - continue - api_tokens.append(token) # group oauth client tokens by client id - # AccessTokens have expires_at as an integer timestamp - now_timestamp = now.timestamp() - oauth_tokens = defaultdict(list) - for token in user.oauth_tokens: - if token.expires_at and token.expires_at < now_timestamp: - self.log.warning("Deleting expired token") + all_tokens = defaultdict(list) + for token in sorted(user.api_tokens, key=sort_key, reverse=True): + if token.expires_at and token.expires_at < now: + self.log.warning(f"Deleting expired token {token}") self.db.delete(token) self.db.commit() continue if not token.client_id: # token should have been deleted when client was deleted - self.log.warning("Deleting stale oauth token for %s", user.name) + self.log.warning("Deleting stale oauth token {token}") self.db.delete(token) self.db.commit() continue - oauth_tokens[token.client_id].append(token) + all_tokens[token.client_id].append(token) + # individually list tokens issued by jupyterhub itself + api_tokens = all_tokens.pop("jupyterhub", []) + + # group all other tokens issued under their owners # get the earliest created and latest last_activity # timestamp for a given oauth client oauth_clients = [] - for client_id, tokens in oauth_tokens.items(): + + for client_id, tokens in all_tokens.items(): created = tokens[0].created last_activity = tokens[0].last_activity for token in tokens[1:]: diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index 7dd7b160..c1369863 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -2,18 +2,18 @@ implements https://oauthlib.readthedocs.io/en/latest/oauth2/server.html """ +from datetime import timedelta + from oauthlib import uri_validate from oauthlib.oauth2 import RequestValidator from oauthlib.oauth2 import WebApplicationServer from oauthlib.oauth2.rfc6749.grant_types import authorization_code from oauthlib.oauth2.rfc6749.grant_types import base -from tornado.escape import url_escape from tornado.log import app_log from .. import orm from ..utils import compare_token from ..utils import hash_token -from ..utils import url_path_join # patch absolute-uri check # because we want to allow relative uri oauth @@ -60,6 +60,9 @@ class JupyterHubRequestValidator(RequestValidator): ) if oauth_client is None: return False + if not client_secret or not oauth_client.secret: + # disallow authentication with no secret + return False if not compare_token(oauth_client.secret, client_secret): app_log.warning("Client secret mismatch for %s", client_id) return False @@ -339,19 +342,22 @@ class JupyterHubRequestValidator(RequestValidator): .filter_by(identifier=request.client.client_id) .first() ) - orm_access_token = orm.OAuthAccessToken( - client=client, - grant_type=orm.GrantType.authorization_code, - expires_at=orm.OAuthAccessToken.now() + token['expires_in'], - refresh_token=token['refresh_token'], - # TODO: save scopes, - # scopes=scopes, + # FIXME: pick a role + # this will be empty for now + roles = list(self.db.query(orm.Role).filter_by(name='identify')) + # FIXME: support refresh tokens + # These should be in a new table + token.pop("refresh_token", None) + + # APIToken.new commits the token to the db + orm.APIToken.new( + client_id=client.identifier, + expires_in=token['expires_in'], + roles=roles, token=token['access_token'], session_id=request.session_id, user=request.user, ) - self.db.add(orm_access_token) - self.db.commit() return client.redirect_uri def validate_bearer_token(self, token, scopes, request): @@ -412,6 +418,8 @@ class JupyterHubRequestValidator(RequestValidator): ) if orm_client is None: return False + if not orm_client.secret: + return False request.client = orm_client return True @@ -574,14 +582,16 @@ class JupyterHubOAuthServer(WebApplicationServer): app_log.info(f'Creating oauth client {client_id}') else: app_log.info(f'Updating oauth client {client_id}') - orm_client.secret = hash_token(client_secret) + orm_client.secret = hash_token(client_secret) if client_secret else "" orm_client.redirect_uri = redirect_uri orm_client.description = description self.db.commit() def fetch_by_client_id(self, client_id): """Find a client by its id""" - return self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first() + client = self.db.query(orm.OAuthClient).filter_by(identifier=client_id).first() + if client and client.secret: + return client def make_provider(session_factory, url_prefix, login_url, **oauth_server_kwargs): diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 6e6e8693..c35a7255 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -277,9 +277,6 @@ class User(Base): last_activity = Column(DateTime, nullable=True) api_tokens = relationship("APIToken", backref="user", cascade="all, delete-orphan") - oauth_tokens = relationship( - "OAuthAccessToken", backref="user", cascade="all, delete-orphan" - ) oauth_codes = relationship( "OAuthCode", backref="user", cascade="all, delete-orphan" ) @@ -485,7 +482,9 @@ class Hashed(Expiring): @classmethod def check_token(cls, db, token): """Check if a token is acceptable""" + print("checking", cls, token, len(token), cls.min_length) if len(token) < cls.min_length: + print("raising") raise ValueError( "Tokens must be at least %i characters, got %r" % (cls.min_length, token) @@ -530,14 +529,34 @@ class Hashed(Expiring): return orm_token +# ------------------------------------ +# OAuth tables +# ------------------------------------ + + +class GrantType(enum.Enum): + # we only use authorization_code for now + authorization_code = 'authorization_code' + implicit = 'implicit' + password = 'password' + client_credentials = 'client_credentials' + refresh_token = 'refresh_token' + + class APIToken(Hashed, Base): """An API token""" __tablename__ = 'api_tokens' - user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) + user_id = Column( + Integer, + ForeignKey('users.id', ondelete="CASCADE"), + nullable=True, + ) service_id = Column( - Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True + Integer, + ForeignKey('services.id', ondelete="CASCADE"), + nullable=True, ) id = Column(Integer, primary_key=True) @@ -548,6 +567,26 @@ class APIToken(Hashed, Base): def api_id(self): return 'a%i' % self.id + # added in 2.0 + client_id = Column( + Unicode(255), + ForeignKey( + 'oauth_clients.identifier', + ondelete='CASCADE', + ), + ) + # FIXME: refresh_tokens not implemented + # should be a relation to another token table + # refresh_token = Column( + # Integer, + # ForeignKey('refresh_tokens.id', ondelete="CASCADE"), + # nullable=True, + # ) + + # the browser session id associated with a given token, + # if issued during oauth to be stored in a cookie + session_id = Column(Unicode(255), nullable=True) + # token metadata for bookkeeping now = datetime.utcnow # for expiry created = Column(DateTime, default=datetime.utcnow) @@ -566,8 +605,12 @@ class APIToken(Hashed, Base): # this shouldn't happen kind = 'owner' name = 'unknown' - return "<{cls}('{pre}...', {kind}='{name}')>".format( - cls=self.__class__.__name__, pre=self.prefix, kind=kind, name=name + return "<{cls}('{pre}...', {kind}='{name}', client_id={client_id!r})>".format( + cls=self.__class__.__name__, + pre=self.prefix, + kind=kind, + name=name, + client_id=self.client_id, ) @classmethod @@ -588,6 +631,14 @@ class APIToken(Hashed, Base): raise ValueError("kind must be 'user', 'service', or None, not %r" % kind) for orm_token in prefix_match: if orm_token.match(token): + if not orm_token.client_id: + app_log.warning( + "Deleting stale oauth token for %s with no client", + orm_token.user and orm_token.user.name, + ) + db.delete(orm_token) + db.commit() + return return orm_token @classmethod @@ -599,7 +650,10 @@ class APIToken(Hashed, Base): roles=None, note='', generated=True, + session_id=None, expires_in=None, + client_id='jupyterhub', + return_orm=False, ): """Generate a new API token for a user or service""" assert user or service @@ -614,7 +668,12 @@ class APIToken(Hashed, Base): cls.check_token(db, token) # two stages to ensure orm_token.generated has been set # before token setter is called - orm_token = cls(generated=generated, note=note or '') + orm_token = cls( + generated=generated, + note=note or '', + client_id=client_id, + session_id=session_id, + ) orm_token.token = token if user: assert user.id is not None @@ -641,76 +700,6 @@ class APIToken(Hashed, Base): return token -# ------------------------------------ -# OAuth tables -# ------------------------------------ - - -class GrantType(enum.Enum): - # we only use authorization_code for now - authorization_code = 'authorization_code' - implicit = 'implicit' - password = 'password' - client_credentials = 'client_credentials' - refresh_token = 'refresh_token' - - -class OAuthAccessToken(Hashed, Base): - __tablename__ = 'oauth_access_tokens' - id = Column(Integer, primary_key=True, autoincrement=True) - - @staticmethod - def now(): - return datetime.utcnow().timestamp() - - @property - def api_id(self): - return 'o%i' % self.id - - client_id = Column( - Unicode(255), ForeignKey('oauth_clients.identifier', ondelete='CASCADE') - ) - grant_type = Column(Enum(GrantType), nullable=False) - expires_at = Column(Integer) - refresh_token = Column(Unicode(255)) - # TODO: drop refresh_expires_at. Refresh tokens shouldn't expire - refresh_expires_at = Column(Integer) - user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) - service = None # for API-equivalence with APIToken - - # the browser session id associated with a given token - session_id = Column(Unicode(255)) - - # from Hashed - hashed = Column(Unicode(255), unique=True) - prefix = Column(Unicode(16), index=True) - - created = Column(DateTime, default=datetime.utcnow) - last_activity = Column(DateTime, nullable=True) - - def __repr__(self): - return "<{cls}('{prefix}...', client_id={client_id!r}, user={user!r}, expires_in={expires_in}>".format( - cls=self.__class__.__name__, - client_id=self.client_id, - user=self.user and self.user.name, - prefix=self.prefix, - expires_in=self.expires_in, - ) - - @classmethod - def find(cls, db, token): - orm_token = super().find(db, token) - if orm_token and not orm_token.client_id: - app_log.warning( - "Deleting stale oauth token for %s with no client", - orm_token.user and orm_token.user.name, - ) - db.delete(orm_token) - db.commit() - return - return orm_token - - class OAuthCode(Expiring, Base): __tablename__ = 'oauth_codes' @@ -752,7 +741,7 @@ class OAuthClient(Base): return self.identifier access_tokens = relationship( - OAuthAccessToken, backref='client', cascade='all, delete-orphan' + APIToken, backref='client', cascade='all, delete-orphan' ) codes = relationship(OAuthCode, backref='client', cascade='all, delete-orphan') diff --git a/jupyterhub/services/service.py b/jupyterhub/services/service.py index 44fd763c..c72ae382 100644 --- a/jupyterhub/services/service.py +++ b/jupyterhub/services/service.py @@ -51,6 +51,7 @@ from traitlets import Dict from traitlets import HasTraits from traitlets import Instance from traitlets import Unicode +from traitlets import validate from traitlets.config import LoggingConfigurable from .. import orm @@ -284,6 +285,15 @@ class Service(LoggingConfigurable): def _default_client_id(self): return 'service-%s' % self.name + @validate("oauth_client_id") + def _validate_client_id(self, proposal): + if not proposal.value.startswith("service-"): + raise ValueError( + f"service {self.name} has oauth_client_id='{proposal.value}'." + " Service oauth client ids must start with 'service-'" + ) + return proposal.value + oauth_redirect_uri = Unicode( help="""OAuth redirect URI for this service. diff --git a/jupyterhub/tests/conftest.py b/jupyterhub/tests/conftest.py index eb203029..f439bfb0 100644 --- a/jupyterhub/tests/conftest.py +++ b/jupyterhub/tests/conftest.py @@ -125,7 +125,11 @@ def db(): """Get a db session""" global _db if _db is None: - _db = orm.new_session_factory('sqlite:///:memory:')() + # make sure some initial db contents are filled out + # specifically, the 'default' jupyterhub oauth client + app = MockHub(db_url='sqlite:///:memory:') + app.init_db() + _db = app.db user = orm.User(name=getuser()) _db.add(user) _db.commit() @@ -164,9 +168,14 @@ def cleanup_after(request, io_loop): allows cleanup of servers between tests without having to launch a whole new app """ + try: yield finally: + if _db is not None: + # cleanup after failed transactions + _db.rollback() + if not MockHub.initialized(): return app = MockHub.instance() diff --git a/jupyterhub/tests/populate_db.py b/jupyterhub/tests/populate_db.py index 2b5c6007..4504a13c 100644 --- a/jupyterhub/tests/populate_db.py +++ b/jupyterhub/tests/populate_db.py @@ -6,6 +6,7 @@ used in test_db.py """ import os from datetime import datetime +from functools import partial import jupyterhub from jupyterhub import orm @@ -62,32 +63,35 @@ def populate_db(url): db.commit() # create some oauth objects - if jupyterhub.version_info >= (0, 8): - # create oauth client - client = orm.OAuthClient(identifier='oauth-client') - db.add(client) - db.commit() - code = orm.OAuthCode(client_id=client.identifier) - db.add(code) - db.commit() - access_token = orm.OAuthAccessToken( - client_id=client.identifier, - user_id=user.id, + client = orm.OAuthClient(identifier='oauth-client') + db.add(client) + db.commit() + code = orm.OAuthCode(client_id=client.identifier) + db.add(code) + db.commit() + if jupyterhub.version_info < (2, 0): + Token = partial( + orm.OAuthAccessToken, grant_type=orm.GrantType.authorization_code, ) - db.add(access_token) - db.commit() + else: + Token = orm.APIToken + access_token = Token( + client_id=client.identifier, + user_id=user.id, + ) + db.add(access_token) + db.commit() # set some timestamps added in 0.9 - if jupyterhub.version_info >= (0, 9): - assert user.created - assert admin.created - # set last_activity - user.last_activity = datetime.utcnow() - spawner = user.orm_spawners[''] - spawner.started = datetime.utcnow() - spawner.last_activity = datetime.utcnow() - db.commit() + assert user.created + assert admin.created + # set last_activity + user.last_activity = datetime.utcnow() + spawner = user.orm_spawners[''] + spawner.started = datetime.utcnow() + spawner.last_activity = datetime.utcnow() + db.commit() if __name__ == '__main__': diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index e22ee51c..2881a69c 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -273,11 +273,10 @@ async def test_get_self(app): oauth_client = orm.OAuthClient(identifier='eurydice') db.add(oauth_client) db.commit() - oauth_token = orm.OAuthAccessToken( + oauth_token = orm.APIToken( user=u.orm_user, client=oauth_client, token=token, - grant_type=orm.GrantType.authorization_code, ) db.add(oauth_token) db.commit() @@ -1423,12 +1422,11 @@ async def test_token_list(app, as_user, for_user, status): if status != 200: return reply = r.json() - assert sorted(reply) == ['api_tokens', 'oauth_tokens'] + assert sorted(reply) == ['api_tokens'] assert len(reply['api_tokens']) == len(for_user_obj.api_tokens) assert all(token['user'] == for_user for token in reply['api_tokens']) - assert all(token['user'] == for_user for token in reply['oauth_tokens']) # validate individual token ids - for token in reply['api_tokens'] + reply['oauth_tokens']: + for token in reply['api_tokens']: r = await api_request( app, 'users', for_user, 'tokens', token['id'], headers=headers ) diff --git a/jupyterhub/tests/test_db.py b/jupyterhub/tests/test_db.py index beb63099..77231f97 100644 --- a/jupyterhub/tests/test_db.py +++ b/jupyterhub/tests/test_db.py @@ -36,7 +36,7 @@ def generate_old_db(env_dir, hub_version, db_url): check_call([env_py, populate_db, db_url]) -@pytest.mark.parametrize('hub_version', ['0.7.2', '0.8.1', '0.9.4']) +@pytest.mark.parametrize('hub_version', ['1.0.0', "1.2.2", "1.3.0"]) async def test_upgrade(tmpdir, hub_version): db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') if db_url: diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index c761a040..093c29be 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -355,8 +355,9 @@ def test_user_delete_cascade(db): spawner.server = server = orm.Server() oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) - oauth_token = orm.OAuthAccessToken( - client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code + oauth_token = orm.APIToken( + client=oauth_client, + user=user, ) db.add(oauth_token) db.commit() @@ -377,7 +378,7 @@ def test_user_delete_cascade(db): assert_not_found(db, orm.Spawner, spawner_id) assert_not_found(db, orm.Server, server_id) assert_not_found(db, orm.OAuthCode, oauth_code_id) - assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) + assert_not_found(db, orm.APIToken, oauth_token_id) def test_oauth_client_delete_cascade(db): @@ -391,12 +392,13 @@ def test_oauth_client_delete_cascade(db): # these should all be deleted automatically when the user goes away oauth_code = orm.OAuthCode(client=oauth_client, user=user) db.add(oauth_code) - oauth_token = orm.OAuthAccessToken( - client=oauth_client, user=user, grant_type=orm.GrantType.authorization_code + oauth_token = orm.APIToken( + client=oauth_client, + user=user, ) db.add(oauth_token) db.commit() - assert user.oauth_tokens == [oauth_token] + assert user.api_tokens == [oauth_token] # record all of the ids oauth_code_id = oauth_code.id @@ -408,8 +410,8 @@ def test_oauth_client_delete_cascade(db): # verify that everything gets deleted assert_not_found(db, orm.OAuthCode, oauth_code_id) - assert_not_found(db, orm.OAuthAccessToken, oauth_token_id) - assert user.oauth_tokens == [] + assert_not_found(db, orm.APIToken, oauth_token_id) + assert user.api_tokens == [] assert user.oauth_codes == [] @@ -510,32 +512,31 @@ def test_expiring_api_token(app, user): def test_expiring_oauth_token(app, user): db = app.db token = "abc123" - now = orm.OAuthAccessToken.now + now = orm.APIToken.now client = orm.OAuthClient(identifier="xxx", secret="yyy") db.add(client) - orm_token = orm.OAuthAccessToken( + orm_token = orm.APIToken( token=token, - grant_type=orm.GrantType.authorization_code, client=client, user=user, - expires_at=now() + 30, + expires_at=now() + timedelta(seconds=30), ) db.add(orm_token) db.commit() - found = orm.OAuthAccessToken.find(db, token) + found = orm.APIToken.find(db, token) assert found is orm_token # purge_expired doesn't delete non-expired - orm.OAuthAccessToken.purge_expired(db) - found = orm.OAuthAccessToken.find(db, token) + orm.APIToken.purge_expired(db) + found = orm.APIToken.find(db, token) assert found is orm_token - with mock.patch.object(orm.OAuthAccessToken, 'now', lambda: now() + 60): - found = orm.OAuthAccessToken.find(db, token) + with mock.patch.object(orm.APIToken, 'now', lambda: now() + timedelta(seconds=60)): + found = orm.APIToken.find(db, token) assert found is None - assert orm_token in db.query(orm.OAuthAccessToken) - orm.OAuthAccessToken.purge_expired(db) - assert orm_token not in db.query(orm.OAuthAccessToken) + assert orm_token in db.query(orm.APIToken) + orm.APIToken.purge_expired(db) + assert orm_token not in db.query(orm.APIToken) def test_expiring_oauth_code(app, user): diff --git a/jupyterhub/tests/test_pages.py b/jupyterhub/tests/test_pages.py index 7140b823..1bba9926 100644 --- a/jupyterhub/tests/test_pages.py +++ b/jupyterhub/tests/test_pages.py @@ -869,8 +869,9 @@ async def test_oauth_token_page(app): user = app.users[orm.User.find(app.db, name)] client = orm.OAuthClient(identifier='token') app.db.add(client) - oauth_token = orm.OAuthAccessToken( - client=client, user=user, grant_type=orm.GrantType.authorization_code + oauth_token = orm.APIToken( + client=client, + user=user, ) app.db.add(oauth_token) app.db.commit() diff --git a/jupyterhub/tests/test_services_auth.py b/jupyterhub/tests/test_services_auth.py index b41d50db..540150bb 100644 --- a/jupyterhub/tests/test_services_auth.py +++ b/jupyterhub/tests/test_services_auth.py @@ -444,11 +444,7 @@ async def test_oauth_logout(app, mockservice_url): def auth_tokens(): """Return list of OAuth access tokens for the user""" - return list( - app.db.query(orm.OAuthAccessToken).filter( - orm.OAuthAccessToken.user_id == app_user.id - ) - ) + return list(app.db.query(orm.APIToken).filter_by(user_id=app_user.id)) # ensure we start empty assert auth_tokens() == []