add service API tokens

service_tokens supersedes api_tokens,
since they now map to a new services collection,
rather than regular Hub usernames.

Services in the ORM have:

- API tokens
- servers (multiple, can be 0)
- pid (0 if not managed)
This commit is contained in:
Min RK
2016-07-16 16:42:42 -05:00
parent 81350322d7
commit 2a35d1c8a6
4 changed files with 232 additions and 44 deletions

View File

@@ -384,9 +384,28 @@ class JupyterHub(Application):
).tag(config=True) ).tag(config=True)
api_tokens = Dict(Unicode(), api_tokens = Dict(Unicode(),
help="""Dict of token:username to be loaded into the database. help="""PENDING DEPRECATION: consider using service_tokens
Dict of token:username to be loaded into the database.
Allows ahead-of-time generation of API tokens for use by services. Allows ahead-of-time generation of API tokens for use by externally managed services,
which authenticate as JupyterHub users.
Consider using service_tokens for general services that talk to the JupyterHub API.
"""
).tag(config=True)
@observe('api_tokens')
def _deprecate_api_tokens(self, change):
self.log.warn("JupyterHub.api_tokens is pending deprecation."
" Consider using JupyterHub.service_tokens."
" If you have a use case for services that identify as users,"
" let us know: https://github.com/jupyterhub/jupyterhub/issues"
)
service_tokens = Dict(Unicode(),
help="""Dict of token:servicename to be loaded into the database.
Allows ahead-of-time generation of API tokens for use by externally managed services.
""" """
).tag(config=True) ).tag(config=True)
@@ -864,38 +883,51 @@ class JupyterHub(Application):
group.users.append(user) group.users.append(user)
db.commit() db.commit()
def init_api_tokens(self): def _add_tokens(self, token_dict, kind):
"""Load predefined API tokens (for services) into database""" """Add tokens for users or services to the database"""
if kind == 'user':
Class = orm.User
elif kind == 'service':
Class = orm.Service
else:
raise ValueError("kind must be user or service, not %r" % kind)
db = self.db db = self.db
for token, username in self.api_tokens.items(): for token, name in token_dict.items():
username = self.authenticator.normalize_username(username) if kind == 'user':
if not self.authenticator.check_whitelist(username): name = self.authenticator.normalize_username(name)
raise ValueError("Token username %r is not in whitelist" % username) if not self.authenticator.check_whitelist(name):
if not self.authenticator.validate_username(username): raise ValueError("Token name %r is not in whitelist" % name)
raise ValueError("Token username %r is not valid" % username) if not self.authenticator.validate_username(name):
raise ValueError("Token name %r is not valid" % name)
orm_token = orm.APIToken.find(db, token) orm_token = orm.APIToken.find(db, token)
if orm_token is None: if orm_token is None:
user = orm.User.find(db, username) obj = Class.find(db, name)
user_created = False created = False
if user is None: if obj is None:
user_created = True created = True
self.log.debug("Adding user %r to database", username) self.log.debug("Adding %s %r to database", kind, name)
user = orm.User(name=username) obj = Class(name=name)
db.add(user) db.add(obj)
db.commit() db.commit()
self.log.info("Adding API token for %s", username) self.log.info("Adding API token for %s: %s", kind, name)
try: try:
user.new_api_token(token) obj.new_api_token(token)
except Exception: except Exception:
if user_created: if created:
# don't allow bad tokens to create users # don't allow bad tokens to create users
db.delete(user) db.delete(obj)
db.commit() db.commit()
raise raise
else: else:
self.log.debug("Not duplicating token %s", orm_token) self.log.debug("Not duplicating token %s", orm_token)
db.commit() db.commit()
def init_api_tokens(self):
"""Load predefined API tokens (for services) into database"""
self._add_tokens(self.service_tokens, kind='service')
self._add_tokens(self.api_tokens, kind='user')
@gen.coroutine @gen.coroutine
def init_spawners(self): def init_spawners(self):
db = self.db db = self.db

View File

