mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-16 22:43:00 +00:00
Move service OAuth state fields to in-memory storage
Only use a random token as the actual OAuth state, and use a local cache dict to store the extra info such as cookie_name and next_url. This keeps the state field from getting too big and avoids passing local browser-server info to anyone else.
This commit is contained in:
@@ -24,16 +24,15 @@ A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import socket
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from http import HTTPStatus
|
||||
from unittest import mock
|
||||
@@ -106,14 +105,24 @@ class _ExpiringDict(dict):
|
||||
"""
|
||||
|
||||
max_age = 0
|
||||
purge_interval = 0
|
||||
|
||||
def __init__(self, max_age=0):
|
||||
def __init__(self, max_age=0, purge_interval="max_age"):
    """Set up an empty expiring cache.

    Parameters
    ----------
    max_age : number
        Age in seconds after which an entry is treated as expired.
        Expiry is only applied when max_age > 0.
    purge_interval : number or the string "max_age"
        Minimum number of seconds between automatic purges.
        The default, "max_age", reuses the value of max_age.
        Automatic purging only happens when the interval is > 0.
    """
    self.max_age = max_age
    if purge_interval == "max_age":
        # default behavior: use max_age
        purge_interval = max_age
    self.purge_interval = purge_interval
    # values and their insertion times are kept in two parallel dicts
    self.timestamps = {}
    self.values = {}
    # start the purge clock now, so the first purge is one interval away
    self._last_purge = time.monotonic()
|
||||
|
||||
def __len__(self):
    """Return the number of stored entries.

    No age check is done here, so expired entries that have not been
    purged yet are still counted.
    """
    return len(self.values)
|
||||
|
||||
def __setitem__(self, key, value):
    """Store key and record timestamp"""
    # opportunistically drop expired entries before adding a new one
    self._maybe_purge()
    self.timestamps[key] = time.monotonic()
    self.values[key] = value
|
||||
|
||||
@@ -139,6 +148,7 @@ class _ExpiringDict(dict):
|
||||
if self.max_age > 0 and timestamp + self.max_age < now:
|
||||
self.values.pop(key)
|
||||
self.timestamps.pop(key)
|
||||
self._maybe_purge()
|
||||
|
||||
def __contains__(self, key):
|
||||
"""dict check for `key in dict`"""
|
||||
@@ -150,17 +160,57 @@ class _ExpiringDict(dict):
|
||||
self._check_age(key)
|
||||
return self.values[key]
|
||||
|
||||
def __delitem__(self, key):
    """Delete key's value and its timestamp.

    Raises KeyError if key is not stored.
    """
    del self.values[key]
    del self.timestamps[key]
|
||||
|
||||
def get(self, key, default=None):
    """dict-like get

    Returns self[key], or `default` when that lookup raises KeyError.
    """
    try:
        return self[key]
    except KeyError:
        return default
|
||||
|
||||
def pop(self, key, default="_raise"):
    """Remove and return an item

    Parameters
    ----------
    key
        The key to remove.
    default
        Value returned when key is not present.
        When omitted (left as the "_raise" sentinel string),
        a missing key raises KeyError instead.
    """
    # membership goes through __contains__ (body not visible in this view;
    # presumably age-aware -- confirm against the class)
    if key in self:
        value = self.values.pop(key)
        del self.timestamps[key]
        return value
    # only compare strings against the sentinel, so a caller-supplied default
    # with an unusual __eq__ is never invoked just to detect "no default"
    if isinstance(default, str) and default == "_raise":
        raise KeyError(key)
    return default
|
||||
|
||||
def clear(self):
    """Clear the cache"""
    self.values.clear()
    self.timestamps.clear()
    # nothing is left to purge, so restart the purge clock from now
    self._last_purge = time.monotonic()
|
||||
|
||||
# extended methods
|
||||
def _maybe_purge(self):
    """purge expired values _if_ it's been purge_interval since the last purge

    Called on every get/set, to keep the expired values clear.
    Does nothing unless purge_interval is > 0.
    """
    if not self.purge_interval > 0:
        return
    now = time.monotonic()
    if self._last_purge < (now - self.purge_interval):
        self.purge_expired()
|
||||
|
||||
def purge_expired(self):
    """Purge all expired values

    Does nothing unless max_age is > 0.
    Otherwise records the purge time in _last_purge and deletes every
    entry whose timestamp is older than max_age seconds.
    """
    if not self.max_age > 0:
        return
    now = self._last_purge = time.monotonic()
    cutoff = now - self.max_age
    # iterate over a snapshot, because deleting mutates self.timestamps
    for key, timestamp in list(self.timestamps.items()):
        if timestamp < cutoff:
            del self[key]
|
||||
|
||||
|
||||
class HubAuth(SingletonConfigurable):
|
||||
@@ -854,37 +904,32 @@ class HubOAuth(HubAuth):
|
||||
|
||||
return token_reply['access_token']
|
||||
|
||||
def _encode_state(self, state):
|
||||
"""Encode a state dict as url-safe base64"""
|
||||
# trim trailing `=` because = is itself not url-safe!
|
||||
json_state = json.dumps(state)
|
||||
return (
|
||||
base64.urlsafe_b64encode(json_state.encode('utf8'))
|
||||
.decode('ascii')
|
||||
.rstrip('=')
|
||||
)
|
||||
# state-related
|
||||
|
||||
def _decode_state(self, b64_state):
|
||||
"""Decode a base64 state
|
||||
oauth_state_max_age = Integer(
|
||||
600,
|
||||
config=True,
|
||||
help="""Max age (seconds) of oauth state.
|
||||
|
||||
Governs both oauth state cookie Max-Age,
|
||||
as well as the in-memory _oauth_states cache.
|
||||
""",
|
||||
)
|
||||
_oauth_states = Instance(
|
||||
_ExpiringDict,
|
||||
allow_none=False,
|
||||
help="""
|
||||
Store oauth state info, such as next_url
|
||||
|
||||
The oauth state only contains the oauth state _id_,
|
||||
while other information such as the cookie name, next_url
|
||||
are stored in this dictionary.
|
||||
""",
|
||||
)
|
||||
|
||||
Always returns a dict.
|
||||
The dict will be empty if the state is invalid.
|
||||
"""
|
||||
if isinstance(b64_state, str):
|
||||
b64_state = b64_state.encode('ascii')
|
||||
if len(b64_state) != 4:
|
||||
# restore padding
|
||||
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4))
|
||||
try:
|
||||
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
|
||||
except ValueError:
|
||||
app_log.error("Failed to b64-decode state: %r", b64_state)
|
||||
return {}
|
||||
try:
|
||||
return json.loads(json_state)
|
||||
except ValueError:
|
||||
app_log.error("Failed to json-decode state: %r", json_state)
|
||||
return {}
|
||||
@default('_oauth_states')
def _default_oauth_states(self):
    """Default for _oauth_states: an empty cache whose max_age is oauth_state_max_age."""
    return _ExpiringDict(max_age=self.oauth_state_max_age)
|
||||
|
||||
def set_state_cookie(self, handler, next_url=None):
|
||||
"""Generate an OAuth state and store it in a cookie
|
||||
@@ -914,7 +959,7 @@ class HubOAuth(HubAuth):
|
||||
extra_state['cookie_name'] = cookie_name
|
||||
else:
|
||||
cookie_name = self.state_cookie_name
|
||||
b64_state = self.generate_state(next_url, **extra_state)
|
||||
state_id = self.generate_state(next_url, **extra_state)
|
||||
kwargs = {
|
||||
'path': self.base_url,
|
||||
'httponly': True,
|
||||
@@ -924,16 +969,24 @@ class HubOAuth(HubAuth):
|
||||
# OAuth that doesn't complete shouldn't linger too long.
|
||||
'max_age': 600,
|
||||
}
|
||||
# don't allow overriding some fields
|
||||
no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"}
|
||||
if get_browser_protocol(handler.request) == 'https':
|
||||
kwargs['secure'] = True
|
||||
# load user cookie overrides
|
||||
kwargs.update(self.cookie_options)
|
||||
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
|
||||
return b64_state
|
||||
for key, value in self.cookie_options:
|
||||
# don't include overrides
|
||||
if key.lower() not in no_override_keys:
|
||||
kwargs[key] = value
|
||||
handler.set_secure_cookie(cookie_name, state_id, **kwargs)
|
||||
return state_id
|
||||
|
||||
def generate_state(self, next_url=None, **extra_state):
|
||||
"""Generate a state string, given a next_url redirect target
|
||||
|
||||
The state info is stored locally in self._oauth_states,
|
||||
and only the state id is returned for use in the oauth state field (cookie, redirect param)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
next_url : str
|
||||
@@ -941,24 +994,44 @@ class HubOAuth(HubAuth):
|
||||
|
||||
Returns
|
||||
-------
|
||||
state (str): The base64-encoded state string.
|
||||
state_id (str): The state string to be used as a cookie value.
|
||||
"""
|
||||
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
|
||||
state.update(extra_state)
|
||||
return self._encode_state(state)
|
||||
state_id = secrets.token_urlsafe(16)
|
||||
state = {'next_url': next_url}
|
||||
if extra_state:
|
||||
state.update(extra_state)
|
||||
self._oauth_states[state_id] = state
|
||||
return state_id
|
||||
|
||||
def get_next_url(self, b64_state=''):
|
||||
def clear_oauth_state(self, state_id):
    """Clear persisted oauth state

    Removes the entry for state_id from the in-memory _oauth_states cache
    (a missing id is ignored), then purges any other expired entries.
    """
    self._oauth_states.pop(state_id, None)
    self._oauth_states.purge_expired()
