Merge pull request #604 from minrk/service-token

add Services to db
This commit is contained in:
Min RK
2016-07-29 10:32:15 +02:00
committed by GitHub
10 changed files with 276 additions and 58 deletions

View File

@@ -10,7 +10,7 @@ before_install:
- npm install -g configurable-http-proxy
- git clone --quiet --depth 1 https://github.com/minrk/travis-wheels travis-wheels
install:
- pip install --pre -f travis-wheels/wheelhouse -r dev-requirements.txt .
- pip install -v --pre -f travis-wheels/wheelhouse -r dev-requirements.txt .
script:
- travis_retry py.test --cov jupyterhub jupyterhub/tests -v
after_success:

View File

@@ -1,4 +1,5 @@
-r requirements.txt
mock
codecov
pytest-cov
pytest>=2.8

View File

@@ -0,0 +1,25 @@
"""services
Revision ID: af4cbdb2d13c
Revises: eeb276e51423
Create Date: 2016-07-28 16:16:38.245348
"""
# revision identifiers, used by Alembic.
revision = 'af4cbdb2d13c'
down_revision = 'eeb276e51423'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('api_tokens', sa.Column('service_id', sa.Integer))
def downgrade():
# sqlite cannot downgrade because of limited ALTER TABLE support (no DROP COLUMN)
op.drop_column('api_tokens', 'service_id')

View File

