use outermost proxied entry when checking for browser protocol

we care about what the browser sees, so trust the outermost entry instead of the innermost one

This is not secure _in general_, since these values can be spoofed by malicious proxies,
but for CORS and cookie purposes we only care about what the browser sees,
however many hops there may be.

A malicious proxy in the chain here isn't a concern because what matters is the immediate
hop from the _browser_, not the immediate hop from the _server_.
Min RK
2022-01-07 14:03:11 +01:00
parent a2ba55756d
commit ccfee4d235
7 changed files with 119 additions and 13 deletions
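
To illustrate the outermost-vs-innermost distinction from the commit message (example values made up, not part of this commit): each proxy hop appends an entry to the forwarded headers, so the outermost (browser-facing) entry comes first. Tornado's xheaders handling trusts the last entry; this change trusts the first.

    # browser --https--> proxy A --http--> proxy B --http--> hub
    x_forwarded_proto = "https, http"

    browser_proto = x_forwarded_proto.split(",")[0].strip()  # outermost hop: "https"
    hub_proto = x_forwarded_proto.split(",")[-1].strip()     # innermost hop: "http"
    assert (browser_proto, hub_proto) == ("https", "http")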

View File

@@ -16,6 +16,7 @@ from tornado import web
 from .. import orm
 from .. import roles
 from .. import scopes
+from ..utils import get_browser_protocol
 from ..utils import token_authenticated
 from .base import APIHandler
 from .base import BaseHandler
@@ -115,7 +116,10 @@ class OAuthHandler:
             # make absolute local redirects full URLs
             # to satisfy oauthlib's absolute URI requirement
             redirect_uri = (
-                self.request.protocol + "://" + self.request.headers['Host'] + redirect_uri
+                get_browser_protocol(self.request)
+                + "://"
+                + self.request.host
+                + redirect_uri
             )
             parsed_url = urlparse(uri)
             query_list = parse_qsl(parsed_url.query, keep_blank_values=True)

View File

@@ -14,6 +14,7 @@ from tornado import web
 from .. import orm
 from ..handlers import BaseHandler
+from ..utils import get_browser_protocol
 from ..utils import isoformat
 from ..utils import url_path_join
@@ -60,6 +61,8 @@ class APIHandler(BaseHandler):
""" """
host_header = self.app.forwarded_host_header or "Host" host_header = self.app.forwarded_host_header or "Host"
host = self.request.headers.get(host_header) host = self.request.headers.get(host_header)
if host and "," in host:
host = host.split(",", 1)[0].strip()
referer = self.request.headers.get("Referer") referer = self.request.headers.get("Referer")
# If no header is provided, assume it comes from a script/curl. # If no header is provided, assume it comes from a script/curl.
@@ -71,7 +74,8 @@ class APIHandler(BaseHandler):
             self.log.warning("Blocking API request with no referer")
             return False

-        proto = self.request.protocol
+        proto = get_browser_protocol(self.request)
+
         full_host = f"{proto}://{host}{self.hub.base_url}"
         host_url = urlparse(full_host)
         referer_url = urlparse(referer)
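
The referer check above applies the same outermost-entry rule to the forwarded host header: when several proxies append to it, the first entry is the Host the browser actually addressed. A minimal sketch of that split (header value made up):

    forwarded_host = "hub.example.org, internal-proxy:8000"  # hypothetical multi-hop value
    host = forwarded_host
    if host and "," in host:
        host = host.split(",", 1)[0].strip()
    assert host == "hub.example.org"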

View File

@@ -49,6 +49,7 @@ from ..spawner import LocalProcessSpawner
 from ..user import User
 from ..utils import AnyTimeoutError
 from ..utils import get_accepted_mimetype
+from ..utils import get_browser_protocol
 from ..utils import maybe_future
 from ..utils import url_path_join
@@ -632,12 +633,10 @@ class BaseHandler(RequestHandler):
         next_url = self.get_argument('next', default='')
         # protect against some browsers' buggy handling of backslash as slash
         next_url = next_url.replace('\\', '%5C')
-        if (next_url + '/').startswith(
-            (
-                f'{self.request.protocol}://{self.request.host}/',
-                f'//{self.request.host}/',
-            )
-        ) or (
+        proto = get_browser_protocol(self.request)
+        host = self.request.host
+
+        if (next_url + '/').startswith((f'{proto}://{host}/', f'//{host}/',)) or (
             self.subdomain_host
             and urlparse(next_url).netloc
             and ("." + urlparse(next_url).netloc).endswith(

View File

@@ -53,6 +53,7 @@ from traitlets import validate
 from traitlets.config import SingletonConfigurable

 from ..scopes import _intersect_expanded_scopes
+from ..utils import get_browser_protocol
 from ..utils import url_path_join
@@ -772,7 +773,7 @@ class HubOAuth(HubAuth):
             # OAuth that doesn't complete shouldn't linger too long.
             'max_age': 600,
         }
-        if handler.request.protocol == 'https':
+        if get_browser_protocol(handler.request) == 'https':
             kwargs['secure'] = True
         # load user cookie overrides
         kwargs.update(self.cookie_options)
@@ -812,7 +813,7 @@ class HubOAuth(HubAuth):
     def set_cookie(self, handler, access_token):
         """Set a cookie recording OAuth result"""
         kwargs = {'path': self.base_url, 'httponly': True}
-        if handler.request.protocol == 'https':
+        if get_browser_protocol(handler.request) == 'https':
             kwargs['secure'] = True
         # load user cookie overrides
         kwargs.update(self.cookie_options)

View File

@@ -118,16 +118,39 @@ async def test_post_content_type(app, content_type, status):
("fake.example", {"netloc": "fake.example", "scheme": "https"}, {}, 403), ("fake.example", {"netloc": "fake.example", "scheme": "https"}, {}, 403),
# explicit ports, match # explicit ports, match
("fake.example:81", {"netloc": "fake.example:81"}, {}, 200), ("fake.example:81", {"netloc": "fake.example:81"}, {}, 200),
# Test proxy defined headers taken into account by xheaders=True in # Test proxy protocol defined headers taken into account by utils.get_browser_protocol
# https://github.com/jupyterhub/jupyterhub/blob/2.0.1/jupyterhub/app.py#L3065
( (
"fake.example", "fake.example",
{"netloc": "fake.example", "scheme": "https"}, {"netloc": "fake.example", "scheme": "https"},
# note {"X-Forwarded-Proto": "https"} does not work
{'X-Scheme': 'https'}, {'X-Scheme': 'https'},
200, 200,
), ),
(
"fake.example",
{"netloc": "fake.example", "scheme": "https"},
{'X-Forwarded-Proto': 'https'},
200,
),
(
"fake.example",
{"netloc": "fake.example", "scheme": "https"},
{
'Forwarded': 'host=fake.example;proto=https,for=1.2.34;proto=http',
'X-Scheme': 'http',
},
200,
),
(
"fake.example",
{"netloc": "fake.example", "scheme": "https"},
{
'Forwarded': 'host=fake.example;proto=http,for=1.2.34;proto=http',
'X-Scheme': 'https',
},
403,
),
("fake.example", {"netloc": "fake.example"}, {'X-Scheme': 'https'}, 403), ("fake.example", {"netloc": "fake.example"}, {'X-Scheme': 'https'}, 403),
("fake.example", {"netloc": "fake.example"}, {'X-Scheme': 'https, http'}, 403),
], ],
) )
async def test_cors_check(request, app, host, referer, extraheaders, status): async def test_cors_check(request, app, host, referer, extraheaders, status):

View File

@@ -2,12 +2,16 @@
 import asyncio
 import time
 from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import Mock

 import pytest
 from async_generator import aclosing
 from tornado import gen
 from tornado.concurrent import run_on_executor
+from tornado.httpserver import HTTPRequest
+from tornado.httputil import HTTPHeaders

+from .. import utils
 from ..utils import iterate_until
@@ -88,3 +92,33 @@ async def test_tornado_coroutines():
     # verify that tornado gen and executor methods return awaitables
     assert (await t.on_executor()) == "executor"
     assert (await t.tornado_coroutine()) == "gen.coroutine"
+
+
+@pytest.mark.parametrize(
+    "forwarded, x_scheme, x_forwarded_proto, expected",
+    [
+        ("", "", "", "_attr_"),
+        ("for=1.2.3.4", "", "", "_attr_"),
+        ("for=1.2.3.4,proto=https", "", "", "_attr_"),
+        ("", "https", "http", "https"),
+        ("", "https, http", "", "https"),
+        ("", "https, http", "http", "https"),
+        ("proto=http ; for=1.2.3.4, proto=https", "https, http", "", "http"),
+        ("proto=invalid;for=1.2.3.4,proto=http", "https, http", "", "https"),
+        ("for=1.2.3.4,proto=http", "https, http", "", "https"),
+        ("", "invalid, http", "", "_attr_"),
+    ],
+)
+def test_browser_protocol(x_scheme, x_forwarded_proto, forwarded, expected):
+    request = Mock(spec=HTTPRequest)
+    request.protocol = "_attr_"
+    request.headers = HTTPHeaders()
+    if x_scheme:
+        request.headers["X-Scheme"] = x_scheme
+    if x_forwarded_proto:
+        request.headers["X-Forwarded-Proto"] = x_forwarded_proto
+    if forwarded:
+        request.headers["Forwarded"] = forwarded
+    proto = utils.get_browser_protocol(request)
+    assert proto == expected

View File

@@ -683,3 +683,44 @@ def catch_db_error(f):
         return r

     return catching
+
+
+def get_browser_protocol(request):
+    """Get the _protocol_ seen by the browser
+
+    Like tornado's _apply_xheaders,
+    but in the case of multiple proxy hops,
+    use the outermost value (what the browser likely sees)
+    instead of the innermost value,
+    which is the most trustworthy.
+
+    We care about what the browser sees,
+    not where the request actually came from,
+    so trusting possible spoofs is the right thing to do.
+    """
+    headers = request.headers
+    # first choice: Forwarded header
+    forwarded_header = headers.get("Forwarded")
+    if forwarded_header:
+        first_forwarded = forwarded_header.split(",", 1)[0].strip()
+        fields = {}
+        forwarded_dict = {}
+        for field in first_forwarded.split(";"):
+            key, _, value = field.partition("=")
+            fields[key.strip().lower()] = value.strip()
+        if "proto" in fields and fields["proto"].lower() in {"http", "https"}:
+            return fields["proto"].lower()
+        else:
+            app_log.warning(
+                f"Forwarded header present without protocol: {forwarded_header}"
+            )
+
+    # second choice: X-Scheme or X-Forwarded-Proto
+    proto_header = headers.get("X-Scheme", headers.get("X-Forwarded-Proto", None))
+    if proto_header:
+        proto_header = proto_header.split(",")[0].strip().lower()
+        if proto_header in {"http", "https"}:
+            return proto_header
+
+    # no forwarded headers
+    return request.protocol
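
A quick usage sketch of the new helper (illustrative, not from the repository; the request is faked the same way the new test does): get_browser_protocol prefers the first element of a multi-hop Forwarded header and falls back to request.protocol when no usable forwarded header is present.

    from unittest.mock import Mock

    from tornado.httpserver import HTTPRequest
    from tornado.httputil import HTTPHeaders

    from jupyterhub.utils import get_browser_protocol

    request = Mock(spec=HTTPRequest)
    request.protocol = "http"  # what the hub itself sees on the last hop
    request.headers = HTTPHeaders(
        {"Forwarded": "host=hub.example.org;proto=https, for=10.0.0.1;proto=http"}
    )
    assert get_browser_protocol(request) == "https"  # outermost entry wins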