diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index 82c7879ba..fa9ccfba9 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -56,7 +56,7 @@ def send_sms_to_provider(notification): if service.research_mode or notification.key_type == KEY_TYPE_TEST: notification.billable_units = 0 - update_notification(notification, provider) + update_notification_to_sending(notification, provider) try: send_sms_response(provider.get_name(), str(notification.id), notification.to) except HTTPError: @@ -76,11 +76,14 @@ def send_sms_to_provider(notification): sender=notification.reply_to_text ) except Exception as e: + notification.billable_units = template.fragment_count + notification.sent_by = provider.get_name() + dao_update_notification(notification) dao_toggle_sms_provider(provider.name) raise e else: notification.billable_units = template.fragment_count - update_notification(notification, provider, notification.international) + update_notification_to_sending(notification, provider, notification.international) current_app.logger.debug( "SMS {} sent to provider {} at {}".format(notification.id, provider.get_name(), notification.sent_at) @@ -116,7 +119,7 @@ def send_email_to_provider(notification): reference = str(create_uuid()) notification.billable_units = 0 notification.reference = reference - update_notification(notification, provider) + update_notification_to_sending(notification, provider) send_email_response(reference, notification.to) else: from_address = '"{}" <{}@{}>'.format(service.name, service.email_from, @@ -133,7 +136,7 @@ def send_email_to_provider(notification): reply_to_address=validate_and_format_email_address(email_reply_to) if email_reply_to else None, ) notification.reference = reference - update_notification(notification, provider) + update_notification_to_sending(notification, provider) current_app.logger.debug( "Email {} sent to provider at {}".format(notification.id, notification.sent_at) @@ -142,7 +145,7 @@ def send_email_to_provider(notification): statsd_client.timing("email.total-time", delta_milliseconds) -def update_notification(notification, provider, international=False): +def update_notification_to_sending(notification, provider, international=False): notification.sent_at = datetime.utcnow() notification.sent_by = provider.get_name() if international: diff --git a/tests/app/celery/test_provider_tasks.py b/tests/app/celery/test_provider_tasks.py index c0dfbdc63..65b63682a 100644 --- a/tests/app/celery/test_provider_tasks.py +++ b/tests/app/celery/test_provider_tasks.py @@ -112,17 +112,6 @@ def test_should_retry_and_log_exception(sample_notification, mocker): assert sample_notification.status == 'created' -def test_send_sms_should_switch_providers_on_provider_failure(sample_notification, mocker): - provider_to_use = mocker.patch('app.delivery.send_to_providers.provider_to_use') - provider_to_use.return_value.send_sms.side_effect = Exception('Error') - switch_provider_mock = mocker.patch('app.delivery.send_to_providers.dao_toggle_sms_provider') - mocker.patch('app.celery.provider_tasks.deliver_sms.retry') - - deliver_sms(sample_notification.id) - - assert switch_provider_mock.called is True - - def test_send_sms_should_not_switch_providers_on_non_provider_failure( sample_notification, mocker diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index f0bb62dc2..ae8b06782 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -562,6 +562,26 @@ def test_should_update_billable_units_according_to_research_mode_and_key_type( assert sample_notification.billable_units == billable_units +def test_should_set_notification_billable_units_and_provider_if_sending_to_provider_fails( + notify_db, + sample_service, + sample_notification, + mocker, +): + mocker.patch('app.mmg_client.send_sms', side_effect=Exception()) + mock_toggle_provider = mocker.patch('app.delivery.send_to_providers.dao_toggle_sms_provider') + + sample_notification.billable_units = 0 + assert sample_notification.sent_by is None + + with pytest.raises(Exception): + send_to_providers.send_sms_to_provider(sample_notification) + + assert sample_notification.billable_units == 1 + assert sample_notification.sent_by == 'mmg' + assert mock_toggle_provider.called + + def test_should_send_sms_to_international_providers( restore_provider_details, sample_sms_template_with_html,