From 571686b638c3f3d15e55cd4ef4630a581a62671d Mon Sep 17 00:00:00 2001 From: Martyn Inglis Date: Tue, 10 May 2016 09:04:22 +0100 Subject: [PATCH] Ensure that the primary provider is used in all tasks --- app/__init__.py | 4 + app/celery/tasks.py | 118 ++++++++++++--------- app/clients/__init__.py | 13 ++- app/dao/provider_details_dao.py | 22 ++++ app/provider_details/__init__.py | 0 app/provider_details/rest.py | 29 +++++ app/schemas.py | 7 ++ config.py | 1 + requirements.txt | 2 +- tests/app/celery/test_tasks.py | 3 +- tests/app/dao/test_provider_details.py | 24 ----- tests/app/dao/test_provider_details_dao.py | 57 ++++++++++ tests/app/provider_details/__init__.py | 0 tests/app/provider_details/test_rest.py | 0 14 files changed, 203 insertions(+), 77 deletions(-) create mode 100644 app/dao/provider_details_dao.py create mode 100644 app/provider_details/__init__.py create mode 100644 app/provider_details/rest.py delete mode 100644 tests/app/dao/test_provider_details.py create mode 100644 tests/app/dao/test_provider_details_dao.py create mode 100644 tests/app/provider_details/__init__.py create mode 100644 tests/app/provider_details/test_rest.py diff --git a/app/__init__.py b/app/__init__.py index c907e58a8..a5a37f5c7 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -64,6 +64,7 @@ def create_app(app_name=None): from app.notifications_statistics.rest import notifications_statistics as notifications_statistics_blueprint from app.template_statistics.rest import template_statistics as template_statistics_blueprint from app.events.rest import events as events_blueprint + from app.provider_details.rest import provider_details as provider_details_blueprint application.register_blueprint(service_blueprint, url_prefix='/service') application.register_blueprint(user_blueprint, url_prefix='/user') @@ -77,6 +78,7 @@ def create_app(app_name=None): application.register_blueprint(notifications_statistics_blueprint) application.register_blueprint(template_statistics_blueprint) application.register_blueprint(events_blueprint) + application.register_blueprint(provider_details_blueprint, url_prefix='/provider-details') return application @@ -90,6 +92,8 @@ def init_app(app): url_for('notifications.process_firetext_response'), url_for('notifications.process_mmg_response'), url_for('status.show_delivery_status'), + url_for('provider_details.get_providers'), + "/provider-details/41ade136-33bc-4151-be23-a04084cde50b" ] if request.path not in no_auth_req: from app.authentication import auth diff --git a/app/celery/tasks.py b/app/celery/tasks.py index 55ec0fd71..4653adab3 100644 --- a/app/celery/tasks.py +++ b/app/celery/tasks.py @@ -3,12 +3,12 @@ from datetime import datetime from flask import current_app from sqlalchemy.exc import SQLAlchemyError - -from app.clients.email.aws_ses import AwsSesClientException -from app.clients.sms.firetext import FiretextClientException -from app.clients.sms.mmg import MMGClientException +from app import clients +from app.clients.email import EmailClientException +from app.clients.sms import SmsClientException from app.dao.services_dao import dao_fetch_service_by_id from app.dao.templates_dao import dao_get_template_by_id +from app.dao.provider_details_dao import get_provider_details_by_notification_type from notifications_utils.template import Template, unlink_govuk_escaped @@ -23,10 +23,7 @@ from app import ( DATETIME_FORMAT, DATE_FORMAT, notify_celery, - encryption, - firetext_client, - aws_ses_client, - mmg_client + encryption ) from app.aws import s3 @@ -159,7 +156,7 @@ def process_job(job_id): 'personalisation': { key: personalisation.get(key) for key in template.placeholders - } + } }) if template.template_type == 'sms': @@ -206,7 +203,8 @@ def remove_job(job_id): def send_sms(service_id, notification_id, encrypted_notification, created_at): notification = encryption.decrypt(encrypted_notification) service = dao_fetch_service_by_id(service_id) - client = mmg_client + + provider = provider_to_use('sms', notification_id) restricted = False @@ -234,23 +232,23 @@ def send_sms(service_id, notification_id, encrypted_notification, created_at): status='failed' if restricted else 'sending', created_at=datetime.strptime(created_at, DATETIME_FORMAT), sent_at=sent_at, - sent_by=client.get_name(), + sent_by=provider.get_name(), content_char_count=template.replaced_content_count ) - dao_create_notification(notification_db_object, TEMPLATE_TYPE_SMS, client.get_name()) + dao_create_notification(notification_db_object, TEMPLATE_TYPE_SMS, provider.get_name()) if restricted: return try: - client.send_sms( + provider.send_sms( to=validate_and_format_phone_number(notification['to']), content=template.replaced, reference=str(notification_id) ) - except MMGClientException as e: + except SmsClientException as e: current_app.logger.error( "SMS notification {} failed".format(notification_id) ) @@ -268,9 +266,10 @@ def send_sms(service_id, notification_id, encrypted_notification, created_at): @notify_celery.task(name="send-email") def send_email(service_id, notification_id, from_address, encrypted_notification, created_at): notification = encryption.decrypt(encrypted_notification) - client = aws_ses_client service = dao_fetch_service_by_id(service_id) + provider = provider_to_use('email', notification_id) + restricted = False if not service_allowed_to_send_to(notification['to'], service): @@ -290,10 +289,10 @@ def send_email(service_id, notification_id, from_address, encrypted_notification status='failed' if restricted else 'sending', created_at=datetime.strptime(created_at, DATETIME_FORMAT), sent_at=sent_at, - sent_by=client.get_name() + sent_by=provider.get_name() ) - dao_create_notification(notification_db_object, TEMPLATE_TYPE_EMAIL, client.get_name()) + dao_create_notification(notification_db_object, TEMPLATE_TYPE_EMAIL, provider.get_name()) if restricted: return @@ -303,7 +302,7 @@ def send_email(service_id, notification_id, from_address, encrypted_notification dao_get_template_by_id(notification['template']).__dict__, values=notification.get('personalisation', {}) ) - reference = client.send_email( + reference = provider.send_email( from_address, notification['to'], template.replaced_subject, @@ -311,7 +310,7 @@ def send_email(service_id, notification_id, from_address, encrypted_notification html_body=template.as_HTML_email, ) update_notification_reference_by_id(notification_id, reference) - except AwsSesClientException as e: + except EmailClientException as e: current_app.logger.exception(e) notification_db_object.status = 'failed' @@ -325,20 +324,15 @@ def send_email(service_id, notification_id, from_address, encrypted_notification @notify_celery.task(name='send-sms-code') def send_sms_code(encrypted_verification): + provider = provider_to_use('sms', 'send-sms-code') + verification_message = encryption.decrypt(encrypted_verification) try: - mmg_client.send_sms(validate_and_format_phone_number(verification_message['to']), - "{} is your Notify authentication code".format( - verification_message['secret_code']), - 'send-sms-code') - except MMGClientException as e: - current_app.logger.exception(e) - - -def send_sms_via_firetext(to, content, reference): - try: - firetext_client.send_sms(to=to, content=content, reference=reference) - except FiretextClientException as e: + provider.send_sms(validate_and_format_phone_number(verification_message['to']), + "{} is your Notify authentication code".format( + verification_message['secret_code']), + 'send-sms-code') + except SmsClientException as e: current_app.logger.exception(e) @@ -369,6 +363,8 @@ def invited_user_url(base_url, token): @notify_celery.task(name='email-invited-user') def email_invited_user(encrypted_invitation): + provider = provider_to_use('email', 'email-invited-user') + invitation = encryption.decrypt(encrypted_invitation) url = invited_user_url(current_app.config['ADMIN_BASE_URL'], invitation['token']) @@ -382,11 +378,11 @@ def email_invited_user(encrypted_invitation): current_app.config['NOTIFY_EMAIL_DOMAIN'] ) subject_line = invitation_subject_line(invitation['user_name'], invitation['service_name']) - aws_ses_client.send_email(email_from, - invitation['to'], - subject_line, - invitation_content) - except AwsSesClientException as e: + provider.send_email(email_from, + invitation['to'], + subject_line, + invitation_content) + except EmailClientException as e: current_app.logger.exception(e) @@ -404,17 +400,23 @@ def password_reset_message(name, url): @notify_celery.task(name='email-reset-password') def email_reset_password(encrypted_reset_password_message): + provider = provider_to_use('email', 'email-reset-password') + reset_password_message = encryption.decrypt(encrypted_reset_password_message) try: email_from = '"GOV.UK Notify" <{}>'.format( current_app.config['VERIFY_CODE_FROM_EMAIL_ADDRESS'] ) - aws_ses_client.send_email(email_from, - reset_password_message['to'], - "Reset your GOV.UK Notify password", - password_reset_message(name=reset_password_message['name'], - url=reset_password_message['reset_password_url'])) - except AwsSesClientException as e: + provider.send_email( + email_from, + reset_password_message['to'], + "Reset your GOV.UK Notify password", + password_reset_message( + name=reset_password_message['name'], + url=reset_password_message['reset_password_url'] + ) + ) + except EmailClientException as e: current_app.logger.exception(e) @@ -429,22 +431,26 @@ def registration_verification_template(name, url): @notify_celery.task(name='email-registration-verification') def email_registration_verification(encrypted_verification_message): + provider = provider_to_use('email', 'email-reset-password') + verification_message = encryption.decrypt(encrypted_verification_message) try: email_from = '"GOV.UK Notify" <{}>'.format( current_app.config['VERIFY_CODE_FROM_EMAIL_ADDRESS'] ) - aws_ses_client.send_email(email_from, - verification_message['to'], - "Confirm GOV.UK Notify registration", - registration_verification_template(name=verification_message['name'], - url=verification_message['url'])) - except AwsSesClientException as e: + provider.send_email( + email_from, + verification_message['to'], + "Confirm GOV.UK Notify registration", + registration_verification_template( + name=verification_message['name'], + url=verification_message['url']) + ) + except EmailClientException as e: current_app.logger.exception(e) def service_allowed_to_send_to(recipient, service): - if not service.restricted: return True @@ -454,3 +460,17 @@ def service_allowed_to_send_to(recipient, service): [user.mobile_number, user.email_address] for user in service.users ) ) + + +def provider_to_use(notification_type, notification_id): + active_providers_in_order = [ + provider for provider in get_provider_details_by_notification_type(notification_type) if provider.active + ] + + if len(active_providers_in_order) == 0: + current_app.logger.error( + "{} {} failed as no active providers".format(notification_type, notification_id) + ) + raise Exception("No active {} providers".format(notification_type)) + + return clients.get_client_by_name_and_type(active_providers_in_order[0].identifier, notification_type) diff --git a/app/clients/__init__.py b/app/clients/__init__.py index af5bf3bb3..74750c51a 100644 --- a/app/clients/__init__.py +++ b/app/clients/__init__.py @@ -28,8 +28,17 @@ class Clients(object): for client in email_clients: self.email_clients[client.name] = client - def sms_client(self, name): + def get_sms_client(self, name): return self.sms_clients.get(name) - def email_client(self, name): + def get_email_client(self, name): return self.email_clients.get(name) + + def get_client_by_name_and_type(self, name, notification_type): + assert notification_type in ['email', 'sms'] + + if notification_type == 'email': + return self.get_email_client(name) + + if notification_type == 'sms': + return self.get_sms_client(name) diff --git a/app/dao/provider_details_dao.py b/app/dao/provider_details_dao.py new file mode 100644 index 000000000..9de2d94aa --- /dev/null +++ b/app/dao/provider_details_dao.py @@ -0,0 +1,22 @@ +from sqlalchemy import asc +from app.dao.dao_utils import transactional +from app.models import ProviderDetails +from app import db + +def get_provider_details(): + return ProviderDetails.query.order_by(asc(ProviderDetails.priority)).all() + + +def get_provider_details_by_id(provider_details_id): + return ProviderDetails.query.get(provider_details_id) + + +def get_provider_details_by_notification_type(notification_type): + return ProviderDetails.query.filter_by( + notification_type=notification_type + ).order_by(asc(ProviderDetails.priority)).all() + + +@transactional +def dao_update_provider_details(provider_details): + db.session.add(provider_details) diff --git a/app/provider_details/__init__.py b/app/provider_details/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py new file mode 100644 index 000000000..ac5ac72b0 --- /dev/null +++ b/app/provider_details/rest.py @@ -0,0 +1,29 @@ +from flask import Blueprint, jsonify, request + +from app.schemas import provider_details_schema +from app.dao.provider_details_dao import ( + get_provider_details, + get_provider_details_by_id, + dao_update_provider_details +) + +provider_details = Blueprint('provider_details', __name__) + + +@provider_details.route('', methods=['GET']) +def get_providers(): + data, errors = provider_details_schema.dump(get_provider_details(), many=True) + return jsonify(provider_details=data) + + +@provider_details.route('/', methods=['POST']) +def update_provider_details(provider_details_id): + fetched_provider_details = get_provider_details_by_id(provider_details_id) + + current_data = dict(provider_details_schema.dump(fetched_provider_details).data.items()) + current_data.update(request.get_json()) + update_dict, errors = provider_details_schema.load(current_data) + if errors: + return jsonify(result="error", message=errors), 400 + dao_update_provider_details(update_dict) + return jsonify(data=provider_details_schema.dump(fetched_provider_details).data), 200 diff --git a/app/schemas.py b/app/schemas.py index 9f18393c5..07b5af5a0 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -69,6 +69,12 @@ class UserSchema(BaseSchema): "_password", "verify_codes") +class ProviderDetailsSchema(BaseSchema): + class Meta: + model = models.ProviderDetails + exclude = ("provider_rates", "provider_stats") + + class ServiceSchema(BaseSchema): created_by = field_for(models.Service, 'created_by', required=True) @@ -373,3 +379,4 @@ api_key_history_schema = ApiKeyHistorySchema() template_history_schema = TemplateHistorySchema() event_schema = EventSchema() from_to_date_schema = FromToDateSchema() +provider_details_schema = ProviderDetailsSchema() diff --git a/config.py b/config.py index 164dd7d9b..938887e9a 100644 --- a/config.py +++ b/config.py @@ -23,6 +23,7 @@ class Config(object): SQLALCHEMY_COMMIT_ON_TEARDOWN = False SQLALCHEMY_DATABASE_URI = os.environ['SQLALCHEMY_DATABASE_URI'] SQLALCHEMY_RECORD_QUERIES = True + SQLALCHEMY_TRACK_MODIFICATIONS = True VERIFY_CODE_FROM_EMAIL_ADDRESS = os.environ['VERIFY_CODE_FROM_EMAIL_ADDRESS'] NOTIFY_EMAIL_DOMAIN = os.environ['NOTIFY_EMAIL_DOMAIN'] PAGE_SIZE = 50 diff --git a/requirements.txt b/requirements.txt index 37f892147..b515baa03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ bleach==1.4.2 Flask==0.10.1 Flask-Script==2.0.5 Flask-Migrate==1.3.1 -Flask-SQLAlchemy==2.1 +Flask-SQLAlchemy==2.0 psycopg2==2.6.1 SQLAlchemy==1.0.5 SQLAlchemy-Utils==0.30.5 diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 9536febb1..1e8c29588 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -40,6 +40,7 @@ class AnyStringWith(str): def __eq__(self, other): return self in other + mmg_error = {'Error': '40', 'Description': 'error'} @@ -804,7 +805,7 @@ def test_email_invited_user_should_send_email(notify_api, mocker): expected_content) -def test_email_reset_password_should_send_email(notify_api, mocker): +def test_email_reset_password_should_send_email(notify_db, notify_db_session, notify_api, mocker): with notify_api.test_request_context(): reset_password_message = {'to': 'someone@it.gov.uk', 'name': 'Some One', diff --git a/tests/app/dao/test_provider_details.py b/tests/app/dao/test_provider_details.py deleted file mode 100644 index a393fe2ad..000000000 --- a/tests/app/dao/test_provider_details.py +++ /dev/null @@ -1,24 +0,0 @@ -from app.models import ProviderDetails -from app import clients - - -def test_should_error_if_any_provider_in_database_not_in_code(notify_db, notify_db_session, notify_api): - providers = ProviderDetails.query.all() - - for provider in providers: - if provider.notification_type == 'sms': - assert clients.sms_client(provider.identifier) - if provider.notification_type == 'email': - assert clients.email_client(provider.identifier) - - -def test_should_not_error_if_any_provider_in_code_not_in_database(notify_db, notify_db_session, notify_api): - providers = ProviderDetails.query.all() - - ProviderDetails.query.filter_by(identifier='mmg').delete() - - for provider in providers: - if provider.notification_type == 'sms': - assert clients.sms_client(provider.identifier) - if provider.notification_type == 'email': - assert clients.email_client(provider.identifier) diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py new file mode 100644 index 000000000..cb63fa400 --- /dev/null +++ b/tests/app/dao/test_provider_details_dao.py @@ -0,0 +1,57 @@ +from app.models import ProviderDetails +from app import clients +from app.dao.provider_details_dao import ( + get_provider_details, + get_provider_details_by_notification_type +) + + +def test_can_get_all_providers(notify_db, notify_db_session): + assert len(get_provider_details()) == 3 + + +def test_can_get_sms_providers(notify_db, notify_db_session): + assert len(get_provider_details_by_notification_type('sms')) == 2 + types = [provider.notification_type for provider in get_provider_details_by_notification_type('sms')] + assert all('sms' == notification_type for notification_type in types) + + +def test_can_get_sms_providers_in_order(notify_db, notify_db_session): + providers = get_provider_details_by_notification_type('sms') + + assert providers[0].identifier == "mmg" + assert providers[1].identifier == "firetext" + + +def test_can_get_email_providers_in_order(notify_db, notify_db_session): + providers = get_provider_details_by_notification_type('email') + + assert providers[0].identifier == "ses" + + +def test_can_get_email_providers(notify_db, notify_db_session): + assert len(get_provider_details_by_notification_type('email')) == 1 + types = [provider.notification_type for provider in get_provider_details_by_notification_type('email')] + assert all('email' == notification_type for notification_type in types) + + +def test_should_error_if_any_provider_in_database_not_in_code(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.get_sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.get_email_client(provider.identifier) + + +def test_should_not_error_if_any_provider_in_code_not_in_database(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + ProviderDetails.query.filter_by(identifier='mmg').delete() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.get_sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.get_email_client(provider.identifier) diff --git a/tests/app/provider_details/__init__.py b/tests/app/provider_details/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/app/provider_details/test_rest.py b/tests/app/provider_details/test_rest.py new file mode 100644 index 000000000..e69de29bb