mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-18 15:33:02 +00:00
move service oauth state fields to in-memory
Only use a random token as the actual OAuth state, and use a local cache dict to store the extra info such as cookie_name and next_url. This avoids the state field getting too big, and avoids passing local browser-server info to anyone else.
This commit is contained in:
@@ -24,16 +24,15 @@ A tornado implementation is provided in :class:`HubOAuthCallbackHandler`.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import secrets
|
||||||
import socket
|
import socket
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
import warnings
|
import warnings
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
@@ -106,14 +105,24 @@ class _ExpiringDict(dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
max_age = 0
|
max_age = 0
|
||||||
|
purge_interval = 0
|
||||||
|
|
||||||
def __init__(self, max_age=0, purge_interval="max_age"):
    """Set up an empty cache

    Parameters
    ----------
    max_age : number
        Stored as ``self.max_age``.
    purge_interval : number or "max_age"
        Stored as ``self.purge_interval``.
        The string ``"max_age"`` (the default) means: reuse `max_age`.
    """
    # resolve the "max_age" placeholder before storing the interval
    resolved_interval = max_age if purge_interval == "max_age" else purge_interval
    self.max_age = max_age
    self.purge_interval = resolved_interval
    self.timestamps = {}
    self.values = {}
    # the purge clock starts counting at construction time
    self._last_purge = time.monotonic()
|
||||||
|
|
||||||
|
def __len__(self):
    """Number of entries currently held in ``self.values``"""
    stored = self.values
    return len(stored)
|
||||||
|
|
||||||
def __setitem__(self, key, value):
    """Insert or overwrite `key`, stamping it with the current monotonic time"""
    # give stale entries a chance to be dropped before recording the new one
    self._maybe_purge()
    stamp = time.monotonic()
    self.timestamps[key] = stamp
    self.values[key] = value
|
||||||
|
|
||||||
@@ -139,6 +148,7 @@ class _ExpiringDict(dict):
|
|||||||
if self.max_age > 0 and timestamp + self.max_age < now:
|
if self.max_age > 0 and timestamp + self.max_age < now:
|
||||||
self.values.pop(key)
|
self.values.pop(key)
|
||||||
self.timestamps.pop(key)
|
self.timestamps.pop(key)
|
||||||
|
self._maybe_purge()
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
"""dict check for `key in dict`"""
|
"""dict check for `key in dict`"""
|
||||||
@@ -150,17 +160,57 @@ class _ExpiringDict(dict):
|
|||||||
self._check_age(key)
|
self._check_age(key)
|
||||||
return self.values[key]
|
return self.values[key]
|
||||||
|
|
||||||
|
def __delitem__(self, key):
    """Drop `key` from the value store, then from the timestamp store

    Raises KeyError if `key` is missing from either store.
    """
    for store in (self.values, self.timestamps):
        del store[key]
|
||||||
|
|
||||||
def get(self, key, default=None):
    """dict-like get: the value for `key`, or `default` when the lookup raises KeyError"""
    try:
        found = self[key]
    except KeyError:
        return default
    else:
        return found
|
||||||
|
|
||||||
|
def pop(self, key, default="_raise"):
    """Remove and return an item

    Parameters
    ----------
    key
        The key to remove.
    default
        Value to return when `key` is not present.
        The string ``"_raise"`` (the default) means no fallback was given.

    Returns
    -------
    The stored value for `key`, or `default` if `key` is absent
    and a default was given.

    Raises
    ------
    KeyError
        If `key` is absent and no default was given.
    """
    if key in self:
        value = self.values.pop(key)
        del self.timestamps[key]
        return value
    # Only the exact "_raise" string marks "no default given".
    # Check the type first so that an arbitrary caller-supplied default
    # never has its own __eq__ invoked (it may raise or be ambiguous).
    if isinstance(default, str) and default == "_raise":
        raise KeyError(key)
    return default
|
||||||
|
|
||||||
def clear(self):
    """Empty the cache and restart the purge clock"""
    for store in (self.values, self.timestamps):
        store.clear()
    self._last_purge = time.monotonic()
|
||||||
|
|
||||||
|
# extended methods
|
||||||
|
def _maybe_purge(self):
    """Run purge_expired() if more than purge_interval has passed since the last purge

    Does nothing when purge_interval is not greater than zero.

    Called on every get/set, to keep the expired values clear.
    """
    interval = self.purge_interval
    if not interval > 0:
        return
    # a purge is due once the last one is older than this point in time
    threshold = time.monotonic() - interval
    if self._last_purge < threshold:
        self.purge_expired()
|
||||||
|
|
||||||
|
def purge_expired(self):
    """Delete every entry whose timestamp is more than max_age older than now

    Does nothing (and leaves the purge clock untouched) when max_age
    is not greater than zero.
    """
    if not self.max_age > 0:
        return
    now = time.monotonic()
    self._last_purge = now
    oldest_allowed = now - self.max_age
    # collect first, delete afterwards, so the dict is never resized mid-iteration
    stale_keys = [
        key for key, stamp in self.timestamps.items() if stamp < oldest_allowed
    ]
    for key in stale_keys:
        del self[key]
|
||||||
|
|
||||||
|
|
||||||
class HubAuth(SingletonConfigurable):
|
class HubAuth(SingletonConfigurable):
|
||||||
@@ -854,37 +904,32 @@ class HubOAuth(HubAuth):
|
|||||||
|
|
||||||
return token_reply['access_token']
|
return token_reply['access_token']
|
||||||
|
|
||||||
def _encode_state(self, state):
|
# state-related
|
||||||
"""Encode a state dict as url-safe base64"""
|
|
||||||
# trim trailing `=` because = is itself not url-safe!
|
|
||||||
json_state = json.dumps(state)
|
|
||||||
return (
|
|
||||||
base64.urlsafe_b64encode(json_state.encode('utf8'))
|
|
||||||
.decode('ascii')
|
|
||||||
.rstrip('=')
|
|
||||||
)
|
|
||||||
|
|
||||||
def _decode_state(self, b64_state):
|
oauth_state_max_age = Integer(
|
||||||
"""Decode a base64 state
|
600,
|
||||||
|
config=True,
|
||||||
|
help="""Max age (seconds) of oauth state.
|
||||||
|
|
||||||
Always returns a dict.
|
Governs both oauth state cookie Max-Age,
|
||||||
The dict will be empty if the state is invalid.
|
as well as the in-memory _oauth_states cache.
|
||||||
"""
|
""",
|
||||||
if isinstance(b64_state, str):
|
)
|
||||||
b64_state = b64_state.encode('ascii')
|
_oauth_states = Instance(
|
||||||
if len(b64_state) != 4:
|
_ExpiringDict,
|
||||||
# restore padding
|
allow_none=False,
|
||||||
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4))
|
help="""
|
||||||
try:
|
Store oauth state info, such as next_url
|
||||||
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
|
|
||||||
except ValueError:
|
The oauth state only contains the oauth state _id_,
|
||||||
app_log.error("Failed to b64-decode state: %r", b64_state)
|
while other information such as the cookie name, next_url
|
||||||
return {}
|
are stored in this dictionary.
|
||||||
try:
|
""",
|
||||||
return json.loads(json_state)
|
)
|
||||||
except ValueError:
|
|
||||||
app_log.error("Failed to json-decode state: %r", json_state)
|
@default('_oauth_states')
|
||||||
return {}
|
def _default_oauth_states(self):
|
||||||
|
return _ExpiringDict(max_age=self.oauth_state_max_age)
|
||||||
|
|
||||||
def set_state_cookie(self, handler, next_url=None):
|
def set_state_cookie(self, handler, next_url=None):
|
||||||
"""Generate an OAuth state and store it in a cookie
|
"""Generate an OAuth state and store it in a cookie
|
||||||
@@ -914,7 +959,7 @@ class HubOAuth(HubAuth):
|
|||||||
extra_state['cookie_name'] = cookie_name
|
extra_state['cookie_name'] = cookie_name
|
||||||
else:
|
else:
|
||||||
cookie_name = self.state_cookie_name
|
cookie_name = self.state_cookie_name
|
||||||
b64_state = self.generate_state(next_url, **extra_state)
|
state_id = self.generate_state(next_url, **extra_state)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'path': self.base_url,
|
'path': self.base_url,
|
||||||
'httponly': True,
|
'httponly': True,
|
||||||
@@ -924,16 +969,24 @@ class HubOAuth(HubAuth):
|
|||||||
# OAuth that doesn't complete shouldn't linger too long.
|
# OAuth that doesn't complete shouldn't linger too long.
|
||||||
'max_age': 600,
|
'max_age': 600,
|
||||||
}
|
}
|
||||||
|
# don't allow overriding some fields
|
||||||
|
no_override_keys = set(kwargs.keys()) | {"expires_days", "expires"}
|
||||||
if get_browser_protocol(handler.request) == 'https':
|
if get_browser_protocol(handler.request) == 'https':
|
||||||
kwargs['secure'] = True
|
kwargs['secure'] = True
|
||||||
# load user cookie overrides
|
# load user cookie overrides
|
||||||
kwargs.update(self.cookie_options)
|
for key, value in self.cookie_options:
|
||||||
handler.set_secure_cookie(cookie_name, b64_state, **kwargs)
|
# don't include overrides
|
||||||
return b64_state
|
if key.lower() not in no_override_keys:
|
||||||
|
kwargs[key] = value
|
||||||
|
handler.set_secure_cookie(cookie_name, state_id, **kwargs)
|
||||||
|
return state_id
|
||||||
|
|
||||||
def generate_state(self, next_url=None, **extra_state):
|
def generate_state(self, next_url=None, **extra_state):
|
||||||
"""Generate a state string, given a next_url redirect target
|
"""Generate a state string, given a next_url redirect target
|
||||||
|
|
||||||
|
The state info is stored locally in self._oauth_states,
|
||||||
|
and only the state id is returned for use in the oauth state field (cookie, redirect param)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
next_url : str
|
next_url : str
|
||||||
@@ -941,24 +994,44 @@ class HubOAuth(HubAuth):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
state (str): The base64-encoded state string.
|
state_id (str): The state string to be used as a cookie value.
|
||||||
"""
|
"""
|
||||||
state = {'uuid': uuid.uuid4().hex, 'next_url': next_url}
|
state_id = secrets.token_urlsafe(16)
|
||||||
state.update(extra_state)
|
state = {'next_url': next_url}
|
||||||
return self._encode_state(state)
|
if extra_state:
|
||||||
|
state.update(extra_state)
|
||||||
|
self._oauth_states[state_id] = state
|
||||||
|
return state_id
|
||||||
|
|
||||||
def get_next_url(self, b64_state=''):
|
def clear_oauth_state(self, state_id):
    """Forget the cached info for `state_id`, then purge expired entries from the cache"""
    cache = self._oauth_states
    # a missing state_id is not an error
    cache.pop(state_id, None)
    cache.purge_expired()
