diff --git a/jupyterhub/apihandlers/base.py b/jupyterhub/apihandlers/base.py index c2cf1bce..ce422009 100644 --- a/jupyterhub/apihandlers/base.py +++ b/jupyterhub/apihandlers/base.py @@ -9,8 +9,43 @@ from http.client import responses from tornado import web from ..handlers import BaseHandler +from ..utils import url_path_join class APIHandler(BaseHandler): + + def check_referer(self): + """Check Origin for cross-site API requests. + + Copied from WebSocket with changes: + + - allow unspecified host/referer (e.g. scripts) + """ + host = self.request.headers.get("Host") + referer = self.request.headers.get("Referer") + + # If no header is provided, assume it comes from a script/curl. + # We are only concerned with cross-site browser stuff here. + if not host: + self.log.warn("Blocking API request with no host") + return False + if not referer: + self.log.warn("Blocking API request with no referer") + return False + + host_path = url_path_join(host, self.hub.server.base_url) + referer_path = referer.split('://', 1)[-1] + if not (referer_path + '/').startswith(host_path): + self.log.warn("Blocking Cross Origin API request. Referer: %s, Host: %s", + referer, host_path) + return False + return True + + def get_current_user_cookie(self): + """Override get_user_cookie to check Referer header""" + if not self.check_referer(): + return None + return super().get_current_user_cookie() + def get_json_body(self): """Return the body of the request as JSON data.""" if not self.request.body: @@ -23,7 +58,6 @@ class APIHandler(BaseHandler): self.log.error("Couldn't parse JSON", exc_info=True) raise web.HTTPError(400, 'Invalid JSON in body of request') return model - def write_error(self, status_code, **kwargs): """Write JSON errors instead of HTML""" diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 7467173e..dc7facdf 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -4,6 +4,7 @@ import json import time from datetime import timedelta from queue import Queue +from urllib.parse import urlparse import requests @@ -97,6 +98,51 @@ def test_auth_api(app): ) assert r.status_code == 403 + +def test_referer_check(app, io_loop): + url = app.hub.server.url + host = urlparse(url).netloc + user = find_user(app.db, 'admin') + if user is None: + user = add_user(app.db, name='admin', admin=True) + cookies = app.login_user('admin') + app_user = get_app_user(app, 'admin') + # stop the admin's server so we don't mess up future tests + io_loop.run_sync(lambda : app.proxy.delete_user(app_user)) + io_loop.run_sync(app_user.stop) + + r = api_request(app, 'users', + headers={ + 'Authorization': '', + 'Referer': 'null', + }, cookies=cookies, + ) + assert r.status_code == 403 + r = api_request(app, 'users', + headers={ + 'Authorization': '', + 'Referer': 'http://attack.com/csrf/vulnerability', + }, cookies=cookies, + ) + assert r.status_code == 403 + r = api_request(app, 'users', + headers={ + 'Authorization': '', + 'Referer': url, + 'Host': host, + }, cookies=cookies, + ) + assert r.status_code == 200 + r = api_request(app, 'users', + headers={ + 'Authorization': '', + 'Referer': ujoin(url, 'foo/bar/baz/bat'), + 'Host': host, + }, cookies=cookies, + ) + assert r.status_code == 200 + + def test_get_users(app): db = app.db r = api_request(app, 'users')