diff --git a/app/dao/complaint_dao.py b/app/dao/complaint_dao.py index 1cc12bdae..63b7487fb 100644 --- a/app/dao/complaint_dao.py +++ b/app/dao/complaint_dao.py @@ -1,10 +1,11 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import desc +from sqlalchemy import desc, func, select from app import db from app.dao.dao_utils import autocommit +from app.dao.inbound_sms_dao import Pagination from app.models import Complaint from app.utils import get_midnight_in_utc @@ -15,23 +16,36 @@ def save_complaint(complaint): def fetch_paginated_complaints(page=1): - return Complaint.query.order_by(desc(Complaint.created_at)).paginate( - page=page, per_page=current_app.config["PAGE_SIZE"] + page_size = current_app.config["PAGE_SIZE"] + total_count = db.session.scalar(select(func.count()).select_from(Complaint)) + offset = (page - 1) * page_size + stmt = ( + select(Complaint) + .order_by(desc(Complaint.created_at)) + .offset(offset) + .limit(page_size) ) + result = db.session.execute(stmt).scalars().all() + pagination = Pagination(result, page=page, per_page=page_size, total=total_count) + return pagination def fetch_complaints_by_service(service_id): - return ( - Complaint.query.filter_by(service_id=service_id) + stmt = ( + select(Complaint) + .filter_by(service_id=service_id) .order_by(desc(Complaint.created_at)) - .all() ) + return db.session.execute(stmt).scalars().all() def fetch_count_of_complaints(start_date, end_date): start_date = get_midnight_in_utc(start_date) end_date = get_midnight_in_utc(end_date + timedelta(days=1)) - return Complaint.query.filter( - Complaint.created_at >= start_date, Complaint.created_at < end_date - ).count() + stmt = ( + select(func.count()) + .select_from(Complaint) + .filter(Complaint.created_at >= start_date, Complaint.created_at < end_date) + ) + return db.session.execute(stmt).scalar() or 0 diff --git a/app/dao/inbound_numbers_dao.py b/app/dao/inbound_numbers_dao.py index 0a390c024..a86ba530e 100644 --- a/app/dao/inbound_numbers_dao.py +++ b/app/dao/inbound_numbers_dao.py @@ -1,24 +1,30 @@ +from sqlalchemy import and_, select, update + from app import db from app.dao.dao_utils import autocommit from app.models import InboundNumber def dao_get_inbound_numbers(): - return InboundNumber.query.order_by(InboundNumber.updated_at).all() + stmt = select(InboundNumber).order_by(InboundNumber.updated_at) + return db.session.execute(stmt).scalars().all() def dao_get_available_inbound_numbers(): - return InboundNumber.query.filter( + stmt = select(InboundNumber).filter( InboundNumber.active, InboundNumber.service_id.is_(None) - ).all() + ) + return db.session.execute(stmt).scalars().all() def dao_get_inbound_number_for_service(service_id): - return InboundNumber.query.filter(InboundNumber.service_id == service_id).first() + stmt = select(InboundNumber).filter(InboundNumber.service_id == service_id) + return db.session.execute(stmt).scalars().first() def dao_get_inbound_number(inbound_number_id): - return InboundNumber.query.filter(InboundNumber.id == inbound_number_id).first() + stmt = select(InboundNumber).filter(InboundNumber.id == inbound_number_id) + return db.session.execute(stmt).scalars().first() @autocommit @@ -29,9 +35,8 @@ def dao_set_inbound_number_to_service(service_id, inbound_number): @autocommit def dao_set_inbound_number_active_flag(service_id, active): - inbound_number = InboundNumber.query.filter( - InboundNumber.service_id == service_id - ).first() + stmt = select(InboundNumber).filter(InboundNumber.service_id == service_id) + inbound_number = db.session.execute(stmt).scalars().first() inbound_number.active = active db.session.add(inbound_number) @@ -39,9 +44,18 @@ def dao_set_inbound_number_active_flag(service_id, active): @autocommit def dao_allocate_number_for_service(service_id, inbound_number_id): - updated = InboundNumber.query.filter_by( - id=inbound_number_id, active=True, service_id=None - ).update({"service_id": service_id}) - if not updated: + stmt = ( + update(InboundNumber) + .where( + and_( + InboundNumber.id == inbound_number_id, # noqa + InboundNumber.active == True, # noqa + InboundNumber.service_id == None, # noqa + ) + ) + .values({"service_id": service_id}) + ) + result = db.session.execute(stmt) + if result.rowcount == 0: raise Exception("Inbound number: {} is not available".format(inbound_number_id)) - return InboundNumber.query.get(inbound_number_id) + return db.session.get(InboundNumber, inbound_number_id) diff --git a/app/dao/inbound_sms_dao.py b/app/dao/inbound_sms_dao.py index 272ae5e1c..c9b4417e3 100644 --- a/app/dao/inbound_sms_dao.py +++ b/app/dao/inbound_sms_dao.py @@ -1,5 +1,5 @@ from flask import current_app -from sqlalchemy import and_, desc +from sqlalchemy import and_, delete, desc, func, select from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import aliased @@ -18,8 +18,10 @@ def dao_create_inbound_sms(inbound_sms): def dao_get_inbound_sms_for_service( service_id, user_number=None, *, limit_days=None, limit=None ): - q = InboundSms.query.filter(InboundSms.service_id == service_id).order_by( - InboundSms.created_at.desc() + q = ( + select(InboundSms) + .filter(InboundSms.service_id == service_id) + .order_by(InboundSms.created_at.desc()) ) if limit_days is not None: start_date = midnight_n_days_ago(limit_days) @@ -31,7 +33,7 @@ def dao_get_inbound_sms_for_service( if limit: q = q.limit(limit) - return q.all() + return db.session.execute(q).scalars().all() def dao_get_paginated_inbound_sms_for_service_for_public_api( @@ -46,27 +48,33 @@ def dao_get_paginated_inbound_sms_for_service_for_public_api( older_than_created_at = ( db.session.query(InboundSms.created_at) .filter(InboundSms.id == older_than) - .as_scalar() + .scalar_subquery() ) filters.append(InboundSms.created_at < older_than_created_at) - query = InboundSms.query.filter(*filters) - - return ( - query.order_by(desc(InboundSms.created_at)).paginate(per_page=page_size).items - ) + # As part of the move to sqlalchemy 2.0, we do this manual pagination + query = db.session.query(InboundSms).filter(*filters) + paginated_items = query.order_by(desc(InboundSms.created_at)).limit(page_size).all() + return paginated_items def dao_count_inbound_sms_for_service(service_id, limit_days): - return InboundSms.query.filter( - InboundSms.service_id == service_id, - InboundSms.created_at >= midnight_n_days_ago(limit_days), - ).count() + stmt = ( + select(func.count()) + .select_from(InboundSms) + .filter( + InboundSms.service_id == service_id, + InboundSms.created_at >= midnight_n_days_ago(limit_days), + ) + ) + result = db.session.execute(stmt).scalar() + return result def _insert_inbound_sms_history(subquery, query_limit=10000): offset = 0 - inbound_sms_query = db.session.query( + subquery_select = select(subquery) + inbound_sms_query = select( InboundSms.id, InboundSms.created_at, InboundSms.service_id, @@ -74,8 +82,10 @@ def _insert_inbound_sms_history(subquery, query_limit=10000): InboundSms.provider_date, InboundSms.provider_reference, InboundSms.provider, - ).filter(InboundSms.id.in_(subquery)) - inbound_sms_count = inbound_sms_query.count() + ).where(InboundSms.id.in_(subquery_select)) + + count_query = select(func.count()).select_from(inbound_sms_query.subquery()) + inbound_sms_count = db.session.execute(count_query).scalar() or 0 while offset < inbound_sms_count: statement = insert(InboundSmsHistory).from_select( @@ -86,7 +96,8 @@ def _insert_inbound_sms_history(subquery, query_limit=10000): statement = statement.on_conflict_do_nothing( constraint="inbound_sms_history_pkey" ) - db.session.connection().execute(statement) + db.session.execute(statement) + db.session.commit() offset += query_limit @@ -95,7 +106,7 @@ def _delete_inbound_sms(datetime_to_delete_from, query_filter): query_limit = 10000 subquery = ( - db.session.query(InboundSms.id) + select(InboundSms.id) .filter(InboundSms.created_at < datetime_to_delete_from, *query_filter) .limit(query_limit) .subquery() @@ -107,9 +118,9 @@ def _delete_inbound_sms(datetime_to_delete_from, query_filter): while number_deleted > 0: _insert_inbound_sms_history(subquery, query_limit=query_limit) - number_deleted = InboundSms.query.filter(InboundSms.id.in_(subquery)).delete( - synchronize_session="fetch" - ) + stmt = delete(InboundSms).filter(InboundSms.id.in_(subquery)) + number_deleted = db.session.execute(stmt).rowcount + db.session.commit() deleted += number_deleted return deleted @@ -121,11 +132,12 @@ def delete_inbound_sms_older_than_retention(): "Deleting inbound sms for services with flexible data retention" ) - flexible_data_retention = ( - ServiceDataRetention.query.join(ServiceDataRetention.service) + stmt = ( + select(ServiceDataRetention) + .join(ServiceDataRetention.service) .filter(ServiceDataRetention.notification_type == NotificationType.SMS) - .all() ) + flexible_data_retention = db.session.execute(stmt).scalars().all() deleted = 0 @@ -158,7 +170,8 @@ def delete_inbound_sms_older_than_retention(): def dao_get_inbound_sms_by_id(service_id, inbound_id): - return InboundSms.query.filter_by(id=inbound_id, service_id=service_id).one() + stmt = select(InboundSms).filter_by(id=inbound_id, service_id=service_id) + return db.session.execute(stmt).scalars().one() def dao_get_paginated_most_recent_inbound_sms_by_user_number_for_service( @@ -184,7 +197,7 @@ def dao_get_paginated_most_recent_inbound_sms_by_user_number_for_service( """ t2 = aliased(InboundSms) q = ( - db.session.query(InboundSms) + select(InboundSms) .outerjoin( t2, and_( @@ -193,12 +206,34 @@ def dao_get_paginated_most_recent_inbound_sms_by_user_number_for_service( InboundSms.created_at < t2.created_at, ), ) - .filter( - t2.id == None, # noqa + .where( + t2.id.is_(None), # noqa InboundSms.service_id == service_id, InboundSms.created_at >= midnight_n_days_ago(limit_days), ) .order_by(InboundSms.created_at.desc()) ) + result = db.session.execute(q).scalars().all() + page_size = current_app.config["PAGE_SIZE"] + offset = (page - 1) * page_size + paginated_results = result[offset : offset + page_size] + pagination = Pagination(paginated_results, page, page_size, len(result)) + return pagination - return q.paginate(page=page, per_page=current_app.config["PAGE_SIZE"]) + +# TODO remove this when billing dao PR is merged. +class Pagination: + def __init__(self, items, page, per_page, total): + self.items = items + self.page = page + self.per_page = per_page + self.total = total + self.pages = (total + per_page - 1) // per_page + self.prev_num = page - 1 if page > 1 else None + self.next_num = page + 1 if page < self.pages else None + + def has_next(self): + return self.page < self.pages + + def has_prev(self): + return self.page > 1 diff --git a/app/inbound_sms/rest.py b/app/inbound_sms/rest.py index 7f8742a16..1cae7a85b 100644 --- a/app/inbound_sms/rest.py +++ b/app/inbound_sms/rest.py @@ -60,9 +60,13 @@ def get_most_recent_inbound_sms_for_service(service_id): results = dao_get_paginated_most_recent_inbound_sms_by_user_number_for_service( service_id, int(page), limit_days ) - return jsonify( - data=[row.serialize() for row in results.items], has_next=results.has_next - ) + try: + x = jsonify( + data=[row.serialize() for row in results.items], has_next=results.has_next() + ) + except Exception as e: + raise e + return x @inbound_sms.route("/summary") diff --git a/tests/app/dao/test_fact_notification_status_dao.py b/tests/app/dao/test_fact_notification_status_dao.py index 2c0de9014..fd97496e3 100644 --- a/tests/app/dao/test_fact_notification_status_dao.py +++ b/tests/app/dao/test_fact_notification_status_dao.py @@ -86,8 +86,7 @@ def test_fetch_notification_status_for_service_by_month(notify_db_session): assert results[0].month.date() == date(2018, 1, 1) assert results[0].notification_type == NotificationType.EMAIL - # TODO fix/investigate - # assert results[0].notification_status == NotificationStatus.DELIVERED + assert results[0].notification_status == NotificationStatus.DELIVERED assert results[0].count == 1 assert results[1].month.date() == date(2018, 1, 1) diff --git a/tests/app/dao/test_inbound_numbers_dao.py b/tests/app/dao/test_inbound_numbers_dao.py index ce3fd6245..efb1e376c 100644 --- a/tests/app/dao/test_inbound_numbers_dao.py +++ b/tests/app/dao/test_inbound_numbers_dao.py @@ -1,6 +1,8 @@ import pytest +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from app import db from app.dao.inbound_numbers_dao import ( dao_allocate_number_for_service, dao_get_available_inbound_numbers, @@ -35,7 +37,8 @@ def test_set_service_id_on_inbound_number(notify_db_session, sample_inbound_numb dao_set_inbound_number_to_service(service.id, numbers[0]) - res = InboundNumber.query.filter(InboundNumber.service_id == service.id).all() + stmt = select(InboundNumber).filter(InboundNumber.service_id == service.id) + res = db.session.execute(stmt).scalars().all() assert len(res) == 1 assert res[0].service_id == service.id diff --git a/tests/app/dao/test_inbound_sms_dao.py b/tests/app/dao/test_inbound_sms_dao.py index 9f3d6738d..39cdb2f53 100644 --- a/tests/app/dao/test_inbound_sms_dao.py +++ b/tests/app/dao/test_inbound_sms_dao.py @@ -2,6 +2,7 @@ from datetime import datetime from itertools import product from freezegun import freeze_time +from sqlalchemy import select from app import db from app.dao.inbound_sms_dao import ( @@ -141,7 +142,8 @@ def test_should_delete_inbound_sms_according_to_data_retention(notify_db_session deleted_count = delete_inbound_sms_older_than_retention() - history = InboundSmsHistory.query.all() + stmt = select(InboundSmsHistory) + history = db.session.execute(stmt).scalars().all() assert len(history) == 7 # four deleted for the 3-day service, two for the default seven days one, one for the 30 day @@ -171,7 +173,8 @@ def test_insert_into_inbound_sms_history_when_deleting_inbound_sms(sample_servic create_inbound_sms(sample_service, created_at=datetime(2019, 12, 19, 20, 19)) delete_inbound_sms_older_than_retention() - history = InboundSmsHistory.query.all() + stmt = select(InboundSmsHistory) + history = db.session.execute(stmt).scalars().all() assert len(history) == 1 for key_name in [ @@ -226,7 +229,8 @@ def test_delete_inbound_sms_older_than_retention_does_nothing_when_database_conf delete_inbound_sms_older_than_retention() - history = InboundSmsHistory.query.all() + stmt = select(InboundSmsHistory) + history = db.session.execute(stmt).scalars().all() assert len(history) == 1 assert history[0].id == inbound_sms_id @@ -391,7 +395,7 @@ def test_most_recent_inbound_sms_only_returns_most_recent_for_each_number( ) # noqa assert len(res.items) == 2 - assert res.has_next is False + assert res.has_next() is False assert res.per_page == 3 assert res.items[0].content == "111 5" assert res.items[1].content == "222 2" @@ -454,7 +458,7 @@ def test_most_recent_inbound_sms_paginates_properly(notify_api, sample_service): sample_service.id, limit_days=7, page=1 ) # noqa assert len(res.items) == 2 - assert res.has_next is True + assert res.has_next() is True assert res.per_page == 2 assert res.items[0].content == "444 2" assert res.items[1].content == "333 2" @@ -464,7 +468,7 @@ def test_most_recent_inbound_sms_paginates_properly(notify_api, sample_service): sample_service.id, limit_days=7, page=2 ) # noqa assert len(res.items) == 2 - assert res.has_next is False + assert res.has_next() is False assert res.items[0].content == "222 2" assert res.items[1].content == "111 2"