mirror of
https://github.com/jupyterhub/jupyterhub.git
synced 2025-10-12 04:23:01 +00:00
avoid database error on repeated group name in sync_groups
This commit is contained in:
@@ -29,12 +29,12 @@ async def test_userdict_get(db, attr):
|
|||||||
["isin1", "isin2"],
|
["isin1", "isin2"],
|
||||||
["isin1"],
|
["isin1"],
|
||||||
["notin", "isin1"],
|
["notin", "isin1"],
|
||||||
["new-group", "isin1"],
|
["new-group", "new-group", "isin1"],
|
||||||
[],
|
[],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sync_groups(app, user, group_names):
|
def test_sync_groups(app, user, group_names):
|
||||||
expected = sorted(group_names)
|
expected = sorted(set(group_names))
|
||||||
db = app.db
|
db = app.db
|
||||||
db.add(orm.Group(name="notin"))
|
db.add(orm.Group(name="notin"))
|
||||||
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
|
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]
|
||||||
|
@@ -310,19 +310,19 @@ class User:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# log group changes
|
# log group changes
|
||||||
new_groups = set(group_names).difference(current_groups)
|
added_groups = new_groups.difference(current_groups)
|
||||||
removed_groups = current_groups.difference(group_names)
|
removed_groups = current_groups.difference(group_names)
|
||||||
if new_groups:
|
if added_groups:
|
||||||
self.log.info(f"Adding user {self.name} to group(s): {new_groups}")
|
self.log.info(f"Adding user {self.name} to group(s): {added_groups}")
|
||||||
if removed_groups:
|
if removed_groups:
|
||||||
self.log.info(f"Removing user {self.name} from group(s): {removed_groups}")
|
self.log.info(f"Removing user {self.name} from group(s): {removed_groups}")
|
||||||
|
|
||||||
if group_names:
|
if group_names:
|
||||||
groups = (
|
groups = (
|
||||||
self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all()
|
self.db.query(orm.Group).filter(orm.Group.name.in_(new_groups)).all()
|
||||||
)
|
)
|
||||||
existing_groups = {g.name for g in groups}
|
existing_groups = {g.name for g in groups}
|
||||||
for group_name in group_names:
|
for group_name in added_groups:
|
||||||
if group_name not in existing_groups:
|
if group_name not in existing_groups:
|
||||||
# create groups that don't exist yet
|
# create groups that don't exist yet
|
||||||
self.log.info(
|
self.log.info(
|
||||||
@@ -331,9 +331,9 @@ class User:
|
|||||||
group = orm.Group(name=group_name)
|
group = orm.Group(name=group_name)
|
||||||
self.db.add(group)
|
self.db.add(group)
|
||||||
groups.append(group)
|
groups.append(group)
|
||||||
self.groups = groups
|
self.orm_user.groups = groups
|
||||||
else:
|
else:
|
||||||
self.groups = []
|
self.orm_user.groups = []
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
async def save_auth_state(self, auth_state):
|
async def save_auth_state(self, auth_state):
|
||||||
|
Reference in New Issue
Block a user