diff --git a/jupyterhub/_memoize.py b/jupyterhub/_memoize.py
new file mode 100644
index 00000000..21907b1d
--- /dev/null
+++ b/jupyterhub/_memoize.py
@@ -0,0 +1,154 @@
+"""Utilities for memoization
+
+Note: a memoized function should always return an _immutable_
+result to avoid later modifications polluting cached results.
+"""
+from collections import OrderedDict
+from functools import wraps
+
+
+class DoNotCache:
+    """Wrapper to return a result without caching it.
+
+    In a function decorated with `@lru_cache_key`:
+
+        return DoNotCache(result)
+
+    is equivalent to:
+
+        return result  # but don't cache it!
+    """
+
+    def __init__(self, result):
+        self.result = result
+
+
+class LRUCache:
+    """A simple Least-Recently-Used (LRU) cache with a max size"""
+
+    def __init__(self, maxsize=1024):
+        self._cache = OrderedDict()
+        self.maxsize = maxsize
+
+    def __contains__(self, key):
+        return key in self._cache
+
+    def get(self, key, default=None):
+        """Get an item from the cache"""
+        if key in self._cache:
+            # cache hit, bump to front of the queue for LRU
+            result = self._cache[key]
+            self._cache.move_to_end(key)
+            return result
+        return default
+
+    def set(self, key, value):
+        """Store an entry in the cache
+
+        Purges oldest entry if cache is full
+        """
+        self._cache[key] = value
+        # cache is full, purge oldest entry
+        if len(self._cache) > self.maxsize:
+            self._cache.popitem(last=False)
+
+    __getitem__ = get
+    __setitem__ = set
+
+
+def lru_cache_key(key_func, maxsize=1024):
+    """Like functools.lru_cache, but takes a custom key function,
+    as seen in sorted(key=func).
+
+    Useful for non-hashable arguments which have a known hashable equivalent (e.g. sets, lists),
+    or mutable objects where only immutable fields might be used
+    (e.g. User, where only username affects output).
+
+    For safety: cached results should always be immutable,
+    such as using `frozenset` instead of mutable `set`.
+
+    Example:
+
+        @lru_cache_key(lambda user: user.name)
+        def func_user(user):
+            # output only varies by name
+
+    Args:
+        key_func (callable):
+            Should have the same signature as the decorated function.
+            Returns a hashable key to use in the cache.
+        maxsize (int):
+            The maximum size of the cache.
+    """
+
+    def cache_func(func):
+        cache = LRUCache(maxsize=maxsize)
+        # the actual decorated function:
+        @wraps(func)
+        def cached(*args, **kwargs):
+            cache_key = key_func(*args, **kwargs)
+            if cache_key in cache:
+                # cache hit
+                return cache[cache_key]
+            else:
+                # cache miss, call function and cache result
+                result = func(*args, **kwargs)
+                if isinstance(result, DoNotCache):
+                    # DoNotCache prevents caching
+                    result = result.result
+                else:
+                    cache[cache_key] = result
+            return result
+
+        return cached
+
+    return cache_func
+
+
+class FrozenDict(dict):
+    """A frozen dictionary subclass
+
+    Immutable and hashable, so it can be used as a cache key
+
+    Values will be frozen with `._freeze(value)`
+    and must be hashable after freezing.
+
+    Not rigorous, but enough for our purposes.
+    """
+
+    _hash = None
+
+    def __init__(self, d):
+        dict_set = dict.__setitem__
+        for key, value in d.items():
+            dict_set(self, key, self._freeze(value))
+
+    def _freeze(self, item):
+        """Make values of a dict hashable
+        - list -> tuple, set -> frozenset
+        - dict -> recursive FrozenDict
+        - anything else: assumed hashable
+        """
+        if isinstance(item, FrozenDict):
+            return item
+        elif isinstance(item, list):
+            return tuple(self._freeze(e) for e in item)
+        elif isinstance(item, set):
+            return frozenset(item)
+        elif isinstance(item, dict):
+            return FrozenDict(item)
+        else:
+            # any other type is assumed hashable
+            return item
+
+    def __setitem__(self, key, value):
+        raise RuntimeError(f"Cannot modify frozen {type(self).__name__}")
+
+    def update(self, other):
+        raise RuntimeError(f"Cannot modify frozen {type(self).__name__}")
+
+    def __hash__(self):
+        """Cache hash because we are immutable"""
+        if self._hash is None:
+            self._hash = hash(tuple((key, value) for key, value in self.items()))
+        return self._hash
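The `lru_cache_key` decorator above is the core of this patch, so a brief usage sketch may help reviewers. This is illustrative only and not part of the diff; `DummyUser` and `scopes_for` are hypothetical names, and the import assumes a JupyterHub checkout with this patch applied.

```python
from dataclasses import dataclass

from jupyterhub._memoize import DoNotCache, lru_cache_key


@dataclass
class DummyUser:
    # stand-in for orm.User; only the name matters for the cache key
    name: str


@lru_cache_key(lambda user: user.name)
def scopes_for(user):
    """Cache per username; the User object itself need not be hashable."""
    if user.name == "anonymous":
        # computed and returned, but never stored in the cache
        return DoNotCache(frozenset())
    # cached results should be immutable, hence frozenset
    return frozenset({f"read:users!user={user.name}"})


# the second call is a cache hit keyed on the name, so the same object comes back
assert scopes_for(DummyUser("ada")) is scopes_for(DummyUser("ada"))
```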
diff --git a/jupyterhub/apihandlers/auth.py b/jupyterhub/apihandlers/auth.py
index 643e3a3a..11171752 100644
--- a/jupyterhub/apihandlers/auth.py
+++ b/jupyterhub/apihandlers/auth.py
@@ -1,7 +1,6 @@
 """Authorization handlers"""
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-import itertools
 import json
 from datetime import datetime
 from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
@@ -30,7 +29,7 @@ class TokenAPIHandler(APIHandler):
         if owner:
             # having a token means we should be able to read the owner's model
             # (this is the only thing this handler is for)
-            self.expanded_scopes.update(scopes.identify_scopes(owner))
+            self.expanded_scopes |= scopes.identify_scopes(owner)
         self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)

         # record activity whenever we see a token
@@ -288,7 +287,7 @@ class OAuthAuthorizeHandler(OAuthHandler, BaseHandler):
             # rather than the expanded_scope intersection
             required_scopes = {*scopes.identify_scopes(), *scopes.access_scopes(client)}

-            user_scopes.update({"inherit", *required_scopes})
+            user_scopes |= {"inherit", *required_scopes}

             allowed_scopes = requested_scopes.intersection(user_scopes)
             excluded_scopes = requested_scopes.difference(user_scopes)
diff --git a/jupyterhub/apihandlers/users.py b/jupyterhub/apihandlers/users.py
index 6341c3b1..5e1ea285 100644
--- a/jupyterhub/apihandlers/users.py
+++ b/jupyterhub/apihandlers/users.py
@@ -44,7 +44,7 @@ class SelfAPIHandler(APIHandler):
         for scope in identify_scopes:
             if scope not in self.expanded_scopes:
                 _added_scopes.add(scope)
-                self.expanded_scopes.add(scope)
+                self.expanded_scopes |= {scope}
         if _added_scopes:
             # re-parse with new scopes
             self.parsed_scopes = scopes.parse_scopes(self.expanded_scopes)
diff --git a/jupyterhub/oauth/provider.py b/jupyterhub/oauth/provider.py
index 3e1c6e0e..05637142 100644
--- a/jupyterhub/oauth/provider.py
+++ b/jupyterhub/oauth/provider.py
@@ -154,9 +154,8 @@ class JupyterHubRequestValidator(RequestValidator):
         scopes = roles_to_scopes(orm_client.allowed_roles)
         if 'inherit' not in scopes:
             # add identify-user scope
-            scopes.update(identify_scopes())
-            # add access-service scope
-            scopes.update(access_scopes(orm_client))
+            # and access-service scope
+            scopes |= identify_scopes() | access_scopes(orm_client)
         return scopes

     def get_original_scopes(self, refresh_token, request, *args, **kwargs):
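The handler and provider changes above all replace in-place `set.update(...)` with `|=`. A minimal sketch of why that matters once scope functions return cached frozensets (the names below are illustrative, not JupyterHub code):

```python
# `|=` rebinds the local name to a new frozenset instead of mutating an
# object that may also live inside a cache.
cached_scopes = frozenset({"read:users!user=ada"})

handler_scopes = cached_scopes
handler_scopes |= {"inherit"}  # equivalent to: handler_scopes = handler_scopes | {"inherit"}

assert handler_scopes == {"inherit", "read:users!user=ada"}
assert cached_scopes == {"read:users!user=ada"}  # cached value untouched

# frozenset has no .update(), and calling .update() on a shared mutable set
# would silently pollute the cached result.
```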
diff --git a/jupyterhub/scopes.py b/jupyterhub/scopes.py
index 25db5c7e..1d83ba60 100644
--- a/jupyterhub/scopes.py
+++ b/jupyterhub/scopes.py
@@ -1,6 +1,10 @@
 """
 General scope definitions and utilities

+Scope functions generally return _immutable_ collections,
+such as `frozenset` to avoid mutating cached values.
+If needed, mutable copies can be made, e.g. `set(frozen_scopes)`
+
 Scope variable nomenclature
 ---------------------------
 scopes or 'raw' scopes: collection of scopes that may contain abbreviations (e.g., in role definition)
@@ -24,6 +28,7 @@ from tornado import web
 from tornado.log import app_log

 from . import orm, roles
+from ._memoize import DoNotCache, FrozenDict, lru_cache_key

 """when modifying the scope definitions, make sure that `docs/source/rbac/generate-scope-table.py`
 is run so that changes are reflected in the documentation and REST API description."""
@@ -144,6 +149,12 @@ class Scope(Enum):
     ALL = True


+def _intersection_cache_key(scopes_a, scopes_b, db=None):
+    """Cache key function for scope intersections"""
+    return (frozenset(scopes_a), frozenset(scopes_b))
+
+
+@lru_cache_key(_intersection_cache_key)
 def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
     """Intersect two sets of scopes by comparing their permissions

@@ -159,11 +170,16 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
     (i.e. users!group=x & users!user=y will be empty, even if user y is in group x.)
     """
     empty_set = frozenset()
+    scopes_a = frozenset(scopes_a)
+    scopes_b = frozenset(scopes_b)

     # cached lookups for group membership of users and servers
     @lru_cache()
     def groups_for_user(username):
         """Get set of group names for a given username"""
+        # if we need a group lookup, the result is not cacheable
+        nonlocal needs_db
+        needs_db = True
         user = db.query(orm.User).filter_by(name=username).first()
         if user is None:
             return empty_set
@@ -179,6 +195,11 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
     parsed_scopes_a = parse_scopes(scopes_a)
     parsed_scopes_b = parse_scopes(scopes_b)

+    # track whether we need a db lookup (for groups)
+    # because we can't cache the intersection if we do
+    # if there are no group filters, this is cacheable
+    needs_db = False
+
     common_bases = parsed_scopes_a.keys() & parsed_scopes_b.keys()

     common_filters = {}
@@ -220,6 +241,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
                     UserWarning,
                 )
                 warned = True
+                needs_db = True

         common_filters[base] = {
             entity: filters_a[entity] & filters_b[entity]
@@ -245,6 +267,7 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
                 # resolve group/server hierarchy if db available
                 servers = servers.difference(common_servers)
                 if db is not None and servers and 'group' in b:
+                    needs_db = True
                     for server in servers:
                         server_groups = groups_for_server(server)
                         if server_groups & b['group']:
@@ -272,7 +295,12 @@ def _intersect_expanded_scopes(scopes_a, scopes_b, db=None):
         if common_users and "user" not in common_filters[base]:
             common_filters[base]["user"] = common_users

-    return unparse_scopes(common_filters)
+    intersection = unparse_scopes(common_filters)
+    if needs_db:
+        # return intersection, but don't cache it if it needed db lookups
+        return DoNotCache(intersection)
+
+    return intersection


 def get_scopes_for(orm_object):
@@ -313,7 +341,7 @@ def get_scopes_for(orm_object):
             # only thing we miss by short-circuiting here: warning about excluded extra scopes
             return owner_scopes

-        token_scopes = expand_scopes(token_scopes, owner=owner)
+        token_scopes = set(expand_scopes(token_scopes, owner=owner))

         if orm_object.client_id != "jupyterhub":
             # oauth tokens can be used to access the service issuing the token,
@@ -358,6 +386,7 @@ def get_scopes_for(orm_object):
     return expanded_scopes


+@lru_cache()
 def _expand_self_scope(username):
     """
     Users have a metascope 'self' that should be expanded to standard user privileges.
@@ -390,9 +419,11 @@ def _expand_self_scope(username):
         'read:tokens',
         'access:servers',
     ]
-    return {f"{scope}!user={username}" for scope in scope_list}
+    # return immutable frozenset because the result is cached
+    return frozenset(f"{scope}!user={username}" for scope in scope_list)


+@lru_cache(maxsize=65535)
 def _expand_scope(scope):
     """Returns a scope and all all subscopes

@@ -433,9 +464,30 @@ def _expand_scope(scope):
     else:
         expanded_scopes = expanded_scope_names

-    return expanded_scopes
+    # return immutable frozenset because the result is cached
+    return frozenset(expanded_scopes)


+def _expand_scopes_key(scopes, owner=None):
+    """Cache key function for expand_scopes
+
+    scopes is usually a mutable list or set,
+    which can be hashed as a frozenset
+
+    For the owner, we only care about what kind they are,
+    and their name.
+    """
+    # freeze scopes for hash
+    frozen_scopes = frozenset(scopes)
+    if owner is None:
+        owner_key = None
+    else:
+        # owner key is the type and name
+        owner_key = (type(owner).__name__, owner.name)
+    return (frozen_scopes, owner_key)
+
+
+@lru_cache_key(_expand_scopes_key)
 def expand_scopes(scopes, owner=None):
     """Returns a set of fully expanded scopes for a collection of raw scopes

@@ -479,8 +531,9 @@ def expand_scopes(scopes, owner=None):
                 stacklevel=2,
             )

-    # reduce to minimize
-    return reduce_scopes(expanded_scopes)
+    # reduce to discard overlapping scopes
+    # return immutable frozenset because the result is cached
+    return frozenset(reduce_scopes(expanded_scopes))


 def _needs_scope_expansion(filter_, filter_value, sub_scope):
@@ -614,6 +667,7 @@ def _check_token_scopes(scopes, owner):
         )


+@lru_cache_key(frozenset)
 def parse_scopes(scope_list):
     """
     Parses scopes and filters in something akin to JSON style
@@ -649,9 +703,11 @@ def parse_scopes(scope_list):
                 parsed_scopes[base_scope][key] = {value}
             else:
                 parsed_scopes[base_scope][key].add(value)
-    return parsed_scopes
+    # return immutable FrozenDict because the result is cached
+    return FrozenDict(parsed_scopes)


+@lru_cache_key(FrozenDict)
 def unparse_scopes(parsed_scopes):
     """Turn a parsed_scopes dictionary back into a expanded scopes set"""
     expanded_scopes = set()
@@ -662,14 +718,17 @@ def unparse_scopes(parsed_scopes):
             for entity, names_list in filters.items():
                 for name in names_list:
                     expanded_scopes.add(f'{base}!{entity}={name}')
-    return expanded_scopes
+    # return immutable frozenset because the result is cached
+    return frozenset(expanded_scopes)


+@lru_cache_key(frozenset)
 def reduce_scopes(expanded_scopes):
     """Reduce expanded scopes to minimal set

-    Eliminates redundancy, such as access:services and access:services!service=x
+    Eliminates overlapping scopes, such as access:services and access:services!service=x
     """
+    # unparse_scopes already returns a frozenset
     return unparse_scopes(parse_scopes(expanded_scopes))


@@ -723,6 +782,14 @@ def needs_scope(*scopes):
     return scope_decorator


+def _identify_key(obj=None):
+    if obj is None:
+        return None
+    else:
+        return (type(obj).__name__, obj.name)
+
+
+@lru_cache_key(_identify_key)
 def identify_scopes(obj=None):
     """Return 'identify' scopes for an orm object

@@ -735,20 +802,25 @@ def identify_scopes(obj=None):
         identify scopes (set): set of scopes needed for 'identify' endpoints
     """
     if obj is None:
-        return {f"read:users:{field}!user" for field in {"name", "groups"}}
frozenset(f"read:users:{field}!user" for field in {"name", "groups"}) elif isinstance(obj, orm.User): - return {f"read:users:{field}!user={obj.name}" for field in {"name", "groups"}} + return frozenset( + f"read:users:{field}!user={obj.name}" for field in {"name", "groups"} + ) elif isinstance(obj, orm.Service): - return {f"read:services:{field}!service={obj.name}" for field in {"name"}} + return frozenset( + f"read:services:{field}!service={obj.name}" for field in {"name"} + ) else: raise TypeError(f"Expected orm.User or orm.Service, got {obj!r}") +@lru_cache_key(lambda oauth_client: oauth_client.identifier) def access_scopes(oauth_client): """Return scope(s) required to access an oauth client""" scopes = set() if oauth_client.identifier == "jupyterhub": - return scopes + return frozenset() spawner = oauth_client.spawner if spawner: scopes.add(f"access:servers!server={spawner.user.name}/{spawner.name}") @@ -760,9 +832,19 @@ def access_scopes(oauth_client): app_log.warning( f"OAuth client {oauth_client} has no associated service or spawner!" ) - return scopes + return frozenset(scopes) +def _check_scope_key(sub_scope, orm_resource, kind): + """Cache key function for check_scope_filter""" + if kind == 'server': + resource_key = (orm_resource.user.name, orm_resource.name) + else: + resource_key = orm_resource.name + return (sub_scope, resource_key, kind) + + +@lru_cache_key(_check_scope_key) def check_scope_filter(sub_scope, orm_resource, kind): """Return whether a sub_scope filter applies to a given resource. @@ -791,8 +873,8 @@ def check_scope_filter(sub_scope, orm_resource, kind): if kind == 'user' and 'group' in sub_scope: group_names = {group.name for group in orm_resource.groups} user_in_group = bool(group_names & set(sub_scope['group'])) - if user_in_group: - return True + # cannot cache if we needed to lookup groups in db + return DoNotCache(user_in_group) return False @@ -831,6 +913,7 @@ def describe_parsed_scopes(parsed_scopes, username=None): return descriptions +@lru_cache_key(lambda raw_scopes, username=None: (frozenset(raw_scopes), username)) def describe_raw_scopes(raw_scopes, username=None): """Return list of descriptions of raw scopes @@ -861,7 +944,8 @@ def describe_raw_scopes(raw_scopes, username=None): "filter": filter_text, } ) - return descriptions + # make sure we return immutable from a cached function + return tuple(descriptions) # regex for custom scope diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py new file mode 100644 index 00000000..c37942a8 --- /dev/null +++ b/jupyterhub/tests/test_memoize.py @@ -0,0 +1,94 @@ +import pytest + +from jupyterhub._memoize import DoNotCache, FrozenDict, LRUCache, lru_cache_key + + +def test_lru_cache(): + cache = LRUCache(maxsize=2) + cache["a"] = 1 + assert "a" in cache + assert "b" not in cache + cache["b"] = 2 + assert cache["b"] == 2 + + # accessing a makes it more recent than b + assert cache["a"] == 1 + assert "b" in cache + assert "a" in cache + + # storing c pushes oldest ('b') out of cache + cache["c"] = 3 + assert len(cache._cache) == 2 + assert "a" in cache + assert "c" in cache + assert "b" not in cache + + +def test_lru_cache_key(): + + call_count = 0 + + @lru_cache_key(frozenset) + def reverse(arg): + nonlocal call_count + call_count += 1 + return list(reversed(arg)) + + in1 = [1, 2] + before = call_count + out1 = reverse(in1) + assert call_count == before + 1 + assert out1 == [2, 1] + + before = call_count + out2 = reverse(in1) + assert call_count == before + assert out2 is out1 + + 
diff --git a/jupyterhub/tests/test_memoize.py b/jupyterhub/tests/test_memoize.py
new file mode 100644
index 00000000..c37942a8
--- /dev/null
+++ b/jupyterhub/tests/test_memoize.py
@@ -0,0 +1,94 @@
+import pytest
+
+from jupyterhub._memoize import DoNotCache, FrozenDict, LRUCache, lru_cache_key
+
+
+def test_lru_cache():
+    cache = LRUCache(maxsize=2)
+    cache["a"] = 1
+    assert "a" in cache
+    assert "b" not in cache
+    cache["b"] = 2
+    assert cache["b"] == 2
+
+    # accessing a makes it more recent than b
+    assert cache["a"] == 1
+    assert "b" in cache
+    assert "a" in cache
+
+    # storing c pushes oldest ('b') out of cache
+    cache["c"] = 3
+    assert len(cache._cache) == 2
+    assert "a" in cache
+    assert "c" in cache
+    assert "b" not in cache
+
+
+def test_lru_cache_key():
+
+    call_count = 0
+
+    @lru_cache_key(frozenset)
+    def reverse(arg):
+        nonlocal call_count
+        call_count += 1
+        return list(reversed(arg))
+
+    in1 = [1, 2]
+    before = call_count
+    out1 = reverse(in1)
+    assert call_count == before + 1
+    assert out1 == [2, 1]
+
+    before = call_count
+    out2 = reverse(in1)
+    assert call_count == before
+    assert out2 is out1
+
+
+def test_do_not_cache():
+
+    call_count = 0
+
+    @lru_cache_key(lambda arg: arg)
+    def is_even(arg):
+        nonlocal call_count
+        call_count += 1
+        if arg % 2:
+            return DoNotCache(False)
+        return True
+
+    before = call_count
+    assert is_even(0) == True
+    assert call_count == before + 1
+
+    # caches even results
+    before = call_count
+    assert is_even(0) == True
+    assert call_count == before
+
+    before = call_count
+    assert is_even(1) == False
+    assert call_count == before + 1
+
+    # doesn't cache odd results
+    before = call_count
+    assert is_even(1) == False
+    assert call_count == before + 1
+
+
+@pytest.mark.parametrize(
+    "d",
+    [
+        {"key": "value"},
+        {"key": ["list"]},
+        {"key": {"set"}},
+        {"key": ("tu", "ple")},
+        {"key": {"nested": ["dict"]}},
+    ],
+)
+def test_frozen_dict(d):
+    frozen_1 = FrozenDict(d)
+    frozen_2 = FrozenDict(d)
+    assert hash(frozen_1) == hash(frozen_2)
+    assert frozen_1 == frozen_2
diff --git a/jupyterhub/tests/test_scopes.py b/jupyterhub/tests/test_scopes.py
index f8a08b4b..2f40d949 100644
--- a/jupyterhub/tests/test_scopes.py
+++ b/jupyterhub/tests/test_scopes.py
@@ -8,6 +8,7 @@ from tornado import web
 from tornado.httputil import HTTPServerRequest

 from .. import orm, roles, scopes
+from .._memoize import FrozenDict
 from ..handlers import BaseHandler
 from ..scopes import (
     Scope,
@@ -38,6 +39,7 @@ def test_scope_constructor():
         f'read:users!user={user2}',
     ]
     parsed_scopes = parse_scopes(scope_list)
+    assert isinstance(parsed_scopes, FrozenDict)

     assert 'read:users' in parsed_scopes
     assert parsed_scopes['users']
@@ -467,6 +469,7 @@ async def test_metascope_self_expansion(
         orm_obj = create_service_with_scopes('self')
     # test expansion of user/service scopes
     scopes = get_scopes_for(orm_obj)
+    assert isinstance(scopes, frozenset)
    assert bool(scopes) == has_user_scopes

     # test expansion of token scopes
@@ -488,6 +491,8 @@ async def test_metascope_inherit_expansion(app, create_user_with_scopes):
     token.scopes.clear()
     app.db.commit()
     token_scope_set = get_scopes_for(token)
+    assert isinstance(token_scope_set, frozenset)
+
     assert token_scope_set.issubset(identify_scopes(user.orm_user))


@@ -1169,4 +1174,5 @@ def test_expand_scopes(user, scopes, expected):
         expected.update(_expand_self_scope(user.name))

     expanded = expand_scopes(scopes, owner=user.orm_user)
+    assert isinstance(expanded, frozenset)
     assert sorted(expanded) == sorted(expected)
diff --git a/jupyterhub/tests/test_services_auth.py b/jupyterhub/tests/test_services_auth.py
index 0d25ddf2..8cda0dbc 100644
--- a/jupyterhub/tests/test_services_auth.py
+++ b/jupyterhub/tests/test_services_auth.py
@@ -360,7 +360,7 @@ async def test_oauth_service_roles(
     )

     if 'inherit' in expected_scopes:
-        expected_scopes = scopes.get_scopes_for(user.orm_user)
+        expected_scopes = set(scopes.get_scopes_for(user.orm_user))

     # always expect identify/access scopes
     # on successful authentication
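Finally, a short sketch of why `parse_scopes` now returns a `FrozenDict`: the parsed structure is both a cached value and, via `@lru_cache_key(FrozenDict)` on `unparse_scopes`, a cache key, so it must be hashable and immutable. The values below are illustrative only, and the imports assume this patch is applied:

```python
from jupyterhub._memoize import FrozenDict, LRUCache

# nested list/set/dict values are recursively frozen on construction
parsed = FrozenDict({"read:users": {"user": ["ada", "grace"]}})

cache = LRUCache(maxsize=2)
cache[parsed] = "expanded-result"  # hashable, so usable as a key
assert cache[FrozenDict({"read:users": {"user": ["ada", "grace"]}})] == "expanded-result"

try:
    parsed["admin:users"] = {"user": ["eve"]}
except RuntimeError as err:
    print(err)  # Cannot modify frozen FrozenDict
```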