From c3510d2853e3d90eaa982ee7cc894b19ffb496dd Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 19 Oct 2023 13:43:20 +0200 Subject: [PATCH] move service oauth state fields to in-memory only use a random token as the actual oauth state, and use a local cache dict to store the extra info like cookie_name, next_url this avoids the state field getting too big and passing local browser-server info to anyone else --- jupyterhub/services/auth.py | 178 ++++++++++++++++++------- jupyterhub/tests/test_services_auth.py | 14 +- 2 files changed, 141 insertions(+), 51 deletions(-) diff --git a/jupyterhub/services/auth.py b/jupyterhub/services/auth.py index 389c7870..7af39944 100644 --- a/jupyterhub/services/auth.py +++ b/jupyterhub/services/auth.py @@ -24,16 +24,15 @@ A tornado implementation is provided in :class:`HubOAuthCallbackHandler`. """ import asyncio -import base64 import hashlib import json import os import random import re +import secrets import socket import string import time -import uuid import warnings from http import HTTPStatus from unittest import mock @@ -106,14 +105,24 @@ class _ExpiringDict(dict): """ max_age = 0 + purge_interval = 0 - def __init__(self, max_age=0): + def __init__(self, max_age=0, purge_interval="max_age"): self.max_age = max_age + if purge_interval == "max_age": + # default behavior: use max_age + purge_interval = max_age + self.purge_interval = purge_interval self.timestamps = {} self.values = {} + self._last_purge = time.monotonic() + + def __len__(self): + return len(self.values) def __setitem__(self, key, value): """Store key and record timestamp""" + self._maybe_purge() self.timestamps[key] = time.monotonic() self.values[key] = value @@ -139,6 +148,7 @@ class _ExpiringDict(dict): if self.max_age > 0 and timestamp + self.max_age < now: self.values.pop(key) self.timestamps.pop(key) + self._maybe_purge() def __contains__(self, key): """dict check for `key in dict`""" @@ -150,17 +160,57 @@ class _ExpiringDict(dict): self._check_age(key) return self.values[key] + def __delitem__(self, key): + del self.values[key] + del self.timestamps[key] + def get(self, key, default=None): - """dict-like get:""" + """dict-like get""" try: return self[key] except KeyError: return default + def pop(self, key, default="_raise"): + """Remove and return an item""" + if key in self: + value = self.values.pop(key) + del self.timestamps[key] + return value + else: + if default == "_raise": + raise KeyError(key) + else: + return default + def clear(self): """Clear the cache""" self.values.clear() self.timestamps.clear() + self._last_purge = time.monotonic() + + # extended methods + def _maybe_purge(self): + """purge expired values _if_ it's been purge_interval since the last purge + + Called on every get/set, to keep the expired values clear. + """ + if not self.purge_interval > 0: + return + now = time.monotonic() + if self._last_purge < (now - self.purge_interval): + self.purge_expired() + + def purge_expired(self): + """Purge all expired values""" + if not self.max_age > 0: + return + now = self._last_purge = time.monotonic() + cutoff = now - self.max_age + for key in list(self.timestamps): + timestamp = self.timestamps[key] + if timestamp < cutoff: + del self[key] class HubAuth(SingletonConfigurable): @@ -854,37 +904,32 @@ class HubOAuth(HubAuth): return token_reply['access_token'] - def _encode_state(self, state): - """Encode a state dict as url-safe base64""" - # trim trailing `=` because = is itself not url-safe! - json_state = json.dumps(state) - return ( - base64.urlsafe_b64encode(json_state.encode('utf8')) - .decode('ascii') - .rstrip('=') - ) + # state-related - def _decode_state(self, b64_state): - """Decode a base64 state + oauth_state_max_age = Integer( + 600, + config=True, + help="""Max age (seconds) of oauth state. + + Governs both oauth state cookie Max-Age, + as well as the in-memory _oauth_states cache. + """, + ) + _oauth_states = Instance( + _ExpiringDict, + allow_none=False, + help=""" + Store oauth state info, such as next_url + + The oauth state only contains the oauth state _id_, + while other information such as the cookie name, next_url + are stored in this dictionary. + """, + ) - Always returns a dict. - The dict will be empty if the state is invalid. - """ - if isinstance(b64_state, str): - b64_state = b64_state.encode('ascii') - if len(b64_state) != 4: - # restore padding - b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4)) - try: - json_state = base64.urlsafe_b64decode(b64_state).decode('utf8') - except ValueError: - app_log.error("Failed to b64-decode state: %r", b64_state) - return {} - try: - return json.loads(json_state) - except ValueError: - app_log.error("Failed to json-decode state: %r", json_state) - return {} + @default('_oauth_states') + def _default_oauth_states(self): + return _ExpiringDict(max_age=self.oauth_state_max_age) def set_state_cookie(self, handler, next_url=None): """Generate an OAuth state and store it in a cookie @@ -914,7 +959,7 @@ class HubOAuth(HubAuth): extra_state['cookie_name'] = cookie_name else: cookie_name = self.state_cookie_name - b64_state = self.generate_state(next_url, **extra_state) + state_id = self.generate_state(next_url, **extra_state) kwargs = { 'path': self.base_url, 'httponly': True, @@ -924,16 +969,24 @@ class HubOAuth(HubAuth): # OAuth that doesn't complete shouldn't linger too long. 'max_age': 600, } + # don't allow overriding some fields + no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"} if get_browser_protocol(handler.request) == 'https': kwargs['secure'] = True # load user cookie overrides - kwargs.update(self.cookie_options) - handler.set_secure_cookie(cookie_name, b64_state, **kwargs) - return b64_state + for key, value in self.cookie_options: + # don't include overrides + if key.lower() not in no_override_keys: + kwargs[key] = value + handler.set_secure_cookie(cookie_name, state_id, **kwargs) + return state_id def generate_state(self, next_url=None, **extra_state): """Generate a state string, given a next_url redirect target + The state info is stored locally in self._oauth_states, + and only the state id is returned for use in the oauth state field (cookie, redirect param) + Parameters ---------- next_url : str @@ -941,24 +994,44 @@ class HubOAuth(HubAuth): Returns ------- - state (str): The base64-encoded state string. + state_id (str): The state string to be used as a cookie value. """ - state = {'uuid': uuid.uuid4().hex, 'next_url': next_url} - state.update(extra_state) - return self._encode_state(state) + state_id = secrets.token_urlsafe(16) + state = {'next_url': next_url} + if extra_state: + state.update(extra_state) + self._oauth_states[state_id] = state + return state_id - def get_next_url(self, b64_state=''): + def clear_oauth_state(self, state_id): + """Clear persisted oauth state""" + self._oauth_states.pop(state_id, None) + self._oauth_states.purge_expired() + + def clear_oauth_state_cookies(self, handler): + """Clear persisted oauth state""" + for cookie_name, cookie in handler.request.cookies.items(): + if cookie_name.startswith(self.state_cookie_name): + handler.clear_cookie( + cookie_name, + path=self.base_url, + ) + + def _decode_state(self, state_id, /): + return self._oauth_states.get(state_id, {}) + + def get_next_url(self, state_id='', /): """Get the next_url for redirection, given an encoded OAuth state""" - state = self._decode_state(b64_state) + state = self._decode_state(state_id) return state.get('next_url') or self.base_url - def get_state_cookie_name(self, b64_state=''): + def get_state_cookie_name(self, state_id='', /): """Get the cookie name for oauth state, given an encoded OAuth state Cookie name is stored in the state itself because the cookie name is randomized to deal with races between concurrent oauth sequences. """ - state = self._decode_state(b64_state) + state = self._decode_state(state_id) return state.get('cookie_name') or self.state_cookie_name def set_cookie(self, handler, access_token): @@ -1230,23 +1303,30 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler): code = self.get_argument("code", False) if not code: - raise HTTPError(400, "oauth callback made without a token") + raise HTTPError(400, "OAuth callback made without a token") # validate OAuth state arg_state = self.get_argument("state", None) if arg_state is None: - raise HTTPError(500, "oauth state is missing. Try logging in again.") + raise HTTPError(400, "OAuth state is missing. Try logging in again.") cookie_name = self.hub_auth.get_state_cookie_name(arg_state) cookie_state = self.get_secure_cookie(cookie_name) # clear cookie state now that we've consumed it - self.clear_cookie(cookie_name, path=self.hub_auth.base_url) + if cookie_state: + self.clear_cookie(cookie_name, path=self.hub_auth.base_url) if isinstance(cookie_state, bytes): cookie_state = cookie_state.decode('ascii', 'replace') # check that state matches if arg_state != cookie_state: app_log.warning("oauth state %r != %r", arg_state, cookie_state) - raise HTTPError(403, "oauth state does not match. Try logging in again.") + raise HTTPError(403, "OAuth state does not match. Try logging in again.") next_url = self.hub_auth.get_next_url(cookie_state) + # clear consumed state from _oauth_states cache now that we're done with it + self.hub_auth.clear_oauth_state(cookie_state) + # clear _all_ oauth state cookies on success ? + # This prevents multiple concurrent logins in the same browser, + # which is probably okay. + # self.hub_auth.clear_oauth_state_cookies(self) token = await self.hub_auth.token_for_code(code, sync=False) session_id = self.hub_auth.get_session_id(self) diff --git a/jupyterhub/tests/test_services_auth.py b/jupyterhub/tests/test_services_auth.py index 68faac9b..8c8b5b4a 100644 --- a/jupyterhub/tests/test_services_auth.py +++ b/jupyterhub/tests/test_services_auth.py @@ -2,6 +2,7 @@ import copy import os import sys +import time from binascii import hexlify from unittest import mock from urllib.parse import parse_qs, urlparse @@ -530,12 +531,21 @@ async def test_oauth_cookie_collision(app, mockservice_url, create_user_with_sco state_cookie_name = 'service-%s-oauth-state' % service.name service_cookie_name = 'service-%s' % service.name oauth_1 = await s.get(url) - print(oauth_1.headers) - print(oauth_1.cookies, oauth_1.url, url) assert state_cookie_name in s.cookies state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)] # only one state cookie assert state_cookies == [state_cookie_name] + # create dict of Cookie objects, so we can check properties + # cookies.__getitem__ returns only cookie _value_, but we want to check expiration + cookie_dict = {cookie.name: cookie for cookie in s.cookies} + state_cookie = cookie_dict[state_cookie_name] + + # check state cookie properties + # should expire in 10 minutes (600 seconds) + assert time.time() < state_cookie.expires < time.time() + 630 + # path is set right + assert state_cookie.path == service.prefix + state_1 = s.cookies[state_cookie_name] # start second oauth login before finishing the first