diff --git a/jupyterhub/apihandlers/proxy.py b/jupyterhub/apihandlers/proxy.py index 54d75500..a2b705f3 100644 --- a/jupyterhub/apihandlers/proxy.py +++ b/jupyterhub/apihandlers/proxy.py @@ -28,7 +28,7 @@ class ProxyAPIHandler(APIHandler): @gen.coroutine def post(self): """POST checks the proxy to ensure""" - yield self.proxy.check_routes(self.users) + yield self.proxy.check_routes(self.users, self.services) @admin_only @@ -59,7 +59,7 @@ class ProxyAPIHandler(APIHandler): self.proxy.auth_token = model['auth_token'] self.db.commit() self.log.info("Updated proxy at %s", server.bind_url) - yield self.proxy.check_routes(self.users) + yield self.proxy.check_routes(self.users, self.services) diff --git a/jupyterhub/app.py b/jupyterhub/app.py index 41886ab0..37dd80a8 100644 --- a/jupyterhub/app.py +++ b/jupyterhub/app.py @@ -976,8 +976,12 @@ class JupyterHub(Application): proxy=self.proxy, hub=self.hub, base_url=self.base_url, db=self.db, orm=orm_service, parent=self, - hub_api_url=self.hub.api_url, - **spec) + hub_api_url=self.hub.api_url) + traits = service.traits(input=True) + for key, value in spec.items(): + if key not in traits: + raise AttributeError("No such service field: %s" % key) + setattr(service, key, value) self._service_map[name] = service if service.managed: if not service.api_token: @@ -986,6 +990,14 @@ class JupyterHub(Application): else: # ensure provided token is registered self.service_tokens[service.api_token] = service.name + else: + self.service_tokens[service.api_token] = service.name + + # delete services from db not in service config: + for service in self.db.query(orm.Service): + if service.name not in self._service_map: + self.db.delete(service) + self.db.commit() @gen.coroutine def init_spawners(self): @@ -1155,6 +1167,7 @@ class JupyterHub(Application): yield self.start_proxy() self.log.info("Setting up routes on new proxy") yield self.proxy.add_all_users(self.users) + yield self.proxy.add_all_services(self.services) self.log.info("New proxy back up, and good to go") def init_tornado_settings(self): @@ -1213,6 +1226,7 @@ class JupyterHub(Application): self.tornado_settings = settings # constructing users requires access to tornado_settings self.tornado_settings['users'] = self.users + self.tornado_settings['services'] = self._service_map def init_tornado_application(self): """Instantiate the tornado Application object""" @@ -1354,7 +1368,7 @@ class JupyterHub(Application): self.statsd.gauge('users.active', active_users_count) self.db.commit() - yield self.proxy.check_routes(self.users, routes) + yield self.proxy.check_routes(self.users, self._service_map, routes) @gen.coroutine def start(self): @@ -1396,6 +1410,7 @@ class JupyterHub(Application): self.exit(1) loop.add_callback(self.proxy.add_all_users, self.users) + loop.add_callback(self.proxy.add_all_services, self._service_map) if self.proxy_process: # only check / restart the proxy if we started it in the first place. diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index 07651787..f03f817b 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -68,6 +68,9 @@ class BaseHandler(RequestHandler): return self.settings.setdefault('users', {}) @property + def services(self): + return self.settings.setdefault('services', {}) + @property def hub(self): return self.settings['hub'] @@ -236,6 +239,10 @@ class BaseHandler(RequestHandler): **kwargs ) + def set_service_cookie(self, user): + """set the login cookie for services""" + self._set_user_cookie(user, self.service_server) + def set_server_cookie(self, user): """set the login cookie for the single-user server""" self._set_user_cookie(user, user.server) @@ -254,6 +261,10 @@ class BaseHandler(RequestHandler): if user.server: self.set_server_cookie(user) + # set single cookie for services + if self.db.query(orm.Service).first(): + self.set_service_cookie(user) + # create and set a new cookie token for the hub if not self.get_current_user_cookie(): self.set_hub_cookie(user) diff --git a/jupyterhub/orm.py b/jupyterhub/orm.py index f7241b9a..5e581e19 100644 --- a/jupyterhub/orm.py +++ b/jupyterhub/orm.py @@ -152,6 +152,35 @@ class Proxy(Base): return client.fetch(req) + @gen.coroutine + def add_service(self, service, client=None): + """Add a service's server to the proxy table.""" + if not service.server: + raise RuntimeError( + "Service %s does not have an http endpoint to add to the proxy.", service.name) + + self.log.info("Adding service %s to proxy %s => %s", + service.name, service.proxy_path, service.server.host, + ) + + yield self.api_request(service.proxy_path, + method='POST', + body=dict( + target=service.server.host, + service=service.name, + ), + client=client, + ) + + @gen.coroutine + def delete_service(self, service, client=None): + """Remove a service's server from the proxy table.""" + self.log.info("Removing service %s from proxy", service.name) + yield self.api_request(service.proxy_path, + method='DELETE', + client=client, + ) + @gen.coroutine def add_user(self, user, client=None): """Add a user's server to the proxy table.""" @@ -174,7 +203,7 @@ class Proxy(Base): @gen.coroutine def delete_user(self, user, client=None): - """Remove a user's server to the proxy table.""" + """Remove a user's server from the proxy table.""" self.log.info("Removing user %s from proxy", user.name) yield self.api_request(user.proxy_path, method='DELETE', @@ -182,10 +211,20 @@ class Proxy(Base): ) @gen.coroutine - def get_routes(self, client=None): - """Fetch the proxy's routes""" - resp = yield self.api_request('', client=client) - return json.loads(resp.body.decode('utf8', 'replace')) + def add_all_services(self, service_dict): + """Update the proxy table from the database. + + Used when loading up a new proxy. + """ + db = inspect(self).session + futures = [] + for orm_service in db.query(Service): + service = service_dict[orm_service.name] + if service.server: + futures.append(self.add_service(service)) + # wait after submitting them all + for f in futures: + yield f @gen.coroutine def add_all_users(self, user_dict): @@ -204,12 +243,18 @@ class Proxy(Base): yield f @gen.coroutine - def check_routes(self, user_dict, routes=None): + def get_routes(self, client=None): + """Fetch the proxy's routes""" + resp = yield self.api_request('', client=client) + return json.loads(resp.body.decode('utf8', 'replace')) + + @gen.coroutine + def check_routes(self, user_dict, service_dict, routes=None): """Check that all users are properly routed on the proxy""" if not routes: routes = yield self.get_routes() - have_routes = { r['user'] for r in routes.values() if 'user' in r } + user_routes = { r['user'] for r in routes.values() if 'user' in r } futures = [] db = inspect(self).session for orm_user in db.query(User).filter(User.server != None): @@ -222,9 +267,22 @@ class Proxy(Base): # catch filter bug, either in sqlalchemy or my understanding of its behavior self.log.error("User %s has no server, but wasn't filtered out.", user) continue - if user.name not in have_routes: + if user.name not in user_routes: self.log.warning("Adding missing route for %s (%s)", user.name, user.server) futures.append(self.add_user(user)) + + # check service routes + service_routes = { r['service'] for r in routes.values() if 'service' in r } + for orm_service in db.query(Service).filter(Service.server != None): + service = service_dict[orm_service.name] + if service.server is None: + # This should never be True, but seems to be on rare occasion. + # catch filter bug, either in sqlalchemy or my understanding of its behavior + self.log.error("Service %s has no server, but wasn't filtered out.", service) + continue + if service.name not in service_routes: + self.log.warning("Adding missing route for %s (%s)", service.name, service.server) + futures.append(self.add_service(service)) for f in futures: yield f @@ -351,13 +409,6 @@ class User(Base): return db.query(cls).filter(cls.name==name).first() -# service:server many:many mapping table -service_server_map = Table('service_server_map', Base.metadata, - Column('service_id', ForeignKey('services.id')), - Column('server_id', ForeignKey('servers.id')), -) - - class Service(Base): """A service run with JupyterHub @@ -369,10 +420,10 @@ class Service(Base): - name - admin - api tokens + - server (if proxied http endpoint) In addition to what it has in common with users, a Service has extra info: - - servers: list of HTTP endpoints for the service - pid: the process id (if managed) """ @@ -386,7 +437,8 @@ class Service(Base): api_tokens = relationship("APIToken", backref="service") # service-specific interface - servers = relationship('Server', secondary='service_server_map') + _server_id = Column(Integer, ForeignKey('servers.id')) + server = relationship(Server, primaryjoin=_server_id == Server.id) pid = Column(Integer) def new_api_token(self, token=None): diff --git a/jupyterhub/services/service.py b/jupyterhub/services/service.py index b4cbe0fa..63b6c17b 100644 --- a/jupyterhub/services/service.py +++ b/jupyterhub/services/service.py @@ -50,7 +50,7 @@ from tornado import gen from traitlets import ( HasTraits, Any, Bool, Dict, Unicode, Instance, - observe, + default, observe, ) from traitlets.config import LoggingConfigurable @@ -74,6 +74,7 @@ class _ServiceSpawner(LocalProcessSpawner): Removes notebook-specific-ness from LocalProcessSpawner. """ cwd = Unicode() + cmd = Command(minlen=0) def make_preexec_fn(self, name): if not name or name == getuser(): @@ -81,7 +82,6 @@ class _ServiceSpawner(LocalProcessSpawner): return return super().make_preexec_fn(name) - @gen.coroutine def start(self): """Start the process""" env = self.get_env() @@ -92,7 +92,7 @@ class _ServiceSpawner(LocalProcessSpawner): self.proc = Popen(self.cmd, env=env, preexec_fn=self.make_preexec_fn(self.user.name), start_new_session=True, # don't forward signals - cwd=self.cwd, + cwd=self.cwd or None, ) except PermissionError: # use which to get abspath @@ -137,47 +137,46 @@ class Service(LoggingConfigurable): If the service has an http endpoint, it """ - ) + ).tag(input=True) admin = Bool(False, help="Does the service need admin-access to the Hub API?" - ) + ).tag(input=True) url = Unicode( help="""URL of the service. Only specify if the service runs an HTTP(s) endpoint that. If managed, will be passed as JUPYTERHUB_SERVICE_URL env. """ - ) + ).tag(input=True) @observe('url') def _url_changed(self, change): url = change['new'] if not url: self.orm.server = None else: - if self.orm.server is None: - parsed = urlparse(url) - if parsed.port is not None: - port = parsed.port - elif parsed.scheme == 'http': - port = 80 - elif parsed.scheme == 'https': - port = 443 - server = self.orm.server = orm.Server( - proto=parsed.scheme, - ip=parsed.host, - port=port, - cookie_name='jupyterhub-services', - base_url=self.proxy_path, - ) - self.db.add(server) - self.db.commit() + parsed = urlparse(url) + if parsed.port is not None: + port = parsed.port + elif parsed.scheme == 'http': + port = 80 + elif parsed.scheme == 'https': + port = 443 + server = self.orm.server = orm.Server( + proto=parsed.scheme, + ip=parsed.hostname, + port=port, + cookie_name='jupyterhub-services', + base_url=self.proxy_path, + ) + self.db.add(server) + self.db.commit() api_token = Unicode( help="""The API token to use for the service. If unspecified, an API token will be generated for managed services. """ - ) + ).tag(input=True) # Managed service API: @property @@ -185,28 +184,37 @@ class Service(LoggingConfigurable): """Am I managed by the Hub?""" return bool(self.command) - command = Command( + command = Command(minlen=0, help="Command to spawn this service, if managed." - ) + ).tag(input=True) cwd = Unicode( help="""The working directory in which to run the service.""" - ) + ).tag(input=True) environment = Dict( help="""Environment variables to pass to the service. Only used if the Hub is spawning the service. """ - ) + ).tag(input=True) user = Unicode(getuser(), help="""The user to become when launching the service. If unspecified, run the service as the same user as the Hub. """ - ) + ).tag(input=True) # handles on globals: proxy = Any() hub = Any() base_url = Unicode() + db = Any() + orm = Any() + @default('orm') + def _orm_default(self): + return self.db.query(orm.Service).filter(orm.Service.name==self.name).first() + + @property + def server(self): + return self.orm.server @property def proxy_path(self): @@ -219,7 +227,6 @@ class Service(LoggingConfigurable): managed=' managed' if self.managed else '', ) - @gen.coroutine def start(self): """Start a managed service""" if not self.managed: @@ -233,6 +240,7 @@ class Service(LoggingConfigurable): env['JUPYTERHUB_API_URL'] = self.hub_api_url env['JUPYTERHUB_BASE_URL'] = self.base_url env['JUPYTERHUB_SERVICE_PATH'] = self.proxy_path + env['JUPYTERHUB_SERVICE_URL'] = self.url self.spawner = _ServiceSpawner( cmd=self.command, @@ -245,7 +253,7 @@ class Service(LoggingConfigurable): server=self.orm.server, ), ) - yield self.spawner.start() + self.spawner.start() self.proc = self.spawner.proc self.spawner.add_poll_callback(self._proc_stopped) self.spawner.start_polling() @@ -253,7 +261,6 @@ class Service(LoggingConfigurable): def _proc_stopped(self): """Called when the service process unexpectedly exits""" self.log.error("Service %s exited with status %i", self.name, self.proc.returncode) - self.proc = None self.start() def stop(self): diff --git a/jupyterhub/tests/mockservice.py b/jupyterhub/tests/mockservice.py new file mode 100644 index 00000000..e989c1e5 --- /dev/null +++ b/jupyterhub/tests/mockservice.py @@ -0,0 +1,59 @@ +"""Mock service for testing + +basic HTTP Server that echos URLs back, +and allow retrieval of sys.argv. +""" + +import argparse +import json +import os +import sys +from urllib.parse import urlparse + +import requests +from tornado import web, httpserver, ioloop + + +class EchoHandler(web.RequestHandler): + def get(self): + self.write(self.request.path) + + +class EnvHandler(web.RequestHandler): + def get(self): + self.set_header('Content-Type', 'application/json') + self.write(json.dumps(dict(os.environ))) + + +class APIHandler(web.RequestHandler): + def get(self, path): + api_token = os.environ['JUPYTERHUB_API_TOKEN'] + api_url = os.environ['JUPYTERHUB_API_URL'] + r = requests.get(api_url + path, headers={ + 'Authorization': 'token %s' % api_token + }) + r.raise_for_status() + self.set_header('Content-Type', 'application/json') + self.write(r.text) + + +def main(): + if os.environ['JUPYTERHUB_SERVICE_URL']: + url = urlparse(os.environ['JUPYTERHUB_SERVICE_URL']) + app = web.Application([ + (r'.*/env', EnvHandler), + (r'.*/api/(.*)', APIHandler), + (r'.*', EchoHandler), + ]) + + server = httpserver.HTTPServer(app) + server.listen(url.port, url.hostname) + try: + ioloop.IOLoop.instance().start() + except KeyboardInterrupt: + print('\nInterrupted') + +if __name__ == '__main__': + from tornado.options import parse_command_line + parse_command_line() + main() diff --git a/jupyterhub/tests/test_orm.py b/jupyterhub/tests/test_orm.py index 6a369ad5..182160e4 100644 --- a/jupyterhub/tests/test_orm.py +++ b/jupyterhub/tests/test_orm.py @@ -124,19 +124,17 @@ def test_service_tokens(db): assert service2.id != service.id -def test_service_servers(db): +def test_service_server(db): service = orm.Service(name='has_servers') db.add(service) db.commit() - assert service.servers == [] - servers = service.servers = [ - orm.Server(), - orm.Server(), - ] - assert [ s.id for s in servers ] == [ None, None ] + assert service.server is None + server = service.server = orm.Server() + assert service + assert server.id is None db.commit() - assert [ type(s.id) for s in servers ] == [ int, int ] + assert isinstance(server.id, int) def test_token_find(db): diff --git a/jupyterhub/tests/test_proxy.py b/jupyterhub/tests/test_proxy.py index 07e389a8..4ce1c1ef 100644 --- a/jupyterhub/tests/test_proxy.py +++ b/jupyterhub/tests/test_proxy.py @@ -137,11 +137,11 @@ def test_check_routes(app, io_loop): zoe = app.users[zoe] before = sorted(io_loop.run_sync(app.proxy.get_routes)) assert unquote(zoe.proxy_path) in before - io_loop.run_sync(lambda : app.proxy.check_routes(app.users)) + io_loop.run_sync(lambda : app.proxy.check_routes(app.users, app._service_map)) io_loop.run_sync(lambda : proxy.delete_user(zoe)) during = sorted(io_loop.run_sync(app.proxy.get_routes)) assert unquote(zoe.proxy_path) not in during - io_loop.run_sync(lambda : app.proxy.check_routes(app.users)) + io_loop.run_sync(lambda : app.proxy.check_routes(app.users, app._service_map)) after = sorted(io_loop.run_sync(app.proxy.get_routes)) assert unquote(zoe.proxy_path) in after assert before == after diff --git a/jupyterhub/tests/test_services.py b/jupyterhub/tests/test_services.py new file mode 100644 index 00000000..84f67b20 --- /dev/null +++ b/jupyterhub/tests/test_services.py @@ -0,0 +1,137 @@ +"""Tests for services""" + +from binascii import hexlify +from contextlib import contextmanager +import os +from subprocess import Popen, TimeoutExpired +import sys +from threading import Event +import time +try: + from unittest import mock +except ImportError: + import mock +from urllib.parse import unquote + +import pytest +from tornado import gen +from tornado.ioloop import IOLoop + + +import jupyterhub.services.service +from .test_pages import get_page +from ..utils import url_path_join, wait_for_http_server + +here = os.path.dirname(os.path.abspath(__file__)) +mockservice_py = os.path.join(here, 'mockservice.py') +mockservice_cmd = [sys.executable, mockservice_py] + +from ..utils import random_port + +@contextmanager +def external_service(app, name='mockservice'): + env = { + 'JUPYTERHUB_API_TOKEN': hexlify(os.urandom(5)), + 'JUPYTERHUB_SERVICE_NAME': name, + 'JUPYTERHUB_API_URL': url_path_join(app.hub.server.url, 'api/'), + 'JUPYTERHUB_SERVICE_URL': 'http://127.0.0.1:%i' % random_port(), + } + p = Popen(mockservice_cmd, env=env) + IOLoop().run_sync(lambda : wait_for_http_server(env['JUPYTERHUB_SERVICE_URL'])) + try: + yield env + finally: + p.terminate() + + +# mock services for testing. +# Shorter intervals, etc. +class MockServiceSpawner(jupyterhub.services.service._ServiceSpawner): + poll_interval = 1 + +@pytest.yield_fixture +def mockservice(request, app): + name = 'mock-service' + with mock.patch.object(jupyterhub.services.service, '_ServiceSpawner', MockServiceSpawner): + app.services = [{ + 'name': name, + 'command': mockservice_cmd, + 'url': 'http://127.0.0.1:%i' % random_port(), + 'admin': True, + }] + app.init_services() + app.io_loop.add_callback(app.proxy.add_all_services, app._service_map) + assert name in app._service_map + service = app._service_map[name] + app.io_loop.add_callback(service.start) + request.addfinalizer(service.stop) + for i in range(20): + if not getattr(service, 'proc', False): + time.sleep(0.2) + # ensure process finishes starting + with pytest.raises(TimeoutExpired): + service.proc.wait(1) + yield service + + +def test_managed_service(app, mockservice): + service = mockservice + proc = service.proc + first_pid = proc.pid + assert proc.poll() is None + # shut it down: + proc.terminate() + proc.wait(10) + assert proc.poll() is not None + # ensure Hub notices and brings it back up: + for i in range(20): + if service.proc is not proc: + break + else: + time.sleep(0.2) + + assert service.proc.pid != first_pid + assert service.proc.poll() is None + + +def test_proxy_service(app, mockservice, io_loop): + name = mockservice.name + routes = io_loop.run_sync(app.proxy.get_routes) + assert unquote(mockservice.proxy_path) in routes + io_loop.run_sync(mockservice.server.wait_up) + path = '/services/{}/foo'.format(name) + r = get_page(path, app, hub=False, allow_redirects=False) + r.raise_for_status() + assert r.status_code == 200 + assert r.text.endswith(path) + + +@pytest.mark.now +def test_external_service(app, io_loop): + name = 'external' + with external_service(app, name=name) as env: + app.services = [{ + 'name': name, + 'admin': True, + 'url': env['JUPYTERHUB_SERVICE_URL'], + 'api_token': env['JUPYTERHUB_API_TOKEN'], + }] + app.init_services() + app.init_api_tokens() + evt = Event() + @gen.coroutine + def add_services(): + yield app.proxy.add_all_services(app._service_map) + evt.set() + app.io_loop.add_callback(add_services) + assert evt.wait(10) + path = '/services/{}/api/users'.format(name) + r = get_page(path, app, hub=False, allow_redirects=False) + print(r.headers, r.status_code) + r.raise_for_status() + assert r.status_code == 200 + resp = r.json() + assert isinstance(resp, list) + assert len(resp) >= 1 + assert isinstance(resp[0], dict) + assert 'name' in resp[0]