SSL setup for testing

Setup general ssl request, not just to api

Basic tests comprised of non-ssl test copies

Create the context only when request is http

Refactor ssl key, cert, ca names

Configure the AsyncHTTPClient at app start

Change tests to import existing ones with ssl on

Override __new__ in MockHub to turn on SSL
This commit is contained in:
Thomas Mendoza
2018-06-07 16:01:24 -07:00
parent 5c39325104
commit 373c3f82dd
16 changed files with 153 additions and 233 deletions

View File

@@ -1148,9 +1148,9 @@ class JupyterHub(Application):
hub_args = dict( hub_args = dict(
base_url=self.hub_prefix, base_url=self.hub_prefix,
public_host=self.subdomain_host, public_host=self.subdomain_host,
ssl_cert_file=self.internal_ssl_cert, certfile=self.internal_ssl_cert,
ssl_key_file=self.internal_ssl_key, keyfile=self.internal_ssl_key,
ssl_ca_file=self.internal_ssl_ca, cafile=self.internal_ssl_ca,
) )
if self.hub_bind_url: if self.hub_bind_url:
# ensure hub_prefix is set on bind_url # ensure hub_prefix is set on bind_url
@@ -1439,9 +1439,9 @@ class JupyterHub(Application):
port=port, port=port,
cookie_name='jupyterhub-services', cookie_name='jupyterhub-services',
base_url=service.prefix, base_url=service.prefix,
ssl_cert_file=self.internal_ssl_cert, certfile=self.internal_ssl_cert,
ssl_key_file=self.internal_ssl_key, keyfile=self.internal_ssl_key,
ssl_ca_file=self.internal_ssl_ca, cafile=self.internal_ssl_ca,
) )
self.db.add(server) self.db.add(server)
@@ -1732,7 +1732,10 @@ class JupyterHub(Application):
extra_names = [socket.getfqdn()] + self.trusted_alt_names extra_names = [socket.getfqdn()] + self.trusted_alt_names
extra_names = ','.join(["DNS:{}".format(name) for name in extra_names]) extra_names = ','.join(["DNS:{}".format(name) for name in extra_names])
alt_names = alt_names.format(extra_names=extra_names).encode() alt_names = alt_names.format(extra_names=extra_names).encode()
internal_key_pair = cert_store.create_signed_pair("localhost", self.internal_authority_name, alt_names=alt_names) internal_key_pair = cert_store.create_signed_pair(
"localhost",
self.internal_authority_name,
alt_names=alt_names)
# Join CA files # Join CA files
with open(internal_key_pair.ca_file) as internal_ca, \ with open(internal_key_pair.ca_file) as internal_ca, \
@@ -1744,6 +1747,14 @@ class JupyterHub(Application):
self.internal_ssl_key = internal_key_pair.key_file self.internal_ssl_key = internal_key_pair.key_file
self.internal_ssl_cert = internal_key_pair.cert_file self.internal_ssl_cert = internal_key_pair.cert_file
self.internal_ssl_ca = joint_ca_file self.internal_ssl_ca = joint_ca_file
# Configure the AsyncHTTPClient
ssl_context = make_ssl_context(
self.internal_ssl_key,
self.internal_ssl_cert,
cafile=self.internal_ssl_ca,
)
AsyncHTTPClient.configure(None, defaults={"ssl_options" : ssl_context})
self.write_pid_file() self.write_pid_file()
def _log_cls(name, cls): def _log_cls(name, cls):

View File

