Merge pull request #3850 from minrk/cache-scopes

memoize some scope functions
This commit is contained in:
Min RK
2022-04-07 12:56:19 +02:00
committed by GitHub
8 changed files with 361 additions and 25 deletions

154
jupyterhub/_memoize.py Normal file
View File

@@ -0,0 +1,154 @@
"""Utilities for memoization
Note: a memoized function should always return an _immutable_
result to avoid later modifications polluting cached results.
"""
from collections import OrderedDict
from functools import wraps
class DoNotCache:
    """Marker wrapper: hand back a result while opting out of caching.

    Inside a function decorated with `@lru_cache_key`, writing

        return DoNotCache(result)

    gives the caller the same value as

        return result

    except that the value is not stored in the cache.
    """

    def __init__(self, result):
        # the wrapped value; the caching decorator unwraps it via `.result`
        self.result = result
class LRUCache:
    """A simple Least-Recently-Used (LRU) cache with a max size.

    Entries live in an `OrderedDict` ordered from least- to most-recently
    used. Both reading (`get`) and writing (`set`) a key make it the most
    recent entry; membership tests (`in`) do not affect the order.

    `cache[key]` and `cache[key] = value` are aliases for `get` and `set`,
    so subscripting a missing key returns None rather than raising KeyError.
    """

    def __init__(self, maxsize=1024):
        self._cache = OrderedDict()
        self.maxsize = maxsize

    def __contains__(self, key):
        """Return whether `key` is cached, without touching its recency."""
        return key in self._cache

    def get(self, key, default=None):
        """Get an item from the cache.

        A hit marks the entry as most recently used.
        Returns `default` if the key is not cached.
        """
        try:
            # cache hit: bump to the most-recent end of the queue
            self._cache.move_to_end(key)
        except KeyError:
            return default
        return self._cache[key]

    def set(self, key, value):
        """Store an entry in the cache.

        The entry becomes the most recently used one.
        Purges the oldest entry if the cache is over `maxsize`.
        """
        self._cache[key] = value
        # Assigning to an existing key keeps its old position in an
        # OrderedDict, so explicitly refresh recency on overwrite.
        self._cache.move_to_end(key)
        # cache is full, purge oldest entry
        if len(self._cache) > self.maxsize:
            self._cache.popitem(last=False)

    __getitem__ = get
    __setitem__ = set
def lru_cache_key(key_func, maxsize=1024):
    """Memoizing decorator factory keyed by a caller-supplied function.

    Works like `functools.lru_cache`, except the cache key is computed by
    `key_func` (in the spirit of `sorted(key=func)`) instead of being the
    arguments themselves.

    This helps when arguments are not hashable but have a hashable
    equivalent (e.g. sets, lists), or are mutable objects of which only
    some immutable fields influence the output
    (e.g. User, where only username affects output).

    For safety, cached results should always be immutable,
    such as using `frozenset` instead of mutable `set`.

    A decorated function may return `DoNotCache(result)` to give `result`
    to the caller without storing it.

    Example:

        @lru_cache_key(lambda user: user.name)
        def func_user(user):
            # output only varies by name

    Args:
        key_func (callable):
            Called with the same arguments as the decorated function.
            Returns a hashable key to use in the cache.
        maxsize (int):
            The maximum size of the cache.
    """

    def decorator(func):
        # one cache per decorated function
        store = LRUCache(maxsize=maxsize)

        @wraps(func)
        def wrapper(*args, **kwargs):
            lookup_key = key_func(*args, **kwargs)
            if lookup_key in store:
                # hit: serve the stored value (this also refreshes recency)
                return store[lookup_key]

            # miss: compute the value
            outcome = func(*args, **kwargs)
            if isinstance(outcome, DoNotCache):
                # explicit opt-out: unwrap and skip storing
                return outcome.result

            store[lookup_key] = outcome
            return outcome

        return wrapper

    return decorator
