diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py index 4d63ec99..2ec2dec0 100644 --- a/jupyterhub/apihandlers/auth.py +++ b/jupyterhub/apihandlers/auth.py @@ -16,6 +16,7 @@ from tornado import web from .. import orm from .. import roles from .. import scopes +from ..utils import get_browser_protocol from ..utils import token_authenticated from .base import APIHandler from .base import BaseHandler @@ -115,7 +116,10 @@ class OAuthHandler: # make absolute local redirects full URLs # to satisfy oauthlib's absolute URI requirement redirect_uri = ( - self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri + get_browser_protocol(self.request) + + "://" + + self.request.host + + redirect_uri ) parsed_url = urlparse(uri) query_list = parse_qsl(parsed_url.query, keep_blank_values=True) diff --git a/jupyterhub/apihandlers/base.py b/jupyterhub/apihandlers/base.py index b107bf64..6cbb7659 100644 --- a/jupyterhub/apihandlers/base.py +++ b/jupyterhub/apihandlers/base.py @@ -14,6 +14,7 @@ from tornado import web from .. import orm from ..handlers import BaseHandler +from ..utils import get_browser_protocol from ..utils import isoformat from ..utils import url_path_join @@ -60,6 +61,8 @@ class APIHandler(BaseHandler): """ host_header = self.app.forwarded_host_header or "Host" host = self.request.headers.get(host_header) + if host and "," in host: + host = host.split(",", 1)[0].strip() referer = self.request.headers.get("Referer") # If no header is provided, assume it comes from a script/curl. @@ -71,7 +74,8 @@ class APIHandler(BaseHandler): self.log.warning("Blocking API request with no referer") return False - proto = self.request.protocol + proto = get_browser_protocol(self.request) + full_host = f"{proto}://{host}{self.hub.base_url}" host_url = urlparse(full_host) referer_url = urlparse(referer) diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index ab37e21f..56cf1fa8 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -49,6 +49,7 @@ from ..spawner import LocalProcessSpawner from ..user import User from ..utils import AnyTimeoutError from ..utils import get_accepted_mimetype +from ..utils import get_browser_protocol from ..utils import maybe_future from ..utils import url_path_join @@ -632,12 +633,10 @@ class BaseHandler(RequestHandler): next_url = self.get_argument('next', default='') # protect against some browsers' buggy handling of backslash as slash next_url = next_url.replace('\\', '%5C') - if (next_url + '/').startswith( - ( - f'{self.request.protocol}://{self.request.host}/', - f'//{self.request.host}/', - ) - ) or ( + proto = get_browser_protocol(self.request) + host = self.request.host + + if (next_url + '/').startswith((f'{proto}://{host}/', f'//{host}/',)) or ( self.subdomain_host and urlparse(next_url).netloc and ("." + urlparse(next_url).netloc).endswith( diff --git a/jupyterhub/services/auth.py b/jupyterhub/services/auth.py index a08049a3..9315564d 100644 --- a/jupyterhub/services/auth.py +++ b/jupyterhub/services/auth.py @@ -53,6 +53,7 @@ from traitlets import validate from traitlets.config import SingletonConfigurable from ..scopes import _intersect_expanded_scopes +from ..utils import get_browser_protocol from ..utils import url_path_join @@ -772,7 +773,7 @@ class HubOAuth(HubAuth): # OAuth that doesn't complete shouldn't linger too long. 'max_age': 600, } - if handler.request.protocol == 'https': + if get_browser_protocol(handler.request) == 'https': kwargs['secure'] = True # load user cookie overrides kwargs.update(self.cookie_options) @@ -812,7 +813,7 @@ class HubOAuth(HubAuth): def set_cookie(self, handler, access_token): """Set a cookie recording OAuth result""" kwargs = {'path': self.base_url, 'httponly': True} - if handler.request.protocol == 'https': + if get_browser_protocol(handler.request) == 'https': kwargs['secure'] = True # load user cookie overrides kwargs.update(self.cookie_options) diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 1b777c74..1397ea63 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -118,16 +118,39 @@ async def test_post_content_type(app, content_type, status): ("fake.example", {"netloc": "fake.example", "scheme": "https"}, {}, 403), # explicit ports, match ("fake.example:81", {"netloc": "fake.example:81"}, {}, 200), - # Test proxy defined headers taken into account by xheaders=True in - # https://github.com/jupyterhub/jupyterhub/blob/2.0.1/jupyterhub/app.py#L3065 + # Test proxy protocol defined headers taken into account by utils.get_browser_protocol ( "fake.example", {"netloc": "fake.example", "scheme": "https"}, - # note {"X-Forwarded-Proto": "https"} does not work {'X-Scheme': 'https'}, 200, ), + ( + "fake.example", + {"netloc": "fake.example", "scheme": "https"}, + {'X-Forwarded-Proto': 'https'}, + 200, + ), + ( + "fake.example", + {"netloc": "fake.example", "scheme": "https"}, + { + 'Forwarded': 'host=fake.example;proto=https,for=1.2.34;proto=http', + 'X-Scheme': 'http', + }, + 200, + ), + ( + "fake.example", + {"netloc": "fake.example", "scheme": "https"}, + { + 'Forwarded': 'host=fake.example;proto=http,for=1.2.34;proto=http', + 'X-Scheme': 'https', + }, + 403, + ), ("fake.example", {"netloc": "fake.example"}, {'X-Scheme': 'https'}, 403), + ("fake.example", {"netloc": "fake.example"}, {'X-Scheme': 'https, http'}, 403), ], ) async def test_cors_check(request, app, host, referer, extraheaders, status): diff --git a/jupyterhub/tests/test_utils.py b/jupyterhub/tests/test_utils.py index c75c9051..129bc0fa 100644 --- a/jupyterhub/tests/test_utils.py +++ b/jupyterhub/tests/test_utils.py @@ -2,12 +2,16 @@ import asyncio import time from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock import pytest from async_generator import aclosing from tornado import gen from tornado.concurrent import run_on_executor +from tornado.httpserver import HTTPRequest +from tornado.httputil import HTTPHeaders +from .. import utils from ..utils import iterate_until @@ -88,3 +92,33 @@ async def test_tornado_coroutines(): # verify that tornado gen and executor methods return awaitables assert (await t.on_executor()) == "executor" assert (await t.tornado_coroutine()) == "gen.coroutine" + + +@pytest.mark.parametrize( + "forwarded, x_scheme, x_forwarded_proto, expected", + [ + ("", "", "", "_attr_"), + ("for=1.2.3.4", "", "", "_attr_"), + ("for=1.2.3.4,proto=https", "", "", "_attr_"), + ("", "https", "http", "https"), + ("", "https, http", "", "https"), + ("", "https, http", "http", "https"), + ("proto=http ; for=1.2.3.4, proto=https", "https, http", "", "http"), + ("proto=invalid;for=1.2.3.4,proto=http", "https, http", "", "https"), + ("for=1.2.3.4,proto=http", "https, http", "", "https"), + ("", "invalid, http", "", "_attr_"), + ], +) +def test_browser_protocol(x_scheme, x_forwarded_proto, forwarded, expected): + request = Mock(spec=HTTPRequest) + request.protocol = "_attr_" + request.headers = HTTPHeaders() + if x_scheme: + request.headers["X-Scheme"] = x_scheme + if x_forwarded_proto: + request.headers["X-Forwarded-Proto"] = x_forwarded_proto + if forwarded: + request.headers["Forwarded"] = forwarded + + proto = utils.get_browser_protocol(request) + assert proto == expected diff --git a/jupyterhub/utils.py b/jupyterhub/utils.py index 60464d78..1fc4aa89 100644 --- a/jupyterhub/utils.py +++ b/jupyterhub/utils.py @@ -683,3 +683,44 @@ def catch_db_error(f): return r return catching + + +def get_browser_protocol(request): + """Get the _protocol_ seen by the browser + + Like tornado's _apply_xheaders, + but in the case of multiple proxy hops, + use the outermost value (what the browser likely sees) + instead of the innermost value, + which is the most trustworthy. + + We care about what the browser sees, + not where the request actually came from, + so trusting possible spoofs is the right thing to do. + """ + headers = request.headers + # first choice: Forwarded header + forwarded_header = headers.get("Forwarded") + if forwarded_header: + first_forwarded = forwarded_header.split(",", 1)[0].strip() + fields = {} + forwarded_dict = {} + for field in first_forwarded.split(";"): + key, _, value = field.partition("=") + fields[key.strip().lower()] = value.strip() + if "proto" in fields and fields["proto"].lower() in {"http", "https"}: + return fields["proto"].lower() + else: + app_log.warning( + f"Forwarded header present without protocol: {forwarded_header}" + ) + + # second choice: X-Scheme or X-Forwarded-Proto + proto_header = headers.get("X-Scheme", headers.get("X-Forwarded-Proto", None)) + if proto_header: + proto_header = proto_header.split(",")[0].strip().lower() + if proto_header in {"http", "https"}: + return proto_header + + # no forwarded headers + return request.protocol