@@ -35,9 +35,9 @@ class Server(HasTraits):
cookie_name = Unicode('') cookie_name = Unicode('')
connect_url = Unicode('') connect_url = Unicode('')
bind_url = Unicode('') bind_url = Unicode('')
ssl_cert_file = Unicode() certfile = Unicode()
ssl_key_file = Unicode() keyfile = Unicode()
ssl_ca_file = Unicode() cafile = Unicode()
@default('bind_url') @default('bind_url')
def bind_url_default(self): def bind_url_default(self):
@@ -126,9 +126,9 @@ class Server(HasTraits):
self.port = obj.port self.port = obj.port
self.base_url = obj.base_url self.base_url = obj.base_url
self.cookie_name = obj.cookie_name self.cookie_name = obj.cookie_name
self.ssl_cert_file = obj.ssl_cert_file self.certfile = obj.certfile
self.ssl_key_file = obj.ssl_key_file self.keyfile = obj.keyfile
self.ssl_ca_file = obj.ssl_ca_file self.cafile = obj.cafile
# setter to pass through to the database # setter to pass through to the database
@observe('ip', 'proto', 'port', 'base_url', 'cookie_name') @observe('ip', 'proto', 'port', 'base_url', 'cookie_name')
@@ -166,9 +166,12 @@ class Server(HasTraits):
def wait_up(self, timeout=10, http=False, ssl_context=None): def wait_up(self, timeout=10, http=False, ssl_context=None):
"""Wait for this server to come up""" """Wait for this server to come up"""
ssl_context = ssl_context or make_ssl_context(self.ssl_key_file, self.ssl_cert_file, cafile=self.ssl_ca_file)
if http: if http:
return wait_for_http_server(self.url, timeout=timeout, ssl_context=ssl_context) ssl_context = ssl_context or make_ssl_context(
self.keyfile, self.certfile, cafile=self.cafile)
return wait_for_http_server(
self.url, timeout=timeout, ssl_context=ssl_context)
else: else:
return wait_for_server(self._connect_ip, self._connect_port, timeout=timeout) return wait_for_server(self._connect_ip, self._connect_port, timeout=timeout)

View File

@@ -77,9 +77,9 @@ class Server(Base):
port = Column(Integer, default=random_port) port = Column(Integer, default=random_port)
base_url = Column(Unicode(255), default='/') base_url = Column(Unicode(255), default='/')
cookie_name = Column(Unicode(255), default='cookie') cookie_name = Column(Unicode(255), default='cookie')
ssl_cert_file = Column(Unicode(4096), default='') certfile = Column(Unicode(4096), default='')
ssl_key_file = Column(Unicode(4096), default='') keyfile = Column(Unicode(4096), default='')
ssl_ca_file = Column(Unicode(4096), default='') cafile = Column(Unicode(4096), default='')
def __repr__(self): def __repr__(self):
return "<Server(%s:%s)>" % (self.ip, self.port) return "<Server(%s:%s)>" % (self.ip, self.port)

View File

@@ -391,15 +391,6 @@ class ConfigurableHTTPProxy(Proxy):
c.ConfigurableHTTPProxy.should_start = False c.ConfigurableHTTPProxy.should_start = False
""" """
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
ssl_context = make_ssl_context(
self.app.internal_ssl_key,
self.app.internal_ssl_cert,
cafile=self.app.internal_ssl_ca,
)
AsyncHTTPClient.configure(None, defaults={"ssl_options" : ssl_context})
proxy_process = Any() proxy_process = Any()
client = Instance(AsyncHTTPClient, ()) client = Instance(AsyncHTTPClient, ())

View File

