merge from main

This commit is contained in:
Kenneth Kehl
2024-10-28 13:42:43 -07:00
20 changed files with 727 additions and 502 deletions

View File

@@ -4,9 +4,11 @@ from functools import partial
import pytest
from freezegun import freeze_time
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm.exc import NoResultFound
from app import db
from app.dao.notifications_dao import (
dao_create_notification,
dao_delete_notifications_by_id,
@@ -55,7 +57,10 @@ def test_should_by_able_to_update_status_by_reference(
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.get(notification.id).status == NotificationStatus.SENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.SENDING
)
notification.reference = "reference"
dao_update_notification(notification)
@@ -64,7 +69,8 @@ def test_should_by_able_to_update_status_by_reference(
)
assert updated.status == NotificationStatus.DELIVERED
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
@@ -81,7 +87,10 @@ def test_should_by_able_to_update_status_by_id(
dao_create_notification(notification)
assert notification.status == NotificationStatus.SENDING
assert Notification.query.get(notification.id).status == NotificationStatus.SENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.SENDING
)
with freeze_time("2000-01-02 12:00:00"):
updated = update_notification_status_by_id(
@@ -92,7 +101,8 @@ def test_should_by_able_to_update_status_by_id(
assert updated.status == NotificationStatus.DELIVERED
assert updated.updated_at == datetime(2000, 1, 2, 12, 0, 0)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
assert notification.updated_at == datetime(2000, 1, 2, 12, 0, 0)
assert notification.status == NotificationStatus.DELIVERED
@@ -107,15 +117,17 @@ def test_should_not_update_status_by_id_if_not_sending_and_does_not_update_job(
job=sample_job,
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
assert not update_notification_status_by_id(
notification.id, NotificationStatus.FAILED
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
assert sample_job == Job.query.get(notification.job_id)
assert sample_job == db.session.get(Job, notification.job_id)
def test_should_not_update_status_by_reference_if_not_sending_and_does_not_update_job(
@@ -128,20 +140,22 @@ def test_should_not_update_status_by_reference_if_not_sending_and_does_not_updat
job=sample_job,
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
assert not update_notification_status_by_reference(
"reference", NotificationStatus.FAILED
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
assert sample_job == Job.query.get(notification.job_id)
assert sample_job == db.session.get(Job, notification.job_id)
def test_should_update_status_by_id_if_created(sample_template, sample_notification):
assert (
Notification.query.get(sample_notification.id).status
db.session.get(Notification, sample_notification.id).status
== NotificationStatus.CREATED
)
updated = update_notification_status_by_id(
@@ -149,7 +163,7 @@ def test_should_update_status_by_id_if_created(sample_template, sample_notificat
NotificationStatus.FAILED,
)
assert (
Notification.query.get(sample_notification.id).status
db.session.get(Notification, sample_notification.id).status
== NotificationStatus.FAILED
)
assert updated.status == NotificationStatus.FAILED
@@ -244,11 +258,17 @@ def test_should_not_update_status_by_reference_if_not_sending(sample_template):
status=NotificationStatus.CREATED,
reference="reference",
)
assert Notification.query.get(notification.id).status == NotificationStatus.CREATED
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.CREATED
)
updated = update_notification_status_by_reference(
"reference", NotificationStatus.FAILED
)
assert Notification.query.get(notification.id).status == NotificationStatus.CREATED
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.CREATED
)
assert not updated
@@ -264,14 +284,18 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_delivered(
assert update_notification_status_by_id(
notification_id=notification.id, status=NotificationStatus.PENDING
)
assert Notification.query.get(notification.id).status == NotificationStatus.PENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.PENDING
)
assert update_notification_status_by_id(
notification.id,
NotificationStatus.DELIVERED,
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
@@ -289,7 +313,10 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_temporary_failure
notification_id=notification.id,
status=NotificationStatus.PENDING,
)
assert Notification.query.get(notification.id).status == NotificationStatus.PENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.PENDING
)
assert update_notification_status_by_id(
notification.id,
@@ -297,7 +324,7 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_temporary_failure
)
assert (
Notification.query.get(notification.id).status
db.session.get(Notification, notification.id).status
== NotificationStatus.TEMPORARY_FAILURE
)
@@ -312,14 +339,17 @@ def test_should_by_able_to_update_status_by_id_from_sending_to_permanent_failure
)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.get(notification.id).status == NotificationStatus.SENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.SENDING
)
assert update_notification_status_by_id(
notification.id,
status=NotificationStatus.PERMANENT_FAILURE,
)
assert (
Notification.query.get(notification.id).status
db.session.get(Notification, notification.id).status
== NotificationStatus.PERMANENT_FAILURE
)
@@ -331,7 +361,10 @@ def test_should_not_update_status_once_notification_status_is_delivered(
template=sample_email_template,
status=NotificationStatus.SENDING,
)
assert Notification.query.get(notification.id).status == NotificationStatus.SENDING
assert (
db.session.get(Notification, notification.id).status
== NotificationStatus.SENDING
)
notification.reference = "reference"
dao_update_notification(notification)
@@ -340,7 +373,8 @@ def test_should_not_update_status_once_notification_status_is_delivered(
NotificationStatus.DELIVERED,
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
update_notification_status_by_reference(
@@ -348,7 +382,8 @@ def test_should_not_update_status_once_notification_status_is_delivered(
NotificationStatus.FAILED,
)
assert (
Notification.query.get(notification.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification.id).status
== NotificationStatus.DELIVERED
)
@@ -370,7 +405,7 @@ def test_create_notification_creates_notification_with_personalisation(
sample_template_with_placeholders,
sample_job,
):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = create_notification(
template=sample_template_with_placeholders,
@@ -379,8 +414,8 @@ def test_create_notification_creates_notification_with_personalisation(
status=NotificationStatus.CREATED,
)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert data.to == notification_from_db.to
assert data.job_id == notification_from_db.job_id
@@ -393,15 +428,15 @@ def test_create_notification_creates_notification_with_personalisation(
def test_save_notification_creates_sms(sample_template, sample_job):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template, job_id=sample_job.id)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["job_id"] == notification_from_db.job_id
@@ -412,16 +447,36 @@ def test_save_notification_creates_sms(sample_template, sample_job):
assert notification_from_db.status == NotificationStatus.CREATED
def _get_notification_query_all():
stmt = select(Notification)
return db.session.execute(stmt).scalars().all()
def _get_notification_query_one():
stmt = select(Notification)
return db.session.execute(stmt).scalars().one()
def _get_notification_query_count():
stmt = select(func.count(Notification.id))
return db.session.execute(stmt).scalar() or 0
def _get_notification_history_query_count():
stmt = select(func.count(NotificationHistory.id))
return db.session.execute(stmt).scalar() or 0
def test_save_notification_and_create_email(sample_email_template, sample_job):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_email_template, job_id=sample_job.id)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["job_id"] == notification_from_db.job_id
@@ -433,29 +488,29 @@ def test_save_notification_and_create_email(sample_email_template, sample_job):
def test_save_notification(sample_email_template, sample_job):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_email_template, job_id=sample_job.id)
notification_1 = Notification(**data)
notification_2 = Notification(**data)
dao_create_notification(notification_1)
assert Notification.query.count() == 1
assert _get_notification_query_count() == 1
dao_create_notification(notification_2)
assert Notification.query.count() == 2
assert _get_notification_query_count() == 2
def test_save_notification_does_not_creates_history(sample_email_template, sample_job):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_email_template, job_id=sample_job.id)
notification_1 = Notification(**data)
dao_create_notification(notification_1)
assert Notification.query.count() == 1
assert NotificationHistory.query.count() == 0
assert _get_notification_query_count() == 1
assert _get_notification_history_query_count() == 0
def test_update_notification_with_research_mode_service_does_not_create_or_update_history(
@@ -464,14 +519,14 @@ def test_update_notification_with_research_mode_service_does_not_create_or_updat
sample_template.service.research_mode = True
notification = create_notification(template=sample_template)
assert Notification.query.count() == 1
assert NotificationHistory.query.count() == 0
assert _get_notification_query_count() == 1
assert _get_notification_history_query_count() == 0
notification.status = NotificationStatus.DELIVERED
dao_update_notification(notification)
assert Notification.query.one().status == NotificationStatus.DELIVERED
assert NotificationHistory.query.count() == 0
assert _get_notification_query_one().status == NotificationStatus.DELIVERED
assert _get_notification_history_query_count() == 0
def test_not_save_notification_and_not_create_stats_on_commit_error(
@@ -479,26 +534,26 @@ def test_not_save_notification_and_not_create_stats_on_commit_error(
):
random_id = str(uuid.uuid4())
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template, job_id=random_id)
notification = Notification(**data)
with pytest.raises(SQLAlchemyError):
dao_create_notification(notification)
assert Notification.query.count() == 0
assert Job.query.get(sample_job.id).notifications_sent == 0
assert _get_notification_query_count() == 0
assert db.session.get(Job, sample_job.id).notifications_sent == 0
def test_save_notification_and_increment_job(sample_template, sample_job, sns_provider):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template, job_id=sample_job.id)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["job_id"] == notification_from_db.job_id
@@ -510,21 +565,21 @@ def test_save_notification_and_increment_job(sample_template, sample_job, sns_pr
notification_2 = Notification(**data)
dao_create_notification(notification_2)
assert Notification.query.count() == 2
assert _get_notification_query_count() == 2
def test_save_notification_and_increment_correct_job(sample_template, sns_provider):
job_1 = create_job(sample_template)
job_2 = create_job(sample_template)
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template, job_id=job_1.id)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["job_id"] == notification_from_db.job_id
@@ -537,14 +592,14 @@ def test_save_notification_and_increment_correct_job(sample_template, sns_provid
def test_save_notification_with_no_job(sample_template, sns_provider):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["service"] == notification_from_db.service
@@ -592,7 +647,7 @@ def test_get_notification_by_id_when_notification_exists_for_different_service(
def test_get_notifications_by_reference(sample_template):
client_reference = "some-client-ref"
assert len(Notification.query.all()) == 0
assert len(_get_notification_query_all()) == 0
create_notification(sample_template, client_reference=client_reference)
create_notification(sample_template, client_reference=client_reference)
create_notification(sample_template, client_reference="other-ref")
@@ -603,14 +658,14 @@ def test_get_notifications_by_reference(sample_template):
def test_save_notification_no_job_id(sample_template):
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
data = _notification_json(sample_template)
notification = Notification(**data)
dao_create_notification(notification)
assert Notification.query.count() == 1
notification_from_db = Notification.query.all()[0]
assert _get_notification_query_count() == 1
notification_from_db = _get_notification_query_all()[0]
assert notification_from_db.id
assert "1" == notification_from_db.to
assert data["service"] == notification_from_db.service
@@ -687,13 +742,13 @@ def test_update_notification_sets_status(sample_notification):
assert sample_notification.status == NotificationStatus.CREATED
sample_notification.status = NotificationStatus.FAILED
dao_update_notification(sample_notification)
notification_from_db = Notification.query.get(sample_notification.id)
notification_from_db = db.session.get(Notification, sample_notification.id)
assert notification_from_db.status == NotificationStatus.FAILED
@freeze_time("2016-01-10")
def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template):
assert len(Notification.query.all()) == 0
assert len(_get_notification_query_all()) == 0
# create one notification a day between 1st and 9th,
# with assumption that the local timezone is EST
@@ -706,7 +761,7 @@ def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template
status=NotificationStatus.FAILED,
)
all_notifications = Notification.query.all()
all_notifications = _get_notification_query_all()
assert len(all_notifications) == 10
all_notifications = get_notifications_for_service(
@@ -722,19 +777,19 @@ def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template
def test_creating_notification_does_not_add_notification_history(sample_template):
create_notification(template=sample_template)
assert Notification.query.count() == 1
assert NotificationHistory.query.count() == 0
assert _get_notification_query_count() == 1
assert _get_notification_history_query_count() == 0
def test_should_delete_notification_for_id(sample_template):
notification = create_notification(template=sample_template)
assert Notification.query.count() == 1
assert NotificationHistory.query.count() == 0
assert _get_notification_query_count() == 1
assert _get_notification_history_query_count() == 0
dao_delete_notifications_by_id(notification.id)
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
def test_should_delete_notification_and_ignore_history_for_research_mode(
@@ -744,31 +799,32 @@ def test_should_delete_notification_and_ignore_history_for_research_mode(
notification = create_notification(template=sample_template)
assert Notification.query.count() == 1
assert _get_notification_query_count() == 1
dao_delete_notifications_by_id(notification.id)
assert Notification.query.count() == 0
assert _get_notification_query_count() == 0
def test_should_delete_only_notification_with_id(sample_template):
notification_1 = create_notification(template=sample_template)
notification_2 = create_notification(template=sample_template)
assert Notification.query.count() == 2
assert _get_notification_query_count() == 2
dao_delete_notifications_by_id(notification_1.id)
assert Notification.query.count() == 1
assert Notification.query.first().id == notification_2.id
assert _get_notification_query_count() == 1
stmt = select(Notification)
assert db.session.execute(stmt).scalars().first().id == notification_2.id
def test_should_delete_no_notifications_if_no_matching_ids(sample_template):
create_notification(template=sample_template)
assert Notification.query.count() == 1
assert _get_notification_query_count() == 1
dao_delete_notifications_by_id(uuid.uuid4())
assert Notification.query.count() == 1
assert _get_notification_query_count() == 1
def _notification_json(sample_template, job_id=None, id=None, status=None):
@@ -814,16 +870,19 @@ def test_dao_timeout_notifications(sample_template):
temporary_failure_notifications = dao_timeout_notifications(utc_now())
assert len(temporary_failure_notifications) == 2
assert Notification.query.get(created.id).status == NotificationStatus.CREATED
assert db.session.get(Notification, created.id).status == NotificationStatus.CREATED
assert (
Notification.query.get(sending.id).status
db.session.get(Notification, sending.id).status
== NotificationStatus.TEMPORARY_FAILURE
)
assert (
Notification.query.get(pending.id).status
db.session.get(Notification, pending.id).status
== NotificationStatus.TEMPORARY_FAILURE
)
assert Notification.query.get(delivered.id).status == NotificationStatus.DELIVERED
assert (
db.session.get(Notification, delivered.id).status
== NotificationStatus.DELIVERED
)
def test_dao_timeout_notifications_only_updates_for_older_notifications(
@@ -842,8 +901,8 @@ def test_dao_timeout_notifications_only_updates_for_older_notifications(
temporary_failure_notifications = dao_timeout_notifications(utc_now())
assert len(temporary_failure_notifications) == 0
assert Notification.query.get(sending.id).status == NotificationStatus.SENDING
assert Notification.query.get(pending.id).status == NotificationStatus.PENDING
assert db.session.get(Notification, sending.id).status == NotificationStatus.SENDING
assert db.session.get(Notification, pending.id).status == NotificationStatus.PENDING
def test_should_return_notifications_excluding_jobs_by_default(
@@ -935,7 +994,7 @@ def test_get_notifications_created_by_api_or_csv_are_returned_correctly_excludin
key_type=sample_test_api_key.key_type,
)
all_notifications = Notification.query.all()
all_notifications = _get_notification_query_all()
assert len(all_notifications) == 4
# returns all real API derived notifications
@@ -982,7 +1041,7 @@ def test_get_notifications_with_a_live_api_key_type(
key_type=sample_test_api_key.key_type,
)
all_notifications = Notification.query.all()
all_notifications = _get_notification_query_all()
assert len(all_notifications) == 4
# only those created with normal API key, no jobs
@@ -1114,7 +1173,7 @@ def test_should_exclude_test_key_notifications_by_default(
key_type=sample_test_api_key.key_type,
)
all_notifications = Notification.query.all()
all_notifications = _get_notification_query_all()
assert len(all_notifications) == 4
all_notifications = get_notifications_for_service(
@@ -1757,10 +1816,10 @@ def test_dao_update_notifications_by_reference_updated_notifications(sample_temp
update_dict={"status": NotificationStatus.DELIVERED, "billable_units": 2},
)
assert updated_count == 2
updated_1 = Notification.query.get(notification_1.id)
updated_1 = db.session.get(Notification, notification_1.id)
assert updated_1.billable_units == 2
assert updated_1.status == NotificationStatus.DELIVERED
updated_2 = Notification.query.get(notification_2.id)
updated_2 = db.session.get(Notification, notification_2.id)
assert updated_2.billable_units == 2
assert updated_2.status == NotificationStatus.DELIVERED
@@ -1823,10 +1882,11 @@ def test_dao_update_notifications_by_reference_updates_history_when_one_of_two_n
assert updated_count == 1
assert updated_history_count == 1
assert (
Notification.query.get(notification2.id).status == NotificationStatus.DELIVERED
db.session.get(Notification, notification2.id).status
== NotificationStatus.DELIVERED
)
assert (
NotificationHistory.query.get(notification1.id).status
db.session.get(NotificationHistory, notification1.id).status
== NotificationStatus.DELIVERED
)

View File

@@ -1,6 +1,7 @@
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from app import db
@@ -57,7 +58,8 @@ def test_get_organization_by_id_gets_correct_organization(notify_db_session):
def test_update_organization(notify_db_session):
create_organization()
organization = Organization.query.one()
stmt = select(Organization)
organization = db.session.execute(stmt).scalars().one()
user = create_user()
email_branding = create_email_branding()
@@ -78,7 +80,8 @@ def test_update_organization(notify_db_session):
dao_update_organization(organization.id, **data)
organization = Organization.query.one()
stmt = select(Organization)
organization = db.session.execute(stmt).scalars().one()
for attribute, value in data.items():
assert getattr(organization, attribute) == value
@@ -102,7 +105,8 @@ def test_update_organization_domains_lowercases(
):
create_organization()
organization = Organization.query.one()
stmt = select(Organization)
organization = db.session.execute(stmt).scalars().one()
# Seed some domains
dao_update_organization(organization.id, domains=["123", "456"])
@@ -121,7 +125,8 @@ def test_update_organization_domains_lowercases_integrity_error(
):
create_organization()
organization = Organization.query.one()
stmt = select(Organization)
organization = db.session.execute(stmt).scalars().one()
# Seed some domains
dao_update_organization(organization.id, domains=["123", "456"])
@@ -175,11 +180,11 @@ def test_update_organization_updates_the_service_org_type_if_org_type_is_provide
assert sample_organization.organization_type == OrganizationType.FEDERAL
assert sample_service.organization_type == OrganizationType.FEDERAL
stmt = select(Service.get_history_model()).filter_by(
id=sample_service.id, version=2
)
assert (
Service.get_history_model()
.query.filter_by(id=sample_service.id, version=2)
.one()
.organization_type
db.session.execute(stmt).scalars().one().organization_type
== OrganizationType.FEDERAL
)
@@ -229,11 +234,11 @@ def test_add_service_to_organization(sample_service, sample_organization):
assert sample_organization.services[0].id == sample_service.id
assert sample_service.organization_type == sample_organization.organization_type
stmt = select(Service.get_history_model()).filter_by(
id=sample_service.id, version=2
)
assert (
Service.get_history_model()
.query.filter_by(id=sample_service.id, version=2)
.one()
.organization_type
db.session.execute(stmt).scalars().one().organization_type
== sample_organization.organization_type
)
assert sample_service.organization_id == sample_organization.id

View File

@@ -1,8 +1,10 @@
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from app import db
from app.dao.service_sms_sender_dao import (
archive_sms_sender,
dao_add_sms_sender_for_service,
@@ -97,10 +99,8 @@ def test_dao_add_sms_sender_for_service(notify_db_session):
is_default=False,
inbound_number_id=None,
)
service_sms_senders = ServiceSmsSender.query.order_by(
ServiceSmsSender.created_at
).all()
stmt = select(ServiceSmsSender).order_by(ServiceSmsSender.created_at)
service_sms_senders = db.session.execute(stmt).scalars().all()
assert len(service_sms_senders) == 2
assert service_sms_senders[0].sms_sender == "testing"
assert service_sms_senders[0].is_default
@@ -116,10 +116,8 @@ def test_dao_add_sms_sender_for_service_switches_default(notify_db_session):
is_default=True,
inbound_number_id=None,
)
service_sms_senders = ServiceSmsSender.query.order_by(
ServiceSmsSender.created_at
).all()
stmt = select(ServiceSmsSender).order_by(ServiceSmsSender.created_at)
service_sms_senders = db.session.execute(stmt).scalars().all()
assert len(service_sms_senders) == 2
assert service_sms_senders[0].sms_sender == "testing"
assert not service_sms_senders[0].is_default
@@ -128,7 +126,8 @@ def test_dao_add_sms_sender_for_service_switches_default(notify_db_session):
def test_dao_update_service_sms_sender(notify_db_session):
service = create_service()
service_sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all()
stmt = select(ServiceSmsSender).filter_by(service_id=service.id)
service_sms_senders = db.session.execute(stmt).scalars().all()
assert len(service_sms_senders) == 1
sms_sender_to_update = service_sms_senders[0]
@@ -138,7 +137,8 @@ def test_dao_update_service_sms_sender(notify_db_session):
is_default=True,
sms_sender="updated",
)
sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all()
stmt = select(ServiceSmsSender).filter_by(service_id=service.id)
sms_senders = db.session.execute(stmt).scalars().all()
assert len(sms_senders) == 1
assert sms_senders[0].is_default
assert sms_senders[0].sms_sender == "updated"
@@ -159,7 +159,8 @@ def test_dao_update_service_sms_sender_switches_default(notify_db_session):
is_default=True,
sms_sender="updated",
)
sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all()
stmt = select(ServiceSmsSender).filter_by(service_id=service.id)
sms_senders = db.session.execute(stmt).scalars().all()
expected = {("testing", False), ("updated", True)}
results = {(sender.sms_sender, sender.is_default) for sender in sms_senders}
@@ -190,7 +191,8 @@ def test_update_existing_sms_sender_with_inbound_number(notify_db_session):
service = create_service()
inbound_number = create_inbound_number(number="12345", service_id=service.id)
existing_sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).one()
stmt = select(ServiceSmsSender).filter_by(service_id=service.id)
existing_sms_sender = db.session.execute(stmt).scalars().one()
sms_sender = update_existing_sms_sender_with_inbound_number(
service_sms_sender=existing_sms_sender,
sms_sender=inbound_number.number,
@@ -206,7 +208,8 @@ def test_update_existing_sms_sender_with_inbound_number_raises_exception_if_inbo
notify_db_session,
):
service = create_service()
existing_sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).one()
stmt = select(ServiceSmsSender).filter_by(service_id=service.id)
existing_sms_sender = db.session.execute(stmt).scalars().one()
with pytest.raises(expected_exception=SQLAlchemyError):
update_existing_sms_sender_with_inbound_number(
service_sms_sender=existing_sms_sender,

View File

@@ -6,6 +6,7 @@ from unittest.mock import Mock
import pytest
import sqlalchemy
from freezegun import freeze_time
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
@@ -89,9 +90,32 @@ from tests.app.db import (
)
def _get_service_query_count():
stmt = select(func.count(Service.id))
return db.session.execute(stmt).scalar() or 0
def _get_service_history_query_count():
stmt = select(func.count(Service.get_history_model().id))
return db.session.execute(stmt).scalar() or 0
def _get_first_service():
stmt = select(Service).limit(1)
service = db.session.execute(stmt).scalars().first()
return service
def _get_service_by_id(service_id):
stmt = select(Service).filter(Service.id == service_id)
service = db.session.execute(stmt).scalars().one()
return service
def test_create_service(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -101,8 +125,8 @@ def test_create_service(notify_db_session):
created_by=user,
)
dao_create_service(service, user)
assert Service.query.count() == 1
service_db = Service.query.one()
assert _get_service_query_count() == 1
service_db = _get_first_service()
assert service_db.name == "service_name"
assert service_db.id == service.id
assert service_db.email_from == "email_from"
@@ -120,7 +144,7 @@ def test_create_service_with_organization(notify_db_session):
organization_type=OrganizationType.STATE,
domains=["local-authority.gov.uk"],
)
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -130,9 +154,9 @@ def test_create_service_with_organization(notify_db_session):
created_by=user,
)
dao_create_service(service, user)
assert Service.query.count() == 1
service_db = Service.query.one()
organization = Organization.query.get(organization.id)
assert _get_service_query_count() == 1
service_db = _get_first_service()
organization = db.session.get(Organization, organization.id)
assert service_db.name == "service_name"
assert service_db.id == service.id
assert service_db.email_from == "email_from"
@@ -151,7 +175,7 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session):
organization_type=OrganizationType.STATE,
domains=["local-authority.gov.uk"],
)
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -161,9 +185,9 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session):
created_by=user,
)
dao_create_service(service, user)
assert Service.query.count() == 1
service_db = Service.query.one()
organization = Organization.query.get(organization.id)
assert _get_service_query_count() == 1
service_db = _get_first_service()
organization = db.session.get(Organization, organization.id)
assert service_db.name == "service_name"
assert service_db.id == service.id
assert service_db.email_from == "email_from"
@@ -183,7 +207,7 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session):
def test_cannot_create_two_services_with_same_name(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service1 = Service(
name="service_name",
email_from="email_from1",
@@ -209,7 +233,7 @@ def test_cannot_create_two_services_with_same_name(notify_db_session):
def test_cannot_create_two_services_with_same_email_from(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service1 = Service(
name="service_name1",
email_from="email_from",
@@ -235,7 +259,7 @@ def test_cannot_create_two_services_with_same_email_from(notify_db_session):
def test_cannot_create_service_with_no_user(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert _get_service_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -258,7 +282,7 @@ def test_should_add_user_to_service(notify_db_session):
created_by=user,
)
dao_create_service(service, user)
assert user in Service.query.first().users
assert user in _get_first_service().users
new_user = User(
name="Test User",
email_address="new_user@digital.fake.gov",
@@ -267,7 +291,7 @@ def test_should_add_user_to_service(notify_db_session):
)
save_model_user(new_user, validated_email_access=True)
dao_add_user_to_service(service, new_user)
assert new_user in Service.query.first().users
assert new_user in _get_first_service().users
def test_dao_add_user_to_service_sets_folder_permissions(sample_user, sample_service):
@@ -314,7 +338,8 @@ def test_dao_add_user_to_service_raises_error_if_adding_folder_permissions_for_a
other_service_folder = create_template_folder(other_service)
folder_permissions = [str(other_service_folder.id)]
assert ServiceUser.query.count() == 2
stmt = select(func.count(ServiceUser.service_id))
assert db.session.execute(stmt).scalar() == 2
with pytest.raises(IntegrityError) as e:
dao_add_user_to_service(
@@ -326,7 +351,8 @@ def test_dao_add_user_to_service_raises_error_if_adding_folder_permissions_for_a
'insert or update on table "user_folder_permissions" violates foreign key constraint'
in str(e.value)
)
assert ServiceUser.query.count() == 2
stmt = select(func.count(ServiceUser.service_id))
assert db.session.execute(stmt).scalar() == 2
def test_should_remove_user_from_service(notify_db_session):
@@ -347,9 +373,9 @@ def test_should_remove_user_from_service(notify_db_session):
)
save_model_user(new_user, validated_email_access=True)
dao_add_user_to_service(service, new_user)
assert new_user in Service.query.first().users
assert new_user in _get_first_service().users
dao_remove_user_from_service(service, new_user)
assert new_user not in Service.query.first().users
assert new_user not in _get_first_service().users
def test_should_remove_user_from_service_exception(notify_db_session):
@@ -382,11 +408,12 @@ def test_should_remove_user_from_service_exception(notify_db_session):
def test_removing_a_user_from_a_service_deletes_their_permissions(
sample_user, sample_service
):
assert len(Permission.query.all()) == 7
stmt = select(Permission)
assert len(db.session.execute(stmt).all()) == 7
dao_remove_user_from_service(sample_service, sample_user)
assert Permission.query.all() == []
assert db.session.execute(stmt).all() == []
def test_removing_a_user_from_a_service_deletes_their_folder_permissions_for_that_service(
@@ -668,8 +695,8 @@ def test_removing_all_permission_returns_service_with_no_permissions(notify_db_s
def test_create_service_creates_a_history_record_with_current_data(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -678,11 +705,12 @@ def test_create_service_creates_a_history_record_with_current_data(notify_db_ses
created_by=user,
)
dao_create_service(service, user)
assert Service.query.count() == 1
assert Service.get_history_model().query.count() == 1
assert _get_service_query_count() == 1
assert _get_service_history_query_count() == 1
service_from_db = Service.query.first()
service_history = Service.get_history_model().query.first()
service_from_db = _get_first_service()
stmt = select(Service.get_history_model())
service_history = db.session.execute(stmt).scalars().first()
assert service_from_db.id == service_history.id
assert service_from_db.name == service_history.name
@@ -694,8 +722,8 @@ def test_create_service_creates_a_history_record_with_current_data(notify_db_ses
def test_update_service_creates_a_history_record_with_current_data(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -705,39 +733,31 @@ def test_update_service_creates_a_history_record_with_current_data(notify_db_ses
)
dao_create_service(service, user)
assert Service.query.count() == 1
assert Service.query.first().version == 1
assert Service.get_history_model().query.count() == 1
assert _get_service_query_count() == 1
assert _get_first_service().version == 1
assert _get_service_history_query_count() == 1
service.name = "updated_service_name"
dao_update_service(service)
assert Service.query.count() == 1
assert Service.get_history_model().query.count() == 2
assert _get_service_query_count() == 1
assert _get_service_history_query_count() == 2
service_from_db = Service.query.first()
service_from_db = _get_first_service()
assert service_from_db.version == 2
assert (
Service.get_history_model().query.filter_by(name="service_name").one().version
== 1
)
assert (
Service.get_history_model()
.query.filter_by(name="updated_service_name")
.one()
.version
== 2
)
stmt = select(Service.get_history_model()).filter_by(name="service_name")
assert db.session.execute(stmt).scalars().one().version == 1
stmt = select(Service.get_history_model()).filter_by(name="updated_service_name")
assert db.session.execute(stmt).scalars().one().version == 2
def test_update_service_permission_creates_a_history_record_with_current_data(
notify_db_session,
):
user = create_user()
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
service = Service(
name="service_name",
email_from="email_from",
@@ -755,17 +775,17 @@ def test_update_service_permission_creates_a_history_record_with_current_data(
],
)
assert Service.query.count() == 1
assert _get_service_query_count() == 1
service.permissions.append(
ServicePermission(service_id=service.id, permission=ServicePermissionType.EMAIL)
)
dao_update_service(service)
assert Service.query.count() == 1
assert Service.get_history_model().query.count() == 2
assert _get_service_query_count() == 1
assert _get_service_history_query_count() == 2
service_from_db = Service.query.first()
service_from_db = _get_first_service()
assert service_from_db.version == 2
@@ -784,10 +804,10 @@ def test_update_service_permission_creates_a_history_record_with_current_data(
service.permissions.remove(permission)
dao_update_service(service)
assert Service.query.count() == 1
assert Service.get_history_model().query.count() == 3
assert _get_service_query_count() == 1
assert _get_service_history_query_count() == 3
service_from_db = Service.query.first()
service_from_db = _get_first_service()
assert service_from_db.version == 3
_assert_service_permissions(
service.permissions,
@@ -797,21 +817,20 @@ def test_update_service_permission_creates_a_history_record_with_current_data(
),
)
history = (
Service.get_history_model()
.query.filter_by(name="service_name")
stmt = (
select(Service.get_history_model())
.filter_by(name="service_name")
.order_by("version")
.all()
)
history = db.session.execute(stmt).scalars().all()
assert len(history) == 3
assert history[2].version == 3
def test_create_service_and_history_is_transactional(notify_db_session):
user = create_user()
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
service = Service(
name=None,
email_from="email_from",
@@ -828,8 +847,8 @@ def test_create_service_and_history_is_transactional(notify_db_session):
in str(seeei)
)
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
def test_delete_service_and_associated_objects(notify_db_session):
@@ -845,8 +864,8 @@ def test_delete_service_and_associated_objects(notify_db_session):
create_notification(template=template, api_key=api_key)
create_invited_user(service=service)
user.organizations = [organization]
assert ServicePermission.query.count() == len(
stmt = select(func.count(ServicePermission.service_id))
assert db.session.execute(stmt).scalar() == len(
(
ServicePermissionType.SMS,
ServicePermissionType.EMAIL,
@@ -855,21 +874,35 @@ def test_delete_service_and_associated_objects(notify_db_session):
)
delete_service_and_all_associated_db_objects(service)
assert VerifyCode.query.count() == 0
assert ApiKey.query.count() == 0
assert ApiKey.get_history_model().query.count() == 0
assert Template.query.count() == 0
assert TemplateHistory.query.count() == 0
assert Job.query.count() == 0
assert Notification.query.count() == 0
assert Permission.query.count() == 0
assert User.query.count() == 0
assert InvitedUser.query.count() == 0
assert Service.query.count() == 0
assert Service.get_history_model().query.count() == 0
assert ServicePermission.query.count() == 0
stmt = select(VerifyCode)
assert db.session.execute(stmt).scalar() is None
stmt = select(ApiKey)
assert db.session.execute(stmt).scalar() is None
stmt = select(ApiKey.get_history_model())
assert db.session.execute(stmt).scalar() is None
stmt = select(Template)
assert db.session.execute(stmt).scalar() is None
stmt = select(TemplateHistory)
assert db.session.execute(stmt).scalar() is None
stmt = select(Job)
assert db.session.execute(stmt).scalar() is None
stmt = select(Notification)
assert db.session.execute(stmt).scalar() is None
stmt = select(Permission)
assert db.session.execute(stmt).scalar() is None
stmt = select(User)
assert db.session.execute(stmt).scalar() is None
stmt = select(InvitedUser)
assert db.session.execute(stmt).scalar() is None
assert _get_service_query_count() == 0
assert _get_service_history_query_count() == 0
stmt = select(ServicePermission)
assert db.session.execute(stmt).scalar() is None
# the organization hasn't been deleted
assert Organization.query.count() == 1
stmt = select(func.count(Organization.id))
assert db.session.execute(stmt).scalar() == 1
def test_add_existing_user_to_another_service_doesnot_change_old_permissions(
@@ -887,9 +920,8 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions(
dao_create_service(service_one, user)
assert user.id == service_one.users[0].id
test_user_permissions = Permission.query.filter_by(
service=service_one, user=user
).all()
stmt = select(Permission).filter_by(service=service_one, user=user)
test_user_permissions = db.session.execute(stmt).all()
assert len(test_user_permissions) == 7
other_user = User(
@@ -909,14 +941,12 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions(
dao_create_service(service_two, other_user)
assert other_user.id == service_two.users[0].id
other_user_permissions = Permission.query.filter_by(
service=service_two, user=other_user
).all()
stmt = select(Permission).filter_by(service=service_two, user=other_user)
other_user_permissions = db.session.execute(stmt).all()
assert len(other_user_permissions) == 7
stmt = select(Permission).filter_by(service=service_one, user=other_user)
other_user_service_one_permissions = db.session.execute(stmt).all()
other_user_service_one_permissions = Permission.query.filter_by(
service=service_one, user=other_user
).all()
assert len(other_user_service_one_permissions) == 0
# adding the other_user to service_one should leave all other_user permissions on service_two intact
@@ -925,15 +955,12 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions(
permissions.append(Permission(permission=p))
dao_add_user_to_service(service_one, other_user, permissions=permissions)
other_user_service_one_permissions = Permission.query.filter_by(
service=service_one, user=other_user
).all()
stmt = select(Permission).filter_by(service=service_one, user=other_user)
other_user_service_one_permissions = db.session.execute(stmt).all()
assert len(other_user_service_one_permissions) == 2
other_user_service_two_permissions = Permission.query.filter_by(
service=service_two, user=other_user
).all()
stmt = select(Permission).filter_by(service=service_two, user=other_user)
other_user_service_two_permissions = db.session.execute(stmt).all()
assert len(other_user_service_two_permissions) == 7
@@ -956,9 +983,10 @@ def test_fetch_stats_filters_on_service(notify_db_session):
def test_fetch_stats_ignores_historical_notification_data(sample_template):
create_notification_history(template=sample_template)
assert Notification.query.count() == 0
assert NotificationHistory.query.count() == 1
stmt = select(func.count(Notification.id))
assert db.session.execute(stmt).scalar() == 0
stmt = select(func.count(NotificationHistory.id))
assert db.session.execute(stmt).scalar() == 1
stats = dao_fetch_todays_stats_for_service(sample_template.service_id)
assert len(stats) == 0
@@ -1316,7 +1344,7 @@ def test_dao_fetch_todays_stats_for_all_services_can_exclude_from_test_key(
def test_dao_suspend_service_with_no_api_keys(notify_db_session):
service = create_service()
dao_suspend_service(service.id)
service = Service.query.get(service.id)
service = _get_service_by_id(service.id)
assert not service.active
assert service.name == service.name
assert service.api_keys == []
@@ -1329,11 +1357,11 @@ def test_dao_suspend_service_marks_service_as_inactive_and_expires_api_keys(
service = create_service()
api_key = create_api_key(service=service)
dao_suspend_service(service.id)
service = Service.query.get(service.id)
service = _get_service_by_id(service.id)
assert not service.active
assert service.name == service.name
api_key = ApiKey.query.get(api_key.id)
api_key = db.session.get(ApiKey, api_key.id)
assert api_key.expiry_date == datetime(2001, 1, 1, 23, 59, 00)
@@ -1344,13 +1372,13 @@ def test_dao_resume_service_marks_service_as_active_and_api_keys_are_still_revok
service = create_service()
api_key = create_api_key(service=service)
dao_suspend_service(service.id)
service = Service.query.get(service.id)
service = _get_service_by_id(service.id)
assert not service.active
dao_resume_service(service.id)
assert Service.query.get(service.id).active
assert _get_service_by_id(service.id).active
api_key = ApiKey.query.get(api_key.id)
api_key = db.session.get(ApiKey, api_key.id)
assert api_key.expiry_date == datetime(2001, 1, 1, 23, 59, 00)

View File

@@ -1,3 +1,5 @@
from sqlalchemy import select
from app import db
from app.dao.service_user_dao import dao_get_service_user
from app.dao.template_folder_dao import (
@@ -17,5 +19,5 @@ def test_dao_delete_template_folder_deletes_user_folder_permissions(
dao_update_template_folder(folder)
dao_delete_template_folder(folder)
assert db.session.query(user_folder_permissions).all() == []
stmt = select(user_folder_permissions)
assert db.session.execute(stmt).scalars().all() == []

View File

@@ -2,8 +2,10 @@ from datetime import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy import func, select
from sqlalchemy.orm.exc import NoResultFound
from app import db
from app.dao.templates_dao import (
dao_create_template,
dao_get_all_templates_for_service,
@@ -17,6 +19,16 @@ from app.models import Template, TemplateHistory, TemplateRedacted
from tests.app.db import create_template
def template_query_count():
stmt = select(func.count()).select_from(Template)
return db.session.execute(stmt).scalar() or 0
def template_history_query_count():
stmt = select(func.count()).select_from(TemplateHistory)
return db.session.execute(stmt).scalar() or 0
@pytest.mark.parametrize(
"template_type, subject",
[
@@ -37,7 +49,7 @@ def test_create_template(sample_service, sample_user, template_type, subject):
template = Template(**data)
dao_create_template(template)
assert Template.query.count() == 1
assert template_query_count() == 1
assert len(dao_get_all_templates_for_service(sample_service.id)) == 1
assert (
dao_get_all_templates_for_service(sample_service.id)[0].name
@@ -50,11 +62,13 @@ def test_create_template(sample_service, sample_user, template_type, subject):
def test_create_template_creates_redact_entry(sample_service):
assert TemplateRedacted.query.count() == 0
stmt = select(func.count()).select_from(TemplateRedacted)
assert db.session.execute(stmt).scalar() == 0
template = create_template(sample_service)
redacted = TemplateRedacted.query.one()
stmt = select(TemplateRedacted)
redacted = db.session.execute(stmt).scalars().one()
assert redacted.template_id == template.id
assert redacted.redact_personalisation is False
assert redacted.updated_by_id == sample_service.created_by_id
@@ -79,7 +93,8 @@ def test_update_template(sample_service, sample_user):
def test_redact_template(sample_template):
redacted = TemplateRedacted.query.one()
stmt = select(TemplateRedacted)
redacted = db.session.execute(stmt).scalars().one()
assert redacted.template_id == sample_template.id
assert redacted.redact_personalisation is False
@@ -96,7 +111,7 @@ def test_get_all_templates_for_service(service_factory):
service_1 = service_factory.get("service 1", email_from="service.1")
service_2 = service_factory.get("service 2", email_from="service.2")
assert Template.query.count() == 2
assert template_query_count() == 2
assert len(dao_get_all_templates_for_service(service_1.id)) == 1
assert len(dao_get_all_templates_for_service(service_2.id)) == 1
@@ -119,7 +134,7 @@ def test_get_all_templates_for_service(service_factory):
content="Template content",
)
assert Template.query.count() == 5
assert template_query_count() == 5
assert len(dao_get_all_templates_for_service(service_1.id)) == 3
assert len(dao_get_all_templates_for_service(service_2.id)) == 2
@@ -144,7 +159,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service):
service=sample_service,
)
assert Template.query.count() == 3
assert template_query_count() == 3
assert (
dao_get_all_templates_for_service(sample_service.id)[0].name
== "Sample Template 1"
@@ -171,7 +186,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service):
def test_get_all_returns_empty_list_if_no_templates(sample_service):
assert Template.query.count() == 0
assert template_query_count() == 0
assert len(dao_get_all_templates_for_service(sample_service.id)) == 0
@@ -257,8 +272,8 @@ def test_get_template_by_id_and_service_returns_none_if_no_template(
def test_create_template_creates_a_history_record_with_current_data(
sample_service, sample_user
):
assert Template.query.count() == 0
assert TemplateHistory.query.count() == 0
assert template_query_count() == 0
assert template_history_query_count() == 0
data = {
"name": "Sample Template",
"template_type": TemplateType.EMAIL,
@@ -270,10 +285,12 @@ def test_create_template_creates_a_history_record_with_current_data(
template = Template(**data)
dao_create_template(template)
assert Template.query.count() == 1
assert template_query_count() == 1
template_from_db = Template.query.first()
template_history = TemplateHistory.query.first()
stmt = select(Template)
template_from_db = db.session.execute(stmt).scalars().first()
stmt = select(TemplateHistory)
template_history = db.session.execute(stmt).scalars().first()
assert template_from_db.id == template_history.id
assert template_from_db.name == template_history.name
@@ -286,8 +303,8 @@ def test_create_template_creates_a_history_record_with_current_data(
def test_update_template_creates_a_history_record_with_current_data(
sample_service, sample_user
):
assert Template.query.count() == 0
assert TemplateHistory.query.count() == 0
assert template_query_count() == 0
assert template_history_query_count() == 0
data = {
"name": "Sample Template",
"template_type": TemplateType.EMAIL,
@@ -301,22 +318,26 @@ def test_update_template_creates_a_history_record_with_current_data(
created = dao_get_all_templates_for_service(sample_service.id)[0]
assert created.name == "Sample Template"
assert Template.query.count() == 1
assert Template.query.first().version == 1
assert TemplateHistory.query.count() == 1
assert template_query_count() == 1
stmt = select(Template)
assert db.session.execute(stmt).scalars().first().version == 1
assert template_history_query_count() == 1
created.name = "new name"
dao_update_template(created)
assert Template.query.count() == 1
assert TemplateHistory.query.count() == 2
assert template_query_count() == 1
assert template_history_query_count() == 2
template_from_db = Template.query.first()
stmt = select(Template)
template_from_db = db.session.execute(stmt).scalars().first()
assert template_from_db.version == 2
assert TemplateHistory.query.filter_by(name="Sample Template").one().version == 1
assert TemplateHistory.query.filter_by(name="new name").one().version == 2
stmt = select(TemplateHistory).filter_by(name="Sample Template")
assert db.session.execute(stmt).scalars().one().version == 1
stmt = select(TemplateHistory).filter_by(name="new name")
assert db.session.execute(stmt).scalars().one().version == 2
def test_get_template_history_version(sample_user, sample_service, sample_template):

View File

@@ -3,6 +3,7 @@ from datetime import timedelta
import pytest
from freezegun import freeze_time
from sqlalchemy import func, select
from sqlalchemy.exc import DataError
from sqlalchemy.orm.exc import NoResultFound
@@ -37,6 +38,21 @@ from tests.app.db import (
)
def _get_user_query_count():
stmt = select(func.count(User.id))
return db.session.execute(stmt).scalar() or 0
def _get_user_query_first():
stmt = select(User)
return db.session.execute(stmt).scalars().first()
def _get_verify_code_query_count():
stmt = select(func.count(VerifyCode.id))
return db.session.execute(stmt).scalar() or 0
@freeze_time("2020-01-28T12:00:00")
@pytest.mark.parametrize(
"phone_number, expected_phone_number",
@@ -55,8 +71,10 @@ def test_create_user(notify_db_session, phone_number, expected_phone_number):
}
user = User(**data)
save_model_user(user, password="password", validated_email_access=True)
assert User.query.count() == 1
user_query = User.query.first()
stmt = select(func.count(User.id))
assert db.session.execute(stmt).scalar() == 1
stmt = select(User)
user_query = db.session.execute(stmt).scalars().first()
assert user_query.email_address == email
assert user_query.id == user.id
assert user_query.mobile_number == expected_phone_number
@@ -68,7 +86,8 @@ def test_get_all_users(notify_db_session):
create_user(email="1@test.com")
create_user(email="2@test.com")
assert User.query.count() == 2
stmt = select(func.count(User.id))
assert db.session.execute(stmt).scalar() == 2
assert len(get_user_by_id()) == 2
@@ -89,9 +108,10 @@ def test_get_user_invalid_id(notify_db_session):
def test_delete_users(sample_user):
assert User.query.count() == 1
stmt = select(func.count(User.id))
assert db.session.execute(stmt).scalar() == 1
delete_model_user(sample_user)
assert User.query.count() == 0
assert db.session.execute(stmt).scalar() == 0
def test_increment_failed_login_should_increment_failed_logins(sample_user):
@@ -127,9 +147,10 @@ def test_get_user_by_email_is_case_insensitive(sample_user):
def test_should_delete_all_verification_codes_more_than_one_day_old(sample_user):
make_verify_code(sample_user, age=timedelta(hours=24), code="54321")
make_verify_code(sample_user, age=timedelta(hours=24), code="54321")
assert VerifyCode.query.count() == 2
stmt = select(func.count(VerifyCode.id))
assert db.session.execute(stmt).scalar() == 2
delete_codes_older_created_more_than_a_day_ago()
assert VerifyCode.query.count() == 0
assert db.session.execute(stmt).scalar() == 0
def test_should_not_delete_verification_codes_less_than_one_day_old(sample_user):
@@ -137,10 +158,11 @@ def test_should_not_delete_verification_codes_less_than_one_day_old(sample_user)
sample_user, age=timedelta(hours=23, minutes=59, seconds=59), code="12345"
)
make_verify_code(sample_user, age=timedelta(hours=24), code="54321")
assert VerifyCode.query.count() == 2
stmt = select(func.count(VerifyCode.id))
assert db.session.execute(stmt).scalar() == 2
delete_codes_older_created_more_than_a_day_ago()
assert VerifyCode.query.one()._code == "12345"
stmt = select(VerifyCode)
assert db.session.execute(stmt).scalars().one()._code == "12345"
def make_verify_code(user, age=None, expiry_age=None, code="12335", code_used=False):