avoid database error on repeated group name in sync_groups

This commit is contained in:
Min RK
2022-08-19 10:53:21 +02:00
parent 71e86f3064
commit 0b9ae96a96
2 changed files with 9 additions and 9 deletions

View File

@@ -29,12 +29,12 @@ async def test_userdict_get(db, attr):
["isin1", "isin2"], ["isin1", "isin2"],
["isin1"], ["isin1"],
["notin", "isin1"], ["notin", "isin1"],
["new-group", "isin1"], ["new-group", "new-group", "isin1"],
[], [],
], ],
) )
def test_sync_groups(app, user, group_names): def test_sync_groups(app, user, group_names):
expected = sorted(group_names) expected = sorted(set(group_names))
db = app.db db = app.db
db.add(orm.Group(name="notin")) db.add(orm.Group(name="notin"))
in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")] in_groups = [orm.Group(name="isin1"), orm.Group(name="isin2")]

View File

@@ -310,19 +310,19 @@ class User:
return return
# log group changes # log group changes
new_groups = set(group_names).difference(current_groups) added_groups = new_groups.difference(current_groups)
removed_groups = current_groups.difference(group_names) removed_groups = current_groups.difference(group_names)
if new_groups: if added_groups:
self.log.info(f"Adding user {self.name} to group(s): {new_groups}") self.log.info(f"Adding user {self.name} to group(s): {added_groups}")
if removed_groups: if removed_groups:
self.log.info(f"Removing user {self.name} from group(s): {removed_groups}") self.log.info(f"Removing user {self.name} from group(s): {removed_groups}")
if group_names: if group_names:
groups = ( groups = (
self.db.query(orm.Group).filter(orm.Group.name.in_(group_names)).all() self.db.query(orm.Group).filter(orm.Group.name.in_(new_groups)).all()
) )
existing_groups = {g.name for g in groups} existing_groups = {g.name for g in groups}
for group_name in group_names: for group_name in added_groups:
if group_name not in existing_groups: if group_name not in existing_groups:
# create groups that don't exist yet # create groups that don't exist yet
self.log.info( self.log.info(
@@ -331,9 +331,9 @@ class User:
group = orm.Group(name=group_name) group = orm.Group(name=group_name)
self.db.add(group) self.db.add(group)
groups.append(group) groups.append(group)
self.groups = groups self.orm_user.groups = groups
else: else:
self.groups = [] self.orm_user.groups = []
self.db.commit() self.db.commit()
async def save_auth_state(self, auth_state): async def save_auth_state(self, auth_state):