@@ -680,18 +680,23 @@ class Spawner(LoggingConfigurable):
internal_authority = self.internal_authority_name internal_authority = self.internal_authority_name
notebook_authority = self.internal_notebook_authority_name notebook_authority = self.internal_notebook_authority_name
internal_key_pair = cert_store.get(internal_authority) internal_key_pair = cert_store.get(internal_authority)
notebook_key_pair = cert_store.create_signed_pair(self.user.name, notebook_authority, alt_names=b"DNS:localhost,IP:127.0.0.1") notebook_key_pair = cert_store.create_signed_pair(
self.user.name,
notebook_authority,
alt_names=b"DNS:localhost,IP:127.0.0.1")
return { return {
"key_file": notebook_key_pair.key_file, "keyfile": notebook_key_pair.key_file,
"cert_file": notebook_key_pair.cert_file, "certfile": notebook_key_pair.cert_file,
"ca_file": internal_key_pair.ca_file, "cafile": internal_key_pair.ca_file,
} }
def move_certs(self, key_pair): def move_certs(self, key_pair):
"""Takes dict of cert/ca file paths and moves, sets up proper ownership for them.""" """Takes dict of cert/ca file paths and moves, sets up proper ownership
key = key_pair['key_file'] for them.
cert = key_pair['cert_file'] """
ca = key_pair['ca_file'] key = key_pair['keyfile']
cert = key_pair['certfile']
ca = key_pair['cafile']
try: try:
user = pwd.getpwnam(self.user.name) user = pwd.getpwnam(self.user.name)
@@ -705,9 +710,9 @@ class Spawner(LoggingConfigurable):
os.makedirs(out_dir, 0o700, exist_ok=True) os.makedirs(out_dir, 0o700, exist_ok=True)
# Move certs to users dir # Move certs to users dir
shutil.move(key_pair['key_file'], out_dir) shutil.move(key_pair['keyfile'], out_dir)
shutil.move(key_pair['cert_file'], out_dir) shutil.move(key_pair['certfile'], out_dir)
shutil.copy(key_pair['ca_file'], out_dir) shutil.copy(key_pair['cafile'], out_dir)
path_tmpl = "{out}/{name}.{ext}" path_tmpl = "{out}/{name}.{ext}"
key = path_tmpl.format(out=out_dir, name=self.user.name, ext="key") key = path_tmpl.format(out=out_dir, name=self.user.name, ext="key")

View File

@@ -42,6 +42,7 @@ from ..utils import random_port
from . import mocking from . import mocking
from .mocking import MockHub from .mocking import MockHub
from .utils import ssl_setup
from .test_services import mockservice_cmd from .test_services import mockservice_cmd
import jupyterhub.services.service import jupyterhub.services.service
@@ -50,10 +51,25 @@ import jupyterhub.services.service
_db = None _db = None
@fixture(scope='session')
def ssl_tmpdir(tmpdir_factory):
return tmpdir_factory.mktemp('ssl')
@fixture(scope='module') @fixture(scope='module')
def app(request, io_loop): def app(request, io_loop, ssl_tmpdir):
"""Mock a jupyterhub app for testing""" """Mock a jupyterhub app for testing"""
mocked_app = MockHub.instance(log_level=logging.DEBUG) mocked_app = MockHub.instance(log_level=logging.DEBUG)
ssl_enabled = getattr(request.module, "ssl_enabled", False)
if ssl_enabled:
internal_authority_name = 'hub'
external_certs = ssl_setup(str(ssl_tmpdir), internal_authority_name)
mocked_app = MockHub.instance(
log_level=logging.DEBUG,
internal_ssl=True,
internal_authority_name=internal_authority_name,
internal_certs_location=str(ssl_tmpdir))
@gen.coroutine @gen.coroutine
def make_app(): def make_app():
@@ -116,16 +132,6 @@ def io_loop(request):
request.addfinalizer(_close) request.addfinalizer(_close)
return io_loop return io_loop
@fixture(scope='module')
def app(request, io_loop):
"""Mock a jupyterhub app for testing"""
ssl_enabled = getattr(request.module, "ssl_enabled", False)
mocked_app = MockHub.instance(log_level=logging.DEBUG, internal_ssl=ssl_enabled)
@gen.coroutine
def make_app():
yield mocked_app.initialize([])
yield mocked_app.start()
io_loop.run_sync(make_app)
@fixture(autouse=True) @fixture(autouse=True)
def cleanup_after(request, io_loop): def cleanup_after(request, io_loop):

View File