class FrozenDict(dict):
    """A frozen dictionary subclass.

    Hashable, so it can be used as a cache key.
    Values are frozen with `._freeze(value)` at construction time
    and must be hashable after freezing.

    Not rigorous, but enough for our purposes:
    only item assignment and `update()` are blocked;
    other inherited dict mutators are not overridden.
    """

    # lazily computed hash, cached on first use
    _hash = None

    def __init__(self, d):
        # go through dict.__setitem__ directly,
        # because our own __setitem__ always raises
        for key, value in d.items():
            dict.__setitem__(self, key, self._freeze(value))

    def _freeze(self, item):
        """Make values of a dict hashable

        - list -> tuple (elements frozen recursively)
        - set -> frozenset
        - dict -> recursive FrozenDict
        - anything else: assumed hashable
        """
        if isinstance(item, FrozenDict):
            return item
        elif isinstance(item, list):
            return tuple(self._freeze(e) for e in item)
        elif isinstance(item, set):
            return frozenset(item)
        elif isinstance(item, dict):
            return FrozenDict(item)
        else:
            # any other type is assumed hashable
            return item

    def __setitem__(self, key, value):
        """Reject item assignment: always raises RuntimeError."""
        raise RuntimeError(f"Cannot modify frozen {type(self).__name__}")

    def update(self, *args, **kwargs):
        """Reject bulk updates: always raises RuntimeError."""
        raise RuntimeError(f"Cannot modify frozen {type(self).__name__}")

    def __hash__(self):
        """Cache hash because we are immutable"""
        if self._hash is None:
            # dict equality ignores insertion order, so the hash must too:
            # hash the items as a frozenset rather than an ordered tuple
            self._hash = hash(frozenset(self.items()))
        return self._hash

View File

@@ -1,7 +1,6 @@
"""Authorization handlers""" """Authorization handlers"""
# Copyright (c) Jupyter Development Team. # Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import itertools
import json import json
from datetime import datetime from datetime import datetime
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
@@ -30,7 +29,7 @@ class TokenAPIHandler(APIHandler):
if owner: if owner:
# having a token means we should be able to read the owner's model # having a token means we should be able to read the owner's model
# (this is the only thing this handler is for) # (this is the only thing this handler is for)
self.expanded_scopes.update(scopes.identify_scopes(owner)) self.expanded_scopes |= scopes.identify_scopes(owner)
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes) self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
# record activity whenever we see a token # record activity whenever we see a token
@@ -288,7 +287,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
# rather than the expanded_scope intersection # rather than the expanded_scope intersection
required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)} required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)}
user_scopes.update({"inherit", *required_scopes}) user_scopes |= {"inherit", *required_scopes}
allowed_scopes = requested_scopes.intersection(user_scopes) allowed_scopes = requested_scopes.intersection(user_scopes)
excluded_scopes = requested_scopes.difference(user_scopes) excluded_scopes = requested_scopes.difference(user_scopes)

View File

@@ -44,7 +44,7 @@ class SelfAPIHandler(APIHandler):
for scope in identify_scopes: for scope in identify_scopes:
if scope not in self.expanded_scopes: if scope not in self.expanded_scopes:
_added_scopes.add(scope) _added_scopes.add(scope)
self.expanded_scopes.add(scope) self.expanded_scopes |= {scope}
if _added_scopes: if _added_scopes:
# re-parse with new scopes # re-parse with new scopes
self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes) self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)

View File

@@ -154,9 +154,8 @@ class JupyterHubRequestValidator(RequestValidator):
scopes = roles_to_scopes(orm_client.allowed_roles) scopes = roles_to_scopes(orm_client.allowed_roles)
if 'inherit' not in scopes: if 'inherit' not in scopes:
# add identify-user scope # add identify-user scope
scopes.update(identify_scopes()) # and access-service scope
# add access-service scope scopes |= identify_scopes() | access_scopes(orm_client)
scopes.update(access_scopes(orm_client))
return scopes return scopes
def get_original_scopes(self, refresh_token, request, *args, **kwargs): def get_original_scopes(self, refresh_token, request, *args, **kwargs):

View File

