diff --git a/app/dao/provider_statistics_dao.py b/app/dao/provider_statistics_dao.py index da942854f..4fffb7dd6 100644 --- a/app/dao/provider_statistics_dao.py +++ b/app/dao/provider_statistics_dao.py @@ -7,7 +7,8 @@ from app.models import ( NotificationHistory, SMS_TYPE, EMAIL_TYPE, - NOTIFICATION_STATUS_TYPES_BILLABLE + NOTIFICATION_STATUS_TYPES_BILLABLE, + KEY_TYPE_TEST ) @@ -21,6 +22,13 @@ def get_provider_statistics(service, **kwargs): def get_fragment_count(service_id): + live_dates = get_service_live_dates(service_id) + shared_filters = [ + NotificationHistory.service_id == service_id, + NotificationHistory.status.in_(NOTIFICATION_STATUS_TYPES_BILLABLE), + NotificationHistory.key_type != KEY_TYPE_TEST + ] + sms_count = db.session.query( func.sum( case( @@ -34,16 +42,15 @@ def get_fragment_count(service_id): ) ) ).filter( - NotificationHistory.service_id == service_id, NotificationHistory.notification_type == SMS_TYPE, - NotificationHistory.status.in_(NOTIFICATION_STATUS_TYPES_BILLABLE) + *shared_filters ) + email_count = db.session.query( func.count(NotificationHistory.id) ).filter( - NotificationHistory.service_id == service_id, NotificationHistory.notification_type == EMAIL_TYPE, - NotificationHistory.status.in_(NOTIFICATION_STATUS_TYPES_BILLABLE) + *shared_filters ) return { 'sms_count': int(sms_count.scalar() or 0), diff --git a/tests/app/dao/test_provider_statistics_dao.py b/tests/app/dao/test_provider_statistics_dao.py index b0c5b4a12..f5f349b7c 100644 --- a/tests/app/dao/test_provider_statistics_dao.py +++ b/tests/app/dao/test_provider_statistics_dao.py @@ -113,6 +113,12 @@ def test_get_fragment_count_filters_on_status(notify_db, sample_template): assert get_fragment_count(sample_template.service_id)['sms_count'] == 6 +def test_get_fragment_count_filters_on_service_id(notify_db, sample_template, service_factory): + service_2 = service_factory.get('service 2', email_from='service.2') + noti_hist(notify_db, sample_template) + assert get_fragment_count(service_2.id)['sms_count'] == 0 + + def test_get_fragment_count_sums_char_count_for_sms(notify_db, sample_template): noti_hist(notify_db, sample_template, content_char_count=1) # 1 noti_hist(notify_db, sample_template, content_char_count=159) # 1 @@ -120,7 +126,17 @@ def test_get_fragment_count_sums_char_count_for_sms(notify_db, sample_template): assert get_fragment_count(sample_template.service_id)['sms_count'] == 4 -def noti_hist(notify_db, template, status='delivered', content_char_count=None): +@pytest.mark.parametrize('key_type,sms_count', [ + (KEY_TYPE_NORMAL, 1), + (KEY_TYPE_TEAM, 1), + (KEY_TYPE_TEST, 0), +]) +def test_get_fragment_count_ignores_test_api_keys(notify_db, sample_template, key_type, sms_count): + noti_hist(notify_db, sample_template, key_type=key_type) + assert get_fragment_count(sample_template.service_id)['sms_count'] == sms_count + + +def noti_hist(notify_db, template, status='delivered', content_char_count=None, key_type=KEY_TYPE_NORMAL): if not content_char_count and template.template_type == 'sms': content_char_count = 1