@@ -48,7 +48,7 @@ from ..objects import Server
from ..spawner import LocalProcessSpawner from ..spawner import LocalProcessSpawner
from ..singleuser import SingleUserNotebookApp from ..singleuser import SingleUserNotebookApp
from ..utils import random_port, url_path_join from ..utils import random_port, url_path_join
from .utils import async_requests from .utils import async_requests, ssl_setup
from pamela import PAMError from pamela import PAMError
@@ -216,6 +216,20 @@ class MockHub(JupyterHub):
last_activity_interval = 2 last_activity_interval = 2
log_datefmt = '%M:%S' log_datefmt = '%M:%S'
def __new__(cls, *args, **kwargs):
try:
# Turn on internalSSL if the options exist
internal_authority_name = 'hub'
cert_location = kwargs['internal_certs_location']
external_certs = ssl_setup(cert_location, internal_authority_name)
kwargs['internal_ssl'] = True
kwargs['internal_authority_name'] = internal_authority_name
kwargs['ssl_cert'] = external_certs.cert_file
kwargs['ssl_key'] = external_certs.key_file
except KeyError:
pass
return super().__new__(cls, *args, **kwargs)
@default('subdomain_host') @default('subdomain_host')
def _subdomain_host_default(self): def _subdomain_host_default(self):
return os.environ.get('JUPYTERHUB_TEST_SUBDOMAIN_HOST', '') return os.environ.get('JUPYTERHUB_TEST_SUBDOMAIN_HOST', '')

View File

@@ -14,9 +14,11 @@ Handlers and their purpose include:
import argparse import argparse
import json import json
import sys import sys
import os
from tornado import web, httpserver, ioloop from tornado import web, httpserver, ioloop
from .mockservice import EnvHandler from .mockservice import EnvHandler
from ..utils import make_ssl_context
class EchoHandler(web.RequestHandler): class EchoHandler(web.RequestHandler):
def get(self): def get(self):
@@ -34,7 +36,19 @@ def main(args):
(r'.*', EchoHandler), (r'.*', EchoHandler),
]) ])
server = httpserver.HTTPServer(app) ssl_context = None
if args.keyfile and args.certfile and args.client_ca:
key = args.keyfile.strip('"')
cert = args.certfile.strip('"')
ca = args.client_ca.strip('"')
ssl_context = make_ssl_context(
key,
cert,
cafile = ca,
check_hostname = False)
server = httpserver.HTTPServer(app, ssl_options=ssl_context)
server.listen(args.port) server.listen(args.port)
try: try:
ioloop.IOLoop.instance().start() ioloop.IOLoop.instance().start()
@@ -44,5 +58,8 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int) parser.add_argument('--port', type=int)
parser.add_argument('--keyfile', type=str)
parser.add_argument('--certfile', type=str)
parser.add_argument('--client-ca', type=str)
args, extra = parser.parse_known_args() args, extra = parser.parse_known_args()
main(args) main(args)

View File