@@ -1,6 +1,10 @@
""" """
General scope definitions and utilities General scope definitions and utilities
Scope functions generally return _immutable_ collections,
such as `frozenset` to avoid mutating cached values.
If needed, mutable copies can be made, e.g. `set(frozen_scopes)`
Scope variable nomenclature Scope variable nomenclature
--------------------------- ---------------------------
scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition) scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition)
@@ -24,6 +28,7 @@ from tornado import web
from tornado.log import app_log from tornado.log import app_log
from . import orm, roles from . import orm, roles
from ._memoize import DoNotCache, FrozenDict, lru_cache_key
"""when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run """when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py` is run
so that changes are reflected in the documentation and REST API description.""" so that changes are reflected in the documentation and REST API description."""
@@ -144,6 +149,12 @@ class Scope(Enum):
ALL = True ALL = True
def _intersection_cache_key(scopes_a, scopes_b, db=None):
"""Cache key function for scope intersections"""
return (frozenset(scopes_a), frozenset(scopes_b))
@lru_cache_key(_intersection_cache_key)
def _intersect_expanded_scopes(scopes_a, scopes_b, db=None): def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
"""Intersect two sets of scopes by comparing their permissions """Intersect two sets of scopes by comparing their permissions
@@ -159,11 +170,16 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
(i.e. users!group=x & users!user=y will be empty, even if user y is in group x.) (i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
""" """
empty_set = frozenset() empty_set = frozenset()
scopes_a = frozenset(scopes_a)
scopes_b = frozenset(scopes_b)
# cached lookups for group membership of users and servers # cached lookups for group membership of users and servers
@lru_cache() @lru_cache()
def groups_for_user(username): def groups_for_user(username):
"""Get set of group names for a given username""" """Get set of group names for a given username"""
# if we need a group lookup, the result is not cacheable
nonlocal needs_db
needs_db = True
user = db.query(orm.User).filter_by(name=username).first() user = db.query(orm.User).filter_by(name=username).first()
if user is None: if user is None:
return empty_set return empty_set
@@ -179,6 +195,11 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
parsed_scopes_a = parse_scopes(scopes_a) parsed_scopes_a = parse_scopes(scopes_a)
parsed_scopes_b = parse_scopes(scopes_b) parsed_scopes_b = parse_scopes(scopes_b)
# track whether we need a db lookup (for groups)
# because we can't cache the intersection if we do
# if there are no group filters, this is cacheable
needs_db = False
common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys() common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys()
common_filters = {} common_filters = {}
@@ -220,6 +241,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
UserWarning, UserWarning,
) )
warned = True warned = True
needs_db = True
common_filters[base] = { common_filters[base] = {
entity: filters_a[entity] & filters_b[entity] entity: filters_a[entity] & filters_b[entity]
@@ -245,6 +267,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
# resolve group/server hierarchy if db available # resolve group/server hierarchy if db available
servers = servers.difference(common_servers) servers = servers.difference(common_servers)
if db is not None and servers and 'group' in b: if db is not None and servers and 'group' in b:
needs_db = True
for server in servers: for server in servers:
server_groups = groups_for_server(server) server_groups = groups_for_server(server)
if server_groups & b['group']: if server_groups & b['group']:
@@ -272,7 +295,12 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
if common_users and "user" not in common_filters[base]: if common_users and "user" not in common_filters[base]:
common_filters[base]["user"] = common_users common_filters[base]["user"] = common_users
return unparse_scopes(common_filters) intersection = unparse_scopes(common_filters)
if needs_db:
# return intersection, but don't cache it if it needed db lookups
return DoNotCache(intersection)
return intersection
def get_scopes_for(orm_object): def get_scopes_for(orm_object):
@@ -313,7 +341,7 @@ def get_scopes_for(orm_object):
# only thing we miss by short-circuiting here: warning about excluded extra scopes # only thing we miss by short-circuiting here: warning about excluded extra scopes
return owner_scopes return owner_scopes
token_scopes = expand_scopes(token_scopes, owner=owner) token_scopes = set(expand_scopes(token_scopes, owner=owner))
if orm_object.client_id != "jupyterhub": if orm_object.client_id != "jupyterhub":
# oauth tokens can be used to access the service issuing the token, # oauth tokens can be used to access the service issuing the token,
@@ -358,6 +386,7 @@ def get_scopes_for(orm_object):
return expanded_scopes return expanded_scopes
@lru_cache()
def _expand_self_scope(username): def _expand_self_scope(username):
""" """
Users have a metascope 'self' that should be expanded to standard user privileges. Users have a metascope 'self' that should be expanded to standard user privileges.
@@ -390,9 +419,11 @@ def _expand_self_scope(username):
'read:tokens', 'read:tokens',
'access:servers', 'access:servers',
] ]
return {f"{scope}!user={username}" for scope in scope_list} # return immutable frozenset because the result is cached
return frozenset(f"{scope}!user={username}" for scope in scope_list)
@lru_cache(maxsize=65535)
def _expand_scope(scope): def _expand_scope(scope):
"""Returns a scope and all all subscopes """Returns a scope and all all subscopes
@@ -433,9 +464,30 @@ def _expand_scope(scope):
else: else:
expanded_scopes = expanded_scope_names expanded_scopes = expanded_scope_names
return expanded_scopes # return immutable frozenset because the result is cached
return frozenset(expanded_scopes)
def _expand_scopes_key(scopes, owner=None):
"""Cache key function for expand_scopes
scopes is usually a mutable list or set,
which can be hashed as a frozenset
For the owner, we only care about what kind they are,
and their name.
"""
# freeze scopes for hash
frozen_scopes = frozenset(scopes)
if owner is None:
owner_key = None
else:
# owner key is the type and name
owner_key = (type(owner).__name__, owner.name)
return (frozen_scopes, owner_key)
@lru_cache_key(_expand_scopes_key)
def expand_scopes(scopes, owner=None): def expand_scopes(scopes, owner=None):
"""Returns a set of fully expanded scopes for a collection of raw scopes """Returns a set of fully expanded scopes for a collection of raw scopes
@@ -479,8 +531,9 @@ def expand_scopes(scopes, owner=None):
stacklevel=2, stacklevel=2,
) )
# reduce to minimize # reduce to discard overlapping scopes
return reduce_scopes(expanded_scopes) # return immutable frozenset because the result is cached
return frozenset(reduce_scopes(expanded_scopes))
def _needs_scope_expansion(filter_, filter_value, sub_scope): def _needs_scope_expansion(filter_, filter_value, sub_scope):
@@ -614,6 +667,7 @@ def _check_token_scopes(scopes, owner):
) )
@lru_cache_key(frozenset)
def parse_scopes(scope_list): def parse_scopes(scope_list):
""" """
Parses scopes and filters in something akin to JSON style Parses scopes and filters in something akin to JSON style
@@ -649,9 +703,11 @@ def parse_scopes(scope_list):
parsed_scopes[base_scope][key] = {value} parsed_scopes[base_scope][key] = {value}
else: else:
parsed_scopes[base_scope][key].add(value) parsed_scopes[base_scope][key].add(value)
return parsed_scopes # return immutable FrozenDict because the result is cached
return FrozenDict(parsed_scopes)
@lru_cache_key(FrozenDict)
def unparse_scopes(parsed_scopes): def unparse_scopes(parsed_scopes):
"""Turn a parsed_scopes dictionary back into a expanded scopes set""" """Turn a parsed_scopes dictionary back into a expanded scopes set"""
expanded_scopes = set() expanded_scopes = set()
@@ -662,14 +718,17 @@ def unparse_scopes(parsed_scopes):
for entity, names_list in filters.items(): for entity, names_list in filters.items():
for name in names_list: for name in names_list:
expanded_scopes.add(f'{base}!{entity}={name}') expanded_scopes.add(f'{base}!{entity}={name}')
return expanded_scopes # return immutable frozenset because the result is cached
return frozenset(expanded_scopes)
@lru_cache_key(frozenset)
def reduce_scopes(expanded_scopes): def reduce_scopes(expanded_scopes):
"""Reduce expanded scopes to minimal set """Reduce expanded scopes to minimal set
Eliminates redundancy, such as access:services and access:services!service=x Eliminates overlapping scopes, such as access:services and access:services!service=x
""" """
# unparse_scopes already returns a frozenset
return unparse_scopes(parse_scopes(expanded_scopes)) return unparse_scopes(parse_scopes(expanded_scopes))
@@ -723,6 +782,14 @@ def needs_scope(*scopes):
return scope_decorator return scope_decorator
def _identify_key(obj=None):
if obj is None:
return None
else:
return (type(obj).__name__, obj.name)
@lru_cache_key(_identify_key)
def identify_scopes(obj=None): def identify_scopes(obj=None):
"""Return 'identify' scopes for an orm object """Return 'identify' scopes for an orm object
@@ -735,20 +802,25 @@ def identify_scopes(obj=None):
identify scopes (set): set of scopes needed for 'identify' endpoints identify scopes (set): set of scopes needed for 'identify' endpoints
""" """
if obj is None: if obj is None:
return {f"read:users:{field}!user" for field in {"name", "groups"}} return frozenset(f"read:users:{field}!user" for field in {"name", "groups"})
elif isinstance(obj, orm.User): elif isinstance(obj, orm.User):
return {f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}} return frozenset(
f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}
)
elif isinstance(obj, orm.Service): elif isinstance(obj, orm.Service):
return {f"read:services:{field}!service={obj.name}" for field in {"name"}} return frozenset(
f"read:services:{field}!service={obj.name}" for field in {"name"}
)
else: else:
raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}") raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}")
@lru_cache_key(lambda oauth_client: oauth_client.identifier)
def access_scopes(oauth_client): def access_scopes(oauth_client):
"""Return scope(s) required to access an oauth client""" """Return scope(s) required to access an oauth client"""
scopes = set() scopes = set()
if oauth_client.identifier == "jupyterhub": if oauth_client.identifier == "jupyterhub":
return scopes return frozenset()
spawner = oauth_client.spawner spawner = oauth_client.spawner
if spawner: if spawner:
scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}") scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}")
@@ -760,9 +832,19 @@ def access_scopes(oauth_client):
app_log.warning( app_log.warning(
f"OAuth client {oauth_client} has no associated service or spawner!" f"OAuth client {oauth_client} has no associated service or spawner!"
) )
return scopes return frozenset(scopes)
def _check_scope_key(sub_scope, orm_resource, kind):
"""Cache key function for check_scope_filter"""
if kind == 'server':
resource_key = (orm_resource.user.name, orm_resource.name)
else:
resource_key = orm_resource.name
return (sub_scope, resource_key, kind)
@lru_cache_key(_check_scope_key)
def check_scope_filter(sub_scope, orm_resource, kind): def check_scope_filter(sub_scope, orm_resource, kind):
"""Return whether a sub_scope filter applies to a given resource. """Return whether a sub_scope filter applies to a given resource.
@@ -791,8 +873,8 @@ def check_scope_filter(sub_scope, orm_resource, kind):
if kind == 'user' and 'group' in sub_scope: if kind == 'user' and 'group' in sub_scope:
group_names = {group.name for group in orm_resource.groups} group_names = {group.name for group in orm_resource.groups}
user_in_group = bool(group_names & set(sub_scope['group'])) user_in_group = bool(group_names & set(sub_scope['group']))
if user_in_group: # cannot cache if we needed to lookup groups in db
return True return DoNotCache(user_in_group)
return False return False
@@ -831,6 +913,7 @@ def describe_parsed_scopes(parsed_scopes, username=None):
return descriptions return descriptions
@lru_cache_key(lambda raw_scopes, username=None: (frozenset(raw_scopes), username))
def describe_raw_scopes(raw_scopes, username=None): def describe_raw_scopes(raw_scopes, username=None):
"""Return list of descriptions of raw scopes """Return list of descriptions of raw scopes
@@ -861,7 +944,8 @@ def describe_raw_scopes(raw_scopes, username=None):
"filter": filter_text, "filter": filter_text,
} }
) )
return descriptions # make sure we return immutable from a cached function
return tuple(descriptions)
# regex for custom scope # regex for custom scope

View File

@@ -0,0 +1,94 @@
import pytest
from jupyterhub._memoize import DoNotCache, FrozenDict, LRUCache, lru_cache_key
def test_lru_cache():
    """LRUCache keeps at most `maxsize` entries and evicts the least recently used."""
    lru = LRUCache(maxsize=2)

    lru["a"] = 1
    assert "a" in lru
    assert "b" not in lru

    lru["b"] = 2
    assert lru["b"] == 2

    # reading 'a' makes it more recent than 'b'
    assert lru["a"] == 1
    assert "b" in lru
    assert "a" in lru

    # a third entry exceeds maxsize: the oldest ('b') is dropped
    lru["c"] = 3
    assert len(lru._cache) == 2
    assert "a" in lru
    assert "c" in lru
    assert "b" not in lru
def test_lru_cache_key():
    """A second call with the same key is served from cache, not recomputed."""
    call_count = 0

    @lru_cache_key(frozenset)
    def reverse(arg):
        nonlocal call_count
        call_count += 1
        return list(reversed(arg))

    in1 = [1, 2]

    # first call: cache miss, the function body runs once
    calls_so_far = call_count
    out1 = reverse(in1)
    assert call_count == calls_so_far + 1
    assert out1 == [2, 1]

    # second call: cache hit, the very same object comes back
    calls_so_far = call_count
    out2 = reverse(in1)
    assert call_count == calls_so_far
    assert out2 is out1
def test_do_not_cache():
    """Results wrapped in DoNotCache are unwrapped and never stored."""
    call_count = 0

    @lru_cache_key(lambda arg: arg)
    def is_even(arg):
        nonlocal call_count
        call_count += 1
        if arg % 2:
            return DoNotCache(False)
        return True

    # identity checks (rather than `== True`) also pin that the plain bool
    # is returned, not some truthy wrapper object
    before = call_count
    assert is_even(0) is True
    assert call_count == before + 1

    # caches even results
    before = call_count
    assert is_even(0) is True
    assert call_count == before

    before = call_count
    assert is_even(1) is False
    assert call_count == before + 1

    # doesn't cache odd results
    before = call_count
    assert is_even(1) is False
    assert call_count == before + 1
@pytest.mark.parametrize(
    "d",
    [
        {"key": "value"},
        {"key": ["list"]},
        {"key": {"set"}},
        {"key": ("tu", "ple")},
        {"key": {"nested": ["dict"]}},
    ],
)
def test_frozen_dict(d):
    """Two FrozenDicts built from the same input hash alike and compare equal."""
    first, second = FrozenDict(d), FrozenDict(d)
    assert hash(first) == hash(second)
    assert first == second

View File

@@ -8,6 +8,7 @@ from tornado import web
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from .. import orm, roles, scopes from .. import orm, roles, scopes
from .._memoize import FrozenDict
from ..handlers import BaseHandler from ..handlers import BaseHandler
from ..scopes import ( from ..scopes import (
Scope, Scope,
@@ -38,6 +39,7 @@ def test_scope_constructor():
f'read:users!user={user2}', f'read:users!user={user2}',
] ]
parsed_scopes = parse_scopes(scope_list) parsed_scopes = parse_scopes(scope_list)
assert isinstance(parsed_scopes, FrozenDict)
assert 'read:users' in parsed_scopes assert 'read:users' in parsed_scopes
assert parsed_scopes['users'] assert parsed_scopes['users']
@@ -467,6 +469,7 @@ async def test_metascope_self_expansion(
orm_obj = create_service_with_scopes('self') orm_obj = create_service_with_scopes('self')
# test expansion of user/service scopes # test expansion of user/service scopes
scopes = get_scopes_for(orm_obj) scopes = get_scopes_for(orm_obj)
assert isinstance(scopes, frozenset)
assert bool(scopes) == has_user_scopes assert bool(scopes) == has_user_scopes
# test expansion of token scopes # test expansion of token scopes
@@ -488,6 +491,8 @@ async def test_metascope_inherit_expansion(app, create_user_with_scopes):
token.scopes.clear() token.scopes.clear()
app.db.commit() app.db.commit()
token_scope_set = get_scopes_for(token) token_scope_set = get_scopes_for(token)
assert isinstance(token_scope_set, frozenset)
assert token_scope_set.issubset(identify_scopes(user.orm_user)) assert token_scope_set.issubset(identify_scopes(user.orm_user))
@@ -1169,4 +1174,5 @@ def test_expand_scopes(user, scopes, expected):
expected.update(_expand_self_scope(user.name)) expected.update(_expand_self_scope(user.name))
expanded = expand_scopes(scopes, owner=user.orm_user) expanded = expand_scopes(scopes, owner=user.orm_user)
assert isinstance(expanded, frozenset)
assert sorted(expanded) == sorted(expected) assert sorted(expanded) == sorted(expected)

View File

@@ -360,7 +360,7 @@ async def test_oauth_service_roles(
) )
if 'inherit' in expected_scopes: if 'inherit' in expected_scopes:
expected_scopes = scopes.get_scopes_for(user.orm_user) expected_scopes = set(scopes.get_scopes_for(user.orm_user))
# always expect identify/access scopes # always expect identify/access scopes
# on successful authentication # on successful authentication