diff --git a/app/dao/fact_billing_dao.py b/app/dao/fact_billing_dao.py index 6c4ed7a90..18387a25e 100644 --- a/app/dao/fact_billing_dao.py +++ b/app/dao/fact_billing_dao.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta, time +from sqlalchemy.dialects.postgresql import insert from sqlalchemy import func, case, desc, Date from app import db @@ -14,7 +15,8 @@ from app.models import ( LETTER_TYPE, SMS_TYPE, Rate, - LetterRate + LetterRate, + NotificationHistory ) from app.utils import convert_utc_to_bst, convert_bst_to_utc @@ -84,43 +86,48 @@ def fetch_monthly_billing_for_year(service_id, year): def fetch_billing_data_for_day(process_day, service_id=None): start_date = convert_bst_to_utc(datetime.combine(process_day, time.min)) end_date = convert_bst_to_utc(datetime.combine(process_day + timedelta(days=1), time.min)) + # use notification_history if process day is older than 7 days + # this is useful if we need to rebuild the ft_billing table for a date older than 7 days ago. + table = Notification + if start_date < datetime.utcnow() - timedelta(days=7): + table = NotificationHistory transit_data = db.session.query( - Notification.template_id, - Notification.service_id, - Notification.notification_type, - func.coalesce(Notification.sent_by, + table.template_id, + table.service_id, + table.notification_type, + func.coalesce(table.sent_by, case( [ - (Notification.notification_type == 'letter', 'dvla'), - (Notification.notification_type == 'sms', 'unknown'), - (Notification.notification_type == 'email', 'ses') + (table.notification_type == 'letter', 'dvla'), + (table.notification_type == 'sms', 'unknown'), + (table.notification_type == 'email', 'ses') ]), ).label('sent_by'), - func.coalesce(Notification.rate_multiplier, 1).label('rate_multiplier'), - func.coalesce(Notification.international, False).label('international'), - func.sum(Notification.billable_units).label('billable_units'), + func.coalesce(table.rate_multiplier, 1).label('rate_multiplier'), + func.coalesce(table.international, False).label('international'), + func.sum(table.billable_units).label('billable_units'), func.count().label('notifications_sent'), Service.crown, ).filter( - Notification.status != NOTIFICATION_CREATED, # at created status, provider information is not available - Notification.status != NOTIFICATION_TECHNICAL_FAILURE, - Notification.key_type != KEY_TYPE_TEST, - Notification.created_at >= start_date, - Notification.created_at < end_date + table.status != NOTIFICATION_CREATED, # at created status, provider information is not available + table.status != NOTIFICATION_TECHNICAL_FAILURE, + table.key_type != KEY_TYPE_TEST, + table.created_at >= start_date, + table.created_at < end_date ).group_by( - Notification.template_id, - Notification.service_id, - Notification.notification_type, + table.template_id, + table.service_id, + table.notification_type, 'sent_by', - Notification.rate_multiplier, - Notification.international, + table.rate_multiplier, + table.international, Service.crown ).join( Service ) if service_id: - transit_data = transit_data.filter(Notification.service_id == service_id) + transit_data = transit_data.filter(table.service_id == service_id) return transit_data.all() @@ -143,33 +150,48 @@ def get_rate(non_letter_rates, letter_rates, notification_type, date, crown=None def update_fact_billing(data, process_day): - inserted_records = 0 - updated_records = 0 non_letter_rates, letter_rates = get_rates_for_billing() - update_count = FactBilling.query.filter( - FactBilling.bst_date == datetime.date(process_day), - FactBilling.template_id == data.template_id, - FactBilling.service_id == data.service_id, - FactBilling.provider == data.sent_by, # This could be zero - this is a bug that needs to be fixed. - FactBilling.rate_multiplier == data.rate_multiplier, - FactBilling.notification_type == data.notification_type, - FactBilling.international == data.international - ).update( - {"notifications_sent": data.notifications_sent, - "billable_units": data.billable_units}, - synchronize_session=False) + rate = get_rate(non_letter_rates, + letter_rates, + data.notification_type, + process_day, + data.crown, + data.rate_multiplier) + billing_record = create_billing_record(data, rate, process_day) - if update_count == 0: - rate = get_rate(non_letter_rates, - letter_rates, - data.notification_type, - process_day, - data.crown, - data.rate_multiplier) - billing_record = create_billing_record(data, rate, process_day) - db.session.add(billing_record) - inserted_records += 1 - updated_records += update_count + table = FactBilling.__table__ + ''' + This uses the Postgres upsert to avoid race conditions when two threads try to insert + at the same row. The excluded object refers to values that we tried to insert but were + rejected. + http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#insert-on-conflict-upsert + ''' + stmt = insert(table).values( + bst_date=billing_record.bst_date, + template_id=billing_record.template_id, + service_id=billing_record.service_id, + provider=billing_record.provider, + rate_multiplier=billing_record.rate_multiplier, + notification_type=billing_record.notification_type, + international=billing_record.international, + billable_units=billing_record.billable_units, + notifications_sent=billing_record.notifications_sent, + rate=billing_record.rate + ) + + stmt = stmt.on_conflict_do_update( + index_elements=[table.c.bst_date, + table.c.template_id, + table.c.service_id, + table.c.provider, + table.c.rate_multiplier, + table.c.notification_type, + table.c.international], + set_={"notifications_sent": stmt.excluded.notifications_sent, + "billable_units": stmt.excluded.billable_units + } + ) + db.session.connection().execute(stmt) db.session.commit() diff --git a/tests/app/celery/test_reporting_tasks.py b/tests/app/celery/test_reporting_tasks.py index 2e33fcc45..9f058654c 100644 --- a/tests/app/celery/test_reporting_tasks.py +++ b/tests/app/celery/test_reporting_tasks.py @@ -438,7 +438,22 @@ def test_create_nightly_billing_update_when_record_exists( assert len(records) == 1 assert records[0].bst_date == date(2018, 1, 14) + assert records[0].billable_units == 1 + + sample_notification( + notify_db, + notify_db_session, + created_at=datetime.now() - timedelta(days=1), + service=sample_service, + template=sample_template, + status='delivered', + sent_by=None, + international=False, + rate_multiplier=1.0, + billable_units=1, + ) # run again, make sure create_nightly_billing() updates with no error create_nightly_billing() assert len(records) == 1 + assert records[0].billable_units == 2 diff --git a/tests/app/dao/test_ft_billing_dao.py b/tests/app/dao/test_ft_billing_dao.py index 4cd880ea0..09deb9266 100644 --- a/tests/app/dao/test_ft_billing_dao.py +++ b/tests/app/dao/test_ft_billing_dao.py @@ -10,7 +10,7 @@ from app.dao.fact_billing_dao import ( get_rate, fetch_billing_totals_for_year, ) -from app.models import FactBilling +from app.models import FactBilling, Notification from app.utils import convert_utc_to_bst from tests.app.db import ( create_ft_billing, @@ -187,6 +187,19 @@ def test_fetch_billing_data_for_day_returns_empty_list(notify_db_session): assert results == [] +def test_fetch_billing_data_for_day_uses_notification_history(notify_db_session): + service = create_service() + sms_template = create_template(service=service, template_type='sms') + create_notification(template=sms_template, status='delivered', created_at=datetime.utcnow() - timedelta(days=8)) + create_notification(template=sms_template, status='delivered', created_at=datetime.utcnow() - timedelta(days=8)) + + Notification.query.delete() + db.session.commit() + results = fetch_billing_data_for_day(process_day=datetime.utcnow() - timedelta(days=8), service_id=service.id) + assert len(results) == 1 + assert results[0].notifications_sent == 2 + + def test_fetch_billing_data_for_day_returns_list_for_given_service(notify_db_session): service = create_service() service_2 = create_service(service_name='Service 2')