@@ -144,7 +144,7 @@ class BaseHandler(RequestHandler):
if orm_token is None: if orm_token is None:
return None return None
else: else:
return orm_token.user return orm_token.user or orm_token.service
def _user_for_cookie(self, cookie_name, cookie_value=None): def _user_for_cookie(self, cookie_name, cookie_value=None):
"""Get the User for a given cookie, if there is one""" """Get the User for a given cookie, if there is one"""

View File

@@ -302,7 +302,7 @@ class User(Base):
""" """
__tablename__ = 'users' __tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(Unicode(1023)) name = Column(Unicode(1023), unique=True)
# should we allow multiple servers per user? # should we allow multiple servers per user?
_server_id = Column(Integer, ForeignKey('servers.id', ondelete="SET NULL")) _server_id = Column(Integer, ForeignKey('servers.id', ondelete="SET NULL"))
server = relationship(Server, primaryjoin=_server_id == Server.id) server = relationship(Server, primaryjoin=_server_id == Server.id)
@@ -340,21 +340,7 @@ class User(Base):
If `token` is given, load that token. If `token` is given, load that token.
""" """
assert self.id is not None return APIToken.new(token=token, user=self)
db = inspect(self).session
if token is None:
token = new_token()
else:
if len(token) < 8:
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])
orm_token = APIToken(user_id=self.id)
orm_token.token = token
db.add(orm_token)
db.commit()
return token
@classmethod @classmethod
def find(cls, db, name): def find(cls, db, name):
@@ -364,13 +350,73 @@ class User(Base):
""" """
return db.query(cls).filter(cls.name==name).first() return db.query(cls).filter(cls.name==name).first()
# service:server many:many mapping table
service_server_map = Table('service_server_map', Base.metadata,
Column('service_id', ForeignKey('services.id')),
Column('server_id', ForeignKey('servers.id')),
)
class Service(Base):
"""A service run with JupyterHub
A service is similar to a User without a Spawner.
A service can have API tokens for accessing the Hub's API
It has:
- name
- admin
- api tokens
In addition to what it has in common with users, a Service has extra info:
- servers: list of HTTP endpoints for the service
- pid: the process id (if managed)
"""
__tablename__ = 'services'
id = Column(Integer, primary_key=True, autoincrement=True)
# common user interface:
name = Column(Unicode(1023), unique=True)
admin = Column(Boolean, default=False)
api_tokens = relationship("APIToken", backref="service")
# service-specific interface
servers = relationship('Server', secondary='service_server_map')
pid = Column(Integer)
def new_api_token(self, token=None):
"""Create a new API token
If `token` is given, load that token.
"""
return APIToken.new(token=token, service=self)
@classmethod
def find(cls, db, name):
"""Find a service by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name==name).first()
class APIToken(Base): class APIToken(Base):
"""An API token""" """An API token"""
__tablename__ = 'api_tokens' __tablename__ = 'api_tokens'
# _constraint = ForeignKeyConstraint(['user_id', 'server_id'], ['users.id', 'services.id'])
@declared_attr @declared_attr
def user_id(cls): def user_id(cls):
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE")) return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
@declared_attr
def service_id(cls):
return Column(Integer, ForeignKey('services.id', ondelete="CASCADE"), nullable=True)
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
hashed = Column(Unicode(1023)) hashed = Column(Unicode(1023))
@@ -391,22 +437,42 @@ class APIToken(Base):
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm) self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)
def __repr__(self): def __repr__(self):
return "<{cls}('{pre}...', user='{u}')>".format( if self.user is not None:
kind = 'user'
name = self.user.name
elif self.service is not None:
kind = 'service'
name = self.service.name
else:
# this shouldn't happen
kind = 'owner'
name = 'unknown'
return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__, cls=self.__class__.__name__,
pre=self.prefix, pre=self.prefix,
u=self.user.name, kind=kind,
name=name,
) )
@classmethod @classmethod
def find(cls, db, token): def find(cls, db, token, *, kind=None):
"""Find a token object by value. """Find a token object by value.
Returns None if not found. Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
""" """
prefix = token[:cls.prefix_length] prefix = token[:cls.prefix_length]
# since we can't filter on hashed values, filter on prefix # since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens # so we aren't comparing with all tokens
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix)) prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))
if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind == 'service':
prefix_match = prefix_match.filter(cls.service_id != None)
elif kind is not None:
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)
for orm_token in prefix_match: for orm_token in prefix_match:
if orm_token.match(token): if orm_token.match(token):
return orm_token return orm_token
@@ -415,6 +481,31 @@ class APIToken(Base):
"""Is this my token?""" """Is this my token?"""
return compare_token(self.hashed, token) return compare_token(self.hashed, token)
@classmethod
def new(cls, token=None, user=None, service=None):
"""Generate a new API token for a user or service"""
assert user or service
assert not (user and service)
db = inspect(user or service).session
if token is None:
token = new_token()
else:
if len(token) < 8:
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])
orm_token = APIToken(token=token)
if user:
assert user.id is not None
orm_token.user_id = user.id
else:
assert service.id is not None
orm_token.service_id = service.id
db.add(orm_token)
db.commit()
return token
def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs): def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url""" """Create a new session at url"""

