diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 2c975234..cdc8e13f 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -352,6 +352,13 @@ class JupyterHub(Application): cookie_secret_file = Unicode('jupyterhub_cookie_secret', help="""File in which to store the cookie secret.""" ).tag(config=True) + + api_tokens = Dict(Unicode(), + help="""Dict of token:username to be loaded into the database. + + Allows ahead-of-time generation of API tokens for use by services. + """ + ).tag(config=True) authenticator_class = Type(PAMAuthenticator, Authenticator, help="""Class for authenticating users. @@ -794,6 +801,28 @@ class JupyterHub(Application): # From this point on, any user changes should be done simultaneously # to the whitelist set and user db, unless the whitelist is empty (all users allowed). + def init_api_tokens(self): + """Load predefined API tokens (for services) into database""" + db = self.db + for token, username in self.api_tokens.items(): + username = self.authenticator.normalize_username(username) + if not self.authenticator.check_whitelist(username): + raise ValueError("Token username %r is not in whitelist" % username) + if not self.authenticator.validate_username(username): + raise ValueError("Token username %r is not valid" % username) + orm_token = orm.APIToken.find(db, token) + if orm_token is None: + user = orm.User.find(db, username) + if user is None: + self.log.debug("Adding user %r to database", username) + user = orm.User(name=username) + db.add(user) + db.commit() + self.log.info("Adding API token for %s", username) + user.new_api_token(token) + else: + self.log.debug("Not duplicating token %s", orm_token) + db.commit() @gen.coroutine def init_spawners(self): @@ -1055,6 +1084,7 @@ class JupyterHub(Application): self.init_hub() self.init_proxy() yield self.init_users() + self.init_api_tokens() self.init_tornado_settings() yield self.init_spawners() self.init_handlers() diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index a0e0cfb0..1e7b7390 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -303,11 +303,19 @@ class User(Base): name=self.name, ) - def new_api_token(self): - """Create a new API token""" + def new_api_token(self, token=None): + """Create a new API token + + If `token` is given, load that token. + """ assert self.id is not None db = inspect(self).session - token = new_token() + if token is None: + token = new_token() + else: + found = APIToken.find(db, token) + if found: + raise ValueError("Collision on token: %s..." % token[:4]) orm_token = APIToken(user_id=self.id) orm_token.token = token db.add(orm_token) diff --git a/jupyterhub/tests/test_app.py b/jupyterhub/tests/test_app.py index 280c2e4b..832b91ac 100644 --- a/jupyterhub/tests/test_app.py +++ b/jupyterhub/tests/test_app.py @@ -6,6 +6,7 @@ import sys from subprocess import check_output, Popen, PIPE from tempfile import NamedTemporaryFile, TemporaryDirectory from .mocking import MockHub +from .. import orm def test_help_all(): out = check_output([sys.executable, '-m', 'jupyterhub', '--help-all']).decode('utf8', 'replace') @@ -48,3 +49,30 @@ def test_generate_config(): assert cfg_file in out assert 'Spawner.cmd' in cfg_text assert 'Authenticator.whitelist' in cfg_text + +def test_init_tokens(): + with TemporaryDirectory() as td: + db_file = os.path.join(td, 'jupyterhub.sqlite') + tokens = { + 'super-secret-token': 'alyx', + 'also-super-secret': 'gordon', + 'boagasdfasdf': 'chell', + } + app = MockHub(db_file=db_file, api_tokens=tokens) + app.initialize([]) + db = app.db + for token, username in tokens.items(): + api_token = orm.APIToken.find(db, token) + assert api_token is not None + user = api_token.user + assert user.name == username + + # simulate second startup, reloading same tokens: + app = MockHub(db_file=db_file, api_tokens=tokens) + app.initialize([]) + db = app.db + for token, username in tokens.items(): + api_token = orm.APIToken.find(db, token) + assert api_token is not None + user = api_token.user + assert user.name == username diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 983c4d12..1243f622 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -93,6 +93,16 @@ def test_tokens(db): found = orm.APIToken.find(db, 'something else') assert found is None + secret = 'super-secret-preload-token' + token = user.new_api_token(secret) + assert token == secret + assert len(user.api_tokens) == 3 + + # raise ValueError on collision + with pytest.raises(ValueError): + user.new_api_token(token) + assert len(user.api_tokens) == 3 + def test_spawn_fails(db, io_loop): orm_user = orm.User(name='aeofel')