diff --git a/app/celery/provider_tasks.py b/app/celery/provider_tasks.py index 32889b663..8b7af54fd 100644 --- a/app/celery/provider_tasks.py +++ b/app/celery/provider_tasks.py @@ -82,7 +82,7 @@ def send_sms_to_provider(self, service_id, notification_id): notification_id, SMS_TYPE, provider.get_name(), - content_char_count=template.replaced_content_count + billable_units=notification.billable_units ) notification.sent_at = datetime.utcnow() @@ -164,7 +164,8 @@ def send_email_to_provider(self, service_id, notification_id): update_provider_stats( notification_id, EMAIL_TYPE, - provider.get_name() + provider.get_name(), + billable_units=1 ) notification.reference = reference notification.sent_at = datetime.utcnow() diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index 14a3821b3..1e7f2d75e 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -299,30 +299,22 @@ def update_provider_stats( id_, notification_type, provider_name, - content_char_count=None): + billable_units=1): notification = Notification.query.filter(Notification.id == id_).one() provider = ProviderDetails.query.filter_by(identifier=provider_name).one() - def unit_count(): - if notification_type == EMAIL_TYPE: - return 1 - else: - if (content_char_count): - return get_sms_fragment_count(content_char_count) - return get_sms_fragment_count(notification.content_char_count) - update_count = db.session.query(ProviderStatistics).filter_by( day=date.today(), service_id=notification.service_id, provider_id=provider.id - ).update({'unit_count': ProviderStatistics.unit_count + unit_count()}) + ).update({'unit_count': ProviderStatistics.unit_count + billable_units}) if update_count == 0: provider_stats = ProviderStatistics( day=notification.created_at.date(), service_id=notification.service_id, provider_id=provider.id, - unit_count=unit_count() + unit_count=billable_units ) db.session.add(provider_stats) diff --git a/tests/app/celery/test_provider_tasks.py b/tests/app/celery/test_provider_tasks.py index 3f8f293ff..eb788129a 100644 --- a/tests/app/celery/test_provider_tasks.py +++ b/tests/app/celery/test_provider_tasks.py @@ -557,8 +557,8 @@ def test_should_not_set_billable_units_if_research_mode(notify_db, sample_servic notify_db.session.commit() send_sms_to_provider( - sample_notification.service_id, - sample_notification.id + sample_notification.service_id, + sample_notification.id ) persisted_notification = notifications_dao.get_notification(sample_service.id, sample_notification.id) diff --git a/tests/app/dao/test_provider_statistics_dao.py b/tests/app/dao/test_provider_statistics_dao.py index c3b268de4..955aa3c23 100644 --- a/tests/app/dao/test_provider_statistics_dao.py +++ b/tests/app/dao/test_provider_statistics_dao.py @@ -1,10 +1,11 @@ from datetime import datetime import uuid -from app.models import NotificationHistory, KEY_TYPE_NORMAL, NOTIFICATION_STATUS_TYPES +import pytest + +from app.models import NotificationHistory, KEY_TYPE_NORMAL, KEY_TYPE_TEAM, KEY_TYPE_TEST, NOTIFICATION_STATUS_TYPES from app.dao.notifications_dao import update_provider_stats -from app.dao.provider_statistics_dao import ( - get_provider_statistics, get_fragment_count) +from app.dao.provider_statistics_dao import get_provider_statistics, get_fragment_count from tests.app.conftest import sample_notification as create_sample_notification @@ -46,24 +47,24 @@ def test_should_update_provider_statistics_sms_multi(notify_db, notify_db, notify_db_session, template=sample_template, - content_char_count=160) - update_provider_stats(n1.id, 'sms', mmg_provider.identifier) + billable_units=1) + update_provider_stats(n1.id, 'sms', mmg_provider.identifier, n1.billable_units) n2 = create_sample_notification( notify_db, notify_db_session, template=sample_template, - content_char_count=161) - update_provider_stats(n2.id, 'sms', mmg_provider.identifier) + billable_units=2) + update_provider_stats(n2.id, 'sms', mmg_provider.identifier, n2.billable_units) n3 = create_sample_notification( notify_db, notify_db_session, template=sample_template, - content_char_count=307) - update_provider_stats(n3.id, 'sms', mmg_provider.identifier) + billable_units=4) + update_provider_stats(n3.id, 'sms', mmg_provider.identifier, n3.billable_units) provider_stats = get_provider_statistics( sample_template.service, providers=[mmg_provider.identifier]).one() - assert provider_stats.unit_count == 6 + assert provider_stats.unit_count == 7 def test_should_update_provider_statistics_email_multi(notify_db, diff --git a/tests/app/service/test_service_fragment_count.py b/tests/app/service/test_service_fragment_count.py deleted file mode 100644 index 090e34d7b..000000000 --- a/tests/app/service/test_service_fragment_count.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -from datetime import (date, timedelta) -from flask import url_for -from tests import create_authorization_header - - -def test_fragment_count(notify_api, sample_provider_statistics): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - endpoint = url_for( - 'service.get_service_provider_aggregate_statistics', - service_id=str(sample_provider_statistics.service.id)) - auth_header = create_authorization_header() - resp = client.get( - endpoint, - headers=[auth_header] - ) - assert resp.status_code == 200 - json_resp = json.loads(resp.get_data(as_text=True)) - assert json_resp['data']['sms_count'] == 1 - - -def test_fragment_count_from_to(notify_api, sample_provider_statistics): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - today_str = date.today().strftime('%Y-%m-%d') - endpoint = url_for( - 'service.get_service_provider_aggregate_statistics', - service_id=str(sample_provider_statistics.service.id), - date_from=today_str, - date_to=today_str) - auth_header = create_authorization_header() - resp = client.get( - endpoint, - headers=[auth_header] - ) - assert resp.status_code == 200 - json_resp = json.loads(resp.get_data(as_text=True)) - assert json_resp['data']['sms_count'] == 1 - - -def test_fragment_count_from_greater_than_to(notify_api, sample_provider_statistics): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - today_str = date.today().strftime('%Y-%m-%d') - yesterday_str = date.today() - timedelta(days=1) - endpoint = url_for( - 'service.get_service_provider_aggregate_statistics', - service_id=str(sample_provider_statistics.service.id), - date_from=today_str, - date_to=yesterday_str) - auth_header = create_authorization_header() - resp = client.get( - endpoint, - headers=[auth_header] - ) - assert resp.status_code == 400 - json_resp = json.loads(resp.get_data(as_text=True)) - assert 'date_from needs to be greater than date_to' in json_resp['message']['_schema'] - - -def test_fragment_count_in_future(notify_api, sample_provider_statistics): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - tomorrow_str = (date.today() + timedelta(days=1)).strftime('%Y-%m-%d') - endpoint = url_for( - 'service.get_service_provider_aggregate_statistics', - service_id=str(sample_provider_statistics.service.id), - date_from=tomorrow_str) - auth_header = create_authorization_header() - resp = client.get( - endpoint, - headers=[auth_header] - ) - assert resp.status_code == 400 - json_resp = json.loads(resp.get_data(as_text=True)) - assert 'Date cannot be in the future' in json_resp['message']['date_from']