diff --git a/jupyterhub/handlers/base.py b/jupyterhub/handlers/base.py index e5764f3b..0db1660c 100644 --- a/jupyterhub/handlers/base.py +++ b/jupyterhub/handlers/base.py @@ -40,6 +40,7 @@ from ..metrics import SERVER_STOP_DURATION_SECONDS from ..metrics import ServerPollStatus from ..metrics import ServerSpawnStatus from ..metrics import ServerStopStatus +from ..metrics import TOTAL_USERS from ..objects import Server from ..spawner import LocalProcessSpawner from ..user import User @@ -453,6 +454,7 @@ class BaseHandler(RequestHandler): # not found, create and register user u = orm.User(name=username) self.db.add(u) + TOTAL_USERS.inc() self.db.commit() user = self._user_from_orm(u) return user diff --git a/jupyterhub/tests/mocking.py b/jupyterhub/tests/mocking.py index 4cadd128..5ae6ec8c 100644 --- a/jupyterhub/tests/mocking.py +++ b/jupyterhub/tests/mocking.py @@ -41,6 +41,7 @@ from traitlets import Bool from traitlets import default from traitlets import Dict +from .. import metrics from .. import orm from ..app import JupyterHub from ..auth import PAMAuthenticator @@ -327,6 +328,7 @@ class MockHub(JupyterHub): user = orm.User(name='user') self.db.add(user) self.db.commit() + metrics.TOTAL_USERS.inc() def stop(self): super().stop() diff --git a/jupyterhub/tests/test_metrics.py b/jupyterhub/tests/test_metrics.py new file mode 100644 index 00000000..29c22122 --- /dev/null +++ b/jupyterhub/tests/test_metrics.py @@ -0,0 +1,34 @@ +import json + +from .utils import add_user +from .utils import api_request +from jupyterhub import metrics +from jupyterhub import orm + + +async def test_total_users(app): + num_users = app.db.query(orm.User).count() + sample = metrics.TOTAL_USERS.collect()[0].samples[0] + assert sample.value == num_users + + await api_request( + app, "/users", method="post", data=json.dumps({"usernames": ["incrementor"]}) + ) + + sample = metrics.TOTAL_USERS.collect()[0].samples[0] + assert sample.value == num_users + 1 + + # GET /users used to double-count + await api_request(app, "/users") + + # populate the Users cache dict if any are missing: + for user in app.db.query(orm.User): + _ = app.users[user.id] + + sample = metrics.TOTAL_USERS.collect()[0].samples[0] + assert sample.value == num_users + 1 + + await api_request(app, "/users/incrementor", method="delete") + + sample = metrics.TOTAL_USERS.collect()[0].samples[0] + assert sample.value == num_users diff --git a/jupyterhub/tests/utils.py b/jupyterhub/tests/utils.py index 09aeb196..fd69178f 100644 --- a/jupyterhub/tests/utils.py +++ b/jupyterhub/tests/utils.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor import requests from certipy import Certipy +from jupyterhub import metrics from jupyterhub import orm from jupyterhub.objects import Server from jupyterhub.utils import url_path_join as ujoin @@ -97,6 +98,7 @@ def add_user(db, app=None, **kwargs): if orm_user is None: orm_user = orm.User(**kwargs) db.add(orm_user) + metrics.TOTAL_USERS.inc() else: for attr, value in kwargs.items(): setattr(orm_user, attr, value) diff --git a/jupyterhub/user.py b/jupyterhub/user.py index a970c120..bf8603f8 100644 --- a/jupyterhub/user.py +++ b/jupyterhub/user.py @@ -69,7 +69,6 @@ class UserDict(dict): """Add a user to the UserDict""" if orm_user.id not in self: self[orm_user.id] = self.from_orm(orm_user) - TOTAL_USERS.inc() return self[orm_user.id] def __contains__(self, key):