From 0c77f2d0107911930a61cadcb85dcdcff9f38f8f Mon Sep 17 00:00:00 2001 From: Ken Tsang Date: Tue, 11 Jul 2017 18:18:23 +0100 Subject: [PATCH] Refactored dao_update_organisation --- app/dao/organisations_dao.py | 7 +++++++ app/organisation/rest.py | 4 ++-- tests/app/dao/test_organisations_dao.py | 16 ++++++++++------ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/app/dao/organisations_dao.py b/app/dao/organisations_dao.py index b9046816f..71b8bfa85 100644 --- a/app/dao/organisations_dao.py +++ b/app/dao/organisations_dao.py @@ -14,3 +14,10 @@ def dao_get_organisation_by_id(org_id): @transactional def dao_create_organisation(organisation): db.session.add(organisation) + + +@transactional +def dao_update_organisation(organisation, **kwargs): + for key, value in kwargs.items(): + setattr(organisation, key, value) + db.session.add(organisation) diff --git a/app/organisation/rest.py b/app/organisation/rest.py index 2262d2cb0..250c0d205 100644 --- a/app/organisation/rest.py +++ b/app/organisation/rest.py @@ -4,6 +4,7 @@ from app.dao.organisations_dao import ( dao_create_organisation, dao_get_organisations, dao_get_organisation_by_id, + dao_update_organisation ) from app.errors import register_errors from app.models import Organisation @@ -45,7 +46,6 @@ def update_organisation(organisation_id): validate(data, post_update_organisation_schema) fetched_organisation = dao_get_organisation_by_id(organisation_id) - for key in data.keys(): - setattr(fetched_organisation, key, data[key]) + dao_update_organisation(fetched_organisation, **data) return jsonify(data=fetched_organisation.serialize()), 200 diff --git a/tests/app/dao/test_organisations_dao.py b/tests/app/dao/test_organisations_dao.py index cb5b1466f..bc5d8c0a6 100644 --- a/tests/app/dao/test_organisations_dao.py +++ b/tests/app/dao/test_organisations_dao.py @@ -4,7 +4,8 @@ from sqlalchemy.exc import IntegrityError from app.dao.organisations_dao import ( dao_create_organisation, dao_get_organisations, - dao_get_organisation_by_id + dao_get_organisation_by_id, + dao_update_organisation, ) from app.models import Organisation @@ -57,11 +58,14 @@ def test_update_organisation(notify_db, notify_db_session): updated_name = 'new name' organisation = create_organisation() - organisation_from_db = Organisation.query.first() - assert organisation_from_db.name != updated_name + organisations_1 = Organisation.query.all() - setattr(organisation_from_db, 'name', updated_name) + assert len(organisations_1) == 1 + assert organisations_1[0].name != updated_name - organisation_from_db_again = Organisation.query.first() + dao_update_organisation(organisations_1[0], name=updated_name) - assert organisation_from_db_again.name == updated_name + organisations_2 = Organisation.query.all() + + assert len(organisations_2) == 1 + assert organisations_2[0].name == updated_name