include session id in cache key

if session id is defined, clearing the session id clears the cache,
allowing immediate revocation of tokens by the Hub.
This commit is contained in:
Min RK
2017-12-07 15:42:15 +01:00
parent 498e234c37
commit ee004486bd

View File

@@ -238,6 +238,8 @@ class HubAuth(Configurable):
cached = self.cache.get(cache_key) cached = self.cache.get(cache_key)
if cached is not None: if cached is not None:
return cached return cached
else:
app_log.debug("Cache miss: %s" % cache_key)
data = self._api_request('GET', url, allow_404=True) data = self._api_request('GET', url, allow_404=True)
if data is None: if data is None:
@@ -285,7 +287,7 @@ class HubAuth(Configurable):
return data return data
def user_for_cookie(self, encrypted_cookie, use_cache=True): def user_for_cookie(self, encrypted_cookie, use_cache=True, session_id=''):
"""Ask the Hub to identify the user for a given cookie. """Ask the Hub to identify the user for a given cookie.
Args: Args:
@@ -302,11 +304,11 @@ class HubAuth(Configurable):
"authorizations/cookie", "authorizations/cookie",
self.cookie_name, self.cookie_name,
quote(encrypted_cookie, safe='')), quote(encrypted_cookie, safe='')),
cache_key='cookie:%s' % encrypted_cookie, cache_key='cookie:{}:{}'.format(session_id, encrypted_cookie),
use_cache=use_cache, use_cache=use_cache,
) )
def user_for_token(self, token, use_cache=True): def user_for_token(self, token, use_cache=True, session_id=''):
"""Ask the Hub to identify the user for a given token. """Ask the Hub to identify the user for a given token.
Args: Args:
@@ -322,10 +324,10 @@ class HubAuth(Configurable):
url=url_path_join(self.api_url, url=url_path_join(self.api_url,
"authorizations/token", "authorizations/token",
quote(token, safe='')), quote(token, safe='')),
cache_key='token:%s' % token, cache_key='token:{}:{}'.format(session_id, token),
use_cache=use_cache, use_cache=use_cache,
) )
auth_header_name = 'Authorization' auth_header_name = 'Authorization'
auth_header_pat = re.compile('token\s+(.+)', re.IGNORECASE) auth_header_pat = re.compile('token\s+(.+)', re.IGNORECASE)
@@ -347,8 +349,16 @@ class HubAuth(Configurable):
def _get_user_cookie(self, handler): def _get_user_cookie(self, handler):
"""Get the user model from a cookie""" """Get the user model from a cookie"""
encrypted_cookie = handler.get_cookie(self.cookie_name) encrypted_cookie = handler.get_cookie(self.cookie_name)
session_id = self.get_session_id(handler)
if encrypted_cookie: if encrypted_cookie:
return self.user_for_cookie(encrypted_cookie) return self.user_for_cookie(encrypted_cookie, session_id=session_id)
def get_session_id(self, handler):
"""Get the jupyterhub session id
from the jupyterhub-session-id cookie.
"""
return handler.get_cookie('jupyterhub-session-id', '')
def get_user(self, handler): def get_user(self, handler):
"""Get the Hub user for a given tornado handler. """Get the Hub user for a given tornado handler.
@@ -371,11 +381,12 @@ class HubAuth(Configurable):
return handler._cached_hub_user return handler._cached_hub_user
handler._cached_hub_user = user_model = None handler._cached_hub_user = user_model = None
session_id = self.get_session_id(handler)
# check token first # check token first
token = self.get_token(handler) token = self.get_token(handler)
if token: if token:
user_model = self.user_for_token(token) user_model = self.user_for_token(token, session_id=session_id)
if user_model: if user_model:
handler._token_authenticated = True handler._token_authenticated = True
@@ -425,8 +436,10 @@ class HubOAuth(HubAuth):
def _get_user_cookie(self, handler): def _get_user_cookie(self, handler):
token = handler.get_secure_cookie(self.cookie_name) token = handler.get_secure_cookie(self.cookie_name)
session_id = self.get_session_id(handler)
if token: if token:
user_model = self.user_for_token(token) token = token.decode('ascii', 'replace')
user_model = self.user_for_token(token, session_id=session_id)
if user_model is None: if user_model is None:
app_log.warning("Token stored in cookie may have expired") app_log.warning("Token stored in cookie may have expired")
handler.clear_cookie(self.cookie_name) handler.clear_cookie(self.cookie_name)
@@ -504,7 +517,7 @@ class HubOAuth(HubAuth):
def _encode_state(self, state): def _encode_state(self, state):
"""Encode a state dict as url-safe base64""" """Encode a state dict as url-safe base64"""
# trim trailing `=` because # trim trailing `=` because = is itself not url-safe!
json_state = json.dumps(state) json_state = json.dumps(state)
return base64.urlsafe_b64encode( return base64.urlsafe_b64encode(
json_state.encode('utf8') json_state.encode('utf8')
@@ -823,7 +836,8 @@ class HubOAuthCallbackHandler(HubOAuthenticated, RequestHandler):
next_url = self.hub_auth.get_next_url(cookie_state) next_url = self.hub_auth.get_next_url(cookie_state)
# TODO: make async (in a Thread?) # TODO: make async (in a Thread?)
token = self.hub_auth.token_for_code(code) token = self.hub_auth.token_for_code(code)
user_model = self.hub_auth.user_for_token(token) session_id = self.hub_auth.get_session_id(self)
user_model = self.hub_auth.user_for_token(token, session_id=session_id)
if user_model is None: if user_model is None:
raise HTTPError(500, "oauth callback failed to identify a user") raise HTTPError(500, "oauth callback failed to identify a user")
app_log.info("Logged-in user %s", user_model) app_log.info("Logged-in user %s", user_model)