|
||||||
|
|
||||||
|
def clear_oauth_state_cookies(self, handler):
    """Clear all oauth state cookies present on the request

    Every cookie in ``handler.request.cookies`` whose name starts with
    ``self.state_cookie_name`` is cleared via ``handler.clear_cookie``
    with ``path=self.base_url``.

    Parameters
    ----------
    handler
        The request handler whose request cookies are inspected
        and on which the cookies are cleared.
    """
    prefix = self.state_cookie_name
    # only the cookie names matter here; the cookie values are never read
    for cookie_name in handler.request.cookies:
        if cookie_name.startswith(prefix):
            handler.clear_cookie(cookie_name, path=self.base_url)
|
||||||
|
|
||||||
|
def _decode_state(self, state_id, /):
    """Look up the locally cached state info for `state_id`

    Returns an empty dict when the cache has nothing for `state_id`.
    """
    fallback = {}
    return self._oauth_states.get(state_id, fallback)
|
||||||
|
|
||||||
|
def get_next_url(self, state_id='', /):
    """Get the next_url for redirection, given an OAuth state id

    Falls back to self.base_url when the decoded state
    has no (truthy) next_url.
    """
    next_url = self._decode_state(state_id).get('next_url')
    if next_url:
        return next_url
    return self.base_url
