From 29b73563dc00f5020eaa604f246bb88d71dfb2d3 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 17 Mar 2022 15:41:01 +0100 Subject: [PATCH 1/5] cache common scope operations we expand/parse the same scopes _a lot_. We can save time with some caching. Main change: cached functions must return immutable frozenset instead of mutable set, to avoid mutating the result of subsequent returns. Some functions can only be cached _sometimes_ (e.g. group lookups in db cannot be cached), for which we have a DoNotCache(result) exception --- jupyterhub/_memoize.py | 90 +++++++++++++++++++++ jupyterhub/apihandlers/auth.py | 5 +- jupyterhub/apihandlers/users.py | 2 +- jupyterhub/oauth/provider.py | 5 +- jupyterhub/scopes.py | 107 +++++++++++++++++++++---- jupyterhub/tests/test_memoize.py | 75 +++++++++++++++++ jupyterhub/tests/test_services_auth.py | 2 +- 7 files changed, 264 insertions(+), 22 deletions(-) create mode 100644 jupyterhub/_memoize.py create mode 100644 jupyterhub/tests/test_memoize.py diff --git a/jupyterhub/_memoize.py b/jupyterhub/_memoize.py new file mode 100644 index 00000000..1ee2861e --- /dev/null +++ b/jupyterhub/_memoize.py @@ -0,0 +1,90 @@ +"""Utilities for memoization""" +from collections import OrderedDict +from functools import wraps + + +class DoNotCache(Exception): + """Special exception to return a result without caching it""" + + def __init__(self, result): + self.result = result + + +class LRUCache: + """A simple Least-Recently-Used (LRU) cache with a max size""" + + def __init__(self, maxsize=1024): + self._cache = OrderedDict() + self.maxsize = maxsize + + def __contains__(self, key): + return key in self._cache + + def get(self, key, default=None): + """Get an item from the cache""" + if key in self._cache: + # cache hit, bump to front of the queue for LRU + result = self._cache[key] + self._cache.move_to_end(key) + return result + return default + + def set(self, key, value): + """Store an entry in the cache + + Purges oldest entry if cache is full 
+ """ + self._cache[key] = value + # cache is full, purge oldest entry + if len(self._cache) > self.maxsize: + self._cache.popitem(last=False) + + __getitem__ = get + __setitem__ = set + + +def lru_cache_key(key_func, maxsize=1024): + """Like functools.lru_cache, but takes a custom key function, + as seen in sorted(key=func). + + Useful for non-hashable arguments which have a known hashable equivalent (e.g. sets, lists), + or mutable objects where only immutable fields might be used + (e.g. User, where only username affects output). + + Example: + + @lru_cache_key(lambda user: user.name) + def func_user(user): + # output only varies by name + + Args: + key (callable): + Should have the same signature as the decorated function. + Returns a hashable key to use in the cache + maxsize (int): + The maximum size of the cache. + """ + + def cache_func(func): + cache = LRUCache(maxsize=maxsize) + # the actual decorated function: + @wraps(func) + def cached(*args, **kwargs): + cache_key = key_func(*args, **kwargs) + if cache_key in cache: + # cache hit + return cache[cache_key] + else: + # cache miss, call function and cache result + try: + result = func(*args, **kwargs) + except DoNotCache as e: + # DoNotCache prevents caching + result = e.result + else: + cache[cache_key] = result + return result + + return cached + + return cache_func diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py index 643e3a3a..11171752 100644 --- a/jupyterhub/apihandlers/auth.py +++ b/jupyterhub/apihandlers/auth.py @@ -1,7 +1,6 @@ """Authorization handlers""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
-import itertools import json from datetime import datetime from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse @@ -30,7 +29,7 @@ class TokenAPIHandler(APIHandler): if owner: # having a token means we should be able to read the owner's model # (this is the only thing this handler is for) - self.expanded_scopes.update(scopes.identify_scopes(owner)) + self.expanded_scopes |= scopes.identify_scopes(owner) self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes) # record activity whenever we see a token @@ -288,7 +287,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler): # rather than the expanded_scope intersection required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)} - user_scopes.update({"inherit", *required_scopes}) + user_scopes |= {"inherit", *required_scopes} allowed_scopes = requested_scopes.intersection(user_scopes) excluded_scopes = requested_scopes.difference(user_scopes) diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py index 6341c3b1..5e1ea285 100644 --- a/jupyterhub/apihandlers/users.py +++ b/jupyterhub/apihandlers/users.py @@ -44,7 +44,7 @@ class SelfAPIHandler(APIHandler): for scope in identify_scopes: if scope not in self.expanded_scopes: _added_scopes.add(scope) - self.expanded_scopes.add(scope) + self.expanded_scopes |= {scope} if _added_scopes: # re-parse with new scopes self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes) diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py index 3e1c6e0e..05637142 100644 --- a/jupyterhub/oauth/provider.py +++ b/jupyterhub/oauth/provider.py @@ -154,9 +154,8 @@ class JupyterHubRequestValidator(RequestValidator): scopes = roles_to_scopes(orm_client.allowed_roles) if 'inherit' not in scopes: # add identify-user scope - scopes.update(identify_scopes()) - # add access-service scope - scopes.update(access_scopes(orm_client)) + # and access-service scope + scopes |= identify_scopes() | 
access_scopes(orm_client) return scopes def get_original_scopes(self, refresh_token, request, *args, **kwargs): diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index 25db5c7e..ccceec9a 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -24,6 +24,7 @@ from tornado import web from tornado.log import app_log from . import orm, roles +from ._memoize import DoNotCache, lru_cache_key """when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run so that changes are reflected in the documentation and REST API description.""" @@ -144,6 +145,12 @@ class Scope(Enum): ALL = True +def _intersection_cache_key(scopes_a, scopes_b, db=None): + """Cache key function for scope intersections""" + return (frozenset(scopes_a), frozenset(scopes_b)) + + +@lru_cache_key(_intersection_cache_key) def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): """Intersect two sets of scopes by comparing their permissions @@ -159,11 +166,16 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): (i.e. users!group=x & users!user=y will be empty, even if user y is in group x.) 
""" empty_set = frozenset() + scopes_a = frozenset(scopes_a) + scopes_b = frozenset(scopes_b) # cached lookups for group membership of users and servers @lru_cache() def groups_for_user(username): """Get set of group names for a given username""" + # if we need a group lookup, the result is not cacheable + nonlocal needs_db + needs_db = True user = db.query(orm.User).filter_by(name=username).first() if user is None: return empty_set @@ -179,6 +191,11 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): parsed_scopes_a = parse_scopes(scopes_a) parsed_scopes_b = parse_scopes(scopes_b) + # track whether we need a db lookup (for groups) + # because we can't cache the intersection if we do + # if there are no group filters, this is cacheable + needs_db = False + common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys() common_filters = {} @@ -220,6 +237,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): UserWarning, ) warned = True + needs_db = True common_filters[base] = { entity: filters_a[entity] & filters_b[entity] @@ -245,6 +263,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): # resolve group/server hierarchy if db available servers = servers.difference(common_servers) if db is not None and servers and 'group' in b: + _needs_db = True for server in servers: server_groups = groups_for_server(server) if server_groups & b['group']: @@ -272,7 +291,12 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): if common_users and "user" not in common_filters[base]: common_filters[base]["user"] = common_users - return unparse_scopes(common_filters) + intersection = unparse_scopes(common_filters) + if needs_db: + # return intersection, but don't cache it if it needed db lookups + raise DoNotCache(intersection) + + return intersection def get_scopes_for(orm_object): @@ -313,7 +337,7 @@ def get_scopes_for(orm_object): # only thing we miss by short-circuiting here: warning about excluded extra scopes return owner_scopes - 
token_scopes = expand_scopes(token_scopes, owner=owner) + token_scopes = set(expand_scopes(token_scopes, owner=owner)) if orm_object.client_id != "jupyterhub": # oauth tokens can be used to access the service issuing the token, @@ -358,6 +382,7 @@ def get_scopes_for(orm_object): return expanded_scopes +@lru_cache() def _expand_self_scope(username): """ Users have a metascope 'self' that should be expanded to standard user privileges. @@ -390,9 +415,11 @@ def _expand_self_scope(username): 'read:tokens', 'access:servers', ] - return {f"{scope}!user={username}" for scope in scope_list} + # return frozenset because the result is cached + return frozenset(f"{scope}!user={username}" for scope in scope_list) +@lru_cache(maxsize=65535) def _expand_scope(scope): """Returns a scope and all all subscopes @@ -433,9 +460,30 @@ def _expand_scope(scope): else: expanded_scopes = expanded_scope_names - return expanded_scopes + # return frozenset because we are cached + return frozenset(expanded_scopes) +def _expand_scopes_key(scopes, owner=None): + """Cache key function for expand_scopes + + scopes is usually a mutable list or set, + which can be hashed as a frozenset + + For the owner, we only care about what kind they are, + and their name. 
+ """ + # freeze scopes for hash + frozen_scopes = frozenset(scopes) + if owner is None: + owner_key = None + else: + # owner key is the type and name + owner_key = (type(owner).__name__, owner.name) + return (frozen_scopes, owner_key) + + +@lru_cache_key(_expand_scopes_key) def expand_scopes(scopes, owner=None): """Returns a set of fully expanded scopes for a collection of raw scopes @@ -479,7 +527,7 @@ def expand_scopes(scopes, owner=None): stacklevel=2, ) - # reduce to minimize + # reduce to discard overlapping scopes return reduce_scopes(expanded_scopes) @@ -614,6 +662,7 @@ def _check_token_scopes(scopes, owner): ) +@lru_cache_key(lambda scope_list: frozenset(scope_list)) def parse_scopes(scope_list): """ Parses scopes and filters in something akin to JSON style @@ -652,6 +701,10 @@ def parse_scopes(scope_list): return parsed_scopes +# Note: it doesn't make sense to cache unparse_scopes +# because computing the cache key is as expensive as the function itself + + def unparse_scopes(parsed_scopes): """Turn a parsed_scopes dictionary back into a expanded scopes set""" expanded_scopes = set() @@ -662,9 +715,10 @@ def unparse_scopes(parsed_scopes): for entity, names_list in filters.items(): for name in names_list: expanded_scopes.add(f'{base}!{entity}={name}') - return expanded_scopes + return frozenset(expanded_scopes) +@lru_cache_key(frozenset) def reduce_scopes(expanded_scopes): """Reduce expanded scopes to minimal set @@ -723,6 +777,14 @@ def needs_scope(*scopes): return scope_decorator +def _identify_key(obj=None): + if obj is None: + return None + else: + return (type(obj).__name__, obj.name) + + +@lru_cache_key(_identify_key) def identify_scopes(obj=None): """Return 'identify' scopes for an orm object @@ -735,20 +797,25 @@ def identify_scopes(obj=None): identify scopes (set): set of scopes needed for 'identify' endpoints """ if obj is None: - return {f"read:users:{field}!user" for field in {"name", "groups"}} + return frozenset(f"read:users:{field}!user" for 
field in {"name", "groups"}) elif isinstance(obj, orm.User): - return {f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}} + return frozenset( + f"read:users:{field}!user={obj.name}" for field in {"name", "groups"} + ) elif isinstance(obj, orm.Service): - return {f"read:services:{field}!service={obj.name}" for field in {"name"}} + return frozenset( + f"read:services:{field}!service={obj.name}" for field in {"name"} + ) else: raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}") +@lru_cache_key(lambda oauth_client: oauth_client.identifier) def access_scopes(oauth_client): """Return scope(s) required to access an oauth client""" scopes = set() if oauth_client.identifier == "jupyterhub": - return scopes + return frozenset() spawner = oauth_client.spawner if spawner: scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}") @@ -760,9 +827,19 @@ def access_scopes(oauth_client): app_log.warning( f"OAuth client {oauth_client} has no associated service or spawner!" ) - return scopes + return frozenset(scopes) +def _check_scope_key(sub_scope, orm_resource, kind): + """Cache key function for check_scope_filter""" + if kind == 'server': + resource_key = (orm_resource.user.name, orm_resource.name) + else: + resource_key = orm_resource.name + return (sub_scope, resource_key, kind) + + +@lru_cache_key(_check_scope_key) def check_scope_filter(sub_scope, orm_resource, kind): """Return whether a sub_scope filter applies to a given resource. 
@@ -791,8 +868,8 @@ def check_scope_filter(sub_scope, orm_resource, kind): if kind == 'user' and 'group' in sub_scope: group_names = {group.name for group in orm_resource.groups} user_in_group = bool(group_names & set(sub_scope['group'])) - if user_in_group: - return True + # cannot cache if we needed to lookup groups in db + raise DoNotCache(user_in_group) return False @@ -831,6 +908,7 @@ def describe_parsed_scopes(parsed_scopes, username=None): return descriptions +@lru_cache_key(lambda raw_scopes, username=None: (frozenset(raw_scopes), username)) def describe_raw_scopes(raw_scopes, username=None): """Return list of descriptions of raw scopes @@ -861,7 +939,8 @@ def describe_raw_scopes(raw_scopes, username=None): "filter": filter_text, } ) - return descriptions + # make sure we return immutable from a cached function + return tuple(descriptions) # regex for custom scope diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py new file mode 100644 index 00000000..a53cc7d5 --- /dev/null +++ b/jupyterhub/tests/test_memoize.py @@ -0,0 +1,75 @@ +from jupyterhub._memoize import DoNotCache, LRUCache, lru_cache_key + + +def test_lru_cache(): + cache = LRUCache(maxsize=2) + cache["a"] = 1 + assert "a" in cache + assert "b" not in cache + cache["b"] = 2 + assert cache["b"] == 2 + + # accessing a makes it more recent than b + assert cache["a"] == 1 + assert "b" in cache + assert "a" in cache + + # storing c pushes oldest ('b') out of cache + cache["c"] = 3 + print(cache.maxsize, len(cache._cache)) + assert "a" in cache + assert "c" in cache + assert "b" not in cache + + +def test_lru_cache_key(): + + call_count = 0 + + @lru_cache_key(frozenset) + def reverse(arg): + nonlocal call_count + call_count += 1 + return list(reversed(arg)) + + in1 = [1, 2] + before = call_count + out1 = reverse(in1) + assert call_count == before + 1 + assert out1 == [2, 1] + + before = call_count + out2 = reverse(in1) + assert call_count == before + assert out2 is out1 + + 
+def test_do_not_cache(): + + call_count = 0 + + @lru_cache_key(lambda arg: arg) + def is_even(arg): + nonlocal call_count + call_count += 1 + if arg % 2: + raise DoNotCache(False) + return True + + before = call_count + assert is_even(0) == True + assert call_count == before + 1 + + # caches even results + before = call_count + assert is_even(0) == True + assert call_count == before + + before = call_count + assert is_even(1) == False + assert call_count == before + 1 + + # doesn't cache odd results + before = call_count + assert is_even(1) == False + assert call_count == before + 1 diff --git a/jupyterhub/tests/test_services_auth.py b/jupyterhub/tests/test_services_auth.py index 0d25ddf2..8cda0dbc 100644 --- a/jupyterhub/tests/test_services_auth.py +++ b/jupyterhub/tests/test_services_auth.py @@ -360,7 +360,7 @@ async def test_oauth_service_roles( ) if 'inherit' in expected_scopes: - expected_scopes = scopes.get_scopes_for(user.orm_user) + expected_scopes = set(scopes.get_scopes_for(user.orm_user)) # always expect identify/access scopes # on successful authentication From bb6427ea9b8e824d5eecbc9000d046af724f8168 Mon Sep 17 00:00:00 2001 From: Min RK Date: Wed, 30 Mar 2022 13:48:26 +0200 Subject: [PATCH 2/5] Add FrozenDict for caching parsed_scopes dicts Since we need them to be immutable --- jupyterhub/_memoize.py | 47 ++++++++++++++++++++++++++++++++ jupyterhub/scopes.py | 11 +++----- jupyterhub/tests/test_memoize.py | 21 +++++++++++++- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/jupyterhub/_memoize.py b/jupyterhub/_memoize.py index 1ee2861e..a8167229 100644 --- a/jupyterhub/_memoize.py +++ b/jupyterhub/_memoize.py @@ -88,3 +88,50 @@ def lru_cache_key(key_func, maxsize=1024): return cached return cache_func + + +class FrozenDict(dict): + """A frozen dictionary subclass + + Immutable and hashable, so it can be used as a cache key + + Values will be frozen with `._freeze(value)` + and must be hashable after freezing. 
+ + Not rigorous, but enough for our purposes. + """ + + _hash = None + + def __init__(self, d): + dict_set = dict.__setitem__ + for key, value in d.items(): + dict.__setitem__(self, key, self._freeze(value)) + + def _freeze(self, item): + """Make values of a dict hashable + - list, set -> frozenset + - dict -> recursive FrozenDict + - anything else: assumed hashable + """ + if isinstance(item, FrozenDict): + return item + elif isinstance(item, (set, list)): + return frozenset(item) + elif isinstance(item, dict): + return FrozenDict(item) + else: + # any other type is assumed hashable + return item + + def __setitem__(self, key): + raise RuntimeError("Cannot modify frozen {type(self).__name__}") + + def update(self, other): + raise RuntimeError("Cannot modify frozen {type(self).__name__}") + + def __hash__(self): + """Cache hash because we are immutable""" + if self._hash is None: + self._hash = hash(tuple((key, value) for key, value in self.items())) + return self._hash diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index ccceec9a..e17ba38e 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -24,7 +24,7 @@ from tornado import web from tornado.log import app_log from . 
import orm, roles -from ._memoize import DoNotCache, lru_cache_key +from ._memoize import DoNotCache, FrozenDict, lru_cache_key """when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run so that changes are reflected in the documentation and REST API description.""" @@ -698,13 +698,10 @@ def parse_scopes(scope_list): parsed_scopes[base_scope][key] = {value} else: parsed_scopes[base_scope][key].add(value) - return parsed_scopes - - -# Note: it doesn't make sense to cache unparse_scopes -# because computing the cache key is as expensive as the function itself + return FrozenDict(parsed_scopes) +@lru_cache_key(FrozenDict) def unparse_scopes(parsed_scopes): """Turn a parsed_scopes dictionary back into a expanded scopes set""" expanded_scopes = set() @@ -722,7 +719,7 @@ def unparse_scopes(parsed_scopes): def reduce_scopes(expanded_scopes): """Reduce expanded scopes to minimal set - Eliminates redundancy, such as access:services and access:services!service=x + Eliminates overlapping scopes, such as access:services and access:services!service=x """ return unparse_scopes(parse_scopes(expanded_scopes)) diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py index a53cc7d5..c7525479 100644 --- a/jupyterhub/tests/test_memoize.py +++ b/jupyterhub/tests/test_memoize.py @@ -1,4 +1,6 @@ -from jupyterhub._memoize import DoNotCache, LRUCache, lru_cache_key +import pytest + +from jupyterhub._memoize import DoNotCache, FrozenDict, LRUCache, lru_cache_key def test_lru_cache(): @@ -73,3 +75,20 @@ def test_do_not_cache(): before = call_count assert is_even(1) == False assert call_count == before + 1 + + +@pytest.mark.parametrize( + "d", + [ + {"key": "value"}, + {"key": ["list"]}, + {"key": {"set"}}, + {"key": ("tu", "ple")}, + {"key": {"nested": ["dict"]}}, + ], +) +def test_frozen_dict(d): + frozen_1 = FrozenDict(d) + frozen_2 = FrozenDict(d) + assert hash(frozen_1) == hash(frozen_2) + assert frozen_1 == frozen_2 
From eebc0f485d1c3238c8b396a16b83ea9571d452a9 Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 1 Apr 2022 11:29:51 +0200 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Simon Li --- jupyterhub/_memoize.py | 4 +++- jupyterhub/scopes.py | 2 +- jupyterhub/tests/test_memoize.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/jupyterhub/_memoize.py b/jupyterhub/_memoize.py index a8167229..6ff79110 100644 --- a/jupyterhub/_memoize.py +++ b/jupyterhub/_memoize.py @@ -116,7 +116,9 @@ class FrozenDict(dict): """ if isinstance(item, FrozenDict): return item - elif isinstance(item, (set, list)): + elif isinstance(item, list): + return tuple(self._freeze(e) for e in item) + elif isinstance(item, set): return frozenset(item) elif isinstance(item, dict): return FrozenDict(item) diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index e17ba38e..3997c9e2 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -662,7 +662,7 @@ def _check_token_scopes(scopes, owner): ) -@lru_cache_key(lambda scope_list: frozenset(scope_list)) +@lru_cache_key(frozenset) def parse_scopes(scope_list): """ Parses scopes and filters in something akin to JSON style diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py index c7525479..0629cf57 100644 --- a/jupyterhub/tests/test_memoize.py +++ b/jupyterhub/tests/test_memoize.py @@ -18,7 +18,7 @@ def test_lru_cache(): # storing c pushes oldest ('b') out of cache cache["c"] = 3 - print(cache.maxsize, len(cache._cache)) + assert len(cache._cache) == 2 assert "a" in cache assert "c" in cache assert "b" not in cache From ab2913008e215de4e2a3b1fadd35317816c478c5 Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 1 Apr 2022 11:54:32 +0200 Subject: [PATCH 4/5] more docs, comments, asserts about immutable scope functions --- jupyterhub/_memoize.py | 29 ++++++++++++++++++++++------- jupyterhub/scopes.py | 18 +++++++++++++----- jupyterhub/tests/test_memoize.py | 2 +- 
jupyterhub/tests/test_scopes.py | 6 ++++++ 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/jupyterhub/_memoize.py b/jupyterhub/_memoize.py index 6ff79110..21907b1d 100644 --- a/jupyterhub/_memoize.py +++ b/jupyterhub/_memoize.py @@ -1,10 +1,23 @@ -"""Utilities for memoization""" +"""Utilities for memoization + +Note: a memoized function should always return an _immutable_ +result to avoid later modifications polluting cached results. +""" from collections import OrderedDict from functools import wraps -class DoNotCache(Exception): - """Special exception to return a result without caching it""" +class DoNotCache: + """Wrapper to return a result without caching it. + + In a function decorated with `@lru_cache_key`: + + return DoNotCache(result) + + is equivalent to: + + return result # but don't cache it! + """ def __init__(self, result): self.result = result @@ -51,6 +64,9 @@ def lru_cache_key(key_func, maxsize=1024): or mutable objects where only immutable fields might be used (e.g. User, where only username affects output). + For safety: Cached results should always be immutable, + such as using `frozenset` instead of mutable `set`. + Example: @lru_cache_key(lambda user: user.name) @@ -76,11 +92,10 @@ def lru_cache_key(key_func, maxsize=1024): return cache[cache_key] else: # cache miss, call function and cache result - try: - result = func(*args, **kwargs) - except DoNotCache as e: + result = func(*args, **kwargs) + if isinstance(result, DoNotCache): # DoNotCache prevents caching - result = e.result + result = result.result else: cache[cache_key] = result return result diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index 3997c9e2..fbbd9256 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -1,6 +1,10 @@ """ General scope definitions and utilities +Scope functions generally return _immutable_ collections, +such as `frozenset` to avoid mutating cached values. +If needed, mutable copies can be made, e.g. 
`set(frozen_scopes)` + Scope variable nomenclature --------------------------- scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition) @@ -294,7 +298,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): intersection = unparse_scopes(common_filters) if needs_db: # return intersection, but don't cache it if it needed db lookups - raise DoNotCache(intersection) + return DoNotCache(intersection) return intersection @@ -415,7 +419,7 @@ def _expand_self_scope(username): 'read:tokens', 'access:servers', ] - # return frozenset because the result is cached + # return immutable frozenset because the result is cached return frozenset(f"{scope}!user={username}" for scope in scope_list) @@ -460,7 +464,7 @@ def _expand_scope(scope): else: expanded_scopes = expanded_scope_names - # return frozenset because we are cached + # return immutable frozenset because the result is cached return frozenset(expanded_scopes) @@ -528,7 +532,8 @@ def expand_scopes(scopes, owner=None): ) # reduce to discard overlapping scopes - return reduce_scopes(expanded_scopes) + # return immutable frozenset because the result is cached + return frozenset(reduce_scopes(expanded_scopes)) def _needs_scope_expansion(filter_, filter_value, sub_scope): @@ -698,6 +703,7 @@ def parse_scopes(scope_list): parsed_scopes[base_scope][key] = {value} else: parsed_scopes[base_scope][key].add(value) + # return immutable FrozenDict because the result is cached return FrozenDict(parsed_scopes) @@ -712,6 +718,7 @@ def unparse_scopes(parsed_scopes): for entity, names_list in filters.items(): for name in names_list: expanded_scopes.add(f'{base}!{entity}={name}') + # return immutable frozenset because the result is cached return frozenset(expanded_scopes) @@ -721,6 +728,7 @@ def reduce_scopes(expanded_scopes): Eliminates overlapping scopes, such as access:services and access:services!service=x """ + # unparse_scopes already returns a frozenset return 
unparse_scopes(parse_scopes(expanded_scopes)) @@ -866,7 +874,7 @@ def check_scope_filter(sub_scope, orm_resource, kind): group_names = {group.name for group in orm_resource.groups} user_in_group = bool(group_names & set(sub_scope['group'])) # cannot cache if we needed to lookup groups in db - raise DoNotCache(user_in_group) + return DoNotCache(user_in_group) return False diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py index 0629cf57..c37942a8 100644 --- a/jupyterhub/tests/test_memoize.py +++ b/jupyterhub/tests/test_memoize.py @@ -55,7 +55,7 @@ def test_do_not_cache(): nonlocal call_count call_count += 1 if arg % 2: - raise DoNotCache(False) + return DoNotCache(False) return True before = call_count diff --git a/jupyterhub/tests/test_scopes.py b/jupyterhub/tests/test_scopes.py index f8a08b4b..2f40d949 100644 --- a/jupyterhub/tests/test_scopes.py +++ b/jupyterhub/tests/test_scopes.py @@ -8,6 +8,7 @@ from tornado import web from tornado.httputil import HTTPServerRequest from .. 
import orm, roles, scopes +from .._memoize import FrozenDict from ..handlers import BaseHandler from ..scopes import ( Scope, @@ -38,6 +39,7 @@ def test_scope_constructor(): f'read:users!user={user2}', ] parsed_scopes = parse_scopes(scope_list) + assert isinstance(parsed_scopes, FrozenDict) assert 'read:users' in parsed_scopes assert parsed_scopes['users'] @@ -467,6 +469,7 @@ async def test_metascope_self_expansion( orm_obj = create_service_with_scopes('self') # test expansion of user/service scopes scopes = get_scopes_for(orm_obj) + assert isinstance(scopes, frozenset) assert bool(scopes) == has_user_scopes # test expansion of token scopes @@ -488,6 +491,8 @@ async def test_metascope_inherit_expansion(app, create_user_with_scopes): token.scopes.clear() app.db.commit() token_scope_set = get_scopes_for(token) + assert isinstance(token_scope_set, frozenset) + assert token_scope_set.issubset(identify_scopes(user.orm_user)) @@ -1169,4 +1174,5 @@ def test_expand_scopes(user, scopes, expected): expected.update(_expand_self_scope(user.name)) expanded = expand_scopes(scopes, owner=user.orm_user) + assert isinstance(expanded, frozenset) assert sorted(expanded) == sorted(expected) From ff020cb5a47e2e54f028834043828ac32d4557e4 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 7 Apr 2022 09:42:25 +0200 Subject: [PATCH 5/5] needs_db typo Co-authored-by: Simon Li --- jupyterhub/scopes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py index fbbd9256..1d83ba60 100644 --- a/jupyterhub/scopes.py +++ b/jupyterhub/scopes.py @@ -267,7 +267,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): # resolve group/server hierarchy if db available servers = servers.difference(common_servers) if db is not None and servers and 'group' in b: - _needs_db = True + needs_db = True for server in servers: server_groups = groups_for_server(server) if server_groups & b['group']: