move service oauth state fields to in-memory

only use a random token as the actual oauth state,
and use a local cache dict to store the extra info like cookie_name, next_url

this avoids the state field getting too big and passing local browser-server info to anyone else
This commit is contained in:
Min RK
2023-10-19 13:43:20 +02:00
parent 0d6c27ca1d
commit c3510d2853
2 changed files with 141 additions and 51 deletions

View File

@@ -24,16 +24,15 @@ A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
""" """
import asyncio import asyncio
import base64
import hashlib import hashlib
import json import json
import os import os
import random import random
import re import re
import secrets
import socket import socket
import string import string
import time import time
import uuid
import warnings import warnings
from http import HTTPStatus from http import HTTPStatus
from unittest import mock from unittest import mock
@@ -106,14 +105,24 @@ class _ExpiringDict(dict):
""" """
max_age = 0 max_age = 0
purge_interval = 0
def __init__(self, max_age=0): def __init__(self, max_age=0, purge_interval="max_age"):
self.max_age = max_age self.max_age = max_age
if purge_interval == "max_age":
# default behavior: use max_age
purge_interval = max_age
self.purge_interval = purge_interval
self.timestamps = {} self.timestamps = {}
self.values = {} self.values = {}
self._last_purge = time.monotonic()
def __len__(self):
return len(self.values)
def __setitem__(self, key, value): def __setitem__(self, key, value):
"""Store key and record timestamp""" """Store key and record timestamp"""
self._maybe_purge()
self.timestamps[key] = time.monotonic() self.timestamps[key] = time.monotonic()
self.values[key] = value self.values[key] = value
@@ -139,6 +148,7 @@ class _ExpiringDict(dict):
if self.max_age > 0 and timestamp + self.max_age < now: if self.max_age > 0 and timestamp + self.max_age < now:
self.values.pop(key) self.values.pop(key)
self.timestamps.pop(key) self.timestamps.pop(key)
self._maybe_purge()
def __contains__(self, key): def __contains__(self, key):
"""dict check for `key in dict`""" """dict check for `key in dict`"""
@@ -150,17 +160,57 @@ class _ExpiringDict(dict):
self._check_age(key) self._check_age(key)
return self.values[key] return self.values[key]
def __delitem__(self, key):
del self.values[key]
del self.timestamps[key]
def get(self, key, default=None): def get(self, key, default=None):
"""dict-like get:""" """dict-like get"""
try: try:
return self[key] return self[key]
except KeyError: except KeyError:
return default return default
def pop(self, key, default="_raise"):
"""Remove and return an item"""
if key in self:
value = self.values.pop(key)
del self.timestamps[key]
return value
else:
if default == "_raise":
raise KeyError(key)
else:
return default
def clear(self): def clear(self):
"""Clear the cache""" """Clear the cache"""
self.values.clear() self.values.clear()
self.timestamps.clear() self.timestamps.clear()
self._last_purge = time.monotonic()
# extended methods
def _maybe_purge(self):
"""purge expired values _if_ it's been purge_interval since the last purge
Called on every get/set, to keep the expired values clear.
"""
if not self.purge_interval > 0:
return
now = time.monotonic()
if self._last_purge < (now - self.purge_interval):
self.purge_expired()
def purge_expired(self):
"""Purge all expired values"""
if not self.max_age > 0:
return
now = self._last_purge = time.monotonic()
cutoff = now - self.max_age
for key in list(self.timestamps):
timestamp = self.timestamps[key]
if timestamp < cutoff:
del self[key]
class HubAuth(SingletonConfigurable): class HubAuth(SingletonConfigurable):
@@ -854,37 +904,32 @@ class HubOAuth(HubAuth):
return token_reply['access_token'] return token_reply['access_token']
def _encode_state(self, state): # state-related
"""Encode a state dict as url-safe base64"""
# trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state)
return (
base64.urlsafe_b64encode(json_state.encode('utf8'))
.decode('ascii')
.rstrip('=')
)
def _decode_state(self, b64_state): oauth_state_max_age = Integer(
"""Decode a base64 state 600,
config=True,
help="""Max age (seconds) of oauth state.
Always returns a dict. Governs both oauth state cookie Max-Age,
The dict will be empty if the state is invalid. as well as the in-memory _oauth_states cache.
""" """,
if isinstance(b64_state, str): )
b64_state = b64_state.encode('ascii') _oauth_states = Instance(
if len(b64_state) != 4: _ExpiringDict,
# restore padding allow_none=False,
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4)) help="""
try: Store oauth state info, such as next_url
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
except ValueError: The oauth state only contains the oauth state _id_,
app_log.error("Failed to b64-decode state: %r", b64_state) while other information such as the cookie name, next_url
return {} are stored in this dictionary.
try: """,
return json.loads(json_state) )
except ValueError:
app_log.error("Failed to json-decode state: %r", json_state) @default('_oauth_states')
return {} def _default_oauth_states(self):
return _ExpiringDict(max_age=self.oauth_state_max_age)
def set_state_cookie(self, handler, next_url=None): def set_state_cookie(self, handler, next_url=None):
"""Generate an OAuth state and store it in a cookie """Generate an OAuth state and store it in a cookie
@@ -914,7 +959,7 @@ class HubOAuth(HubAuth):
extra_state['cookie_name'] = cookie_name extra_state['cookie_name'] = cookie_name
else: else:
cookie_name = self.state_cookie_name cookie_name = self.state_cookie_name
b64_state = self.generate_state(next_url, **extra_state) state_id = self.generate_state(next_url, **extra_state)
kwargs = { kwargs = {
'path': self.base_url, 'path': self.base_url,
'httponly': True, 'httponly': True,
@@ -924,16 +969,24 @@ class HubOAuth(HubAuth):
# OAuth that doesn't complete shouldn't linger too long. # OAuth that doesn't complete shouldn't linger too long.
'max_age': 600, 'max_age': 600,
} }
# don't allow overriding some fields
no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"}
if get_browser_protocol(handler.request) == 'https': if get_browser_protocol(handler.request) == 'https':
kwargs['secure'] = True kwargs['secure'] = True
# load user cookie overrides # load user cookie overrides
kwargs.update(self.cookie_options) for key, value in self.cookie_options:
handler.set_secure_cookie(cookie_name, b64_state, **kwargs) # don't include overrides
return b64_state if key.lower() not in no_override_keys:
kwargs[key] = value
handler.set_secure_cookie(cookie_name, state_id, **kwargs)
return state_id
def generate_state(self, next_url=None, **extra_state): def generate_state(self, next_url=None, **extra_state):
"""Generate a state string, given a next_url redirect target """Generate a state string, given a next_url redirect target
The state info is stored locally in self._oauth_states,
and only the state id is returned for use in the oauth state field (cookie, redirect param)
Parameters Parameters
---------- ----------
next_url : str next_url : str
@@ -941,24 +994,44 @@ class HubOAuth(HubAuth):
Returns Returns
------- -------
state (str): The base64-encoded state string. state_id (str): The state string to be used as a cookie value.
""" """
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url} state_id = secrets.token_urlsafe(16)
state.update(extra_state) state = {'next_url': next_url}
return self._encode_state(state) if extra_state:
state.update(extra_state)
self._oauth_states[state_id] = state
return state_id
def get_next_url(self, b64_state=''): def clear_oauth_state(self, state_id):
"""Clear persisted oauth state"""
self._oauth_states.pop(state_id, None)
self._oauth_states.purge_expired()
def clear_oauth_state_cookies(self, handler):
"""Clear persisted oauth state"""
for cookie_name, cookie in handler.request.cookies.items():
if cookie_name.startswith(self.state_cookie_name):
handler.clear_cookie(
cookie_name,
path=self.base_url,
)
def _decode_state(self, state_id, /):
return self._oauth_states.get(state_id, {})
def get_next_url(self, state_id='', /):
"""Get the next_url for redirection, given an encoded OAuth state""" """Get the next_url for redirection, given an encoded OAuth state"""
state = self._decode_state(b64_state) state = self._decode_state(state_id)
return state.get('next_url') or self.base_url return state.get('next_url') or self.base_url
def get_state_cookie_name(self, b64_state=''): def get_state_cookie_name(self, state_id='', /):
"""Get the cookie name for oauth state, given an encoded OAuth state """Get the cookie name for oauth state, given an encoded OAuth state
Cookie name is stored in the state itself because the cookie name Cookie name is stored in the state itself because the cookie name
is randomized to deal with races between concurrent oauth sequences. is randomized to deal with races between concurrent oauth sequences.
""" """
state = self._decode_state(b64_state) state = self._decode_state(state_id)
return state.get('cookie_name') or self.state_cookie_name return state.get('cookie_name') or self.state_cookie_name
def set_cookie(self, handler, access_token): def set_cookie(self, handler, access_token):
@@ -1230,23 +1303,30 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
code = self.get_argument("code", False) code = self.get_argument("code", False)
if not code: if not code:
raise HTTPError(400, "oauth callback made without a token") raise HTTPError(400, "OAuth callback made without a token")
# validate OAuth state # validate OAuth state
arg_state = self.get_argument("state", None) arg_state = self.get_argument("state", None)
if arg_state is None: if arg_state is None:
raise HTTPError(500, "oauth state is missing. Try logging in again.") raise HTTPError(400, "OAuth state is missing. Try logging in again.")
cookie_name = self.hub_auth.get_state_cookie_name(arg_state) cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
cookie_state = self.get_secure_cookie(cookie_name) cookie_state = self.get_secure_cookie(cookie_name)
# clear cookie state now that we've consumed it # clear cookie state now that we've consumed it
self.clear_cookie(cookie_name, path=self.hub_auth.base_url) if cookie_state:
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
if isinstance(cookie_state, bytes): if isinstance(cookie_state, bytes):
cookie_state = cookie_state.decode('ascii', 'replace') cookie_state = cookie_state.decode('ascii', 'replace')
# check that state matches # check that state matches
if arg_state != cookie_state: if arg_state != cookie_state:
app_log.warning("oauth state %r != %r", arg_state, cookie_state) app_log.warning("oauth state %r != %r", arg_state, cookie_state)
raise HTTPError(403, "oauth state does not match. Try logging in again.") raise HTTPError(403, "OAuth state does not match. Try logging in again.")
next_url = self.hub_auth.get_next_url(cookie_state) next_url = self.hub_auth.get_next_url(cookie_state)
# clear consumed state from _oauth_states cache now that we're done with it
self.hub_auth.clear_oauth_state(cookie_state)
# clear _all_ oauth state cookies on success ?
# This prevents multiple concurrent logins in the same browser,
# which is probably okay.
# self.hub_auth.clear_oauth_state_cookies(self)
token = await self.hub_auth.token_for_code(code, sync=False) token = await self.hub_auth.token_for_code(code, sync=False)
session_id = self.hub_auth.get_session_id(self) session_id = self.hub_auth.get_session_id(self)

