diff --git a/app/__init__.py b/app/__init__.py index 175e44a69..c907e58a8 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -7,6 +7,7 @@ from flask_marshmallow import Marshmallow from werkzeug.local import LocalProxy from notifications_utils import logging from app.celery.celery import NotifyCelery +from app.clients import Clients from app.clients.sms.mmg import MMGClient from app.clients.sms.twilio import TwilioClient from app.clients.sms.firetext import FiretextClient @@ -25,8 +26,7 @@ mmg_client = MMGClient() aws_ses_client = AwsSesClient() encryption = Encryption() -sms_clients = [] -email_clients = [] +clients = Clients() api_user = LocalProxy(lambda: _request_ctx_stack.top.api_user) @@ -50,6 +50,7 @@ def create_app(app_name=None): aws_ses_client.init_app(application.config['AWS_REGION']) notify_celery.init_app(application) encryption.init_app(application) + clients.init_app(sms_clients=[firetext_client, mmg_client], email_clients=[aws_ses_client]) from app.service.rest import service as service_blueprint from app.user.rest import user as user_blueprint @@ -77,9 +78,6 @@ def create_app(app_name=None): application.register_blueprint(template_statistics_blueprint) application.register_blueprint(events_blueprint) - email_clients = [aws_ses_client] - sms_clients = [mmg_client, firetext_client] - return application diff --git a/app/clients/__init__.py b/app/clients/__init__.py index 30c359631..af5bf3bb3 100644 --- a/app/clients/__init__.py +++ b/app/clients/__init__.py @@ -1,4 +1,3 @@ - class ClientException(Exception): ''' Base Exceptions for sending notifications that fail @@ -16,3 +15,21 @@ class Client(object): STATISTICS_REQUESTED = 'requested' STATISTICS_DELIVERED = 'delivered' STATISTICS_FAILURE = 'failure' + + +class Clients(object): + sms_clients = {} + email_clients = {} + + def init_app(self, sms_clients, email_clients): + for client in sms_clients: + self.sms_clients[client.name] = client + + for client in email_clients: + self.email_clients[client.name] = client + + def sms_client(self, name): + return self.sms_clients.get(name) + + def email_client(self, name): + return self.email_clients.get(name) diff --git a/tests/app/dao/test_provider_details.py b/tests/app/dao/test_provider_details.py new file mode 100644 index 000000000..a393fe2ad --- /dev/null +++ b/tests/app/dao/test_provider_details.py @@ -0,0 +1,24 @@ +from app.models import ProviderDetails +from app import clients + + +def test_should_error_if_any_provider_in_database_not_in_code(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.email_client(provider.identifier) + + +def test_should_not_error_if_any_provider_in_code_not_in_database(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + ProviderDetails.query.filter_by(identifier='mmg').delete() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.email_client(provider.identifier)