|
||||||
|
|
||||||
def get_state_cookie_name(self, state_id='', /):
    """Get the cookie name for oauth state, given an OAuth state id

    The cookie name is kept with the state info because the cookie name
    is randomized to deal with races between concurrent oauth sequences.

    Falls back to self.state_cookie_name when the decoded state
    has no (truthy) cookie_name.
    """
    cookie_name = self._decode_state(state_id).get('cookie_name')
    if cookie_name:
        return cookie_name
    return self.state_cookie_name
|
||||||
|
|
||||||
def set_cookie(self, handler, access_token):
|
def set_cookie(self, handler, access_token):
|
||||||
@@ -1230,23 +1303,30 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
|||||||
|
|
||||||
code = self.get_argument("code", False)
|
code = self.get_argument("code", False)
|
||||||
if not code:
|
if not code:
|
||||||
raise HTTPError(400, "oauth callback made without a token")
|
raise HTTPError(400, "OAuth callback made without a token")
|
||||||
|
|
||||||
# validate OAuth state
|
# validate OAuth state
|
||||||
arg_state = self.get_argument("state", None)
|
arg_state = self.get_argument("state", None)
|
||||||
if arg_state is None:
|
if arg_state is None:
|
||||||
raise HTTPError(500, "oauth state is missing. Try logging in again.")
|
raise HTTPError(400, "OAuth state is missing. Try logging in again.")
|
||||||
cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
|
cookie_name = self.hub_auth.get_state_cookie_name(arg_state)
|
||||||
cookie_state = self.get_secure_cookie(cookie_name)
|
cookie_state = self.get_secure_cookie(cookie_name)
|
||||||
# clear cookie state now that we've consumed it
|
# clear cookie state now that we've consumed it
|
||||||
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
|
if cookie_state:
|
||||||
|
self.clear_cookie(cookie_name, path=self.hub_auth.base_url)
|
||||||
if isinstance(cookie_state, bytes):
|
if isinstance(cookie_state, bytes):
|
||||||
cookie_state = cookie_state.decode('ascii', 'replace')
|
cookie_state = cookie_state.decode('ascii', 'replace')
|
||||||
# check that state matches
|
# check that state matches
|
||||||
if arg_state != cookie_state:
|
if arg_state != cookie_state:
|
||||||
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
app_log.warning("oauth state %r != %r", arg_state, cookie_state)
|
||||||
raise HTTPError(403, "oauth state does not match. Try logging in again.")
|
raise HTTPError(403, "OAuth state does not match. Try logging in again.")
|
||||||
next_url = self.hub_auth.get_next_url(cookie_state)
|
next_url = self.hub_auth.get_next_url(cookie_state)
|
||||||
|
# clear consumed state from _oauth_states cache now that we're done with it
|
||||||
|
self.hub_auth.clear_oauth_state(cookie_state)
|
||||||
|
# clear _all_ oauth state cookies on success ?
|
||||||
|
# This prevents multiple concurrent logins in the same browser,
|
||||||
|
# which is probably okay.
|
||||||
|
# self.hub_auth.clear_oauth_state_cookies(self)
|
||||||
|
|
||||||
token = await self.hub_auth.token_for_code(code, sync=False)
|
token = await self.hub_auth.token_for_code(code, sync=False)
|
||||||
session_id = self.hub_auth.get_session_id(self)
|
session_id = self.hub_auth.get_session_id(self)
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
@@ -530,12 +531,21 @@ async def test_oauth_cookie_collision(app, mockservice_url, create_user_with_sco
|
|||||||
state_cookie_name = 'service-%s-oauth-state' % service.name
|
state_cookie_name = 'service-%s-oauth-state' % service.name
|
||||||
service_cookie_name = 'service-%s' % service.name
|
service_cookie_name = 'service-%s' % service.name
|
||||||
oauth_1 = await s.get(url)
|
oauth_1 = await s.get(url)
|
||||||
print(oauth_1.headers)
|
|
||||||
print(oauth_1.cookies, oauth_1.url, url)
|
|
||||||
assert state_cookie_name in s.cookies
|
assert state_cookie_name in s.cookies
|
||||||
state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)]
|
state_cookies = [c for c in s.cookies.keys() if c.startswith(state_cookie_name)]
|
||||||
# only one state cookie
|
# only one state cookie
|
||||||
assert state_cookies == [state_cookie_name]
|
assert state_cookies == [state_cookie_name]
|
||||||
|
# create dict of Cookie objects, so we can check properties
|
||||||
|
# cookies.__getitem__ returns only cookie _value_, but we want to check expiration
|
||||||
|
cookie_dict = {cookie.name: cookie for cookie in s.cookies}
|
||||||
|
state_cookie = cookie_dict[state_cookie_name]
|
||||||
|
|
||||||
|
# check state cookie properties
|
||||||
|
# should expire in 10 minutes (600 seconds)
|
||||||
|
assert time.time() < state_cookie.expires < time.time() + 630
|
||||||
|
# path is set right
|
||||||
|
assert state_cookie.path == service.prefix
|
||||||
|
|
||||||
state_1 = s.cookies[state_cookie_name]
|
state_1 = s.cookies[state_cookie_name]
|
||||||
|
|
||||||
# start second oauth login before finishing the first
|
# start second oauth login before finishing the first
|
||||||
|
Reference in New Issue
Block a user