View File

@@ -2,6 +2,7 @@
import copy import copy
import os import os
import sys import sys
import time
from binascii import hexlify from binascii import hexlify
from unittest import mock from unittest import mock
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
@@ -530,12 +531,21 @@ async def test_oauth_cookie_collision(app, mockservice_url, create_user_with_sco
state_cookie_name = 'service-%s-oauth-state' % service.name state_cookie_name = 'service-%s-oauth-state' % service.name
service_cookie_name = 'service-%s' % service.name service_cookie_name = 'service-%s' % service.name
oauth_1 = await s.get(url) oauth_1 = await s.get(url)
print(oauth_1.headers)
print(oauth_1.cookies, oauth_1.url, url)
assert state_cookie_name in s.cookies assert state_cookie_name in s.cookies
state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)] state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)]
# only one state cookie # only one state cookie
assert state_cookies == [state_cookie_name] assert state_cookies == [state_cookie_name]
# create dict of Cookie objects, so we can check properties
# cookies.__getitem__ returns only cookie _value_, but we want to check expiration
cookie_dict = {cookie.name: cookie for cookie in s.cookies}
state_cookie = cookie_dict[state_cookie_name]
# check state cookie properties
# should expire in 10 minutes (600 seconds)
assert time.time() < state_cookie.expires < time.time() + 630
# path is set right
assert state_cookie.path == service.prefix
state_1 = s.cookies[state_cookie_name] state_1 = s.cookies[state_cookie_name]
# start second oauth login before finishing the first # start second oauth login before finishing the first