mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-07 10:04:07 +00:00
Merge pull request #3850 from minrk/cache-scopes
memoize some scope functions
This commit is contained in:
154
jupyterhub/_memoize.py
Normal file
154
jupyterhub/_memoize.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""Utilities for memoization
|
||||||
|
|
||||||
|
Note: a memoized function should always return an _immutable_
|
||||||
|
result to avoid later modifications polluting cached results.
|
||||||
|
"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
|
class DoNotCache:
|
||||||
|
"""Wrapper to return a result without caching it.
|
||||||
|
|
||||||
|
In a function decorated with `@lru_cache_key`:
|
||||||
|
|
||||||
|
return DoNotCache(result)
|
||||||
|
|
||||||
|
is equivalent to:
|
||||||
|
|
||||||
|
return result # but don't cache it!
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, result):
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache:
|
||||||
|
"""A simple Least-Recently-Used (LRU) cache with a max size"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize=1024):
|
||||||
|
self._cache = OrderedDict()
|
||||||
|
self.maxsize = maxsize
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return key in self._cache
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
"""Get an item from the cache"""
|
||||||
|
if key in self._cache:
|
||||||
|
# cache hit, bump to front of the queue for LRU
|
||||||
|
result = self._cache[key]
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
return result
|
||||||
|
return default
|
||||||
|
|
||||||
|
def set(self, key, value):
|
||||||
|
"""Store an entry in the cache
|
||||||
|
|
||||||
|
Purges oldest entry if cache is full
|
||||||
|
"""
|
||||||
|
self._cache[key] = value
|
||||||
|
# cache is full, purge oldest entry
|
||||||
|
if len(self._cache) > self.maxsize:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
|
||||||
|
__getitem__ = get
|
||||||
|
__setitem__ = set
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cache_key(key_func, maxsize=1024):
|
||||||
|
"""Like functools.lru_cache, but takes a custom key function,
|
||||||
|
as seen in sorted(key=func).
|
||||||
|
|
||||||
|
Useful for non-hashable arguments which have a known hashable equivalent (e.g. sets, lists),
|
||||||
|
or mutable objects where only immutable fields might be used
|
||||||
|
(e.g. User, where only username affects output).
|
||||||
|
|
||||||
|
For safety: Cached results should always be immutable,
|
||||||
|
such as using `frozenset` instead of mutable `set`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
@lru_cache_key(lambda user: user.name)
|
||||||
|
def func_user(user):
|
||||||
|
# output only varies by name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (callable):
|
||||||
|
Should have the same signature as the decorated function.
|
||||||
|
Returns a hashable key to use in the cache
|
||||||
|
maxsize (int):
|
||||||
|
The maximum size of the cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def cache_func(func):
|
||||||
|
cache = LRUCache(maxsize=maxsize)
|
||||||
|
# the actual decorated function:
|
||||||
|
@wraps(func)
|
||||||
|
def cached(*args, **kwargs):
|
||||||
|
cache_key = key_func(*args, **kwargs)
|
||||||
|
if cache_key in cache:
|
||||||
|
# cache hit
|
||||||
|
return cache[cache_key]
|
||||||
|
else:
|
||||||
|
# cache miss, call function and cache result
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
if isinstance(result, DoNotCache):
|
||||||
|
# DoNotCache prevents caching
|
||||||
|
result = result.result
|
||||||
|
else:
|
||||||
|
cache[cache_key] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
return cache_func
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenDict(dict):
|
||||||
|
"""A frozen dictionary subclass
|
||||||
|
|
||||||
|
Immutable and hashable, so it can be used as a cache key
|
||||||
|
|
||||||
|
Values will be frozen with `.freeze(value)`
|
||||||
|
and must be hashable after freezing.
|
||||||
|
|
||||||
|
Not rigorous, but enough for our purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_hash = None
|
||||||
|
|
||||||
|
def __init__(self, d):
|
||||||
|
dict_set = dict.__setitem__
|
||||||
|
for key, value in d.items():
|
||||||
|
dict.__setitem__(self, key, self._freeze(value))
|
||||||
|
|
||||||
|
def _freeze(self, item):
|
||||||
|
"""Make values of a dict hashable
|
||||||
|
- list, set -> frozenset
|
||||||
|
- dict -> recursive _FrozenDict
|
||||||
|
- anything else: assumed hashable
|
||||||
|
"""
|
||||||
|
if isinstance(item, FrozenDict):
|
||||||
|
return item
|
||||||
|
elif isinstance(item, list):
|
||||||
|
return tuple(self._freeze(e) for e in item)
|
||||||
|
elif isinstance(item, set):
|
||||||
|
return frozenset(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
return FrozenDict(item)
|
||||||
|
else:
|
||||||
|
# any other type is assumed hashable
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __setitem__(self, key):
|
||||||
|
raise RuntimeError("Cannot modify frozen {type(self).__name__}")
|
||||||
|
|
||||||
|
def update(self, other):
|
||||||
|
raise RuntimeError("Cannot modify frozen {type(self).__name__}")
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
"""Cache hash because we are immutable"""
|
||||||
|
if self._hash is None:
|
||||||
|
self._hash = hash(tuple((key, value) for key, value in self.items()))
|
||||||
|
return self._hash
|
@@ -1,7 +1,6 @@
|
|||||||
"""Authorization handlers"""
|
"""Authorization handlers"""
|
||||||
# Copyright (c) Jupyter Development Team.
|
# Copyright (c) Jupyter Development Team.
|
||||||
# Distributed under the terms of the Modified BSD License.
|
# Distributed under the terms of the Modified BSD License.
|
||||||
import itertools
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
|
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
|
||||||
@@ -30,7 +29,7 @@ class TokenAPIHandler(APIHandler):
|
|||||||
if owner:
|
if owner:
|
||||||
# having a token means we should be able to read the owner's model
|
# having a token means we should be able to read the owner's model
|
||||||
# (this is the only thing this handler is for)
|
# (this is the only thing this handler is for)
|
||||||
self.expanded_scopes.update(scopes.identify_scopes(owner))
|
self.expanded_scopes |= scopes.identify_scopes(owner)
|
||||||
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
|
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
|
||||||
|
|
||||||
# record activity whenever we see a token
|
# record activity whenever we see a token
|
||||||
@@ -288,7 +287,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
|
|||||||
# rather than the expanded_scope intersection
|
# rather than the expanded_scope intersection
|
||||||
|
|
||||||
required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)}
|
required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)}
|
||||||
user_scopes.update({"inherit", *required_scopes})
|
user_scopes |= {"inherit", *required_scopes}
|
||||||
|
|
||||||
allowed_scopes = requested_scopes.intersection(user_scopes)
|
allowed_scopes = requested_scopes.intersection(user_scopes)
|
||||||
excluded_scopes = requested_scopes.difference(user_scopes)
|
excluded_scopes = requested_scopes.difference(user_scopes)
|
||||||
|
@@ -44,7 +44,7 @@ class SelfAPIHandler(APIHandler):
|
|||||||
for scope in identify_scopes:
|
for scope in identify_scopes:
|
||||||
if scope not in self.expanded_scopes:
|
if scope not in self.expanded_scopes:
|
||||||
_added_scopes.add(scope)
|
_added_scopes.add(scope)
|
||||||
self.expanded_scopes.add(scope)
|
self.expanded_scopes |= {scope}
|
||||||
if _added_scopes:
|
if _added_scopes:
|
||||||
# re-parse with new scopes
|
# re-parse with new scopes
|
||||||
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
|
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
|
||||||
|
@@ -154,9 +154,8 @@ class JupyterHubRequestValidator(RequestValidator):
|
|||||||
scopes = roles_to_scopes(orm_client.allowed_roles)
|
scopes = roles_to_scopes(orm_client.allowed_roles)
|
||||||
if 'inherit' not in scopes:
|
if 'inherit' not in scopes:
|
||||||
# add identify-user scope
|
# add identify-user scope
|
||||||
scopes.update(identify_scopes())
|
# and access-service scope
|
||||||
# add access-service scope
|
scopes |= identify_scopes() | access_scopes(orm_client)
|
||||||
scopes.update(access_scopes(orm_client))
|
|
||||||
return scopes
|
return scopes
|
||||||
|
|
||||||
def get_original_scopes(self, refresh_token, request, *args, **kwargs):
|
def get_original_scopes(self, refresh_token, request, *args, **kwargs):
|
||||||
|
@@ -1,6 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
General scope definitions and utilities
|
General scope definitions and utilities
|
||||||
|
|
||||||
|
Scope functions generally return _immutable_ collections,
|
||||||
|
such as `frozenset` to avoid mutating cached values.
|
||||||
|
If needed, mutable copies can be made, e.g. `set(frozen_scopes)`
|
||||||
|
|
||||||
Scope variable nomenclature
|
Scope variable nomenclature
|
||||||
---------------------------
|
---------------------------
|
||||||
scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition)
|
scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition)
|
||||||
@@ -24,6 +28,7 @@ from tornado import web
|
|||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
|
|
||||||
from . import orm, roles
|
from . import orm, roles
|
||||||
|
from ._memoize import DoNotCache, FrozenDict, lru_cache_key
|
||||||
|
|
||||||
"""when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run
|
"""when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run
|
||||||
so that changes are reflected in the documentation and REST API description."""
|
so that changes are reflected in the documentation and REST API description."""
|
||||||
@@ -144,6 +149,12 @@ class Scope(Enum):
|
|||||||
ALL = True
|
ALL = True
|
||||||
|
|
||||||
|
|
||||||
|
def _intersection_cache_key(scopes_a, scopes_b, db=None):
|
||||||
|
"""Cache key function for scope intersections"""
|
||||||
|
return (frozenset(scopes_a), frozenset(scopes_b))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(_intersection_cache_key)
|
||||||
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
||||||
"""Intersect two sets of scopes by comparing their permissions
|
"""Intersect two sets of scopes by comparing their permissions
|
||||||
|
|
||||||
@@ -159,11 +170,16 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
|||||||
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
|
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
|
||||||
"""
|
"""
|
||||||
empty_set = frozenset()
|
empty_set = frozenset()
|
||||||
|
scopes_a = frozenset(scopes_a)
|
||||||
|
scopes_b = frozenset(scopes_b)
|
||||||
|
|
||||||
# cached lookups for group membership of users and servers
|
# cached lookups for group membership of users and servers
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def groups_for_user(username):
|
def groups_for_user(username):
|
||||||
"""Get set of group names for a given username"""
|
"""Get set of group names for a given username"""
|
||||||
|
# if we need a group lookup, the result is not cacheable
|
||||||
|
nonlocal needs_db
|
||||||
|
needs_db = True
|
||||||
user = db.query(orm.User).filter_by(name=username).first()
|
user = db.query(orm.User).filter_by(name=username).first()
|
||||||
if user is None:
|
if user is None:
|
||||||
return empty_set
|
return empty_set
|
||||||
@@ -179,6 +195,11 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
|||||||
parsed_scopes_a = parse_scopes(scopes_a)
|
parsed_scopes_a = parse_scopes(scopes_a)
|
||||||
parsed_scopes_b = parse_scopes(scopes_b)
|
parsed_scopes_b = parse_scopes(scopes_b)
|
||||||
|
|
||||||
|
# track whether we need a db lookup (for groups)
|
||||||
|
# because we can't cache the intersection if we do
|
||||||
|
# if there are no group filters, this is cacheable
|
||||||
|
needs_db = False
|
||||||
|
|
||||||
common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys()
|
common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys()
|
||||||
|
|
||||||
common_filters = {}
|
common_filters = {}
|
||||||
@@ -220,6 +241,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
warned = True
|
warned = True
|
||||||
|
needs_db = True
|
||||||
|
|
||||||
common_filters[base] = {
|
common_filters[base] = {
|
||||||
entity: filters_a[entity] & filters_b[entity]
|
entity: filters_a[entity] & filters_b[entity]
|
||||||
@@ -245,6 +267,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
|||||||
# resolve group/server hierarchy if db available
|
# resolve group/server hierarchy if db available
|
||||||
servers = servers.difference(common_servers)
|
servers = servers.difference(common_servers)
|
||||||
if db is not None and servers and 'group' in b:
|
if db is not None and servers and 'group' in b:
|
||||||
|
needs_db = True
|
||||||
for server in servers:
|
for server in servers:
|
||||||
server_groups = groups_for_server(server)
|
server_groups = groups_for_server(server)
|
||||||
if server_groups & b['group']:
|
if server_groups & b['group']:
|
||||||
@@ -272,7 +295,12 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
|
|||||||
if common_users and "user" not in common_filters[base]:
|
if common_users and "user" not in common_filters[base]:
|
||||||
common_filters[base]["user"] = common_users
|
common_filters[base]["user"] = common_users
|
||||||
|
|
||||||
return unparse_scopes(common_filters)
|
intersection = unparse_scopes(common_filters)
|
||||||
|
if needs_db:
|
||||||
|
# return intersection, but don't cache it if it needed db lookups
|
||||||
|
return DoNotCache(intersection)
|
||||||
|
|
||||||
|
return intersection
|
||||||
|
|
||||||
|
|
||||||
def get_scopes_for(orm_object):
|
def get_scopes_for(orm_object):
|
||||||
@@ -313,7 +341,7 @@ def get_scopes_for(orm_object):
|
|||||||
# only thing we miss by short-circuiting here: warning about excluded extra scopes
|
# only thing we miss by short-circuiting here: warning about excluded extra scopes
|
||||||
return owner_scopes
|
return owner_scopes
|
||||||
|
|
||||||
token_scopes = expand_scopes(token_scopes, owner=owner)
|
token_scopes = set(expand_scopes(token_scopes, owner=owner))
|
||||||
|
|
||||||
if orm_object.client_id != "jupyterhub":
|
if orm_object.client_id != "jupyterhub":
|
||||||
# oauth tokens can be used to access the service issuing the token,
|
# oauth tokens can be used to access the service issuing the token,
|
||||||
@@ -358,6 +386,7 @@ def get_scopes_for(orm_object):
|
|||||||
return expanded_scopes
|
return expanded_scopes
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def _expand_self_scope(username):
|
def _expand_self_scope(username):
|
||||||
"""
|
"""
|
||||||
Users have a metascope 'self' that should be expanded to standard user privileges.
|
Users have a metascope 'self' that should be expanded to standard user privileges.
|
||||||
@@ -390,9 +419,11 @@ def _expand_self_scope(username):
|
|||||||
'read:tokens',
|
'read:tokens',
|
||||||
'access:servers',
|
'access:servers',
|
||||||
]
|
]
|
||||||
return {f"{scope}!user={username}" for scope in scope_list}
|
# return immutable frozenset because the result is cached
|
||||||
|
return frozenset(f"{scope}!user={username}" for scope in scope_list)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=65535)
|
||||||
def _expand_scope(scope):
|
def _expand_scope(scope):
|
||||||
"""Returns a scope and all all subscopes
|
"""Returns a scope and all all subscopes
|
||||||
|
|
||||||
@@ -433,9 +464,30 @@ def _expand_scope(scope):
|
|||||||
else:
|
else:
|
||||||
expanded_scopes = expanded_scope_names
|
expanded_scopes = expanded_scope_names
|
||||||
|
|
||||||
return expanded_scopes
|
# return immutable frozenset because the result is cached
|
||||||
|
return frozenset(expanded_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_scopes_key(scopes, owner=None):
|
||||||
|
"""Cache key function for expand_scopes
|
||||||
|
|
||||||
|
scopes is usually a mutable list or set,
|
||||||
|
which can be hashed as a frozenset
|
||||||
|
|
||||||
|
For the owner, we only care about what kind they are,
|
||||||
|
and their name.
|
||||||
|
"""
|
||||||
|
# freeze scopes for hash
|
||||||
|
frozen_scopes = frozenset(scopes)
|
||||||
|
if owner is None:
|
||||||
|
owner_key = None
|
||||||
|
else:
|
||||||
|
# owner key is the type and name
|
||||||
|
owner_key = (type(owner).__name__, owner.name)
|
||||||
|
return (frozen_scopes, owner_key)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(_expand_scopes_key)
|
||||||
def expand_scopes(scopes, owner=None):
|
def expand_scopes(scopes, owner=None):
|
||||||
"""Returns a set of fully expanded scopes for a collection of raw scopes
|
"""Returns a set of fully expanded scopes for a collection of raw scopes
|
||||||
|
|
||||||
@@ -479,8 +531,9 @@ def expand_scopes(scopes, owner=None):
|
|||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# reduce to minimize
|
# reduce to discard overlapping scopes
|
||||||
return reduce_scopes(expanded_scopes)
|
# return immutable frozenset because the result is cached
|
||||||
|
return frozenset(reduce_scopes(expanded_scopes))
|
||||||
|
|
||||||
|
|
||||||
def _needs_scope_expansion(filter_, filter_value, sub_scope):
|
def _needs_scope_expansion(filter_, filter_value, sub_scope):
|
||||||
@@ -614,6 +667,7 @@ def _check_token_scopes(scopes, owner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(frozenset)
|
||||||
def parse_scopes(scope_list):
|
def parse_scopes(scope_list):
|
||||||
"""
|
"""
|
||||||
Parses scopes and filters in something akin to JSON style
|
Parses scopes and filters in something akin to JSON style
|
||||||
@@ -649,9 +703,11 @@ def parse_scopes(scope_list):
|
|||||||
parsed_scopes[base_scope][key] = {value}
|
parsed_scopes[base_scope][key] = {value}
|
||||||
else:
|
else:
|
||||||
parsed_scopes[base_scope][key].add(value)
|
parsed_scopes[base_scope][key].add(value)
|
||||||
return parsed_scopes
|
# return immutable FrozenDict because the result is cached
|
||||||
|
return FrozenDict(parsed_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(FrozenDict)
|
||||||
def unparse_scopes(parsed_scopes):
|
def unparse_scopes(parsed_scopes):
|
||||||
"""Turn a parsed_scopes dictionary back into a expanded scopes set"""
|
"""Turn a parsed_scopes dictionary back into a expanded scopes set"""
|
||||||
expanded_scopes = set()
|
expanded_scopes = set()
|
||||||
@@ -662,14 +718,17 @@ def unparse_scopes(parsed_scopes):
|
|||||||
for entity, names_list in filters.items():
|
for entity, names_list in filters.items():
|
||||||
for name in names_list:
|
for name in names_list:
|
||||||
expanded_scopes.add(f'{base}!{entity}={name}')
|
expanded_scopes.add(f'{base}!{entity}={name}')
|
||||||
return expanded_scopes
|
# return immutable frozenset because the result is cached
|
||||||
|
return frozenset(expanded_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(frozenset)
|
||||||
def reduce_scopes(expanded_scopes):
|
def reduce_scopes(expanded_scopes):
|
||||||
"""Reduce expanded scopes to minimal set
|
"""Reduce expanded scopes to minimal set
|
||||||
|
|
||||||
Eliminates redundancy, such as access:services and access:services!service=x
|
Eliminates overlapping scopes, such as access:services and access:services!service=x
|
||||||
"""
|
"""
|
||||||
|
# unparse_scopes already returns a frozenset
|
||||||
return unparse_scopes(parse_scopes(expanded_scopes))
|
return unparse_scopes(parse_scopes(expanded_scopes))
|
||||||
|
|
||||||
|
|
||||||
@@ -723,6 +782,14 @@ def needs_scope(*scopes):
|
|||||||
return scope_decorator
|
return scope_decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _identify_key(obj=None):
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return (type(obj).__name__, obj.name)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(_identify_key)
|
||||||
def identify_scopes(obj=None):
|
def identify_scopes(obj=None):
|
||||||
"""Return 'identify' scopes for an orm object
|
"""Return 'identify' scopes for an orm object
|
||||||
|
|
||||||
@@ -735,20 +802,25 @@ def identify_scopes(obj=None):
|
|||||||
identify scopes (set): set of scopes needed for 'identify' endpoints
|
identify scopes (set): set of scopes needed for 'identify' endpoints
|
||||||
"""
|
"""
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return {f"read:users:{field}!user" for field in {"name", "groups"}}
|
return frozenset(f"read:users:{field}!user" for field in {"name", "groups"})
|
||||||
elif isinstance(obj, orm.User):
|
elif isinstance(obj, orm.User):
|
||||||
return {f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}}
|
return frozenset(
|
||||||
|
f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}
|
||||||
|
)
|
||||||
elif isinstance(obj, orm.Service):
|
elif isinstance(obj, orm.Service):
|
||||||
return {f"read:services:{field}!service={obj.name}" for field in {"name"}}
|
return frozenset(
|
||||||
|
f"read:services:{field}!service={obj.name}" for field in {"name"}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}")
|
raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(lambda oauth_client: oauth_client.identifier)
|
||||||
def access_scopes(oauth_client):
|
def access_scopes(oauth_client):
|
||||||
"""Return scope(s) required to access an oauth client"""
|
"""Return scope(s) required to access an oauth client"""
|
||||||
scopes = set()
|
scopes = set()
|
||||||
if oauth_client.identifier == "jupyterhub":
|
if oauth_client.identifier == "jupyterhub":
|
||||||
return scopes
|
return frozenset()
|
||||||
spawner = oauth_client.spawner
|
spawner = oauth_client.spawner
|
||||||
if spawner:
|
if spawner:
|
||||||
scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}")
|
scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}")
|
||||||
@@ -760,9 +832,19 @@ def access_scopes(oauth_client):
|
|||||||
app_log.warning(
|
app_log.warning(
|
||||||
f"OAuth client {oauth_client} has no associated service or spawner!"
|
f"OAuth client {oauth_client} has no associated service or spawner!"
|
||||||
)
|
)
|
||||||
return scopes
|
return frozenset(scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_scope_key(sub_scope, orm_resource, kind):
|
||||||
|
"""Cache key function for check_scope_filter"""
|
||||||
|
if kind == 'server':
|
||||||
|
resource_key = (orm_resource.user.name, orm_resource.name)
|
||||||
|
else:
|
||||||
|
resource_key = orm_resource.name
|
||||||
|
return (sub_scope, resource_key, kind)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(_check_scope_key)
|
||||||
def check_scope_filter(sub_scope, orm_resource, kind):
|
def check_scope_filter(sub_scope, orm_resource, kind):
|
||||||
"""Return whether a sub_scope filter applies to a given resource.
|
"""Return whether a sub_scope filter applies to a given resource.
|
||||||
|
|
||||||
@@ -791,8 +873,8 @@ def check_scope_filter(sub_scope, orm_resource, kind):
|
|||||||
if kind == 'user' and 'group' in sub_scope:
|
if kind == 'user' and 'group' in sub_scope:
|
||||||
group_names = {group.name for group in orm_resource.groups}
|
group_names = {group.name for group in orm_resource.groups}
|
||||||
user_in_group = bool(group_names & set(sub_scope['group']))
|
user_in_group = bool(group_names & set(sub_scope['group']))
|
||||||
if user_in_group:
|
# cannot cache if we needed to lookup groups in db
|
||||||
return True
|
return DoNotCache(user_in_group)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -831,6 +913,7 @@ def describe_parsed_scopes(parsed_scopes, username=None):
|
|||||||
return descriptions
|
return descriptions
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache_key(lambda raw_scopes, username=None: (frozenset(raw_scopes), username))
|
||||||
def describe_raw_scopes(raw_scopes, username=None):
|
def describe_raw_scopes(raw_scopes, username=None):
|
||||||
"""Return list of descriptions of raw scopes
|
"""Return list of descriptions of raw scopes
|
||||||
|
|
||||||
@@ -861,7 +944,8 @@ def describe_raw_scopes(raw_scopes, username=None):
|
|||||||
"filter": filter_text,
|
"filter": filter_text,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return descriptions
|
# make sure we return immutable from a cached function
|
||||||
|
return tuple(descriptions)
|
||||||
|
|
||||||
|
|
||||||
# regex for custom scope
|
# regex for custom scope
|
||||||
|
94
jupyterhub/tests/test_memoize.py
Normal file
94
jupyterhub/tests/test_memoize.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from jupyterhub._memoize import DoNotCache, FrozenDict, LRUCache, lru_cache_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_lru_cache():
|
||||||
|
cache = LRUCache(maxsize=2)
|
||||||
|
cache["a"] = 1
|
||||||
|
assert "a" in cache
|
||||||
|
assert "b" not in cache
|
||||||
|
cache["b"] = 2
|
||||||
|
assert cache["b"] == 2
|
||||||
|
|
||||||
|
# accessing a makes it more recent than b
|
||||||
|
assert cache["a"] == 1
|
||||||
|
assert "b" in cache
|
||||||
|
assert "a" in cache
|
||||||
|
|
||||||
|
# storing c pushes oldest ('b') out of cache
|
||||||
|
cache["c"] = 3
|
||||||
|
assert len(cache._cache) == 2
|
||||||
|
assert "a" in cache
|
||||||
|
assert "c" in cache
|
||||||
|
assert "b" not in cache
|
||||||
|
|
||||||
|
|
||||||
|
def test_lru_cache_key():
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@lru_cache_key(frozenset)
|
||||||
|
def reverse(arg):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return list(reversed(arg))
|
||||||
|
|
||||||
|
in1 = [1, 2]
|
||||||
|
before = call_count
|
||||||
|
out1 = reverse(in1)
|
||||||
|
assert call_count == before + 1
|
||||||
|
assert out1 == [2, 1]
|
||||||
|
|
||||||
|
before = call_count
|
||||||
|
out2 = reverse(in1)
|
||||||
|
assert call_count == before
|
||||||
|
assert out2 is out1
|
||||||
|
|
||||||
|
|
||||||
|
def test_do_not_cache():
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@lru_cache_key(lambda arg: arg)
|
||||||
|
def is_even(arg):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if arg % 2:
|
||||||
|
return DoNotCache(False)
|
||||||
|
return True
|
||||||
|
|
||||||
|
before = call_count
|
||||||
|
assert is_even(0) == True
|
||||||
|
assert call_count == before + 1
|
||||||
|
|
||||||
|
# caches even results
|
||||||
|
before = call_count
|
||||||
|
assert is_even(0) == True
|
||||||
|
assert call_count == before
|
||||||
|
|
||||||
|
before = call_count
|
||||||
|
assert is_even(1) == False
|
||||||
|
assert call_count == before + 1
|
||||||
|
|
||||||
|
# doesn't cache odd results
|
||||||
|
before = call_count
|
||||||
|
assert is_even(1) == False
|
||||||
|
assert call_count == before + 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"d",
|
||||||
|
[
|
||||||
|
{"key": "value"},
|
||||||
|
{"key": ["list"]},
|
||||||
|
{"key": {"set"}},
|
||||||
|
{"key": ("tu", "ple")},
|
||||||
|
{"key": {"nested": ["dict"]}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_frozen_dict(d):
|
||||||
|
frozen_1 = FrozenDict(d)
|
||||||
|
frozen_2 = FrozenDict(d)
|
||||||
|
assert hash(frozen_1) == hash(frozen_2)
|
||||||
|
assert frozen_1 == frozen_2
|
@@ -8,6 +8,7 @@ from tornado import web
|
|||||||
from tornado.httputil import HTTPServerRequest
|
from tornado.httputil import HTTPServerRequest
|
||||||
|
|
||||||
from .. import orm, roles, scopes
|
from .. import orm, roles, scopes
|
||||||
|
from .._memoize import FrozenDict
|
||||||
from ..handlers import BaseHandler
|
from ..handlers import BaseHandler
|
||||||
from ..scopes import (
|
from ..scopes import (
|
||||||
Scope,
|
Scope,
|
||||||
@@ -38,6 +39,7 @@ def test_scope_constructor():
|
|||||||
f'read:users!user={user2}',
|
f'read:users!user={user2}',
|
||||||
]
|
]
|
||||||
parsed_scopes = parse_scopes(scope_list)
|
parsed_scopes = parse_scopes(scope_list)
|
||||||
|
assert isinstance(parsed_scopes, FrozenDict)
|
||||||
|
|
||||||
assert 'read:users' in parsed_scopes
|
assert 'read:users' in parsed_scopes
|
||||||
assert parsed_scopes['users']
|
assert parsed_scopes['users']
|
||||||
@@ -467,6 +469,7 @@ async def test_metascope_self_expansion(
|
|||||||
orm_obj = create_service_with_scopes('self')
|
orm_obj = create_service_with_scopes('self')
|
||||||
# test expansion of user/service scopes
|
# test expansion of user/service scopes
|
||||||
scopes = get_scopes_for(orm_obj)
|
scopes = get_scopes_for(orm_obj)
|
||||||
|
assert isinstance(scopes, frozenset)
|
||||||
assert bool(scopes) == has_user_scopes
|
assert bool(scopes) == has_user_scopes
|
||||||
|
|
||||||
# test expansion of token scopes
|
# test expansion of token scopes
|
||||||
@@ -488,6 +491,8 @@ async def test_metascope_inherit_expansion(app, create_user_with_scopes):
|
|||||||
token.scopes.clear()
|
token.scopes.clear()
|
||||||
app.db.commit()
|
app.db.commit()
|
||||||
token_scope_set = get_scopes_for(token)
|
token_scope_set = get_scopes_for(token)
|
||||||
|
assert isinstance(token_scope_set, frozenset)
|
||||||
|
|
||||||
assert token_scope_set.issubset(identify_scopes(user.orm_user))
|
assert token_scope_set.issubset(identify_scopes(user.orm_user))
|
||||||
|
|
||||||
|
|
||||||
@@ -1169,4 +1174,5 @@ def test_expand_scopes(user, scopes, expected):
|
|||||||
expected.update(_expand_self_scope(user.name))
|
expected.update(_expand_self_scope(user.name))
|
||||||
|
|
||||||
expanded = expand_scopes(scopes, owner=user.orm_user)
|
expanded = expand_scopes(scopes, owner=user.orm_user)
|
||||||
|
assert isinstance(expanded, frozenset)
|
||||||
assert sorted(expanded) == sorted(expected)
|
assert sorted(expanded) == sorted(expected)
|
||||||
|
@@ -360,7 +360,7 @@ async def test_oauth_service_roles(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if 'inherit' in expected_scopes:
|
if 'inherit' in expected_scopes:
|
||||||
expected_scopes = scopes.get_scopes_for(user.orm_user)
|
expected_scopes = set(scopes.get_scopes_for(user.orm_user))
|
||||||
|
|
||||||
# always expect identify/access scopes
|
# always expect identify/access scopes
|
||||||
# on successful authentication
|
# on successful authentication
|
||||||
|
Reference in New Issue
Block a user