diff --git a/app/__init__.py b/app/__init__.py index 3321e72d2..f6228bd60 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,6 +1,5 @@ import uuid import os -import re from flask import request, url_for from flask import Flask, _request_ctx_stack from flask.ext.sqlalchemy import SQLAlchemy @@ -8,6 +7,7 @@ from flask_marshmallow import Marshmallow from werkzeug.local import LocalProxy from notifications_utils import logging from app.celery.celery import NotifyCelery +from app.clients import Clients from app.clients.sms.mmg import MMGClient from app.clients.sms.twilio import TwilioClient from app.clients.sms.firetext import FiretextClient @@ -26,6 +26,8 @@ mmg_client = MMGClient() aws_ses_client = AwsSesClient() encryption = Encryption() +clients = Clients() + api_user = LocalProxy(lambda: _request_ctx_stack.top.api_user) @@ -48,6 +50,7 @@ def create_app(app_name=None): aws_ses_client.init_app(application.config['AWS_REGION']) notify_celery.init_app(application) encryption.init_app(application) + clients.init_app(sms_clients=[firetext_client, mmg_client], email_clients=[aws_ses_client]) from app.service.rest import service as service_blueprint from app.user.rest import user as user_blueprint @@ -61,6 +64,7 @@ def create_app(app_name=None): from app.notifications_statistics.rest import notifications_statistics as notifications_statistics_blueprint from app.template_statistics.rest import template_statistics as template_statistics_blueprint from app.events.rest import events as events_blueprint + from app.provider_details.rest import provider_details as provider_details_blueprint application.register_blueprint(service_blueprint, url_prefix='/service') application.register_blueprint(user_blueprint, url_prefix='/user') @@ -74,6 +78,7 @@ def create_app(app_name=None): application.register_blueprint(notifications_statistics_blueprint) application.register_blueprint(template_statistics_blueprint) application.register_blueprint(events_blueprint) + application.register_blueprint(provider_details_blueprint, url_prefix='/provider-details') return application @@ -86,7 +91,7 @@ def init_app(app): url_for('notifications.process_ses_response'), url_for('notifications.process_firetext_response'), url_for('notifications.process_mmg_response'), - url_for('status.show_delivery_status'), + url_for('status.show_delivery_status') ] if request.path not in no_auth_req: from app.authentication import auth diff --git a/app/celery/tasks.py b/app/celery/tasks.py index c22f8bac2..cee8ab46e 100644 --- a/app/celery/tasks.py +++ b/app/celery/tasks.py @@ -3,12 +3,12 @@ from datetime import datetime from flask import current_app from sqlalchemy.exc import SQLAlchemyError - -from app.clients.email.aws_ses import AwsSesClientException -from app.clients.sms.firetext import FiretextClientException -from app.clients.sms.mmg import MMGClientException +from app import clients +from app.clients.email import EmailClientException +from app.clients.sms import SmsClientException from app.dao.services_dao import dao_fetch_service_by_id from app.dao.templates_dao import dao_get_template_by_id +from app.dao.provider_details_dao import get_provider_details_by_notification_type from notifications_utils.template import Template, unlink_govuk_escaped @@ -23,10 +23,7 @@ from app import ( DATETIME_FORMAT, DATE_FORMAT, notify_celery, - encryption, - firetext_client, - aws_ses_client, - mmg_client + encryption ) from app.aws import s3 @@ -160,7 +157,7 @@ def process_job(job_id): 'personalisation': { key: personalisation.get(key) for key in template.placeholders - } + } }) if template.template_type == 'sms': @@ -207,7 +204,8 @@ def remove_job(job_id): def send_sms(service_id, notification_id, encrypted_notification, created_at): notification = encryption.decrypt(encrypted_notification) service = dao_fetch_service_by_id(service_id) - client = mmg_client + + provider = provider_to_use('sms', notification_id) restricted = False @@ -236,23 +234,23 @@ def send_sms(service_id, notification_id, encrypted_notification, created_at): status='failed' if restricted else 'sending', created_at=datetime.strptime(created_at, DATETIME_FORMAT), sent_at=sent_at, - sent_by=client.get_name(), + sent_by=provider.get_name(), content_char_count=template.replaced_content_count ) - dao_create_notification(notification_db_object, TEMPLATE_TYPE_SMS, client.get_name()) + dao_create_notification(notification_db_object, TEMPLATE_TYPE_SMS, provider.get_name()) if restricted: return try: - client.send_sms( + provider.send_sms( to=validate_and_format_phone_number(notification['to']), content=template.replaced, reference=str(notification_id) ) - except MMGClientException as e: + except SmsClientException as e: current_app.logger.error( "SMS notification {} failed".format(notification_id) ) @@ -270,9 +268,10 @@ def send_sms(service_id, notification_id, encrypted_notification, created_at): @notify_celery.task(name="send-email") def send_email(service_id, notification_id, from_address, encrypted_notification, created_at): notification = encryption.decrypt(encrypted_notification) - client = aws_ses_client service = dao_fetch_service_by_id(service_id) + provider = provider_to_use('email', notification_id) + restricted = False if not service_allowed_to_send_to(notification['to'], service): @@ -292,10 +291,10 @@ def send_email(service_id, notification_id, from_address, encrypted_notification status='failed' if restricted else 'sending', created_at=datetime.strptime(created_at, DATETIME_FORMAT), sent_at=sent_at, - sent_by=client.get_name() + sent_by=provider.get_name() ) - dao_create_notification(notification_db_object, TEMPLATE_TYPE_EMAIL, client.get_name()) + dao_create_notification(notification_db_object, TEMPLATE_TYPE_EMAIL, provider.get_name()) if restricted: return @@ -305,7 +304,7 @@ def send_email(service_id, notification_id, from_address, encrypted_notification dao_get_template_by_id(notification['template']).__dict__, values=notification.get('personalisation', {}) ) - reference = client.send_email( + reference = provider.send_email( from_address, notification['to'], template.replaced_subject, @@ -313,7 +312,7 @@ def send_email(service_id, notification_id, from_address, encrypted_notification html_body=template.as_HTML_email, ) update_notification_reference_by_id(notification_id, reference) - except AwsSesClientException as e: + except EmailClientException as e: current_app.logger.exception(e) notification_db_object.status = 'failed' @@ -327,20 +326,15 @@ def send_email(service_id, notification_id, from_address, encrypted_notification @notify_celery.task(name='send-sms-code') def send_sms_code(encrypted_verification): + provider = provider_to_use('sms', 'send-sms-code') + verification_message = encryption.decrypt(encrypted_verification) try: - mmg_client.send_sms(validate_and_format_phone_number(verification_message['to']), - "{} is your Notify authentication code".format( - verification_message['secret_code']), - 'send-sms-code') - except MMGClientException as e: - current_app.logger.exception(e) - - -def send_sms_via_firetext(to, content, reference): - try: - firetext_client.send_sms(to=to, content=content, reference=reference) - except FiretextClientException as e: + provider.send_sms(validate_and_format_phone_number(verification_message['to']), + "{} is your Notify authentication code".format( + verification_message['secret_code']), + 'send-sms-code') + except SmsClientException as e: current_app.logger.exception(e) @@ -371,6 +365,8 @@ def invited_user_url(base_url, token): @notify_celery.task(name='email-invited-user') def email_invited_user(encrypted_invitation): + provider = provider_to_use('email', 'email-invited-user') + invitation = encryption.decrypt(encrypted_invitation) url = invited_user_url(current_app.config['ADMIN_BASE_URL'], invitation['token']) @@ -384,11 +380,11 @@ def email_invited_user(encrypted_invitation): current_app.config['NOTIFY_EMAIL_DOMAIN'] ) subject_line = invitation_subject_line(invitation['user_name'], invitation['service_name']) - aws_ses_client.send_email(email_from, - invitation['to'], - subject_line, - invitation_content) - except AwsSesClientException as e: + provider.send_email(email_from, + invitation['to'], + subject_line, + invitation_content) + except EmailClientException as e: current_app.logger.exception(e) @@ -406,17 +402,23 @@ def password_reset_message(name, url): @notify_celery.task(name='email-reset-password') def email_reset_password(encrypted_reset_password_message): + provider = provider_to_use('email', 'email-reset-password') + reset_password_message = encryption.decrypt(encrypted_reset_password_message) try: email_from = '"GOV.UK Notify" <{}>'.format( current_app.config['VERIFY_CODE_FROM_EMAIL_ADDRESS'] ) - aws_ses_client.send_email(email_from, - reset_password_message['to'], - "Reset your GOV.UK Notify password", - password_reset_message(name=reset_password_message['name'], - url=reset_password_message['reset_password_url'])) - except AwsSesClientException as e: + provider.send_email( + email_from, + reset_password_message['to'], + "Reset your GOV.UK Notify password", + password_reset_message( + name=reset_password_message['name'], + url=reset_password_message['reset_password_url'] + ) + ) + except EmailClientException as e: current_app.logger.exception(e) @@ -431,22 +433,26 @@ def registration_verification_template(name, url): @notify_celery.task(name='email-registration-verification') def email_registration_verification(encrypted_verification_message): + provider = provider_to_use('email', 'email-reset-password') + verification_message = encryption.decrypt(encrypted_verification_message) try: email_from = '"GOV.UK Notify" <{}>'.format( current_app.config['VERIFY_CODE_FROM_EMAIL_ADDRESS'] ) - aws_ses_client.send_email(email_from, - verification_message['to'], - "Confirm GOV.UK Notify registration", - registration_verification_template(name=verification_message['name'], - url=verification_message['url'])) - except AwsSesClientException as e: + provider.send_email( + email_from, + verification_message['to'], + "Confirm GOV.UK Notify registration", + registration_verification_template( + name=verification_message['name'], + url=verification_message['url']) + ) + except EmailClientException as e: current_app.logger.exception(e) def service_allowed_to_send_to(recipient, service): - if not service.restricted: return True @@ -456,3 +462,17 @@ def service_allowed_to_send_to(recipient, service): [user.mobile_number, user.email_address] for user in service.users ) ) + + +def provider_to_use(notification_type, notification_id): + active_providers_in_order = [ + provider for provider in get_provider_details_by_notification_type(notification_type) if provider.active + ] + + if not active_providers_in_order: + current_app.logger.error( + "{} {} failed as no active providers".format(notification_type, notification_id) + ) + raise Exception("No active {} providers".format(notification_type)) + + return clients.get_client_by_name_and_type(active_providers_in_order[0].identifier, notification_type) diff --git a/app/clients/__init__.py b/app/clients/__init__.py index 30c359631..74750c51a 100644 --- a/app/clients/__init__.py +++ b/app/clients/__init__.py @@ -1,4 +1,3 @@ - class ClientException(Exception): ''' Base Exceptions for sending notifications that fail @@ -16,3 +15,30 @@ class Client(object): STATISTICS_REQUESTED = 'requested' STATISTICS_DELIVERED = 'delivered' STATISTICS_FAILURE = 'failure' + + +class Clients(object): + sms_clients = {} + email_clients = {} + + def init_app(self, sms_clients, email_clients): + for client in sms_clients: + self.sms_clients[client.name] = client + + for client in email_clients: + self.email_clients[client.name] = client + + def get_sms_client(self, name): + return self.sms_clients.get(name) + + def get_email_client(self, name): + return self.email_clients.get(name) + + def get_client_by_name_and_type(self, name, notification_type): + assert notification_type in ['email', 'sms'] + + if notification_type == 'email': + return self.get_email_client(name) + + if notification_type == 'sms': + return self.get_sms_client(name) diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index 6b5bc2a1d..27a35532a 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -1,4 +1,3 @@ -import math from sqlalchemy import (desc, func, Integer) from sqlalchemy.sql.expression import cast @@ -20,8 +19,8 @@ from app.models import ( TEMPLATE_TYPE_SMS, TEMPLATE_TYPE_EMAIL, Template, - ProviderStatistics -) + ProviderStatistics, + ProviderDetails) from notifications_utils.template import get_sms_fragment_count @@ -88,7 +87,9 @@ def dao_get_template_statistics_for_service(service_id, limit_days=None): @transactional -def dao_create_notification(notification, notification_type, provider): +def dao_create_notification(notification, notification_type, provider_identifier): + provider = ProviderDetails.query.filter_by(identifier=provider_identifier).one() + if notification.job_id: db.session.query(Job).filter_by( id=notification.job_id @@ -125,7 +126,7 @@ def dao_create_notification(notification, notification_type, provider): update_count = db.session.query(ProviderStatistics).filter_by( day=date.today(), service_id=notification.service_id, - provider=provider + provider_id=provider.id ).update({'unit_count': ProviderStatistics.unit_count + ( 1 if notification_type == TEMPLATE_TYPE_EMAIL else get_sms_fragment_count(notification.content_char_count))}) @@ -133,7 +134,7 @@ def dao_create_notification(notification, notification_type, provider): provider_stats = ProviderStatistics( day=notification.created_at.date(), service_id=notification.service_id, - provider=provider, + provider_id=provider.id, unit_count=1 if notification_type == TEMPLATE_TYPE_EMAIL else get_sms_fragment_count( notification.content_char_count)) db.session.add(provider_stats) diff --git a/app/dao/provider_details_dao.py b/app/dao/provider_details_dao.py new file mode 100644 index 000000000..275d74e5a --- /dev/null +++ b/app/dao/provider_details_dao.py @@ -0,0 +1,23 @@ +from sqlalchemy import asc +from app.dao.dao_utils import transactional +from app.models import ProviderDetails +from app import db + + +def get_provider_details(): + return ProviderDetails.query.order_by(asc(ProviderDetails.priority), asc(ProviderDetails.notification_type)).all() + + +def get_provider_details_by_id(provider_details_id): + return ProviderDetails.query.get(provider_details_id) + + +def get_provider_details_by_notification_type(notification_type): + return ProviderDetails.query.filter_by( + notification_type=notification_type + ).order_by(asc(ProviderDetails.priority)).all() + + +@transactional +def dao_update_provider_details(provider_details): + db.session.add(provider_details) diff --git a/app/dao/provider_rates_dao.py b/app/dao/provider_rates_dao.py index e42909a8f..145bd431e 100644 --- a/app/dao/provider_rates_dao.py +++ b/app/dao/provider_rates_dao.py @@ -1,9 +1,11 @@ -from app.models import ProviderRates +from app.models import ProviderRates, ProviderDetails from app import db from app.dao.dao_utils import transactional @transactional -def create_provider_rates(provider, valid_from, rate): - provider_rates = ProviderRates(provider=provider, valid_from=valid_from, rate=rate) +def create_provider_rates(provider_identifier, valid_from, rate): + provider = ProviderDetails.query.filter_by(identifier=provider_identifier).one() + + provider_rates = ProviderRates(provider_id=provider.id, valid_from=valid_from, rate=rate) db.session.add(provider_rates) diff --git a/app/dao/provider_statistics_dao.py b/app/dao/provider_statistics_dao.py index 48489b75a..f84625429 100644 --- a/app/dao/provider_statistics_dao.py +++ b/app/dao/provider_statistics_dao.py @@ -1,6 +1,5 @@ from sqlalchemy import func -from app import db -from app.models import (ProviderStatistics, SMS_PROVIDERS, EMAIL_PROVIDERS) +from app.models import (ProviderStatistics, SMS_PROVIDERS, EMAIL_PROVIDERS, ProviderDetails) def get_provider_statistics(service, **kwargs): @@ -33,7 +32,9 @@ def get_fragment_count(service, date_from, date_to): def filter_query(query, service, **kwargs): query = query.filter_by(service=service) if 'providers' in kwargs: - query = query.filter(ProviderStatistics.provider.in_(kwargs['providers'])) + providers = ProviderDetails.query.filter(ProviderDetails.identifier.in_(kwargs['providers'])).all() + provider_ids = [provider.id for provider in providers] + query = query.filter(ProviderStatistics.provider_id.in_(provider_ids)) if 'date_from' in kwargs: query.filter(ProviderStatistics.day >= kwargs['date_from']) if 'date_to' in kwargs: diff --git a/app/models.py b/app/models.py index e524a14bb..497153a05 100644 --- a/app/models.py +++ b/app/models.py @@ -191,13 +191,18 @@ SMS_PROVIDERS = [MMG_PROVIDER, TWILIO_PROVIDER, FIRETEXT_PROVIDER] EMAIL_PROVIDERS = [SES_PROVIDER] PROVIDERS = SMS_PROVIDERS + EMAIL_PROVIDERS +NOTIFICATION_TYPE = ['email', 'sms', 'letter'] + class ProviderStatistics(db.Model): __tablename__ = 'provider_statistics' id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) day = db.Column(db.Date, nullable=False) - provider = db.Column(db.Enum(*PROVIDERS, name='providers'), nullable=False) + provider_id = db.Column(UUID(as_uuid=True), db.ForeignKey('provider_details.id'), index=True, nullable=False) + provider = db.relationship( + 'ProviderDetails', backref=db.backref('provider_stats', lazy='dynamic') + ) service_id = db.Column(UUID(as_uuid=True), db.ForeignKey('services.id'), index=True, nullable=False) service = db.relationship('Service', backref=db.backref('service_provider_stats', lazy='dynamic')) unit_count = db.Column(db.BigInteger, nullable=False) @@ -208,8 +213,20 @@ class ProviderRates(db.Model): id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) valid_from = db.Column(db.DateTime, nullable=False) - provider = db.Column(db.Enum(*PROVIDERS, name='providers'), nullable=False) rate = db.Column(db.Numeric(), nullable=False) + provider_id = db.Column(UUID(as_uuid=True), db.ForeignKey('provider_details.id'), index=True, nullable=False) + provider = db.relationship('ProviderDetails', backref=db.backref('provider_rates', lazy='dynamic')) + + +class ProviderDetails(db.Model): + __tablename__ = 'provider_details' + + id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + display_name = db.Column(db.String, nullable=False) + identifier = db.Column(db.String, nullable=False) + priority = db.Column(db.Integer, nullable=False) + notification_type = db.Column(db.Enum(*NOTIFICATION_TYPE, name='notification_type'), nullable=False) + active = db.Column(db.Boolean, default=False) JOB_STATUS_TYPES = ['pending', 'in progress', 'finished', 'sending limits exceeded'] diff --git a/app/provider_details/__init__.py b/app/provider_details/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py new file mode 100644 index 000000000..f976f1560 --- /dev/null +++ b/app/provider_details/rest.py @@ -0,0 +1,42 @@ +from flask import Blueprint, jsonify, request + +from app.schemas import provider_details_schema +from app.dao.provider_details_dao import ( + get_provider_details, + get_provider_details_by_id, + get_provider_details_by_id, + dao_update_provider_details +) + +provider_details = Blueprint('provider_details', __name__) + + +@provider_details.route('', methods=['GET']) +def get_providers(): + data, errors = provider_details_schema.dump(get_provider_details(), many=True) + return jsonify(provider_details=data) + + +@provider_details.route('/', methods=['GET']) +def get_provider_by_id(provider_details_id): + data, errors = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)) + return jsonify(provider_details=data) + + +@provider_details.route('/', methods=['POST']) +def update_provider_details(provider_details_id): + fetched_provider_details = get_provider_details_by_id(provider_details_id) + + current_data = dict(provider_details_schema.dump(fetched_provider_details).data.items()) + current_data.update(request.get_json()) + update_dict, errors = provider_details_schema.load(current_data) + if errors: + return jsonify(result="error", message=errors), 400 + + if "identifier" in request.get_json().keys(): + return jsonify(message={ + "identifier": ["Not permitted to be updated"] + }, result='error'), 400 + + dao_update_provider_details(update_dict) + return jsonify(provider_details=provider_details_schema.dump(fetched_provider_details).data), 200 diff --git a/app/schemas.py b/app/schemas.py index 51f08c345..7e8bcfbb6 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -84,6 +84,12 @@ class UserSchema(BaseSchema): "_password", "verify_codes") +class ProviderDetailsSchema(BaseSchema): + class Meta: + model = models.ProviderDetails + exclude = ("provider_rates", "provider_stats") + + class ServiceSchema(BaseSchema): created_by = field_for(models.Service, 'created_by', required=True) @@ -402,4 +408,5 @@ api_key_history_schema = ApiKeyHistorySchema() template_history_schema = TemplateHistorySchema() event_schema = EventSchema() from_to_date_schema = FromToDateSchema() +provider_details_schema = ProviderDetailsSchema() week_aggregate_notification_statistics_schema = WeekAggregateNotificationStatisticsSchema() diff --git a/config.py b/config.py index 164dd7d9b..938887e9a 100644 --- a/config.py +++ b/config.py @@ -23,6 +23,7 @@ class Config(object): SQLALCHEMY_COMMIT_ON_TEARDOWN = False SQLALCHEMY_DATABASE_URI = os.environ['SQLALCHEMY_DATABASE_URI'] SQLALCHEMY_RECORD_QUERIES = True + SQLALCHEMY_TRACK_MODIFICATIONS = True VERIFY_CODE_FROM_EMAIL_ADDRESS = os.environ['VERIFY_CODE_FROM_EMAIL_ADDRESS'] NOTIFY_EMAIL_DOMAIN = os.environ['NOTIFY_EMAIL_DOMAIN'] PAGE_SIZE = 50 diff --git a/migrations/versions/0005_add_provider_stats.py b/migrations/versions/0005_add_provider_stats.py index 8db036ba6..01ffad89a 100644 --- a/migrations/versions/0005_add_provider_stats.py +++ b/migrations/versions/0005_add_provider_stats.py @@ -15,7 +15,6 @@ import sqlalchemy as sa from sqlalchemy.dialects import postgresql def upgrade(): - ### commands auto generated by Alembic - please adjust! ### op.create_table('provider_rates', sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('valid_from', sa.DateTime(), nullable=False), @@ -33,12 +32,9 @@ def upgrade(): sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_provider_statistics_service_id'), 'provider_statistics', ['service_id'], unique=False) - ### end Alembic commands ### def downgrade(): - ### commands auto generated by Alembic - please adjust! ### op.drop_index(op.f('ix_provider_statistics_service_id'), table_name='provider_statistics') op.drop_table('provider_statistics') op.drop_table('provider_rates') - ### end Alembic commands ### diff --git a/migrations/versions/0011_ad_provider_details.py b/migrations/versions/0011_ad_provider_details.py new file mode 100644 index 000000000..9dad883be --- /dev/null +++ b/migrations/versions/0011_ad_provider_details.py @@ -0,0 +1,73 @@ +"""empty message + +Revision ID: 0011_ad_provider_details +Revises: 0010_events_table +Create Date: 2016-05-05 09:14:29.328841 + +""" + +# revision identifiers, used by Alembic. +revision = '0011_ad_provider_details' +down_revision = '0010_events_table' + +import uuid + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +def upgrade(): + op.create_table('provider_details', + sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('display_name', sa.String(), nullable=False), + sa.Column('identifier', sa.String(), nullable=False), + sa.Column('priority', sa.Integer(), nullable=False), + sa.Column('notification_type', sa.Enum('email', 'sms', 'letter', name='notification_type'), nullable=False), + sa.Column('active', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + op.add_column('provider_rates', sa.Column('provider_id', postgresql.UUID(as_uuid=True), nullable=True)) + op.create_index(op.f('ix_provider_rates_provider_id'), 'provider_rates', ['provider_id'], unique=False) + op.create_foreign_key("provider_rate_to_provider_fk", 'provider_rates', 'provider_details', ['provider_id'], ['id']) + op.add_column('provider_statistics', sa.Column('provider_id', postgresql.UUID(as_uuid=True), nullable=True)) + op.create_index(op.f('ix_provider_statistics_provider_id'), 'provider_statistics', ['provider_id'], unique=False) + op.create_foreign_key('provider_stats_to_provider_fk', 'provider_statistics', 'provider_details', ['provider_id'], ['id']) + + op.execute( + "INSERT INTO provider_details (id, display_name, identifier, priority, notification_type, active) values ('{}', 'MMG', 'mmg', 10, 'sms', true)".format(str(uuid.uuid4())) + ) + op.execute( + "INSERT INTO provider_details (id, display_name, identifier, priority, notification_type, active) values ('{}', 'Firetext', 'firetext', 20, 'sms', true)".format(str(uuid.uuid4())) + ) + op.execute( + "INSERT INTO provider_details (id, display_name, identifier, priority, notification_type, active) values ('{}', 'AWS SES', 'ses', 10, 'email', true)".format(str(uuid.uuid4())) + ) + op.execute( + "UPDATE provider_rates set provider_id = (select id from provider_details where identifier = 'mmg') where provider = 'mmg'" + ) + op.execute( + "UPDATE provider_rates set provider_id = (select id from provider_details where identifier = 'firetext') where provider = 'firetext'" + ) + op.execute( + "UPDATE provider_rates set provider_id = (select id from provider_details where identifier = 'ses') where provider = 'ses'" + ) + op.execute( + "UPDATE provider_statistics set provider_id = (select id from provider_details where identifier = 'mmg') where provider = 'mmg'" + ) + op.execute( + "UPDATE provider_statistics set provider_id = (select id from provider_details where identifier = 'firetext') where provider = 'firetext'" + ) + op.execute( + "UPDATE provider_statistics set provider_id = (select id from provider_details where identifier = 'ses') where provider = 'ses'" + ) + +def downgrade(): + + op.drop_index(op.f('ix_provider_statistics_provider_id'), table_name='provider_statistics') + op.drop_column('provider_statistics', 'provider_id') + op.drop_index(op.f('ix_provider_rates_provider_id'), table_name='provider_rates') + op.drop_column('provider_rates', 'provider_id') + + op.drop_table('provider_details') + op.execute('drop type notification_type') diff --git a/migrations/versions/0012_complete_provider_details.py b/migrations/versions/0012_complete_provider_details.py new file mode 100644 index 000000000..2165b336a --- /dev/null +++ b/migrations/versions/0012_complete_provider_details.py @@ -0,0 +1,80 @@ +"""empty message + +Revision ID: 0012_complete_provider_details +Revises: 0011_ad_provider_details +Create Date: 2016-05-05 09:18:26.926275 + +""" + +# revision identifiers, used by Alembic. +revision = '0012_complete_provider_details' +down_revision = '0011_ad_provider_details' + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import ENUM + + +def upgrade(): + + op.alter_column('provider_rates', 'provider_id', + existing_type=postgresql.UUID(), + nullable=False) + op.drop_column('provider_rates', 'provider') + op.alter_column('provider_statistics', 'provider_id', + existing_type=postgresql.UUID(), + nullable=False) + op.drop_column('provider_statistics', 'provider') + op.execute('drop type providers') + + +def downgrade(): + + provider_enum = ENUM('loadtesting', 'firetext', 'mmg', 'ses', 'twilio', name='providers', create_type=True) + provider_enum.create(op.get_bind(), checkfirst=False) + + op.add_column('provider_statistics', sa.Column('provider', provider_enum, autoincrement=False, nullable=True)) + op.alter_column('provider_statistics', 'provider_id', + existing_type=postgresql.UUID(), + nullable=True) + op.add_column('provider_rates', sa.Column('provider', provider_enum, autoincrement=False, nullable=True)) + op.alter_column('provider_rates', 'provider_id', + existing_type=postgresql.UUID(), + nullable=True) + + + op.execute( + "UPDATE provider_rates set provider = 'mmg' where provider_id = (select id from provider_details where identifier = 'mmg')" + ) + op.execute( + "UPDATE provider_rates set provider = 'firetext' where provider_id = (select id from provider_details where identifier = 'firetext')" + ) + op.execute( + "UPDATE provider_rates set provider = 'ses' where provider_id = (select id from provider_details where identifier = 'ses')" + ) + op.execute( + "UPDATE provider_rates set provider = 'loadtesting' where provider_id = (select id from provider_details where identifier = 'loadtesting')" + ) + + op.execute( + "UPDATE provider_statistics set provider = 'mmg' where provider_id = (select id from provider_details where identifier = 'mmg')" + ) + op.execute( + "UPDATE provider_statistics set provider = 'firetext' where provider_id = (select id from provider_details where identifier = 'firetext')" + ) + op.execute( + "UPDATE provider_statistics set provider = 'ses' where provider_id = (select id from provider_details where identifier = 'ses')" + ) + op.execute( + "UPDATE provider_statistics set provider = 'loadtesting' where provider_id = (select id from provider_details where identifier = 'loadtesting')" + ) + + + op.alter_column('provider_rates', 'provider', + existing_type=postgresql.UUID(), + nullable=False) + + op.alter_column('provider_statistics', 'provider', + existing_type=postgresql.UUID(), + nullable=False) \ No newline at end of file diff --git a/requirements_for_test.txt b/requirements_for_test.txt index dbbcc7ae3..66f6f39eb 100644 --- a/requirements_for_test.txt +++ b/requirements_for_test.txt @@ -1,6 +1,6 @@ -r requirements.txt pep8==1.5.7 -pytest==2.8.1 +pytest==2.8.3 pytest-mock==0.8.1 pytest-cov==2.2.0 mock==1.0.1 diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 24ccd2a44..31bbd4ad9 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -13,12 +13,13 @@ from app.celery.tasks import ( delete_verify_codes, delete_invitations, delete_failed_notifications, - delete_successful_notifications + delete_successful_notifications, + provider_to_use ) from app import (aws_ses_client, encryption, DATETIME_FORMAT, mmg_client) from app.clients.email.aws_ses import AwsSesClientException from app.clients.sms.mmg import MMGClientException -from app.dao import notifications_dao, jobs_dao +from app.dao import notifications_dao, jobs_dao, provider_details_dao from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound from app.celery.tasks import s3 @@ -40,9 +41,40 @@ class AnyStringWith(str): def __eq__(self, other): return self in other + mmg_error = {'Error': '40', 'Description': 'error'} +def test_should_return_highest_priority_active_provider(notify_db, notify_db_session): + providers = provider_details_dao.get_provider_details_by_notification_type('sms') + first = providers[0] + second = providers[1] + + assert provider_to_use('sms', '1234').name == first.identifier + + first.priority = 20 + second.priority = 10 + + provider_details_dao.dao_update_provider_details(first) + provider_details_dao.dao_update_provider_details(second) + + assert provider_to_use('sms', '1234').name == second.identifier + + first.priority = 10 + first.active = False + second.priority = 20 + + provider_details_dao.dao_update_provider_details(first) + provider_details_dao.dao_update_provider_details(second) + + assert provider_to_use('sms', '1234').name == second.identifier + + first.active = True + provider_details_dao.dao_update_provider_details(first) + + assert provider_to_use('sms', '1234').name == first.identifier + + def test_should_call_delete_notifications_more_than_week_in_task(notify_api, mocker): mocked = mocker.patch('app.celery.tasks.delete_notifications_created_more_than_a_week_ago') delete_successful_notifications() @@ -804,7 +836,7 @@ def test_email_invited_user_should_send_email(notify_api, mocker): expected_content) -def test_email_reset_password_should_send_email(notify_api, mocker): +def test_email_reset_password_should_send_email(notify_db, notify_db_session, notify_api, mocker): with notify_api.test_request_context(): reset_password_message = {'to': 'someone@it.gov.uk', 'name': 'Some One', diff --git a/tests/app/conftest.py b/tests/app/conftest.py index a3a973e74..50b02bf43 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -10,10 +10,8 @@ from app.models import ( Notification, InvitedUser, Permission, - MMG_PROVIDER, - SES_PROVIDER, - TWILIO_PROVIDER, ProviderStatistics, + ProviderDetails, NotificationStatistics) from app.dao.users_dao import (save_model_user, create_user_code, create_secret_code) from app.dao.services_dao import (dao_create_service, dao_add_user_to_service) @@ -329,7 +327,7 @@ def sample_notification(notify_db, notification_id = uuid.uuid4() if provider_name is None: - provider_name = mmg_provider_name() if template.template_type == 'sms' else ses_provider_name() + provider = mmg_provider() if template.template_type == 'sms' else ses_provider() if to_field: to = to_field @@ -351,7 +349,7 @@ def sample_notification(notify_db, } notification = Notification(**data) if create: - dao_create_notification(notification, template.template_type, provider_name) + dao_create_notification(notification, template.template_type, provider.identifier) return notification @@ -463,18 +461,18 @@ def fake_uuid(): @pytest.fixture(scope='function') -def ses_provider_name(): - return SES_PROVIDER +def ses_provider(): + return ProviderDetails.query.filter_by(identifier='ses').one() @pytest.fixture(scope='function') -def mmg_provider_name(): - return MMG_PROVIDER +def firetext_provider(): + return ProviderDetails.query.filter_by(identifier='mmg').one() @pytest.fixture(scope='function') -def twilio_provider_name(): - return TWILIO_PROVIDER +def mmg_provider(): + return ProviderDetails.query.filter_by(identifier='mmg').one() @pytest.fixture(scope='function') @@ -484,13 +482,14 @@ def sample_provider_statistics(notify_db, provider=None, day=None, unit_count=1): + if provider is None: - provider = mmg_provider_name() + provider = ProviderDetails.query.filter_by(identifier='mmg').first() if day is None: day = date.today() stats = ProviderStatistics( service=sample_service, - provider=provider, + provider_id=provider.id, day=day, unit_count=unit_count) notify_db.session.add(stats) diff --git a/tests/app/dao/test_notification_dao.py b/tests/app/dao/test_notification_dao.py index 21dd3c337..2732ece42 100644 --- a/tests/app/dao/test_notification_dao.py +++ b/tests/app/dao/test_notification_dao.py @@ -44,14 +44,14 @@ def test_should_by_able_to_update_reference_by_id(sample_notification): assert Notification.query.get(sample_notification.id).reference == 'reference' -def test_should_by_able_to_update_status_by_reference(sample_email_template, ses_provider_name): +def test_should_by_able_to_update_status_by_reference(sample_email_template, ses_provider): data = _notification_json(sample_email_template) notification = Notification(**data) dao_create_notification( notification, sample_email_template.template_type, - ses_provider_name) + ses_provider.identifier) assert Notification.query.get(notification.id).status == "sending" update_notification_reference_by_id(notification.id, 'reference') @@ -100,11 +100,11 @@ def test_should_be_able_to_record_statistics_failure_for_sms(sample_notification ).one().sms_failed == 1 -def test_should_be_able_to_record_statistics_failure_for_email(sample_email_template, ses_provider_name): +def test_should_be_able_to_record_statistics_failure_for_email(sample_email_template, ses_provider): data = _notification_json(sample_email_template) notification = Notification(**data) - dao_create_notification(notification, sample_email_template.template_type, ses_provider_name) + dao_create_notification(notification, sample_email_template.template_type, ses_provider.identifier) update_notification_reference_by_id(notification.id, 'reference') count = update_notification_status_by_reference('reference', 'failed', 'failure') @@ -129,11 +129,11 @@ def test_should_return_zero_count_if_no_notification_with_reference(): assert update_notification_status_by_reference('something', 'delivered', 'delivered') == 0 -def test_should_be_able_to_get_statistics_for_a_service(sample_template, mmg_provider_name): +def test_should_be_able_to_get_statistics_for_a_service(sample_template, mmg_provider): data = _notification_json(sample_template) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) stats = dao_get_notification_statistics_for_service(sample_template.service.id) assert len(stats) == 1 @@ -148,11 +148,11 @@ def test_should_be_able_to_get_statistics_for_a_service(sample_template, mmg_pro assert stats[0].emails_failed == 0 -def test_should_be_able_to_get_statistics_for_a_service_for_a_day(sample_template, mmg_provider_name): +def test_should_be_able_to_get_statistics_for_a_service_for_a_day(sample_template, mmg_provider): now = datetime.utcnow() data = _notification_json(sample_template) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) stat = dao_get_notification_statistics_for_service_and_day( sample_template.service.id, now.date() ) @@ -166,25 +166,25 @@ def test_should_be_able_to_get_statistics_for_a_service_for_a_day(sample_templat assert stat.service_id == notification.service_id -def test_should_return_none_if_no_statistics_for_a_service_for_a_day(sample_template, mmg_provider_name): +def test_should_return_none_if_no_statistics_for_a_service_for_a_day(sample_template, mmg_provider): data = _notification_json(sample_template) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert not dao_get_notification_statistics_for_service_and_day( sample_template.service.id, (datetime.utcnow() - timedelta(days=1)).date() ) -def test_should_be_able_to_get_all_statistics_for_a_service(sample_template, mmg_provider_name): +def test_should_be_able_to_get_all_statistics_for_a_service(sample_template, mmg_provider): data = _notification_json(sample_template) notification_1 = Notification(**data) notification_2 = Notification(**data) notification_3 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_3, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_3, sample_template.template_type, mmg_provider.identifier) stats = dao_get_notification_statistics_for_service(sample_template.service.id) assert len(stats) == 1 @@ -192,7 +192,7 @@ def test_should_be_able_to_get_all_statistics_for_a_service(sample_template, mmg assert stats[0].sms_requested == 3 -def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days(sample_template, mmg_provider_name): +def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days(sample_template, mmg_provider): data = _notification_json(sample_template) today = datetime.utcnow() @@ -210,9 +210,9 @@ def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days(sam 'created_at': two_days_ago }) notification_3 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_3, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_3, sample_template.template_type, mmg_provider.identifier) stats = dao_get_notification_statistics_for_service(sample_template.service.id) assert len(stats) == 3 @@ -232,7 +232,7 @@ def test_should_be_empty_list_if_no_statistics_for_a_service(sample_service): def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days_previous(sample_template, - mmg_provider_name): + mmg_provider): data = _notification_json(sample_template) today = datetime.utcnow() @@ -250,9 +250,9 @@ def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days_pre 'created_at': eight_days_ago }) notification_3 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_3, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_3, sample_template.template_type, mmg_provider.identifier) stats = dao_get_notification_statistics_for_service( sample_template.service.id, 7 @@ -266,7 +266,7 @@ def test_should_be_able_to_get_all_statistics_for_a_service_for_several_days_pre assert stats[1].day == seven_days_ago.date() -def test_save_notification_creates_sms_and_template_stats(sample_template, sample_job, mmg_provider_name): +def test_save_notification_creates_sms_and_template_stats(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 assert NotificationStatistics.query.count() == 0 assert TemplateStatistics.query.count() == 0 @@ -274,7 +274,7 @@ def test_save_notification_creates_sms_and_template_stats(sample_template, sampl data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -302,7 +302,7 @@ def test_save_notification_creates_sms_and_template_stats(sample_template, sampl assert template_stats.usage_count == 1 -def test_save_notification_and_create_email_and_template_stats(sample_email_template, sample_job, ses_provider_name): +def test_save_notification_and_create_email_and_template_stats(sample_email_template, sample_job, ses_provider): assert Notification.query.count() == 0 assert NotificationStatistics.query.count() == 0 @@ -311,7 +311,7 @@ def test_save_notification_and_create_email_and_template_stats(sample_email_temp data = _notification_json(sample_email_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_email_template.template_type, ses_provider_name) + dao_create_notification(notification, sample_email_template.template_type, ses_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -340,12 +340,12 @@ def test_save_notification_and_create_email_and_template_stats(sample_email_temp @freeze_time("2016-01-01 00:00:00.000000") -def test_save_notification_handles_midnight_properly(sample_template, sample_job, mmg_provider_name): +def test_save_notification_handles_midnight_properly(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 @@ -357,12 +357,12 @@ def test_save_notification_handles_midnight_properly(sample_template, sample_job @freeze_time("2016-01-01 23:59:59.999999") -def test_save_notification_handles_just_before_midnight_properly(sample_template, sample_job, mmg_provider_name): +def test_save_notification_handles_just_before_midnight_properly(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 @@ -373,13 +373,13 @@ def test_save_notification_handles_just_before_midnight_properly(sample_template assert stats.day == date(2016, 1, 1) -def test_save_notification_and_increment_email_stats(sample_email_template, sample_job, ses_provider_name): +def test_save_notification_and_increment_email_stats(sample_email_template, sample_job, ses_provider): assert Notification.query.count() == 0 data = _notification_json(sample_email_template, sample_job.id) notification_1 = Notification(**data) notification_2 = Notification(**data) - dao_create_notification(notification_1, sample_email_template.template_type, ses_provider_name) + dao_create_notification(notification_1, sample_email_template.template_type, ses_provider.identifier) assert Notification.query.count() == 1 @@ -390,7 +390,7 @@ def test_save_notification_and_increment_email_stats(sample_email_template, samp assert stats1.emails_requested == 1 assert stats1.sms_requested == 0 - dao_create_notification(notification_2, sample_email_template.template_type, ses_provider_name) + dao_create_notification(notification_2, sample_email_template.template_type, ses_provider.identifier) assert Notification.query.count() == 2 @@ -402,13 +402,13 @@ def test_save_notification_and_increment_email_stats(sample_email_template, samp assert stats2.sms_requested == 0 -def test_save_notification_and_increment_sms_stats(sample_template, sample_job, mmg_provider_name): +def test_save_notification_and_increment_sms_stats(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification_1 = Notification(**data) notification_2 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 @@ -419,7 +419,7 @@ def test_save_notification_and_increment_sms_stats(sample_template, sample_job, assert stats1.emails_requested == 0 assert stats1.sms_requested == 1 - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 2 @@ -431,7 +431,7 @@ def test_save_notification_and_increment_sms_stats(sample_template, sample_job, assert stats2.sms_requested == 2 -def test_not_save_notification_and_not_create_stats_on_commit_error(sample_template, sample_job, mmg_provider_name): +def test_not_save_notification_and_not_create_stats_on_commit_error(sample_template, sample_job, mmg_provider): random_id = str(uuid.uuid4()) assert Notification.query.count() == 0 @@ -439,7 +439,7 @@ def test_not_save_notification_and_not_create_stats_on_commit_error(sample_templ notification = Notification(**data) with pytest.raises(SQLAlchemyError): - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 0 assert Job.query.get(sample_job.id).notifications_sent == 0 @@ -447,12 +447,12 @@ def test_not_save_notification_and_not_create_stats_on_commit_error(sample_templ assert TemplateStatistics.query.count() == 0 -def test_save_notification_and_increment_job(sample_template, sample_job, mmg_provider_name): +def test_save_notification_and_increment_job(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -466,20 +466,20 @@ def test_save_notification_and_increment_job(sample_template, sample_job, mmg_pr assert Job.query.get(sample_job.id).notifications_sent == 1 notification_2 = Notification(**data) - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 2 assert Job.query.get(sample_job.id).notifications_sent == 2 -def test_should_not_increment_job_if_notification_fails_to_persist(sample_template, sample_job, mmg_provider_name): +def test_should_not_increment_job_if_notification_fails_to_persist(sample_template, sample_job, mmg_provider): random_id = str(uuid.uuid4()) assert Notification.query.count() == 0 data = { 'id': random_id, 'to': '+44709123456', 'job_id': sample_job.id, - 'service': sample_template.service, 'service_id': sample_template.service.id, + 'service': sample_template.service, 'template': sample_template, 'template_id': sample_template.id, 'template_version': sample_template.version, @@ -488,7 +488,7 @@ def test_should_not_increment_job_if_notification_fails_to_persist(sample_templa } notification_1 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 assert Job.query.get(sample_job.id).notifications_sent == 1 @@ -496,14 +496,14 @@ def test_should_not_increment_job_if_notification_fails_to_persist(sample_templa notification_2 = Notification(**data) with pytest.raises(SQLAlchemyError): - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 assert Job.query.get(sample_job.id).notifications_sent == 1 assert Job.query.get(sample_job.id).updated_at == job_last_updated_at -def test_save_notification_and_increment_correct_job(notify_db, notify_db_session, sample_template, mmg_provider_name): +def test_save_notification_and_increment_correct_job(notify_db, notify_db_session, sample_template, mmg_provider): job_1 = sample_job(notify_db, notify_db_session, sample_template.service) job_2 = sample_job(notify_db, notify_db_session, sample_template.service) @@ -511,7 +511,7 @@ def test_save_notification_and_increment_correct_job(notify_db, notify_db_sessio data = _notification_json(sample_template, job_1.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -526,12 +526,12 @@ def test_save_notification_and_increment_correct_job(notify_db, notify_db_sessio assert Job.query.get(job_2.id).notifications_sent == 0 -def test_save_notification_with_no_job(sample_template, mmg_provider_name): +def test_save_notification_with_no_job(sample_template, mmg_provider): assert Notification.query.count() == 0 data = _notification_json(sample_template) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -550,13 +550,13 @@ def test_get_notification(sample_notification): assert sample_notification == notifcation_from_db -def test_save_notification_no_job_id(sample_template, mmg_provider_name): +def test_save_notification_no_job_id(sample_template, mmg_provider): assert Notification.query.count() == 0 to = '+44709123456' data = _notification_json(sample_template) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert Notification.query.count() == 1 notification_from_db = Notification.query.all()[0] @@ -640,13 +640,13 @@ def test_should_not_delete_failed_notifications_before_seven_days(notify_db, not @freeze_time("2016-03-30") -def test_save_new_notification_creates_template_stats(sample_template, sample_job, mmg_provider_name): +def test_save_new_notification_creates_template_stats(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 assert TemplateStatistics.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert TemplateStatistics.query.count() == 1 template_stats = TemplateStatistics.query.filter(TemplateStatistics.service_id == sample_template.service.id, @@ -658,13 +658,13 @@ def test_save_new_notification_creates_template_stats(sample_template, sample_jo @freeze_time("2016-03-30") -def test_save_new_notification_creates_template_stats_per_day(sample_template, sample_job, mmg_provider_name): +def test_save_new_notification_creates_template_stats_per_day(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 assert TemplateStatistics.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) assert TemplateStatistics.query.count() == 1 template_stats = TemplateStatistics.query.filter(TemplateStatistics.service_id == sample_template.service.id, @@ -678,7 +678,7 @@ def test_save_new_notification_creates_template_stats_per_day(sample_template, s with freeze_time('2016-03-31'): assert TemplateStatistics.query.count() == 1 new_notification = Notification(**data) - dao_create_notification(new_notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(new_notification, sample_template.template_type, mmg_provider.identifier) assert TemplateStatistics.query.count() == 2 first_stats = TemplateStatistics.query.filter(TemplateStatistics.day == datetime(2016, 3, 30)).first() @@ -694,14 +694,14 @@ def test_save_new_notification_creates_template_stats_per_day(sample_template, s assert second_stats.usage_count == 1 -def test_save_another_notification_increments_template_stats(sample_template, sample_job, mmg_provider_name): +def test_save_another_notification_increments_template_stats(sample_template, sample_job, mmg_provider): assert Notification.query.count() == 0 assert TemplateStatistics.query.count() == 0 data = _notification_json(sample_template, sample_job.id) notification_1 = Notification(**data) notification_2 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) assert TemplateStatistics.query.count() == 1 template_stats = TemplateStatistics.query.filter(TemplateStatistics.service_id == sample_template.service.id, @@ -710,7 +710,7 @@ def test_save_another_notification_increments_template_stats(sample_template, sa assert template_stats.template_id == sample_template.id assert template_stats.usage_count == 1 - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) assert TemplateStatistics.query.count() == 1 template_stats = TemplateStatistics.query.filter(TemplateStatistics.service_id == sample_template.service.id, @@ -720,7 +720,7 @@ def test_save_another_notification_increments_template_stats(sample_template, sa def test_successful_notification_inserts_followed_by_failure_does_not_increment_template_stats(sample_template, sample_job, - mmg_provider_name): + mmg_provider): assert Notification.query.count() == 0 assert NotificationStatistics.query.count() == 0 assert TemplateStatistics.query.count() == 0 @@ -730,9 +730,9 @@ def test_successful_notification_inserts_followed_by_failure_does_not_increment_ notification_1 = Notification(**data) notification_2 = Notification(**data) notification_3 = Notification(**data) - dao_create_notification(notification_1, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_2, sample_template.template_type, mmg_provider_name) - dao_create_notification(notification_3, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification_1, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_2, sample_template.template_type, mmg_provider.identifier) + dao_create_notification(notification_3, sample_template.template_type, mmg_provider.identifier) assert NotificationStatistics.query.count() == 1 notication_stats = NotificationStatistics.query.filter( @@ -751,7 +751,7 @@ def test_successful_notification_inserts_followed_by_failure_does_not_increment_ try: # Mess up db in really bad way db.session.execute('DROP TABLE TEMPLATE_STATISTICS') - dao_create_notification(failing_notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(failing_notification, sample_template.template_type, mmg_provider.identifier) except Exception as e: # There should be no additional notification stats or counts assert NotificationStatistics.query.count() == 1 @@ -764,24 +764,24 @@ def test_successful_notification_inserts_followed_by_failure_does_not_increment_ @freeze_time("2016-03-30") def test_get_template_stats_for_service_returns_stats_in_reverse_date_order(sample_template, sample_job, - mmg_provider_name): + mmg_provider): template_stats = dao_get_template_statistics_for_service(sample_template.service.id) assert len(template_stats) == 0 data = _notification_json(sample_template, sample_job.id) notification = Notification(**data) - dao_create_notification(notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(notification, sample_template.template_type, mmg_provider.identifier) # move on one day with freeze_time('2016-03-31'): new_notification = Notification(**data) - dao_create_notification(new_notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(new_notification, sample_template.template_type, mmg_provider.identifier) # move on one more day with freeze_time('2016-04-01'): new_notification = Notification(**data) - dao_create_notification(new_notification, sample_template.template_type, mmg_provider_name) + dao_create_notification(new_notification, sample_template.template_type, mmg_provider.identifier) template_stats = dao_get_template_statistics_for_service(sample_template.service_id) assert len(template_stats) == 3 diff --git a/tests/app/dao/test_notifications_dao_provider_statistics.py b/tests/app/dao/test_notifications_dao_provider_statistics.py index 3e879ad81..cc166edd1 100644 --- a/tests/app/dao/test_notifications_dao_provider_statistics.py +++ b/tests/app/dao/test_notifications_dao_provider_statistics.py @@ -3,112 +3,114 @@ from app.models import ProviderStatistics from app.dao.provider_statistics_dao import ( get_provider_statistics, get_fragment_count) -from app.models import Notification from tests.app.conftest import sample_notification as create_sample_notification def test_should_update_provider_statistics_sms(notify_db, notify_db_session, sample_template, - mmg_provider_name): - notification = create_sample_notification( + mmg_provider): + create_sample_notification( notify_db, notify_db_session, template=sample_template) provider_stats = get_provider_statistics( sample_template.service, - providers=[mmg_provider_name]).one() + providers=[mmg_provider.identifier]).one() assert provider_stats.unit_count == 1 def test_should_update_provider_statistics_email(notify_db, notify_db_session, sample_email_template, - ses_provider_name): - notification = create_sample_notification( + ses_provider): + create_sample_notification( notify_db, notify_db_session, template=sample_email_template) provider_stats = get_provider_statistics( sample_email_template.service, - providers=[ses_provider_name]).one() + providers=[ses_provider.identifier]).one() assert provider_stats.unit_count == 1 def test_should_update_provider_statistics_sms_multi(notify_db, notify_db_session, sample_template, - mmg_provider_name): - notification1 = create_sample_notification( + mmg_provider): + create_sample_notification( notify_db, notify_db_session, template=sample_template, content_char_count=160) - notification1 = create_sample_notification( + create_sample_notification( notify_db, notify_db_session, template=sample_template, content_char_count=161) - notification1 = create_sample_notification( + create_sample_notification( notify_db, notify_db_session, template=sample_template, content_char_count=307) provider_stats = get_provider_statistics( sample_template.service, - providers=[mmg_provider_name]).one() + providers=[mmg_provider.identifier]).one() assert provider_stats.unit_count == 6 def test_should_update_provider_statistics_email_multi(notify_db, notify_db_session, sample_email_template, - ses_provider_name): - notification1 = create_sample_notification( + ses_provider): + create_sample_notification( notify_db, notify_db_session, template=sample_email_template) - notification2 = create_sample_notification( + create_sample_notification( notify_db, notify_db_session, template=sample_email_template) - notification3 = create_sample_notification( + create_sample_notification( notify_db, notify_db_session, template=sample_email_template) provider_stats = get_provider_statistics( sample_email_template.service, - providers=[ses_provider_name]).one() + providers=[ses_provider.identifier]).one() assert provider_stats.unit_count == 3 def test_should_aggregate_fragment_count(notify_db, notify_db_session, sample_service, - mmg_provider_name, - twilio_provider_name, - ses_provider_name): + mmg_provider, + firetext_provider, + ses_provider): day = date.today() stats_mmg = ProviderStatistics( service=sample_service, day=day, - provider=mmg_provider_name, + provider_id=mmg_provider.id, unit_count=2 ) - stats_twilio = ProviderStatistics( + + stats_firetext = ProviderStatistics( service=sample_service, day=day, - provider=twilio_provider_name, + provider_id=firetext_provider.id, unit_count=3 ) - stats_twilio = ProviderStatistics( + + stats_ses = ProviderStatistics( service=sample_service, day=day, - provider=ses_provider_name, + provider_id=ses_provider.id, unit_count=1 ) notify_db.session.add(stats_mmg) - notify_db.session.add(stats_twilio) + notify_db.session.add(stats_firetext) + notify_db.session.add(stats_ses) notify_db.session.commit() results = get_fragment_count(sample_service, day, day) assert results['sms_count'] == 5 @@ -118,19 +120,19 @@ def test_should_aggregate_fragment_count(notify_db, def test_should_aggregate_fragment_count_over_days(notify_db, notify_db_session, sample_service, - mmg_provider_name): + mmg_provider): today = date.today() yesterday = today - timedelta(days=1) stats_today = ProviderStatistics( service=sample_service, day=today, - provider=mmg_provider_name, + provider_id=mmg_provider.id, unit_count=2 ) stats_yesterday = ProviderStatistics( service=sample_service, day=yesterday, - provider=mmg_provider_name, + provider_id=mmg_provider.id, unit_count=3 ) notify_db.session.add(stats_today) diff --git a/tests/app/dao/test_provider_details_dao.py b/tests/app/dao/test_provider_details_dao.py new file mode 100644 index 000000000..cb63fa400 --- /dev/null +++ b/tests/app/dao/test_provider_details_dao.py @@ -0,0 +1,57 @@ +from app.models import ProviderDetails +from app import clients +from app.dao.provider_details_dao import ( + get_provider_details, + get_provider_details_by_notification_type +) + + +def test_can_get_all_providers(notify_db, notify_db_session): + assert len(get_provider_details()) == 3 + + +def test_can_get_sms_providers(notify_db, notify_db_session): + assert len(get_provider_details_by_notification_type('sms')) == 2 + types = [provider.notification_type for provider in get_provider_details_by_notification_type('sms')] + assert all('sms' == notification_type for notification_type in types) + + +def test_can_get_sms_providers_in_order(notify_db, notify_db_session): + providers = get_provider_details_by_notification_type('sms') + + assert providers[0].identifier == "mmg" + assert providers[1].identifier == "firetext" + + +def test_can_get_email_providers_in_order(notify_db, notify_db_session): + providers = get_provider_details_by_notification_type('email') + + assert providers[0].identifier == "ses" + + +def test_can_get_email_providers(notify_db, notify_db_session): + assert len(get_provider_details_by_notification_type('email')) == 1 + types = [provider.notification_type for provider in get_provider_details_by_notification_type('email')] + assert all('email' == notification_type for notification_type in types) + + +def test_should_error_if_any_provider_in_database_not_in_code(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.get_sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.get_email_client(provider.identifier) + + +def test_should_not_error_if_any_provider_in_code_not_in_database(notify_db, notify_db_session, notify_api): + providers = ProviderDetails.query.all() + + ProviderDetails.query.filter_by(identifier='mmg').delete() + + for provider in providers: + if provider.notification_type == 'sms': + assert clients.get_sms_client(provider.identifier) + if provider.notification_type == 'email': + assert clients.get_email_client(provider.identifier) diff --git a/tests/app/dao/test_provider_rates_dao.py b/tests/app/dao/test_provider_rates_dao.py index d9ce53e41..417612781 100644 --- a/tests/app/dao/test_provider_rates_dao.py +++ b/tests/app/dao/test_provider_rates_dao.py @@ -1,14 +1,17 @@ from datetime import datetime from decimal import Decimal from app.dao.provider_rates_dao import create_provider_rates -from app.models import ProviderRates +from app.models import ProviderRates, ProviderDetails -def test_create_provider_rates(notify_db, notify_db_session, mmg_provider_name): - now = datetime.utcnow() +def test_create_provider_rates(notify_db, notify_db_session, mmg_provider): + now = datetime.now() rate = Decimal("1.00000") - create_provider_rates(mmg_provider_name, now, rate) + + provider = ProviderDetails.query.filter_by(identifier=mmg_provider.identifier).one() + + create_provider_rates(mmg_provider.identifier, now, rate) assert ProviderRates.query.count() == 1 assert ProviderRates.query.first().rate == rate assert ProviderRates.query.first().valid_from == now - assert ProviderRates.query.first().provider == mmg_provider_name + assert ProviderRates.query.first().provider_id == provider.id diff --git a/tests/app/provider_details/__init__.py b/tests/app/provider_details/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/app/provider_details/test_rest.py b/tests/app/provider_details/test_rest.py new file mode 100644 index 000000000..2cfcf3979 --- /dev/null +++ b/tests/app/provider_details/test_rest.py @@ -0,0 +1,140 @@ +from flask import json +from tests import create_authorization_header + + +def test_get_provider_details_in_type_and_identifier_order(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + assert response.status_code == 200 + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] + assert len(json_resp) == 3 + + assert json_resp[0]['identifier'] == 'ses' + assert json_resp[1]['identifier'] == 'mmg' + assert json_resp[2]['identifier'] == 'firetext' + + +def test_get_provider_details_by_id(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] + + provider_resp = client.get( + '/provider-details/{}'.format(json_resp[0]['id']), + headers=[auth_header] + ) + + provider = json.loads(provider_resp.get_data(as_text=True))['provider_details'] + assert provider['identifier'] == json_resp[0]['identifier'] + + +def test_get_provider_details_contains_correct_fields(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] + allowed_keys = {"id", "display_name", "identifier", "priority", 'notification_type', "active"} + assert \ + allowed_keys == \ + set(json_resp[0].keys()) + + +def test_should_be_able_to_update_priority(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] + + provider_id = fetch_resp[2]['id'] + + update_resp = client.post( + '/provider-details/{}'.format(provider_id), + headers=[('Content-Type', 'application/json'), auth_header], + data=json.dumps({ + 'priority': 10 + }) + ) + assert update_resp.status_code == 200 + update_json = json.loads(update_resp.get_data(as_text=True))['provider_details'] + assert update_json['identifier'] == 'firetext' + assert update_json['priority'] == 10 + + +def test_should_be_able_to_update_status(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] + + provider_id = fetch_resp[2]['id'] + + update_resp_1 = client.post( + '/provider-details/{}'.format(provider_id), + headers=[('Content-Type', 'application/json'), auth_header], + data=json.dumps({ + 'active': False + }) + ) + assert update_resp_1.status_code == 200 + update_resp_1 = json.loads(update_resp_1.get_data(as_text=True))['provider_details'] + assert update_resp_1['identifier'] == 'firetext' + assert not update_resp_1['active'] + + update_resp_2 = client.post( + '/provider-details/{}'.format(provider_id), + headers=[('Content-Type', 'application/json'), auth_header], + data=json.dumps({ + 'active': True + }) + ) + assert update_resp_2.status_code == 200 + update_resp_2 = json.loads(update_resp_2.get_data(as_text=True))['provider_details'] + assert update_resp_2['identifier'] == 'firetext' + assert update_resp_2['active'] + + +def test_should_not_be_able_to_update_identifier(notify_db, notify_db_session, notify_api): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + auth_header = create_authorization_header() + response = client.get( + '/provider-details', + headers=[auth_header] + ) + fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] + + provider_id = fetch_resp[2]['id'] + + update_resp = client.post( + '/provider-details/{}'.format(provider_id), + headers=[('Content-Type', 'application/json'), auth_header], + data=json.dumps({ + 'identifier': "new" + }) + ) + assert update_resp.status_code == 400 + update_resp = json.loads(update_resp.get_data(as_text=True)) + assert update_resp['message']['identifier'][0] == 'Not permitted to be updated' + assert update_resp['result'] == 'error' diff --git a/tests/conftest.py b/tests/conftest.py index 1b45b6369..24f3702ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,7 +52,8 @@ def notify_db_session(request): def teardown(): db.session.remove() for tbl in reversed(meta.sorted_tables): - db.engine.execute(tbl.delete()) + if tbl.name not in ["provider_details"]: + db.engine.execute(tbl.delete()) db.session.commit() meta = MetaData(bind=db.engine, reflect=True)