@@ -384,9 +384,28 @@ class JupyterHub(Application):
).tag(config=True)
api_tokens = Dict(Unicode(),
help="""Dict of token:username to be loaded into the database.
help="""PENDING DEPRECATION: consider using service_tokens
Allows ahead-of-time generation of API tokens for use by services.
Dict of token:username to be loaded into the database.
Allows ahead-of-time generation of API tokens for use by externally managed services,
which authenticate as JupyterHub users.
Consider using service_tokens for general services that talk to the JupyterHub API.
"""
).tag(config=True)
@observe('api_tokens')
def _deprecate_api_tokens(self, change):
self.log.warn("JupyterHub.api_tokens is pending deprecation."
" Consider using JupyterHub.service_tokens."
" If you have a use case for services that identify as users,"
" let us know: https://github.com/jupyterhub/jupyterhub/issues"
)
service_tokens = Dict(Unicode(),
help="""Dict of token:servicename to be loaded into the database.
Allows ahead-of-time generation of API tokens for use by externally managed services.
"""
).tag(config=True)
@@ -864,38 +883,51 @@ class JupyterHub(Application):
group.users.append(user)
db.commit()
def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
def _add_tokens(self, token_dict, kind):
"""Add tokens for users or services to the database"""
if kind == 'user':
Class = orm.User
elif kind == 'service':
Class = orm.Service
else:
raise ValueError("kind must be user or service, not %r" % kind)
db = self.db
for token, username in self.api_tokens.items():
username = self.authenticator.normalize_username(username)
if not self.authenticator.check_whitelist(username):
raise ValueError("Token username %r is not in whitelist" % username)
if not self.authenticator.validate_username(username):
raise ValueError("Token username %r is not valid" % username)
for token, name in token_dict.items():
if kind == 'user':
name = self.authenticator.normalize_username(name)
if not self.authenticator.check_whitelist(name):
raise ValueError("Token name %r is not in whitelist" % name)
if not self.authenticator.validate_username(name):
raise ValueError("Token name %r is not valid" % name)
orm_token = orm.APIToken.find(db, token)
if orm_token is None:
user = orm.User.find(db, username)
user_created = False
if user is None:
user_created = True
self.log.debug("Adding user %r to database", username)
user = orm.User(name=username)
db.add(user)
obj = Class.find(db, name)
created = False
if obj is None:
created = True
self.log.debug("Adding %s %r to database", kind, name)
obj = Class(name=name)
db.add(obj)
db.commit()
self.log.info("Adding API token for %s", username)
self.log.info("Adding API token for %s: %s", kind, name)
try:
user.new_api_token(token)
obj.new_api_token(token)
except Exception:
if user_created:
if created:
# don't allow bad tokens to create users
db.delete(user)
db.delete(obj)
db.commit()
raise
else:
self.log.debug("Not duplicating token %s", orm_token)
db.commit()
def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
self._add_tokens(self.service_tokens, kind='service')
self._add_tokens(self.api_tokens, kind='user')
@gen.coroutine
def init_spawners(self):
db = self.db

View File

@@ -144,7 +144,7 @@ class BaseHandler(RequestHandler):
if orm_token is None:
return None
else:
return orm_token.user
return orm_token.user or orm_token.service
def _user_for_cookie(self, cookie_name, cookie_value=None):
"""Get the User for a given cookie, if there is one"""

View File

@@ -302,7 +302,7 @@ class User(Base):
"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(1023))
name = Column(Unicode(1023), unique=True)
# should we allow multiple servers per user?
_server_id = Column(Integer, ForeignKey('servers.id', ondelete="SET NULL"))
server = relationship(Server, primaryjoin=_server_id == Server.id)
@@ -340,21 +340,7 @@ class User(Base):
If `token` is given, load that token.
"""
assert self.id is not None
db = inspect(self).session
if token is None:
token = new_token()
else:
if len(token) < 8:
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])
orm_token = APIToken(user_id=self.id)
orm_token.token = token
db.add(orm_token)
db.commit()
return token
return APIToken.new(token=token, user=self)
@classmethod
def find(cls, db, name):
@@ -364,13 +350,73 @@ class User(Base):
"""
return db.query(cls).filter(cls.name==name).first()
# service:server many:many mapping table
service_server_map = Table('service_server_map', Base.metadata,
Column('service_id', ForeignKey('services.id')),
Column('server_id', ForeignKey('servers.id')),
)
class Service(Base):
"""A service run with JupyterHub
A service is similar to a User without a Spawner.
A service can have API tokens for accessing the Hub's API
It has:
- name
- admin
- api tokens
In addition to what it has in common with users, a Service has extra info:
- servers: list of HTTP endpoints for the service
- pid: the process id (if managed)
"""
__tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True)
# common user interface:
name = Column(Unicode(1023), unique=True)
admin = Column(Boolean, default=False)
api_tokens = relationship("APIToken", backref="service")
# service-specific interface
servers = relationship('Server', secondary='service_server_map')
pid = Column(Integer)
def new_api_token(self, token=None):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, service=self)
@classmethod
def find(cls, db, name):
"""Find a service by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name==name).first()
class APIToken(Base):
"""An API token"""
__tablename__ = 'api_tokens'
# _constraint = ForeignKeyConstraint(['user_id', 'server_id'], ['users.id', 'services.id'])
@declared_attr
def user_id(cls):
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"))
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
@declared_attr
def service_id(cls):
return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
id = Column(Integer, primary_key=True)
hashed = Column(Unicode(1023))
@@ -391,22 +437,42 @@ class APIToken(Base):
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
def __repr__(self):
return "<{cls}('{pre}...', user='{u}')>".format(
if self.user is not None:
kind = 'user'
name = self.user.name
elif self.service is not None:
kind = 'service'
name = self.service.name
else:
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__,
pre=self.prefix,
u=self.user.name,
kind=kind,
name=name,
)
@classmethod
def find(cls, db, token):
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
"""
prefix = token[:cls.prefix_length]
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
prefix_match = prefix_match.filter(cls.service_id != None)
elif kind is not None:
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match:
if orm_token.match(token):
return orm_token
@@ -415,6 +481,31 @@ class APIToken(Base):
"""Is this my token?"""
return compare_token(self.hashed, token)
@classmethod
def new(cls, token=None, user=None, service=None):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
db = inspect(user or service).session
if token is None:
token = new_token()
else:
if len(token) < 8:
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])
orm_token = APIToken(token=token)
if user:
assert user.id is not None
orm_token.user_id = user.id
else:
assert service.id is not None
orm_token.service_id = service.id
db.add(orm_token)
db.commit()
return token
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url"""

View File

@@ -42,8 +42,13 @@ def find_user(db, name):
return db.query(orm.User).filter(orm.User.name==name).first()
def add_user(db, app=None, **kwargs):
orm_user = find_user(db, name=kwargs.get('name'))
if orm_user is None:
orm_user = orm.User(**kwargs)
db.add(orm_user)
else:
for attr, value in kwargs.items():
setattr(orm_user, attr, value)
db.commit()
if app:
user = app.users[orm_user.id] = User(orm_user, app.tornado_settings)

View File

