fix and test TOTAL_USERS count

Don't assume UserDict contains all users

which assumption led to double-counting when a user in the db was loaded into the dict cache
This commit is contained in:
Min RK
2020-11-30 13:27:52 +01:00
parent 18393ec6b4
commit 7e469f911d
5 changed files with 40 additions and 1 deletions

View File

@@ -40,6 +40,7 @@ from ..metrics import SERVER_STOP_DURATION_SECONDS
from ..metrics import ServerPollStatus
from ..metrics import ServerSpawnStatus
from ..metrics import ServerStopStatus
from ..metrics import TOTAL_USERS
from ..objects import Server
from ..spawner import LocalProcessSpawner
from ..user import User
@@ -453,6 +454,7 @@ class BaseHandler(RequestHandler):
# not found, create and register user
u = orm.User(name=username)
self.db.add(u)
TOTAL_USERS.inc()
self.db.commit()
user = self._user_from_orm(u)
return user

View File

@@ -41,6 +41,7 @@ from traitlets import Bool
from traitlets import default
from traitlets import Dict
from .. import metrics
from .. import orm
from ..app import JupyterHub
from ..auth import PAMAuthenticator
@@ -327,6 +328,7 @@ class MockHub(JupyterHub):
user = orm.User(name='user')
self.db.add(user)
self.db.commit()
metrics.TOTAL_USERS.inc()
def stop(self):
super().stop()

View File

@@ -0,0 +1,34 @@
import json
from .utils import add_user
from .utils import api_request
from jupyterhub import metrics
from jupyterhub import orm
async def test_total_users(app):
num_users = app.db.query(orm.User).count()
sample = metrics.TOTAL_USERS.collect()[0].samples[0]
assert sample.value == num_users
await api_request(
app, "/users", method="post", data=json.dumps({"usernames": ["incrementor"]})
)
sample = metrics.TOTAL_USERS.collect()[0].samples[0]
assert sample.value == num_users + 1
# GET /users used to double-count
await api_request(app, "/users")
# populate the Users cache dict if any are missing:
for user in app.db.query(orm.User):
_ = app.users[user.id]
sample = metrics.TOTAL_USERS.collect()[0].samples[0]
assert sample.value == num_users + 1
await api_request(app, "/users/incrementor", method="delete")
sample = metrics.TOTAL_USERS.collect()[0].samples[0]
assert sample.value == num_users

View File

@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
import requests
from certipy import Certipy
from jupyterhub import metrics
from jupyterhub import orm
from jupyterhub.objects import Server
from jupyterhub.utils import url_path_join as ujoin
@@ -97,6 +98,7 @@ def add_user(db, app=None, **kwargs):
if orm_user is None:
orm_user = orm.User(**kwargs)
db.add(orm_user)
metrics.TOTAL_USERS.inc()
else:
for attr, value in kwargs.items():
setattr(orm_user, attr, value)

View File

@@ -69,7 +69,6 @@ class UserDict(dict):
"""Add a user to the UserDict"""
if orm_user.id not in self:
self[orm_user.id] = self.from_orm(orm_user)
TOTAL_USERS.inc()
return self[orm_user.id]
def __contains__(self, key):