diff --git a/jupyterhub/apihandlers/base.py b/jupyterhub/apihandlers/base.py index 73b129d8..2542c825 100644 --- a/jupyterhub/apihandlers/base.py +++ b/jupyterhub/apihandlers/base.py @@ -58,7 +58,8 @@ class APIHandler(BaseHandler): - allow unspecified host/referer (e.g. scripts) """ - host = self.request.headers.get(self.app.forwarded_host_header or "Host") + host_header = self.app.forwarded_host_header or "Host" + host = self.request.headers.get(host_header) referer = self.request.headers.get("Referer") # If no header is provided, assume it comes from a script/curl. @@ -70,13 +71,24 @@ class APIHandler(BaseHandler): self.log.warning("Blocking API request with no referer") return False - host_path = url_path_join(host, self.hub.base_url) - referer_path = referer.split('://', 1)[-1] - if not (referer_path + '/').startswith(host_path): + proto = self.request.protocol + full_host = f"{proto}://{host}{self.hub.base_url}" + host_url = urlparse(full_host) + referer_url = urlparse(referer) + # resolve default ports for http[s] + referer_port = referer_url.port or ( + 443 if referer_url.scheme == 'https' else 80 + ) + host_port = host_url.port or (443 if host_url.scheme == 'https' else 80) + if ( + referer_url.scheme != host_url.scheme + or referer_url.hostname != host_url.hostname + or referer_port != host_port + or not (referer_url.path + "/").startswith(host_url.path) + ): self.log.warning( - "Blocking Cross Origin API request. Referer: %s, Host: %s", - referer, - host_path, + f"Blocking Cross Origin API request. Referer: {referer}," + " {host_header}: {host}, Host URL: {full_host}", ) return False return True diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 634a75f5..4f526ded 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -9,6 +9,7 @@ from datetime import timedelta from unittest import mock from urllib.parse import quote from urllib.parse import urlparse +from urllib.parse import urlunparse from pytest import fixture from pytest import mark @@ -65,7 +66,15 @@ async def test_auth_api(app): assert r.status_code == 403 -async def test_cors_checks(request, app): +@mark.parametrize( + "content_type, status", + [ + ("text/plain", 403), + # accepted, but invalid + ("application/json; charset=UTF-8", 400), + ], +) +async def test_post_content_type(app, content_type, status): url = ujoin(public_host(app), app.hub.base_url) host = urlparse(url).netloc # add admin user @@ -74,42 +83,6 @@ async def test_cors_checks(request, app): user = add_user(app.db, name='admin', admin=True) cookies = await app.login_user('admin') - r = await api_request( - app, 'users', headers={'Authorization': '', 'Referer': 'null'}, cookies=cookies - ) - assert r.status_code == 403 - - r = await api_request( - app, - 'users', - headers={ - 'Authorization': '', - 'Referer': 'http://attack.com/csrf/vulnerability', - }, - cookies=cookies, - ) - assert r.status_code == 403 - - r = await api_request( - app, - 'users', - headers={'Authorization': '', 'Referer': url, 'Host': host}, - cookies=cookies, - ) - assert r.status_code == 200 - - r = await api_request( - app, - 'users', - headers={ - 'Authorization': '', - 'Referer': ujoin(url, 'foo/bar/baz/bat'), - 'Host': host, - }, - cookies=cookies, - ) - assert r.status_code == 200 - r = await api_request( app, 'users', @@ -117,24 +90,62 @@ async def test_cors_checks(request, app): data='{}', headers={ "Authorization": "", - "Content-Type": "text/plain", + "Content-Type": content_type, }, cookies=cookies, ) - assert r.status_code == 403 + assert r.status_code == status - r = await api_request( - app, - 'users', - method='post', - data='{}', - headers={ - "Authorization": "", - "Content-Type": "application/json; charset=UTF-8", - }, - cookies=cookies, - ) - assert r.status_code == 400 # accepted, but invalid + +@mark.parametrize( + "host, referer, status", + [ + ('$host', '$url', 200), + (None, None, 200), + (None, 'null', 403), + (None, 'http://attack.com/csrf/vulnerability', 403), + ('$host', {"path": "/user/someuser"}, 403), + ('$host', {"path": "{path}/foo/bar/subpath"}, 200), + # mismatch host + ("mismatch.com", "$url", 403), + # explicit host, matches + ("fake.example", {"netloc": "fake.example"}, 200), + # explicit port, matches implicit port + ("fake.example:80", {"netloc": "fake.example"}, 200), + # explicit port, mismatch + ("fake.example:81", {"netloc": "fake.example"}, 403), + # implicit ports, mismatch proto + ("fake.example", {"netloc": "fake.example", "scheme": "https"}, 403), + ], +) +async def test_cors_check(request, app, host, referer, status): + url = ujoin(public_host(app), app.hub.base_url) + real_host = urlparse(url).netloc + if host == "$host": + host = real_host + + if referer == '$url': + referer = url + elif isinstance(referer, dict): + parsed_url = urlparse(url) + # apply {} + url_ns = {key: getattr(parsed_url, key) for key in parsed_url._fields} + for key, value in referer.items(): + referer[key] = value.format(**url_ns) + referer = urlunparse(parsed_url._replace(**referer)) + + # disable default auth header, cors is for cookie auth + headers = {"Authorization": ""} + if host is not None: + headers['X-Forwarded-Host'] = host + if referer is not None: + headers['Referer'] = referer + + # add admin user + user = find_user(app.db, 'admin') + if user is None: + user = add_user(app.db, name='admin', admin=True) + cookies = await app.login_user('admin') # test custom forwarded_host_header behavior app.forwarded_host_header = 'X-Forwarded-Host' @@ -148,28 +159,10 @@ async def test_cors_checks(request, app): r = await api_request( app, 'users', - headers={ - 'Authorization': '', - 'Referer': url, - 'Host': host, - 'X-Forwarded-Host': 'example.com', - }, + headers=headers, cookies=cookies, ) - assert r.status_code == 403 - - r = await api_request( - app, - 'users', - headers={ - 'Authorization': '', - 'Referer': url, - 'Host': host, - 'X-Forwarded-Host': host, - }, - cookies=cookies, - ) - assert r.status_code == 200 + assert r.status_code == status # --------------