@@ -72,7 +72,7 @@ def test_init_tokens():
'also-super-secret': 'gordon', 'also-super-secret': 'gordon',
'boagasdfasdf': 'chell', 'boagasdfasdf': 'chell',
} }
app = MockHub(db_url=db_file, api_tokens=tokens) app = MockHub(db_url=db_file, api_tokens=tokens, internal_certs_location=td)
yield app.initialize([]) yield app.initialize([])
db = app.db db = app.db
for token, username in tokens.items(): for token, username in tokens.items():
@@ -82,7 +82,7 @@ def test_init_tokens():
assert user.name == username assert user.name == username
# simulate second startup, reloading same tokens: # simulate second startup, reloading same tokens:
app = MockHub(db_url=db_file, api_tokens=tokens) app = MockHub(db_url=db_file, api_tokens=tokens, internal_certs_location=td)
yield app.initialize([]) yield app.initialize([])
db = app.db db = app.db
for token, username in tokens.items(): for token, username in tokens.items():
@@ -93,7 +93,7 @@ def test_init_tokens():
# don't allow failed token insertion to create users: # don't allow failed token insertion to create users:
tokens['short'] = 'gman' tokens['short'] = 'gman'
app = MockHub(db_url=db_file, api_tokens=tokens) app = MockHub(db_url=db_file, api_tokens=tokens, internal_certs_location=td)
with pytest.raises(ValueError): with pytest.raises(ValueError):
yield app.initialize([]) yield app.initialize([])
assert orm.User.find(app.db, 'gman') is None assert orm.User.find(app.db, 'gman') is None
@@ -101,7 +101,7 @@ def test_init_tokens():
def test_write_cookie_secret(tmpdir): def test_write_cookie_secret(tmpdir):
secret_path = str(tmpdir.join('cookie_secret')) secret_path = str(tmpdir.join('cookie_secret'))
hub = MockHub(cookie_secret_file=secret_path) hub = MockHub(cookie_secret_file=secret_path, internal_certs_location=str(tmpdir))
hub.init_secrets() hub.init_secrets()
assert os.path.exists(secret_path) assert os.path.exists(secret_path)
assert os.stat(secret_path).st_mode & 0o600 assert os.stat(secret_path).st_mode & 0o600
@@ -113,7 +113,7 @@ def test_cookie_secret_permissions(tmpdir):
secret_path = str(secret_file) secret_path = str(secret_file)
secret = os.urandom(COOKIE_SECRET_BYTES) secret = os.urandom(COOKIE_SECRET_BYTES)
secret_file.write(binascii.b2a_hex(secret)) secret_file.write(binascii.b2a_hex(secret))
hub = MockHub(cookie_secret_file=secret_path) hub = MockHub(cookie_secret_file=secret_path, internal_certs_location=str(tmpdir))
# raise with public secret file # raise with public secret file
os.chmod(secret_path, 0o664) os.chmod(secret_path, 0o664)
@@ -131,13 +131,13 @@ def test_cookie_secret_content(tmpdir):
secret_file.write('not base 64: uñiço∂e') secret_file.write('not base 64: uñiço∂e')
secret_path = str(secret_file) secret_path = str(secret_file)
os.chmod(secret_path, 0o660) os.chmod(secret_path, 0o660)
hub = MockHub(cookie_secret_file=secret_path) hub = MockHub(cookie_secret_file=secret_path, internal_certs_location=str(tmpdir))
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
hub.init_secrets() hub.init_secrets()
def test_cookie_secret_env(tmpdir): def test_cookie_secret_env(tmpdir):
hub = MockHub(cookie_secret_file=str(tmpdir.join('cookie_secret'))) hub = MockHub(cookie_secret_file=str(tmpdir.join('cookie_secret')), internal_certs_location=str(tmpdir))
with patch.dict(os.environ, {'JPY_COOKIE_SECRET': 'not hex'}): with patch.dict(os.environ, {'JPY_COOKIE_SECRET': 'not hex'}):
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -150,12 +150,12 @@ def test_cookie_secret_env(tmpdir):
@pytest.mark.gen_test @pytest.mark.gen_test
def test_load_groups(): def test_load_groups(tmpdir):
to_load = { to_load = {
'blue': ['cyclops', 'rogue', 'wolverine'], 'blue': ['cyclops', 'rogue', 'wolverine'],
'gold': ['storm', 'jean-grey', 'colossus'], 'gold': ['storm', 'jean-grey', 'colossus'],
} }
hub = MockHub(load_groups=to_load) hub = MockHub(load_groups=to_load, internal_certs_location=str(tmpdir))
hub.init_db() hub.init_db()
yield hub.init_users() yield hub.init_users()
yield hub.init_groups() yield hub.init_groups()
@@ -178,7 +178,7 @@ def test_resume_spawners(tmpdir, request):
request.addfinalizer(p.stop) request.addfinalizer(p.stop)
@gen.coroutine @gen.coroutine
def new_hub(): def new_hub():
app = MockHub() app = MockHub(internal_certs_location=str(tmpdir))
app.config.ConfigurableHTTPProxy.should_start = False app.config.ConfigurableHTTPProxy.should_start = False
app.config.ConfigurableHTTPProxy.auth_token = 'unused' app.config.ConfigurableHTTPProxy.auth_token = 'unused'
yield app.initialize([]) yield app.initialize([])

