diff --git a/app/celery/tasks.py b/app/celery/tasks.py index 02291d728..ce7f906e7 100644 --- a/app/celery/tasks.py +++ b/app/celery/tasks.py @@ -8,10 +8,6 @@ from notifications_utils.recipients import ( RecipientCSV ) from notifications_utils.statsd_decorators import statsd -from notifications_utils.template import ( - SMSMessageTemplate, - WithSubjectTemplate, -) from notifications_utils.timezones import convert_utc_to_bst from requests import ( HTTPError, @@ -128,9 +124,8 @@ def job_complete(job, resumed=False, start=None): def get_recipient_csv_and_template_and_sender_id(job): db_template = dao_get_template_by_id(job.template_id, job.template_version) + template = db_template._as_utils_template() - TemplateClass = get_template_class(db_template.template_type) - template = TemplateClass(db_template.__dict__) contents, meta_data = s3.get_job_and_metadata_from_s3(service_id=str(job.service_id), job_id=str(job.id)) recipient_csv = RecipientCSV(file_data=contents, template_type=template.template_type, @@ -454,15 +449,6 @@ def handle_exception(task, notification, notification_id, exc): current_app.logger.error('Max retry failed' + retry_msg) -def get_template_class(template_type): - if template_type == SMS_TYPE: - return SMSMessageTemplate - elif template_type in (EMAIL_TYPE, LETTER_TYPE): - # since we don't need rendering capabilities (we only need to extract placeholders) both email and letter can - # use the same base template - return WithSubjectTemplate - - @notify_celery.task(bind=True, name='update-letter-notifications-statuses') @statsd(namespace="tasks") def update_letter_notifications_statuses(self, filename): diff --git a/app/commands.py b/app/commands.py index 7b877f2c1..a28be6fa7 100644 --- a/app/commands.py +++ b/app/commands.py @@ -17,7 +17,7 @@ from notifications_utils.statsd_decorators import statsd from app import db, DATETIME_FORMAT, encryption from app.aws import s3 -from app.celery.tasks import record_daily_sorted_counts, get_template_class, process_row +from app.celery.tasks import record_daily_sorted_counts, process_row from app.celery.nightly_tasks import send_total_sent_notifications_to_performance_platform from app.celery.service_callback_tasks import send_delivery_status_to_service from app.celery.letters_pdf_tasks import create_letters_pdf @@ -889,8 +889,7 @@ def process_row_from_job(job_id, job_row_number): job = dao_get_job_by_id(job_id) db_template = dao_get_template_by_id(job.template_id, job.template_version) - TemplateClass = get_template_class(db_template.template_type) - template = TemplateClass(db_template.__dict__) + template = db_template._as_utils_template() for row in RecipientCSV( s3.get_job_from_s3(str(job.service_id), str(job.id)), diff --git a/app/models.py b/app/models.py index 64beaca52..2bed6639d 100644 --- a/app/models.py +++ b/app/models.py @@ -977,16 +977,12 @@ class TemplateBase(db.Model): def _as_utils_template(self): if self.template_type == EMAIL_TYPE: - return PlainTextEmailTemplate( - {'content': self.content, 'subject': self.subject} - ) + return PlainTextEmailTemplate(self.__dict__) if self.template_type == SMS_TYPE: - return SMSMessageTemplate( - {'content': self.content} - ) + return SMSMessageTemplate(self.__dict__) if self.template_type == LETTER_TYPE: return LetterPrintTemplate( - {'content': self.content, 'subject': self.subject}, + self.__dict__, contact_block=self.service.get_default_letter_contact(), ) diff --git a/tests/app/celery/test_tasks.py b/tests/app/celery/test_tasks.py index 3175f7f1e..d1a60cf75 100644 --- a/tests/app/celery/test_tasks.py +++ b/tests/app/celery/test_tasks.py @@ -9,7 +9,11 @@ from freezegun import freeze_time from requests import RequestException from sqlalchemy.exc import SQLAlchemyError from celery.exceptions import Retry -from notifications_utils.template import SMSMessageTemplate, WithSubjectTemplate +from notifications_utils.template import ( + LetterPrintTemplate, + PlainTextEmailTemplate, + SMSMessageTemplate, +) from notifications_utils.columns import Row from app import ( @@ -26,11 +30,12 @@ from app.celery.tasks import ( save_letter, process_incomplete_job, process_incomplete_jobs, - get_template_class, s3, send_inbound_sms_to_service, process_returned_letters_list, - save_api_email) + save_api_email, + get_recipient_csv_and_template_and_sender_id, +) from app.config import QueueNames from app.dao import jobs_dao, service_email_reply_to_dao, service_sms_sender_dao from app.models import ( @@ -1298,13 +1303,81 @@ def test_should_cancel_job_if_service_is_inactive(sample_service, tasks.process_row.assert_not_called() -@pytest.mark.parametrize('template_type, expected_class', [ - (SMS_TYPE, SMSMessageTemplate), - (EMAIL_TYPE, WithSubjectTemplate), - (LETTER_TYPE, WithSubjectTemplate), -]) -def test_get_template_class(template_type, expected_class): - assert get_template_class(template_type) == expected_class +def test_get_email_template_instance(mocker, sample_email_template, sample_job): + mocker.patch( + 'app.celery.tasks.s3.get_job_and_metadata_from_s3', + return_value=('', {}), + ) + sample_job.template_id = sample_email_template.id + ( + recipient_csv, + template, + _sender_id, + ) = get_recipient_csv_and_template_and_sender_id(sample_job) + + assert isinstance(template, PlainTextEmailTemplate) + assert recipient_csv.placeholders == [ + 'email address' + ] + + +def test_get_sms_template_instance(mocker, sample_template, sample_job): + mocker.patch( + 'app.celery.tasks.s3.get_job_and_metadata_from_s3', + return_value=('', {}), + ) + sample_job.template = sample_template + ( + recipient_csv, + template, + _sender_id, + ) = get_recipient_csv_and_template_and_sender_id(sample_job) + + assert isinstance(template, SMSMessageTemplate) + assert recipient_csv.placeholders == [ + 'phone number' + ] + + +def test_get_letter_template_instance(mocker, sample_job): + mocker.patch( + 'app.celery.tasks.s3.get_job_and_metadata_from_s3', + return_value=('', {}), + ) + sample_contact_block = create_letter_contact( + service=sample_job.service, + contact_block='((reference number))' + ) + sample_template = create_template( + service=sample_job.service, + template_type=LETTER_TYPE, + reply_to=sample_contact_block.id, + ) + sample_job.template_id = sample_template.id + + ( + recipient_csv, + template, + _sender_id, + ) = get_recipient_csv_and_template_and_sender_id(sample_job) + + assert isinstance(template, LetterPrintTemplate) + assert template.contact_block == ( + '((reference number))' + ) + assert template.placeholders == { + 'reference number' + } + assert recipient_csv.placeholders == [ + 'reference number', + 'address line 1', + 'address line 2', + 'address line 3', + 'address line 4', + 'address line 5', + 'address line 6', + 'postcode', + ] def test_send_inbound_sms_to_service_post_https_request_to_service(notify_api, sample_service):