diff --git a/jupyterhub/apihandlers/groups.py b/jupyterhub/apihandlers/groups.py index 3167d9c1..217cf433 100644 --- a/jupyterhub/apihandlers/groups.py +++ b/jupyterhub/apihandlers/groups.py @@ -41,6 +41,36 @@ class GroupListAPIHandler(_GroupAPIHandler): data = [ self.group_model(g) for g in self.db.query(orm.Group) ] self.write(json.dumps(data)) + @admin_only + async def post(self): + """POST creates Multiple groups """ + model = self.get_json_body() + if not model or not isinstance(model, dict) or not model.get('groups'): + raise web.HTTPError(400, "Must specify at least one group to create") + + groupnames = model.pop("groups",[]) + self._check_group_model(model) + + created = [] + for name in groupnames: + existing = orm.Group.find(self.db, name=name) + if existing is not None: + raise web.HTTPError(400, "Group %s already exists" % name) + + usernames = model.get('users', []) + # check that users exist + users = self._usernames_to_users(usernames) + # create the group + self.log.info("Creating new group %s with %i users", + name, len(users), + ) + self.log.debug("Users: %s", usernames) + group = orm.Group(name=name, users=users) + self.db.add(group) + self.db.commit() + created.append(group) + self.write(json.dumps([self.group_model(group) for group in created])) + self.set_status(201) class GroupAPIHandler(_GroupAPIHandler): """View and modify groups by name""" diff --git a/jupyterhub/tests/test_api.py b/jupyterhub/tests/test_api.py index 5123236d..a5fc2b80 100644 --- a/jupyterhub/tests/test_api.py +++ b/jupyterhub/tests/test_api.py @@ -988,6 +988,26 @@ def test_groups_list(app): }] +@mark.group +@mark.gen_test +def test_add_multi_group(app): + db = app.db + names = ['group1', 'group2'] + r = yield api_request(app, 'groups', method='post', + data=json.dumps({'groups': names}), + ) + assert r.status_code == 201 + reply = r.json() + r_names = [group['name'] for group in reply] + assert names == r_names + + # try to create the same groups again + r = yield api_request(app, 'users', method='post', + data=json.dumps({'groups': names}), + ) + assert r.status_code == 400 + + @mark.group @mark.gen_test def test_group_get(app):