@@ -55,8 +55,7 @@ def test_generate_config():
assert 'Spawner.cmd' in cfg_text
assert 'Authenticator.whitelist' in cfg_text
def test_init_tokens():
def test_init_tokens(io_loop):
with TemporaryDirectory() as td:
db_file = os.path.join(td, 'jupyterhub.sqlite')
tokens = {
@@ -64,8 +63,8 @@ def test_init_tokens():
'also-super-secret': 'gordon',
'boagasdfasdf': 'chell',
}
app = MockHub(db_file=db_file, api_tokens=tokens)
app.initialize([])
app = MockHub(db_url=db_file, api_tokens=tokens)
io_loop.run_sync(lambda : app.initialize([]))
db = app.db
for token, username in tokens.items():
api_token = orm.APIToken.find(db, token)
@@ -74,8 +73,8 @@ def test_init_tokens():
assert user.name == username
# simulate second startup, reloading same tokens:
app = MockHub(db_file=db_file, api_tokens=tokens)
app.initialize([])
app = MockHub(db_url=db_file, api_tokens=tokens)
io_loop.run_sync(lambda : app.initialize([]))
db = app.db
for token, username in tokens.items():
api_token = orm.APIToken.find(db, token)
@@ -85,9 +84,9 @@ def test_init_tokens():
# don't allow failed token insertion to create users:
tokens['short'] = 'gman'
app = MockHub(db_file=db_file, api_tokens=tokens)
# with pytest.raises(ValueError):
app.initialize([])
app = MockHub(db_url=db_file, api_tokens=tokens)
with pytest.raises(ValueError):
io_loop.run_sync(lambda : app.initialize([]))
assert orm.User.find(app.db, 'gman') is None

View File

@@ -23,7 +23,7 @@ def test_upgrade(tmpdir):
print(db_url)
upgrade(db_url)
def test_upgrade_entrypoint(tmpdir):
def test_upgrade_entrypoint(tmpdir, io_loop):
generate_old_db(str(tmpdir))
tmpdir.chdir()
tokenapp = NewToken()
@@ -32,7 +32,7 @@ def test_upgrade_entrypoint(tmpdir):
tokenapp.start()
upgradeapp = UpgradeDB()
upgradeapp.initialize([])
io_loop.run_sync(lambda : upgradeapp.initialize([]))
upgradeapp.start()
# run tokenapp again, it should work

View File

@@ -90,6 +90,8 @@ def test_tokens(db):
assert len(user.api_tokens) == 2
found = orm.APIToken.find(db, token=token)
assert found.match(token)
assert found.user is user
assert found.service is None
found = orm.APIToken.find(db, 'something else')
assert found is None
@@ -104,6 +106,69 @@ def test_tokens(db):
assert len(user.api_tokens) == 3
def test_service_tokens(db):
service = orm.Service(name='secret')
db.add(service)
db.commit()
token = service.new_api_token()
assert any(t.match(token) for t in service.api_tokens)
service.new_api_token()
assert len(service.api_tokens) == 2
found = orm.APIToken.find(db, token=token)
assert found.match(token)
assert found.user is None
assert found.service is service
service2 = orm.Service(name='secret')
db.add(service)
db.commit()
assert service2.id != service.id
def test_service_servers(db):
service = orm.Service(name='has_servers')
db.add(service)
db.commit()
assert service.servers == []
servers = service.servers = [
orm.Server(),
orm.Server(),
]
assert [ s.id for s in servers ] == [ None, None ]
db.commit()
assert [ type(s.id) for s in servers ] == [ int, int ]
def test_token_find(db):
service = db.query(orm.Service).first()
user = db.query(orm.User).first()
service_token = service.new_api_token()
user_token = user.new_api_token()
with pytest.raises(ValueError):
orm.APIToken.find(db, 'irrelevant', kind='richard')
# no kind, find anything
found = orm.APIToken.find(db, token=user_token)
assert found
assert found.match(user_token)
found = orm.APIToken.find(db, token=service_token)
assert found
assert found.match(service_token)
# kind=user, only find user tokens
found = orm.APIToken.find(db, token=user_token, kind='user')
assert found
assert found.match(user_token)
found = orm.APIToken.find(db, token=service_token, kind='user')
assert found is None
# kind=service, only find service tokens
found = orm.APIToken.find(db, token=service_token, kind='service')
assert found
assert found.match(service_token)
found = orm.APIToken.find(db, token=user_token, kind='service')
assert found is None
def test_spawn_fails(db, io_loop):
orm_user = orm.User(name='aeofel')
db.add(orm_user)
@@ -126,7 +191,7 @@ def test_spawn_fails(db, io_loop):
def test_groups(db):
user = orm.User(name='aeofel')
user = orm.User.find(db, name='aeofel')
db.add(user)
group = orm.Group(name='lives')