|
||||
|
||||
def clear_oauth_state_cookies(self, handler):
    """Clear all oauth state cookies sent with the current request

    Every request cookie whose name starts with self.state_cookie_name
    is cleared on `handler`, using self.base_url as the cookie path.
    """
    # only the cookie names are needed, so iterate keys rather than items
    for cookie_name in handler.request.cookies:
        if cookie_name.startswith(self.state_cookie_name):
            handler.clear_cookie(
                cookie_name,
                path=self.base_url,
            )
|
||||
|
||||
def _decode_state(self, state_id, /):
    """Look up the locally stored state info for an oauth state id.

    Always returns a dict; it is empty when state_id is not in the
    _oauth_states cache.
    """
    return self._oauth_states.get(state_id, {})
|
||||
|
||||
def get_next_url(self, state_id='', /):
    """Get the next_url for redirection, given an OAuth state id

    Falls back to self.base_url when the state is unknown
    or has no next_url recorded.
    """
    state = self._decode_state(state_id)
    return state.get('next_url') or self.base_url
|
||||
|
||||
def get_state_cookie_name(self, b64_state=''):
|
||||
def get_state_cookie_name(self, state_id='', /):
    """Get the cookie name for oauth state, given an OAuth state id

    Cookie name is stored with the state info because the cookie name
    is randomized to deal with races between concurrent oauth sequences.
    Falls back to self.state_cookie_name when the state is unknown
    or has no cookie_name recorded.
    """
    state = self._decode_state(state_id)
    return state.get('cookie_name') or self.state_cookie_name