View File

@@ -1,167 +0,0 @@
"""Tests for the SSL enabled REST API."""
from concurrent.futures import Future
import json
import time
import sys
from unittest import mock
from urllib.parse import urlparse, quote
import pytest
from pytest import mark
import requests
from tornado import gen
import jupyterhub
from .. import orm
from ..user import User
from ..utils import url_path_join as ujoin
from . import mocking
from .mocking import public_host, public_url
from .utils import async_requests
ssl_enabled = True
def check_db_locks(func):
"""Decorator that verifies no locks are held on database upon exit.
This decorator for test functions verifies no locks are held on the
application's database upon exit by creating and dropping a dummy table.
The decorator relies on an instance of JupyterHubApp being the first
argument to the decorated function.
Example
-------
@check_db_locks
def api_request(app, *api_path, **kwargs):
"""
def new_func(app, *args, **kwargs):
retval = func(app, *args, **kwargs)
temp_session = app.session_factory()
temp_session.execute('CREATE TABLE dummy (foo INT)')
temp_session.execute('DROP TABLE dummy')
temp_session.close()
return retval
return new_func
def find_user(db, name):
"""Find user in database."""
return db.query(orm.User).filter(orm.User.name == name).first()
def add_user(db, app=None, **kwargs):
"""Add a user to the database."""
orm_user = find_user(db, name=kwargs.get('name'))
if orm_user is None:
orm_user = orm.User(**kwargs)
db.add(orm_user)
else:
for attr, value in kwargs.items():
setattr(orm_user, attr, value)
db.commit()
if app:
user = app.users[orm_user.id] = User(orm_user, app.tornado_settings)
return user
else:
return orm_user
def auth_header(db, name):
"""Return header with user's API authorization token."""
user = find_user(db, name)
if user is None:
user = add_user(db, name=name)
token = user.new_api_token()
return {'Authorization': 'token %s' % token}
@check_db_locks
@gen.coroutine
def api_request(app, *api_path, **kwargs):
"""Make an API request"""
base_url = app.hub.url
headers = kwargs.setdefault('headers', {})
if 'Authorization' not in headers and not kwargs.pop('noauth', False):
headers.update(auth_header(app.db, 'admin'))
kwargs['cert'] = (app.internal_ssl_cert, app.internal_ssl_key)
kwargs['verify'] = app.internal_ssl_ca
url = ujoin(base_url, 'api', *api_path)
method = kwargs.pop('method', 'get')
f = getattr(async_requests, method)
resp = yield f(url, **kwargs)
assert "frame-ancestors 'self'" in resp.headers['Content-Security-Policy']
assert ujoin(app.hub.base_url, "security/csp-report") in resp.headers['Content-Security-Policy']
assert 'http' not in resp.headers['Content-Security-Policy']
return resp
@mark.gen_test
def test_spawn(app):
db = app.db
name = 'wash'
user = add_user(db, app=app, name=name)
options = {
's': ['value'],
'i': 5,
}
before_servers = sorted(db.query(orm.Server), key=lambda s: s.url)
r = yield api_request(app, 'users', name, 'server', method='post',
data=json.dumps(options),
)
assert r.status_code == 201
assert 'pid' in user.orm_spawners[''].state
app_user = app.users[name]
assert app_user.spawner is not None
spawner = app_user.spawner
assert app_user.spawner.user_options == options
assert not app_user.spawner._spawn_pending
status = yield app_user.spawner.poll()
assert status is None
assert spawner.server.base_url == ujoin(app.base_url, 'user/%s' % name) + '/'
url = public_url(app, user)
r = yield api_request(app, url)
assert r.status_code == 200
assert r.text == spawner.server.base_url
r = yield api_request(app, ujoin(url, 'args'))
assert r.status_code == 200
argv = r.json()
assert '--port' in ' '.join(argv)
r = yield api_request(app, ujoin(url, 'env'))
env = r.json()
for expected in ['JUPYTERHUB_USER', 'JUPYTERHUB_BASE_URL', 'JUPYTERHUB_API_TOKEN']:
assert expected in env
if app.subdomain_host:
assert env['JUPYTERHUB_HOST'] == app.subdomain_host
r = yield api_request(app, 'users', name, 'server', method='delete')
assert r.status_code == 204
assert 'pid' not in user.orm_spawners[''].state
status = yield app_user.spawner.poll()
assert status == 0
# check that we cleaned up after ourselves
assert spawner.server is None
after_servers = sorted(db.query(orm.Server), key=lambda s: s.url)
assert before_servers == after_servers
tokens = list(db.query(orm.APIToken).filter(orm.APIToken.user_id == user.id))
assert tokens == []
assert app.users.count_active_users()['pending'] == 0
@mark.gen_test
def test_root_api(app):
base_url = app.hub.url
r = yield api_request(app, '')
r.raise_for_status()
expected = {
'version': jupyterhub.__version__
}
assert r.json() == expected

