mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-11 20:13:02 +00:00
implement state handling in HubOAuth
This commit is contained in:
@@ -9,11 +9,15 @@ model describing the authenticated user.
|
||||
authenticate with the Hub.
|
||||
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from urllib.parse import quote, urlencode
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
import requests
|
||||
@@ -397,6 +401,14 @@ class HubOAuth(HubAuth):
|
||||
"""
|
||||
return self.oauth_client_id
|
||||
|
||||
@property
|
||||
def state_cookie_name(self):
|
||||
"""The cookie name for storing OAuth state
|
||||
|
||||
This cookie is only live for the duration of the OAuth handshake.
|
||||
"""
|
||||
return self.cookie_name + '-oauth-state'
|
||||
|
||||
def _get_user_cookie(self, handler):
|
||||
token = handler.get_secure_cookie(self.cookie_name)
|
||||
if token:
|
||||
@@ -476,6 +488,84 @@ class HubOAuth(HubAuth):
|
||||
|
||||
return token_reply['access_token']
|
||||
|
||||
def _encode_state(self, state):
|
||||
"""Encode a state dict as url-safe base64"""
|
||||
# trim trailing `=` because
|
||||
json_state = json.dumps(state)
|
||||
return base64.urlsafe_b64encode(
|
||||
json_state.encode('utf8')
|
||||
).decode('ascii').rstrip('=')
|
||||
|
||||
def _decode_state(self, b64_state):
|
||||
"""Decode a base64 state
|
||||
|
||||
Always returns a dict.
|
||||
The dict will be empty if the state is invalid.
|
||||
"""
|
||||
if isinstance(b64_state, str):
|
||||
b64_state = b64_state.encode('ascii')
|
||||
if len(b64_state) != 4:
|
||||
# restore padding
|
||||
b64_state = b64_state + (b'=' * (4 - len(b64_state) % 4))
|
||||
try:
|
||||
json_state = base64.urlsafe_b64decode(b64_state).decode('utf8')
|
||||
except ValueError:
|
||||
app_log.error("Failed to b64-decode state: %r", b64_state)
|
||||
return {}
|
||||
try:
|
||||
return json.loads(json_state)
|
||||
except ValueError:
|
||||
app_log.error("Failed to json-decode state: %r", json_state)
|
||||
return {}
|
||||
|
||||
def set_state_cookie(self, handler, next_url=None):
|
||||
"""Generate an OAuth state and store it in a cookie
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler (RequestHandler): A tornado RequestHandler
|
||||
next_url (str): The page to redirect to on successful login
|
||||
|
||||
Returns
|
||||
-------
|
||||
state (str): The OAuth state that has been stored in the cookie (url safe, base64-encoded)
|
||||
"""
|
||||
b64_state = self.generate_state(next_url)
|
||||
kwargs = {
|
||||
'path': self.base_url,
|
||||
'httponly': True,
|
||||
'expires_days': 1,
|
||||
}
|
||||
if handler.request.protocol == 'https':
|
||||
kwargs['secure'] = True
|
||||
handler.set_secure_cookie(
|
||||
self.state_cookie_name,
|
||||
b64_state,
|
||||
**kwargs
|
||||
)
|
||||
return b64_state
|
||||
|
||||
def generate_state(self, next_url=None):
|
||||
"""Generate a state string, given a next_url redirect target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
next_url (str): The URL of the page to redirect to on successful login.
|
||||
|
||||
Returns
|
||||
-------
|
||||
state (str): The base64-encoded state string.
|
||||
"""
|
||||
return self._encode_state({
|
||||
'uuid': uuid.uuid4().hex,
|
||||
'next_url': next_url
|
||||
})
|
||||
|
||||
def get_next_url(self, b64_state=''):
|
||||
"""Get the next_url for redirection, given an encoded OAuth state"""
|
||||
state = self._decode_state(b64_state)
|
||||
return state.get('next_url') or self.base_url
|
||||
|
||||
def set_cookie(self, handler, access_token):
|
||||
"""Set a cookie recording OAuth result"""
|
||||
kwargs = {
|
||||
@@ -565,8 +655,14 @@ class HubAuthenticated(object):
|
||||
|
||||
def get_login_url(self):
|
||||
"""Return the Hub's login URL"""
|
||||
app_log.debug("Redirecting to login url: %s" % self.hub_auth.login_url)
|
||||
return self.hub_auth.login_url
|
||||
login_url = self.hub_auth.login_url
|
||||
app_log.debug("Redirecting to login url: %s", login_url)
|
||||
if isinstance(self.hub_auth, HubOAuthenticated):
|
||||
# add state argument to OAuth url
|
||||
state = self.hub_auth.set_state_cookie(self, next_url=self.request.uri)
|
||||
return url_concat(login_url, {'state': state})
|
||||
else:
|
||||
return login_url
|
||||
|
||||
def check_hub_user(self, model):
|
||||
"""Check whether Hub-authenticated user or service should be allowed.
|
||||
@@ -657,6 +753,21 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
code = self.get_argument("code", False)
|
||||
if not code:
|
||||
raise HTTPError(400, "oauth callback made without a token")
|
||||
|
||||
# validate OAuth state
|
||||
arg_state = self.get_argument("state", None)
|
||||
cookie_state = self.get_secure_cookie(self.hub_auth.state_cookie_name)
|
||||
next_url = None
|
||||
if arg_state or cookie_state:
|
||||
# clear cookie state now that we've consumed it
|
||||
self.clear_cookie(self.hub_auth.state_cookie_name)
|
||||
if isinstance(cookie_state, bytes):
|
||||
cookie_state = cookie_state.decode('ascii', 'replace')
|
||||
# check that state matches
|
||||
if arg_state != cookie_state:
|
||||
app_log.debug("oauth state %r != %r", arg_state, cookie_state)
|
||||
raise HTTPError(403, "oauth state does not match")
|
||||
next_url = self.hub_auth.get_next_url(cookie_state)
|
||||
# TODO: make async (in a Thread?)
|
||||
token = self.hub_auth.token_for_code(code)
|
||||
user_model = self.hub_auth.user_for_token(token)
|
||||
@@ -664,7 +775,6 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
|
||||
raise HTTPError(500, "oauth callback failed to identify a user")
|
||||
app_log.info("Logged-in user %s", user_model)
|
||||
self.hub_auth.set_cookie(self, token)
|
||||
next_url = self.get_argument('next', '') or self.hub_auth.base_url
|
||||
self.redirect(next_url)
|
||||
self.redirect(next_url or self.hub_auth.base_url)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user