diff --git a/.travis.yml b/.travis.yml index 0dae176e..249f7f30 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,7 @@ before_install: - npm install -g configurable-http-proxy - git clone --quiet --depth 1 https://github.com/minrk/travis-wheels travis-wheels install: - - pip install --pre -f travis-wheels/wheelhouse -r dev-requirements.txt . + - pip install -v --pre -f travis-wheels/wheelhouse -r dev-requirements.txt . script: - travis_retry py.test --cov jupyterhub jupyterhub/tests -v after_success: diff --git a/dev-requirements.txt b/dev-requirements.txt index 12bd80c1..a118ce73 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,5 @@ -r requirements.txt +mock codecov pytest-cov pytest>=2.8 diff --git a/jupyterhub/alembic/versions/af4cbdb2d13c_services.py b/jupyterhub/alembic/versions/af4cbdb2d13c_services.py new file mode 100644 index 00000000..ebf2f851 --- /dev/null +++ b/jupyterhub/alembic/versions/af4cbdb2d13c_services.py @@ -0,0 +1,25 @@ +"""services + +Revision ID: af4cbdb2d13c +Revises: eeb276e51423 +Create Date: 2016-07-28 16:16:38.245348 + +""" + +# revision identifiers, used by Alembic. +revision = 'af4cbdb2d13c' +down_revision = 'eeb276e51423' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('api_tokens', sa.Column('service_id', sa.Integer)) + + +def downgrade(): + # sqlite cannot downgrade because of limited ALTER TABLE support (no DROP COLUMN) + op.drop_column('api_tokens', 'service_id') diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 79c5c3ec..26151771 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -384,9 +384,28 @@ class JupyterHub(Application): ).tag(config=True) api_tokens = Dict(Unicode(), - help="""Dict of token:username to be loaded into the database. + help="""PENDING DEPRECATION: consider using service_tokens + + Dict of token:username to be loaded into the database. - Allows ahead-of-time generation of API tokens for use by services. + Allows ahead-of-time generation of API tokens for use by externally managed services, + which authenticate as JupyterHub users. + + Consider using service_tokens for general services that talk to the JupyterHub API. + """ + ).tag(config=True) + @observe('api_tokens') + def _deprecate_api_tokens(self, change): + self.log.warn("JupyterHub.api_tokens is pending deprecation." + " Consider using JupyterHub.service_tokens." + " If you have a use case for services that identify as users," + " let us know: https://github.com/jupyterhub/jupyterhub/issues" + ) + + service_tokens = Dict(Unicode(), + help="""Dict of token:servicename to be loaded into the database. + + Allows ahead-of-time generation of API tokens for use by externally managed services. """ ).tag(config=True) @@ -864,38 +883,51 @@ class JupyterHub(Application): group.users.append(user) db.commit() - def init_api_tokens(self): - """Load predefined API tokens (for services) into database""" + def _add_tokens(self, token_dict, kind): + """Add tokens for users or services to the database""" + if kind == 'user': + Class = orm.User + elif kind == 'service': + Class = orm.Service + else: + raise ValueError("kind must be user or service, not %r" % kind) + db = self.db - for token, username in self.api_tokens.items(): - username = self.authenticator.normalize_username(username) - if not self.authenticator.check_whitelist(username): - raise ValueError("Token username %r is not in whitelist" % username) - if not self.authenticator.validate_username(username): - raise ValueError("Token username %r is not valid" % username) + for token, name in token_dict.items(): + if kind == 'user': + name = self.authenticator.normalize_username(name) + if not self.authenticator.check_whitelist(name): + raise ValueError("Token name %r is not in whitelist" % name) + if not self.authenticator.validate_username(name): + raise ValueError("Token name %r is not valid" % name) orm_token = orm.APIToken.find(db, token) if orm_token is None: - user = orm.User.find(db, username) - user_created = False - if user is None: - user_created = True - self.log.debug("Adding user %r to database", username) - user = orm.User(name=username) - db.add(user) + obj = Class.find(db, name) + created = False + if obj is None: + created = True + self.log.debug("Adding %s %r to database", kind, name) + obj = Class(name=name) + db.add(obj) db.commit() - self.log.info("Adding API token for %s", username) + self.log.info("Adding API token for %s: %s", kind, name) try: - user.new_api_token(token) + obj.new_api_token(token) except Exception: - if user_created: + if created: # don't allow bad tokens to create users - db.delete(user) + db.delete(obj) db.commit() raise else: self.log.debug("Not duplicating token %s", orm_token) db.commit() + def init_api_tokens(self): + """Load predefined API tokens (for services) into database""" + self._add_tokens(self.service_tokens, kind='service') + self._add_tokens(self.api_tokens, kind='user') + @gen.coroutine def init_spawners(self): db = self.db diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 092049d7..07651787 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -144,7 +144,7 @@ class BaseHandler(RequestHandler): if orm_token is None: return None else: - return orm_token.user + return orm_token.user or orm_token.service def _user_for_cookie(self, cookie_name, cookie_value=None): """Get the User for a given cookie, if there is one""" diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index 7bc2a35c..f7241b9a 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -302,7 +302,7 @@ class User(Base): """ __tablename__ = 'users' id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(Unicode(1023)) + name = Column(Unicode(1023), unique=True) # should we allow multiple servers per user? _server_id = Column(Integer, ForeignKey('servers.id', ondelete="SET NULL")) server = relationship(Server, primaryjoin=_server_id == Server.id) @@ -340,21 +340,7 @@ class User(Base): If `token` is given, load that token. """ - assert self.id is not None - db = inspect(self).session - if token is None: - token = new_token() - else: - if len(token) < 8: - raise ValueError("Tokens must be at least 8 characters, got %r" % token) - found = APIToken.find(db, token) - if found: - raise ValueError("Collision on token: %s..." % token[:4]) - orm_token = APIToken(user_id=self.id) - orm_token.token = token - db.add(orm_token) - db.commit() - return token + return APIToken.new(token=token, user=self) @classmethod def find(cls, db, name): @@ -364,13 +350,73 @@ class User(Base): """ return db.query(cls).filter(cls.name==name).first() + +# service:server many:many mapping table +service_server_map = Table('service_server_map', Base.metadata, + Column('service_id', ForeignKey('services.id')), + Column('server_id', ForeignKey('servers.id')), +) + + +class Service(Base): + """A service run with JupyterHub + + A service is similar to a User without a Spawner. + A service can have API tokens for accessing the Hub's API + + It has: + + - name + - admin + - api tokens + + In addition to what it has in common with users, a Service has extra info: + + - servers: list of HTTP endpoints for the service + - pid: the process id (if managed) + + """ + __tablename__ = 'services' + id = Column(Integer, primary_key=True, autoincrement=True) + + # common user interface: + name = Column(Unicode(1023), unique=True) + admin = Column(Boolean, default=False) + + api_tokens = relationship("APIToken", backref="service") + + # service-specific interface + servers = relationship('Server', secondary='service_server_map') + pid = Column(Integer) + + def new_api_token(self, token=None): + """Create a new API token + + If `token` is given, load that token. + """ + return APIToken.new(token=token, service=self) + + @classmethod + def find(cls, db, name): + """Find a service by name. + + Returns None if not found. + """ + return db.query(cls).filter(cls.name==name).first() + + class APIToken(Base): """An API token""" __tablename__ = 'api_tokens' - + + # _constraint = ForeignKeyConstraint(['user_id', 'server_id'], ['users.id', 'services.id']) @declared_attr def user_id(cls): - return Column(Integer, ForeignKey('users.id', ondelete="CASCADE")) + return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True) + + @declared_attr + def service_id(cls): + return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True) id = Column(Integer, primary_key=True) hashed = Column(Unicode(1023)) @@ -391,22 +437,42 @@ class APIToken(Base): self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm) def __repr__(self): - return "<{cls}('{pre}...', user='{u}')>".format( + if self.user is not None: + kind = 'user' + name = self.user.name + elif self.service is not None: + kind = 'service' + name = self.service.name + else: + # this shouldn't happen + kind = 'owner' + name = 'unknown' + return "<{cls}('{pre}...', {kind}='{name}')>".format( cls=self.__class__.__name__, pre=self.prefix, - u=self.user.name, + kind=kind, + name=name, ) @classmethod - def find(cls, db, token): + def find(cls, db, token, *, kind=None): """Find a token object by value. Returns None if not found. + + `kind='user'` only returns API tokens for users + `kind='service'` only returns API tokens for services """ prefix = token[:cls.prefix_length] # since we can't filter on hashed values, filter on prefix # so we aren't comparing with all tokens prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix)) + if kind == 'user': + prefix_match = prefix_match.filter(cls.user_id != None) + elif kind == 'service': + prefix_match = prefix_match.filter(cls.service_id != None) + elif kind is not None: + raise ValueError("kind must be 'user', 'service', or None, not %r" % kind) for orm_token in prefix_match: if orm_token.match(token): return orm_token @@ -415,6 +481,31 @@ class APIToken(Base): """Is this my token?""" return compare_token(self.hashed, token) + @classmethod + def new(cls, token=None, user=None, service=None): + """Generate a new API token for a user or service""" + assert user or service + assert not (user and service) + db = inspect(user or service).session + if token is None: + token = new_token() + else: + if len(token) < 8: + raise ValueError("Tokens must be at least 8 characters, got %r" % token) + found = APIToken.find(db, token) + if found: + raise ValueError("Collision on token: %s..." % token[:4]) + orm_token = APIToken(token=token) + if user: + assert user.id is not None + orm_token.user_id = user.id + else: + assert service.id is not None + orm_token.service_id = service.id + db.add(orm_token) + db.commit() + return token + def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): """Create a new session at url""" diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 8b08fabf..493f46d1 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -42,8 +42,13 @@ def find_user(db, name): return db.query(orm.User).filter(orm.User.name==name).first() def add_user(db, app=None, **kwargs): - orm_user = orm.User(**kwargs) - db.add(orm_user) + orm_user = find_user(db, name=kwargs.get('name')) + if orm_user is None: + orm_user = orm.User(**kwargs) + db.add(orm_user) + else: + for attr, value in kwargs.items(): + setattr(orm_user, attr, value) db.commit() if app: user = app.users[orm_user.id] = User(orm_user, app.tornado_settings) diff --git a/jupyterhub/tests/test_app.py b/jupyterhub/tests/test_app.py index 75648473..d7b20d48 100644 --- a/jupyterhub/tests/test_app.py +++ b/jupyterhub/tests/test_app.py @@ -55,8 +55,7 @@ def test_generate_config(): assert 'Spawner.cmd' in cfg_text assert 'Authenticator.whitelist' in cfg_text - -def test_init_tokens(): +def test_init_tokens(io_loop): with TemporaryDirectory() as td: db_file = os.path.join(td, 'jupyterhub.sqlite') tokens = { @@ -64,8 +63,8 @@ def test_init_tokens(): 'also-super-secret': 'gordon', 'boagasdfasdf': 'chell', } - app = MockHub(db_file=db_file, api_tokens=tokens) - app.initialize([]) + app = MockHub(db_url=db_file, api_tokens=tokens) + io_loop.run_sync(lambda : app.initialize([])) db = app.db for token, username in tokens.items(): api_token = orm.APIToken.find(db, token) @@ -74,8 +73,8 @@ def test_init_tokens(): assert user.name == username # simulate second startup, reloading same tokens: - app = MockHub(db_file=db_file, api_tokens=tokens) - app.initialize([]) + app = MockHub(db_url=db_file, api_tokens=tokens) + io_loop.run_sync(lambda : app.initialize([])) db = app.db for token, username in tokens.items(): api_token = orm.APIToken.find(db, token) @@ -85,9 +84,9 @@ def test_init_tokens(): # don't allow failed token insertion to create users: tokens['short'] = 'gman' - app = MockHub(db_file=db_file, api_tokens=tokens) - # with pytest.raises(ValueError): - app.initialize([]) + app = MockHub(db_url=db_file, api_tokens=tokens) + with pytest.raises(ValueError): + io_loop.run_sync(lambda : app.initialize([])) assert orm.User.find(app.db, 'gman') is None diff --git a/jupyterhub/tests/test_db.py b/jupyterhub/tests/test_db.py index 4d844d04..d2955e06 100644 --- a/jupyterhub/tests/test_db.py +++ b/jupyterhub/tests/test_db.py @@ -23,7 +23,7 @@ def test_upgrade(tmpdir): print(db_url) upgrade(db_url) -def test_upgrade_entrypoint(tmpdir): +def test_upgrade_entrypoint(tmpdir, io_loop): generate_old_db(str(tmpdir)) tmpdir.chdir() tokenapp = NewToken() @@ -32,7 +32,7 @@ def test_upgrade_entrypoint(tmpdir): tokenapp.start() upgradeapp = UpgradeDB() - upgradeapp.initialize([]) + io_loop.run_sync(lambda : upgradeapp.initialize([])) upgradeapp.start() # run tokenapp again, it should work diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index d4843492..6a369ad5 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -90,6 +90,8 @@ def test_tokens(db): assert len(user.api_tokens) == 2 found = orm.APIToken.find(db, token=token) assert found.match(token) + assert found.user is user + assert found.service is None found = orm.APIToken.find(db, 'something else') assert found is None @@ -104,6 +106,69 @@ def test_tokens(db): assert len(user.api_tokens) == 3 +def test_service_tokens(db): + service = orm.Service(name='secret') + db.add(service) + db.commit() + token = service.new_api_token() + assert any(t.match(token) for t in service.api_tokens) + service.new_api_token() + assert len(service.api_tokens) == 2 + found = orm.APIToken.find(db, token=token) + assert found.match(token) + assert found.user is None + assert found.service is service + service2 = orm.Service(name='secret') + db.add(service) + db.commit() + assert service2.id != service.id + + +def test_service_servers(db): + service = orm.Service(name='has_servers') + db.add(service) + db.commit() + + assert service.servers == [] + servers = service.servers = [ + orm.Server(), + orm.Server(), + ] + assert [ s.id for s in servers ] == [ None, None ] + db.commit() + assert [ type(s.id) for s in servers ] == [ int, int ] + + +def test_token_find(db): + service = db.query(orm.Service).first() + user = db.query(orm.User).first() + service_token = service.new_api_token() + user_token = user.new_api_token() + with pytest.raises(ValueError): + orm.APIToken.find(db, 'irrelevant', kind='richard') + # no kind, find anything + found = orm.APIToken.find(db, token=user_token) + assert found + assert found.match(user_token) + found = orm.APIToken.find(db, token=service_token) + assert found + assert found.match(service_token) + + # kind=user, only find user tokens + found = orm.APIToken.find(db, token=user_token, kind='user') + assert found + assert found.match(user_token) + found = orm.APIToken.find(db, token=service_token, kind='user') + assert found is None + + # kind=service, only find service tokens + found = orm.APIToken.find(db, token=service_token, kind='service') + assert found + assert found.match(service_token) + found = orm.APIToken.find(db, token=user_token, kind='service') + assert found is None + + def test_spawn_fails(db, io_loop): orm_user = orm.User(name='aeofel') db.add(orm_user) @@ -126,7 +191,7 @@ def test_spawn_fails(db, io_loop): def test_groups(db): - user = orm.User(name='aeofel') + user = orm.User.find(db, name='aeofel') db.add(user) group = orm.Group(name='lives')