View File

@@ -0,0 +1,7 @@
"""Tests for the SSL enabled REST API."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_api import *
ssl_enabled = True

View File

@@ -0,0 +1,10 @@
"""Test the JupyterHub entry point with internal ssl"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import sys
import jupyterhub.tests.mocking
from .utils import ssl_setup
from jupyterhub.tests.test_app import *

View File

@@ -0,0 +1,7 @@
"""Tests for process spawning with internal_ssl"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyterhub.tests.test_spawner import *
ssl_enabled = True

View File

@@ -1,6 +1,8 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import requests import requests
from certipy import Certipy
class _AsyncRequests: class _AsyncRequests:
"""Wrapper around requests to return a Future from request methods """Wrapper around requests to return a Future from request methods
@@ -16,3 +18,10 @@ class _AsyncRequests:
# async_requests.get = requests.get returning a Future, etc. # async_requests.get = requests.get returning a Future, etc.
async_requests = _AsyncRequests() async_requests = _AsyncRequests()
def ssl_setup(cert_dir, authority_name):
# Set up the external certs with the same authority as the internal
# one so that certificate trust works regardless of chosen endpoint.
cert_store = Certipy(store_dir=cert_dir)
internal_authority = cert_store.create_ca(authority_name)
external_certs = cert_store.create_signed_pair('external', authority_name)
return external_certs

View File

@@ -508,7 +508,10 @@ class User:
cert = self.settings['internal_ssl_cert'] cert = self.settings['internal_ssl_cert']
ca = self.settings['internal_ssl_ca'] ca = self.settings['internal_ssl_ca']
ssl_context = make_ssl_context(key, cert, cafile=ca) ssl_context = make_ssl_context(key, cert, cafile=ca)
resp = await server.wait_up(http=True, timeout=spawner.http_timeout, ssl_context=ssl_context) resp = await server.wait_up(
http=True,
timeout=spawner.http_timeout,
ssl_context=ssl_context)
except Exception as e: except Exception as e:
if isinstance(e, TimeoutError): if isinstance(e, TimeoutError):
self.log.warning( self.log.warning(

View File

@@ -72,7 +72,11 @@ def can_connect(ip, port):
return True return True
def make_ssl_context(keyfile, certfile, cafile=None, verify=True, check_hostname=True): def make_ssl_context(
keyfile, certfile, cafile=None,
verify=True, check_hostname=True):
"""Setup context for starting an https server or making requests over ssl.
"""
if not keyfile or not certfile: if not keyfile or not certfile:
return None return None
purpose = ssl.Purpose.SERVER_AUTH if verify else ssl.Purpose.CLIENT_AUTH purpose = ssl.Purpose.SERVER_AUTH if verify else ssl.Purpose.CLIENT_AUTH