From 271ce6d76efb7f8a07df98bc5140508b4d53057a Mon Sep 17 00:00:00 2001 From: Rebecca Law Date: Tue, 15 May 2018 11:21:10 +0100 Subject: [PATCH] Changed the update/insert to a postgres upsert to avoid concurrency issues. --- app/dao/fact_billing_dao.py | 66 +++++++++++++++--------- tests/app/celery/test_reporting_tasks.py | 15 ++++++ 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/app/dao/fact_billing_dao.py b/app/dao/fact_billing_dao.py index 0c86e1ac9..18387a25e 100644 --- a/app/dao/fact_billing_dao.py +++ b/app/dao/fact_billing_dao.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta, time +from sqlalchemy.dialects.postgresql import insert from sqlalchemy import func, case, desc, Date from app import db @@ -149,33 +150,48 @@ def get_rate(non_letter_rates, letter_rates, notification_type, date, crown=None def update_fact_billing(data, process_day): - inserted_records = 0 - updated_records = 0 non_letter_rates, letter_rates = get_rates_for_billing() - update_count = FactBilling.query.filter( - FactBilling.bst_date == datetime.date(process_day), - FactBilling.template_id == data.template_id, - FactBilling.service_id == data.service_id, - FactBilling.provider == data.sent_by, # This could be zero - this is a bug that needs to be fixed. - FactBilling.rate_multiplier == data.rate_multiplier, - FactBilling.notification_type == data.notification_type, - FactBilling.international == data.international - ).update( - {"notifications_sent": data.notifications_sent, - "billable_units": data.billable_units}, - synchronize_session=False) + rate = get_rate(non_letter_rates, + letter_rates, + data.notification_type, + process_day, + data.crown, + data.rate_multiplier) + billing_record = create_billing_record(data, rate, process_day) - if update_count == 0: - rate = get_rate(non_letter_rates, - letter_rates, - data.notification_type, - process_day, - data.crown, - data.rate_multiplier) - billing_record = create_billing_record(data, rate, process_day) - db.session.add(billing_record) - inserted_records += 1 - updated_records += update_count + table = FactBilling.__table__ + ''' + This uses the Postgres upsert to avoid race conditions when two threads try to insert + at the same row. The excluded object refers to values that we tried to insert but were + rejected. + http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#insert-on-conflict-upsert + ''' + stmt = insert(table).values( + bst_date=billing_record.bst_date, + template_id=billing_record.template_id, + service_id=billing_record.service_id, + provider=billing_record.provider, + rate_multiplier=billing_record.rate_multiplier, + notification_type=billing_record.notification_type, + international=billing_record.international, + billable_units=billing_record.billable_units, + notifications_sent=billing_record.notifications_sent, + rate=billing_record.rate + ) + + stmt = stmt.on_conflict_do_update( + index_elements=[table.c.bst_date, + table.c.template_id, + table.c.service_id, + table.c.provider, + table.c.rate_multiplier, + table.c.notification_type, + table.c.international], + set_={"notifications_sent": stmt.excluded.notifications_sent, + "billable_units": stmt.excluded.billable_units + } + ) + db.session.connection().execute(stmt) db.session.commit() diff --git a/tests/app/celery/test_reporting_tasks.py b/tests/app/celery/test_reporting_tasks.py index 2e33fcc45..9f058654c 100644 --- a/tests/app/celery/test_reporting_tasks.py +++ b/tests/app/celery/test_reporting_tasks.py @@ -438,7 +438,22 @@ def test_create_nightly_billing_update_when_record_exists( assert len(records) == 1 assert records[0].bst_date == date(2018, 1, 14) + assert records[0].billable_units == 1 + + sample_notification( + notify_db, + notify_db_session, + created_at=datetime.now() - timedelta(days=1), + service=sample_service, + template=sample_template, + status='delivered', + sent_by=None, + international=False, + rate_multiplier=1.0, + billable_units=1, + ) # run again, make sure create_nightly_billing() updates with no error create_nightly_billing() assert len(records) == 1 + assert records[0].billable_units == 2