move service oauth state fields to in-memory

only use a random token as the actual oauth state,
and use a local cache dict to store the extra info like cookie_name, next_url

this avoids the state field getting too big and passing local browser-server info to anyone else
This commit is contained in:
Min RK
2023-10-19 13:43:20 +02:00
parent 0d6c27ca1d
commit c3510d2853
2 changed files with 141 additions and 51 deletions

View File

@@ -24,16 +24,15 @@ A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
"""
import asyncio
import base64
import hashlib
import json
import os
import random
import re
import secrets
import socket
import string
import time
import uuid
import warnings
from http import HTTPStatus
from unittest import mock
@@ -106,14 +105,24 @@ class _ExpiringDict(dict):
"""
max_age = 0
purge_interval = 0
def __init__(self, max_age=0):
def __init__(self, max_age=0, purge_interval="max_age"):
self.max_age = max_age
if purge_interval == "max_age":
# default behavior: use max_age
purge_interval = max_age
self.purge_interval = purge_interval
self.timestamps = {}
self.values = {}
self._last_purge = time.monotonic()
def __len__(self):
return len(self.values)
def __setitem__(self, key, value):
"""Store key and record timestamp"""
self._maybe_purge()
self.timestamps[key] = time.monotonic()
self.values[key] = value
@@ -139,6 +148,7 @@ class _ExpiringDict(dict):
if self.max_age > 0 and timestamp + self.max_age < now:
self.values.pop(key)
self.timestamps.pop(key)
self._maybe_purge()
def __contains__(self, key):
"""dict check for `key in dict`"""
@@ -150,17 +160,57 @@ class _ExpiringDict(dict):
self._check_age(key)
return self.values[key]
def __delitem__(self, key):
del self.values[key]
del self.timestamps[key]
def get(self, key, default=None):
"""dict-like get:"""
"""dict-like get"""
try:
return self[key]
except KeyError:
return default
def pop(self, key, default="_raise"):
"""Remove and return an item"""
if key in self:
value = self.values.pop(key)
del self.timestamps[key]
return value
else:
if default == "_raise":
raise KeyError(key)
else:
return default
def clear(self):
"""Clear the cache"""
self.values.clear()
self.timestamps.clear()
self._last_purge = time.monotonic()
# extended methods
def _maybe_purge(self):
"""purge expired values _if_ it's been purge_interval since the last purge
Called on every get/set, to keep the expired values clear.
"""
if not self.purge_interval > 0:
return
now = time.monotonic()
if self._last_purge < (now - self.purge_interval):
self.purge_expired()
def purge_expired(self):
"""Purge all expired values"""
if not self.max_age > 0:
return
now = self._last_purge = time.monotonic()
cutoff = now - self.max_age
for key in list(self.timestamps):
timestamp = self.timestamps[key]
if timestamp < cutoff:
del self[key]
class HubAuth(SingletonConfigurable):
@@ -854,37 +904,32 @@ class HubOAuth(HubAuth):
return token_reply['access_token']
def _encode_state(self, state):
"""Encode a state dict as url-safe base64"""
# trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state)
return (
base64.urlsafe_b64encode(json_state.encode('utf8'))
.decode('ascii')
.rstrip('=')
)
# state-related
def _decode_state(self, b64_state):
"""Decode a base64 state
oauth_state_max_age = Integer(
600,
config=True,
help="""Max age (seconds) of oauth state.
Governs both oauth state cookie Max-Age,
as well as the in-memory _oauth_states cache.
""",
)
_oauth_states = Instance(
_ExpiringDict,
allow_none=False,
help="""
Store oauth state info, such as next_url
The oauth state only contains the oauth state _id_,
while other information such as the cookie name, next_url
are stored in this dictionary.
""",
)
Always returns a dict.
The dict will be empty if the state is invalid.
"""
if isinstance(b64_state, str):
b64_state = b64_state.encode('ascii')
if len(b64_state) != 4:
# restore padding
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4))
try:
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
except ValueError:
app_log.error("Failed to b64-decode state: %r", b64_state)
return {}
try:
return json.loads(json_state)
except ValueError:
app_log.error("Failed to json-decode state: %r", json_state)
return {}
@default('_oauth_states')
def _default_oauth_states(self):
return _ExpiringDict(max_age=self.oauth_state_max_age)
def set_state_cookie(self, handler, next_url=None):
"""Generate an OAuth state and store it in a cookie
@@ -914,7 +959,7 @@ class HubOAuth(HubAuth):
extra_state['cookie_name'] = cookie_name
else:
cookie_name = self.state_cookie_name
b64_state = self.generate_state(next_url, **extra_state)
state_id = self.generate_state(next_url, **extra_state)
kwargs = {
'path': self.base_url,
'httponly': True,
@@ -924,16 +969,24 @@ class HubOAuth(HubAuth):
# OAuth that doesn't complete shouldn't linger too long.
'max_age': 600,
}
# don't allow overriding some fields
no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"}
if get_browser_protocol(handler.request) == 'https':
kwargs['secure'] = True
# load user cookie overrides
kwargs.update(self.cookie_options)
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
return b64_state
for key, value in self.cookie_options:
# don't include overrides
if key.lower() not in no_override_keys:
kwargs[key] = value
handler.set_secure_cookie(cookie_name, state_id, **kwargs)
return state_id
def generate_state(self, next_url=None, **extra_state):
"""Generate a state string, given a next_url redirect target
The state info is stored locally in self._oauth_states,
and only the state id is returned for use in the oauth state field (cookie, redirect param)
Parameters
----------
next_url : str
@@ -941,24 +994,44 @@ class HubOAuth(HubAuth):
Returns
-------
state (str): The base64-encoded state string.
state_id (str): The state string to be used as a cookie value.
"""
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
state.update(extra_state)
return self._encode_state(state)
state_id = secrets.token_urlsafe(16)
state = {'next_url': next_url}
if extra_state:
state.update(extra_state)
self._oauth_states[state_id] = state
return state_id
def get_next_url(self, b64_state=''):
def clear_oauth_state(self, state_id):
"""Clear persisted oauth state"""
self._oauth_states.pop(state_id, None)
self._oauth_states.purge_expired()
def clear_oauth_state_cookies(self, handler):
"""Clear persisted oauth state"""
for cookie_name, cookie in handler.request.cookies.items():
if cookie_name.startswith(self.state_cookie_name):
handler.clear_cookie(
cookie_name,
path=self.base_url,
)
def _decode_state(self, state_id, /):
return self._oauth_states.get(state_id, {})
def get_next_url(self, state_id='', /):
"""Get the next_url for redirection, given an encoded OAuth state"""
state = self._decode_state(b64_state)
state = self._decode_state(state_id)
return state.get('next_url') or self.base_url
def get_state_cookie_name(self, b64_state=''):
def get_state_cookie_name(self, state_id='', /):
"""Get the cookie name for oauth state, given an encoded OAuth state
Cookie name is stored in the state itself because the cookie name
is randomized to deal with races between concurrent oauth sequences.
"""
state = self._decode_state(b64_state)
state = self._decode_state(state_id)
return state.get('cookie_name') or self.state_cookie_name
def set_cookie(self, handler, access_token):
@@ -1230,23 +1303,30 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
code = self.get_argument("code", False)
if not code:
raise HTTPError(400, "oauth callback made without a token")
raise HTTPError(400, "OAuth callback made without a token")
# validate OAuth state
arg_state = self.get_argument("state", None)
if arg_state is None:
raise HTTPError(500, "oauth state is missing. Try logging in again.")
raise HTTPError(400, "OAuth state is missing. Try logging in again.")
cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
cookie_state = self.get_secure_cookie(cookie_name)
# clear cookie state now that we've consumed it
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
if cookie_state:
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
if isinstance(cookie_state, bytes):
cookie_state = cookie_state.decode('ascii', 'replace')
# check that state matches
if arg_state != cookie_state:
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
raise HTTPError(403, "oauth state does not match. Try logging in again.")
raise HTTPError(403, "OAuth state does not match. Try logging in again.")
next_url = self.hub_auth.get_next_url(cookie_state)
# clear consumed state from _oauth_states cache now that we're done with it
self.hub_auth.clear_oauth_state(cookie_state)
# clear _all_ oauth state cookies on success ?
# This prevents multiple concurrent logins in the same browser,
# which is probably okay.
# self.hub_auth.clear_oauth_state_cookies(self)
token = await self.hub_auth.token_for_code(code, sync=False)
session_id = self.hub_auth.get_session_id(self)

View File

@@ -2,6 +2,7 @@
import copy
import os
import sys
import time
from binascii import hexlify
from unittest import mock
from urllib.parse import parse_qs, urlparse
@@ -530,12 +531,21 @@ async def test_oauth_cookie_collision(app, mockservice_url, create_user_with_sco
state_cookie_name = 'service-%s-oauth-state' % service.name
service_cookie_name = 'service-%s' % service.name
oauth_1 = await s.get(url)
print(oauth_1.headers)
print(oauth_1.cookies, oauth_1.url, url)
assert state_cookie_name in s.cookies
state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)]
# only one state cookie
assert state_cookies == [state_cookie_name]
# create dict of Cookie objects, so we can check properties
# cookies.__getitem__ returns only cookie _value_, but we want to check expiration
cookie_dict = {cookie.name: cookie for cookie in s.cookies}
state_cookie = cookie_dict[state_cookie_name]
# check state cookie properties
# should expire in 10 minutes (600 seconds)
assert time.time() < state_cookie.expires < time.time() + 630
# path is set right
assert state_cookie.path == service.prefix
state_1 = s.cookies[state_cookie_name]
# start second oauth login before finishing the first