|
||||
|
||||
def set_cookie(self, handler, access_token):
|
||||
@@ -1230,23 +1303,30 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
|
||||
code = self.get_argument("code", False)
|
||||
if not code:
|
||||
raise HTTPError(400, "oauth callback made without a token")
|
||||
raise HTTPError(400, "OAuth callback made without a token")
|
||||
|
||||
# validate OAuth state
|
||||
arg_state = self.get_argument("state", None)
|
||||
if arg_state is None:
|
||||
raise HTTPError(500, "oauth state is missing. Try logging in again.")
|
||||
raise HTTPError(400, "OAuth state is missing. Try logging in again.")
|
||||
cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
|
||||
cookie_state = self.get_secure_cookie(cookie_name)
|
||||
# clear cookie state now that we've consumed it
|
||||
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
|
||||
if cookie_state:
|
||||
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
|
||||
if isinstance(cookie_state, bytes):
|
||||
cookie_state = cookie_state.decode('ascii', 'replace')
|
||||
# check that state matches
|
||||
if arg_state != cookie_state:
|
||||
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
||||
raise HTTPError(403, "oauth state does not match. Try logging in again.")
|
||||
raise HTTPError(403, "OAuth state does not match. Try logging in again.")
|
||||
next_url = self.hub_auth.get_next_url(cookie_state)
|
||||
# clear consumed state from _oauth_states cache now that we're done with it
|
||||
self.hub_auth.clear_oauth_state(cookie_state)
|
||||
# clear _all_ oauth state cookies on success ?
|
||||
# This prevents multiple concurrent logins in the same browser,
|
||||
# which is probably okay.
|
||||
# self.hub_auth.clear_oauth_state_cookies(self)
|
||||
|
||||
token = await self.hub_auth.token_for_code(code, sync=False)
|
||||
session_id = self.hub_auth.get_session_id(self)
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import copy
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from binascii import hexlify
|
||||
from unittest import mock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
@@ -530,12 +531,21 @@ async def test_oauth_cookie_collision(app, mockservice_url, create_user_with_sco
|
||||
state_cookie_name = 'service-%s-oauth-state' % service.name
|
||||
service_cookie_name = 'service-%s' % service.name
|
||||
oauth_1 = await s.get(url)
|
||||
print(oauth_1.headers)
|
||||
print(oauth_1.cookies, oauth_1.url, url)
|
||||
assert state_cookie_name in s.cookies
|
||||
state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)]
|
||||
# only one state cookie
|
||||
assert state_cookies == [state_cookie_name]
|
||||
# create dict of Cookie objects, so we can check properties
|
||||
# cookies.__getitem__ returns only cookie _value_, but we want to check expiration
|
||||
cookie_dict = {cookie.name: cookie for cookie in s.cookies}
|
||||
state_cookie = cookie_dict[state_cookie_name]
|
||||
|
||||
# check state cookie properties
|
||||
# should expire in 10 minutes (600 seconds)
|
||||
assert time.time() < state_cookie.expires < time.time() + 630
|
||||
# path is set right
|
||||
assert state_cookie.path == service.prefix
|
||||
|
||||
state_1 = s.cookies[state_cookie_name]
|
||||
|
||||
# start second oauth login before finishing the first
|
||||
|
Reference in New Issue
Block a user