From 3f10e59db3fd072ff28a774f22e13ff0196fef61 Mon Sep 17 00:00:00 2001 From: Imdad Ahad Date: Mon, 7 Nov 2016 17:42:39 +0000 Subject: [PATCH] Add user dao method to update a single user attr --- app/dao/users_dao.py | 16 ++++++++++++---- tests/app/dao/test_users_dao.py | 29 ++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/app/dao/users_dao.py b/app/dao/users_dao.py index 1a240fecc..6654234a5 100644 --- a/app/dao/users_dao.py +++ b/app/dao/users_dao.py @@ -7,18 +7,26 @@ from app import db from app.models import (User, VerifyCode) +def _remove_values_for_keys_if_present(dict, keys): + for key in keys: + dict.pop(key, None) + + def create_secret_code(): return ''.join(map(str, random.sample(range(9), 5))) +def save_user_attribute(usr, update_dict={}): + db.session.query(User).filter_by(id=usr.id).update(update_dict) + db.session.commit() + + def save_model_user(usr, update_dict={}, pwd=None): if pwd: usr.password = pwd usr.password_changed_at = datetime.utcnow() if update_dict: - if update_dict.get('id'): - del update_dict['id'] - update_dict.pop('password_changed_at') + _remove_values_for_keys_if_present(update_dict, ['id', 'password_changed_at']) db.session.query(User).filter_by(id=usr.id).update(update_dict) else: db.session.add(usr) @@ -74,7 +82,7 @@ def delete_user_verify_codes(user): db.session.commit() -def get_model_users(user_id=None): +def get_user_by_id(user_id=None): if user_id: return User.query.filter_by(id=user_id).one() return User.query.filter_by().all() diff --git a/tests/app/dao/test_users_dao.py b/tests/app/dao/test_users_dao.py index 7afcf72ce..7d1901f0d 100644 --- a/tests/app/dao/test_users_dao.py +++ b/tests/app/dao/test_users_dao.py @@ -7,12 +7,13 @@ import pytest from app.dao.users_dao import ( save_model_user, - get_model_users, + save_user_attribute, + get_user_by_id, delete_model_user, increment_failed_login_count, reset_failed_login_count, get_user_by_email, - delete_codes_older_created_more_than_a_day_ago + delete_codes_older_created_more_than_a_day_ago, ) from tests.app.conftest import sample_user as create_sample_user @@ -37,13 +38,13 @@ def test_create_user(notify_api, notify_db, notify_db_session): def test_get_all_users(notify_api, notify_db, notify_db_session, sample_user): assert User.query.count() == 1 - assert len(get_model_users()) == 1 + assert len(get_user_by_id()) == 1 email = "another.notify@digital.cabinet-office.gov.uk" another_user = create_sample_user(notify_db, notify_db_session, email=email) assert User.query.count() == 2 - assert len(get_model_users()) == 2 + assert len(get_user_by_id()) == 2 def test_get_user(notify_api, notify_db, notify_db_session): @@ -51,12 +52,12 @@ def test_get_user(notify_api, notify_db, notify_db_session): another_user = create_sample_user(notify_db, notify_db_session, email=email) - assert get_model_users(user_id=another_user.id).email_address == email + assert get_user_by_id(user_id=another_user.id).email_address == email def test_get_user_not_exists(notify_api, notify_db, notify_db_session, fake_uuid): try: - get_model_users(user_id=fake_uuid) + get_user_by_id(user_id=fake_uuid) pytest.fail("NoResultFound exception not thrown.") except NoResultFound as e: pass @@ -64,7 +65,7 @@ def test_get_user_not_exists(notify_api, notify_db, notify_db_session, fake_uuid def test_get_user_invalid_id(notify_api, notify_db, notify_db_session): try: - get_model_users(user_id="blah") + get_user_by_id(user_id="blah") pytest.fail("DataError exception not thrown.") except DataError: pass @@ -131,3 +132,17 @@ def make_verify_code(user, age=timedelta(hours=0), code="12335"): ) db.session.add(verify_code) db.session.commit() + + +@pytest.mark.parametrize('user_attribute, user_value', [ + ('name', 'New User'), + ('email_address', 'newuser@mail.com'), + ('mobile_number', '+4407700900460') +]) +def test_update_user_attribute(client, sample_user, user_attribute, user_value): + assert getattr(sample_user, user_attribute) != user_value + update_dict = { + user_attribute: user_value + } + save_user_attribute(sample_user, update_dict) + assert getattr(sample_user, user_attribute) == user_value