View File

@@ -90,6 +90,8 @@ def test_tokens(db):
assert len(user.api_tokens) == 2 assert len(user.api_tokens) == 2
found = orm.APIToken.find(db, token=token) found = orm.APIToken.find(db, token=token)
assert found.match(token) assert found.match(token)
assert found.user is user
assert found.service is None
found = orm.APIToken.find(db, 'something else') found = orm.APIToken.find(db, 'something else')
assert found is None assert found is None
@@ -104,6 +106,69 @@ def test_tokens(db):
assert len(user.api_tokens) == 3 assert len(user.api_tokens) == 3
def test_service_tokens(db):
service = orm.Service(name='secret')
db.add(service)
db.commit()
token = service.new_api_token()
assert any(t.match(token) for t in service.api_tokens)
service.new_api_token()
assert len(service.api_tokens) == 2
found = orm.APIToken.find(db, token=token)
assert found.match(token)
assert found.user is None
assert found.service is service
service2 = orm.Service(name='secret')
db.add(service)
db.commit()
assert service2.id != service.id
def test_service_servers(db):
service = orm.Service(name='has_servers')
db.add(service)
db.commit()
assert service.servers == []
servers = service.servers = [
orm.Server(),
orm.Server(),
]
assert [ s.id for s in servers ] == [ None, None ]
db.commit()
assert [ type(s.id) for s in servers ] == [ int, int ]
def test_token_find(db):
service = db.query(orm.Service).first()
user = db.query(orm.User).first()
service_token = service.new_api_token()
user_token = user.new_api_token()
with pytest.raises(ValueError):
orm.APIToken.find(db, 'irrelevant', kind='richard')
# no kind, find anything
found = orm.APIToken.find(db, token=user_token)
assert found
assert found.match(user_token)
found = orm.APIToken.find(db, token=service_token)
assert found
assert found.match(service_token)
# kind=user, only find user tokens
found = orm.APIToken.find(db, token=user_token, kind='user')
assert found
assert found.match(user_token)
found = orm.APIToken.find(db, token=service_token, kind='user')
assert found is None
# kind=service, only find service tokens
found = orm.APIToken.find(db, token=service_token, kind='service')
assert found
assert found.match(service_token)
found = orm.APIToken.find(db, token=user_token, kind='service')
assert found is None
def test_spawn_fails(db, io_loop): def test_spawn_fails(db, io_loop):
orm_user = orm.User(name='aeofel') orm_user = orm.User(name='aeofel')
db.add(orm_user) db.add(orm_user)
@@ -126,7 +191,7 @@ def test_spawn_fails(db, io_loop):
def test_groups(db): def test_groups(db):
user = orm.User(name='aeofel') user = orm.User.find(db, name='aeofel')
db.add(user) db.add(user)
group = orm.Group(name='lives') group = orm.Group(name='lives')