diff --git a/app/celery/provider_tasks.py b/app/celery/provider_tasks.py index 60ed740b5..2c07eb251 100644 --- a/app/celery/provider_tasks.py +++ b/app/celery/provider_tasks.py @@ -1,6 +1,4 @@ from datetime import datetime -from monotonic import monotonic -from urllib.parse import urljoin from flask import current_app from notifications_utils.recipients import ( @@ -143,6 +141,7 @@ def send_email_to_provider(self, service_id, notification_id): send_email_response.apply_async( (provider.get_name(), reference, notification.to), queue='research-mode' ) + notification.billable_units = 0 else: from_address = '"{}" <{}@{}>'.format(service.name, service.email_from, current_app.config['NOTIFY_EMAIL_DOMAIN']) diff --git a/app/dao/provider_statistics_dao.py b/app/dao/provider_statistics_dao.py index 9a29cffe0..588e9a7dc 100644 --- a/app/dao/provider_statistics_dao.py +++ b/app/dao/provider_statistics_dao.py @@ -1,9 +1,7 @@ -from sqlalchemy import func, cast, Float, case +from sqlalchemy import func from app import db from app.models import ( - ProviderStatistics, - ProviderDetails, NotificationHistory, SMS_TYPE, EMAIL_TYPE, @@ -12,15 +10,6 @@ from app.models import ( ) -def get_provider_statistics(service, **kwargs): - query = ProviderStatistics.query.filter_by(service=service) - if 'providers' in kwargs: - providers = ProviderDetails.query.filter(ProviderDetails.identifier.in_(kwargs['providers'])).all() - provider_ids = [provider.id for provider in providers] - query = query.filter(ProviderStatistics.provider_id.in_(provider_ids)) - return query - - def get_fragment_count(service_id): shared_filters = [ NotificationHistory.service_id == service_id, diff --git a/tests/app/celery/test_provider_tasks.py b/tests/app/celery/test_provider_tasks.py index 08a6026e9..b6ae54637 100644 --- a/tests/app/celery/test_provider_tasks.py +++ b/tests/app/celery/test_provider_tasks.py @@ -3,29 +3,25 @@ from datetime import datetime import pytest from celery.exceptions import MaxRetriesExceededError -from unittest.mock import ANY, call +from unittest.mock import ANY from notifications_utils.recipients import validate_phone_number, format_phone_number import app -from app import statsd_client, mmg_client +from app import mmg_client from app.celery import provider_tasks from app.celery.provider_tasks import send_sms_to_provider, send_email_to_provider from app.celery.research_mode_tasks import send_sms_response, send_email_response from app.clients.email import EmailClientException from app.clients.sms import SmsClientException from app.dao import notifications_dao, provider_details_dao -from app.dao import provider_statistics_dao -from app.dao.provider_statistics_dao import get_provider_statistics from app.models import ( Notification, - NotificationStatistics, - Job, Organisation, KEY_TYPE_NORMAL, KEY_TYPE_TEST, BRANDING_ORG, - BRANDING_BOTH -) + BRANDING_BOTH, + KEY_TYPE_TEAM) from tests.app.conftest import sample_notification @@ -241,14 +237,23 @@ def test_should_call_send_sms_response_task_if_research_mode(notify_db, sample_s assert not persisted_notification.personalisation -@pytest.mark.parametrize('research_mode,key_type', [ - (True, KEY_TYPE_NORMAL), - (False, KEY_TYPE_TEST) +@pytest.mark.parametrize('research_mode,key_type, billable_units', [ + (True, KEY_TYPE_NORMAL, 0), + (False, KEY_TYPE_NORMAL, 1), + (False, KEY_TYPE_TEST, 0), + (True, KEY_TYPE_TEST, 0), + (True, KEY_TYPE_TEAM, 0), + (False, KEY_TYPE_TEAM, 1) ]) -def test_not_should_update_provider_stats_on_success_in_research_mode(notify_db, sample_service, sample_notification, - mocker, research_mode, key_type): - provider_stats = provider_statistics_dao.get_provider_statistics(sample_service).all() - assert len(provider_stats) == 0 +def test_should_update_billable_units_according_to_research_mode_and_key_type(notify_db, + sample_service, + sample_notification, + mocker, + research_mode, + key_type, + billable_units): + + assert Notification.query.count() == 1 mocker.patch('app.mmg_client.send_sms') mocker.patch('app.mmg_client.get_name', return_value="mmg") @@ -264,8 +269,8 @@ def test_not_should_update_provider_stats_on_success_in_research_mode(notify_db, sample_notification.id ) - updated_provider_stats = provider_statistics_dao.get_provider_statistics(sample_service).all() - assert len(updated_provider_stats) == 0 + assert Notification.query.get(sample_notification.id).billable_units == billable_units, \ + "Research mode: {0}, key type: {1}, billable_units: {2}".format(research_mode, key_type, billable_units) def test_should_not_send_to_provider_when_status_is_not_created(notify_db, notify_db_session, @@ -369,9 +374,6 @@ def test_send_email_to_provider_should_call_research_mode_task_response_task_if_ sample_service.research_mode = True notify_db.session.add(sample_service) notify_db.session.commit() - assert not get_provider_statistics( - sample_email_template.service, - providers=[ses_provider.identifier]).first() send_email_to_provider( sample_service.id, notification.id @@ -380,9 +382,6 @@ def test_send_email_to_provider_should_call_research_mode_task_response_task_if_ send_email_response.apply_async.assert_called_once_with( ('ses', str(reference), 'john@smith.com'), queue="research-mode" ) - assert not get_provider_statistics( - sample_email_template.service, - providers=[ses_provider.identifier]).first() persisted_notification = Notification.query.filter_by(id=notification.id).one() assert persisted_notification.to == 'john@smith.com' @@ -392,6 +391,7 @@ def test_send_email_to_provider_should_call_research_mode_task_response_task_if_ assert persisted_notification.created_at <= datetime.utcnow() assert persisted_notification.sent_by == 'ses' assert persisted_notification.reference == str(reference) + assert persisted_notification.billable_units == 0 def test_send_email_to_provider_should_go_into_technical_error_if_exceeds_retries( @@ -465,24 +465,6 @@ def test_send_email_should_use_service_reply_to_email( ) -def test_should_not_set_billable_units_if_research_mode(notify_db, sample_service, sample_notification, mocker): - mocker.patch('app.mmg_client.send_sms') - mocker.patch('app.mmg_client.get_name', return_value="mmg") - mocker.patch('app.celery.research_mode_tasks.send_sms_response.apply_async') - - sample_service.research_mode = True - notify_db.session.add(sample_service) - notify_db.session.commit() - - send_sms_to_provider( - sample_notification.service_id, - sample_notification.id - ) - - persisted_notification = notifications_dao.get_notification_by_id(sample_notification.id) - assert persisted_notification.billable_units == 0 - - def test_get_html_email_renderer_should_return_for_normal_service(sample_service): renderer = provider_tasks.get_html_email_renderer(sample_service) assert renderer.govuk_banner @@ -519,12 +501,3 @@ def test_get_html_email_renderer_prepends_logo_path(notify_db, sample_service): renderer = provider_tasks.get_html_email_renderer(sample_service) assert renderer.brand_logo == 'http://localhost:6012/static/images/email-template/crests/justice-league.png' - - -def _get_provider_statistics(service, **kwargs): - query = ProviderStatistics.query.filter_by(service=service) - if 'providers' in kwargs: - providers = ProviderDetails.query.filter(ProviderDetails.identifier.in_(kwargs['providers'])).all() - provider_ids = [provider.id for provider in providers] - query = query.filter(ProviderStatistics.provider_id.in_(provider_ids)) - return query diff --git a/tests/app/notifications/rest/test_callbacks.py b/tests/app/notifications/rest/test_callbacks.py index eac7aa477..43cfe926f 100644 --- a/tests/app/notifications/rest/test_callbacks.py +++ b/tests/app/notifications/rest/test_callbacks.py @@ -3,7 +3,6 @@ import uuid from datetime import datetime from flask import json from freezegun import freeze_time -from mock import call import app.celery.tasks from app.dao.notifications_dao import (