More tweaks, trying to get tests to be clean.

Signed-off-by: Cliff Hill <Clifford.hill@gsa.gov>
This commit is contained in:
Cliff Hill
2024-01-10 11:18:33 -05:00
parent 908d695b54
commit ac9591ec7c
18 changed files with 37 additions and 29 deletions

View File

@@ -25,7 +25,7 @@ def serialize_ft_billing_remove_emails(rows):
"charged_units": row.charged_units,
}
for row in rows
if row.notification_type != "email"
if row.notification_type != NotificationType.EMAIL
]

View File

@@ -69,13 +69,13 @@ def delete_notifications_older_than_retention():
@notify_celery.task(name="delete-sms-notifications")
@cronitor("delete-sms-notifications")
def delete_sms_notifications_older_than_retention():
_delete_notifications_older_than_retention_by_type("sms")
_delete_notifications_older_than_retention_by_type(NotificationType.SMS)
@notify_celery.task(name="delete-email-notifications")
@cronitor("delete-email-notifications")
def delete_email_notifications_older_than_retention():
_delete_notifications_older_than_retention_by_type("email")
_delete_notifications_older_than_retention_by_type(NotificationType.EMAIL)
def _delete_notifications_older_than_retention_by_type(notification_type):

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Protocol
from app.models import NotificationType
from botocore.config import Config
AWS_CLIENT_CONFIG = Config(
@@ -53,10 +54,10 @@ class NotificationProviderClients(object):
return self.email_clients.get(name)
def get_client_by_name_and_type(self, name, notification_type):
assert notification_type in ["email", "sms"] # nosec B101
assert notification_type in {NotificationType.EMAIL, NotificationType.SMS} # nosec B101
if notification_type == "email":
if notification_type == NotificationType.EMAIL:
return self.get_email_client(name)
if notification_type == "sms":
if notification_type == NotificationType.SMS:
return self.get_sms_client(name)

View File

@@ -23,6 +23,7 @@ from app.models import (
NOTIFICATION_TEMPORARY_FAILURE,
FactNotificationStatus,
Notification,
NotificationType,
NotificationAllTimeView,
Service,
Template,
@@ -467,7 +468,7 @@ def get_total_notifications_for_date_range(start_date, end_date):
case(
[
(
FactNotificationStatus.notification_type == "email",
FactNotificationStatus.notification_type == NotificationType.EMAIL,
FactNotificationStatus.notification_count,
)
],
@@ -478,7 +479,7 @@ def get_total_notifications_for_date_range(start_date, end_date):
case(
[
(
FactNotificationStatus.notification_type == "sms",
FactNotificationStatus.notification_type == NotificationType.SMS,
FactNotificationStatus.notification_count,
)
],

View File

@@ -62,7 +62,7 @@ def _get_sms_providers_for_update(time_threshold):
# get current priority of both providers
q = (
ProviderDetails.query.filter(
ProviderDetails.notification_type == "sms", ProviderDetails.active
ProviderDetails.notification_type == NotificationType.SMS, ProviderDetails.active
)
.with_for_update()
.all()

View File

@@ -106,7 +106,7 @@ def dao_fetch_live_services_data():
case(
[
(
this_year_ft_billing.c.notification_type == "email",
this_year_ft_billing.c.notification_type == NotificationType.EMAIL,
func.sum(this_year_ft_billing.c.notifications_sent),
)
],
@@ -115,7 +115,7 @@ def dao_fetch_live_services_data():
case(
[
(
this_year_ft_billing.c.notification_type == "sms",
this_year_ft_billing.c.notification_type == NotificationType.SMS,
func.sum(this_year_ft_billing.c.notifications_sent),
)
],

View File

@@ -11,6 +11,7 @@ from app.dao.service_data_retention_dao import (
)
from app.errors import register_errors
from app.inbound_sms.inbound_sms_schemas import get_inbound_sms_for_service_schema
from app.models import NotificationType
from app.schema_validation import validate
inbound_sms = Blueprint(
@@ -31,7 +32,7 @@ def post_inbound_sms_for_service(service_id):
# user_number = try_validate_and_format_phone_number(user_number, international=True)
inbound_data_retention = fetch_service_data_retention_by_notification_type(
service_id, "sms"
service_id, NotificationType.SMS
)
limit_days = (
inbound_data_retention.days_of_retention if inbound_data_retention else 7
@@ -49,7 +50,7 @@ def get_most_recent_inbound_sms_for_service(service_id):
page = request.args.get("page", 1)
inbound_data_retention = fetch_service_data_retention_by_notification_type(
service_id, "sms"
service_id, NotificationType.SMS
)
limit_days = (
inbound_data_retention.days_of_retention if inbound_data_retention else 7

View File

@@ -33,6 +33,7 @@ from app.utils import (
class TemplateType(Enum):
SMS = "sms"
EMAIL = "email"
LETTER = "letter"
class NotificationType(Enum):

View File

@@ -2,6 +2,7 @@ from app import performance_platform_client
from app.dao.fact_notification_status_dao import (
get_total_sent_notifications_for_day_and_type,
)
from app.models import NotificationType
# TODO: is this obsolete? it doesn't seem to be used anywhere
@@ -19,8 +20,8 @@ def send_total_notifications_sent_for_day_stats(start_time, notification_type, c
# TODO: is this obsolete? it doesn't seem to be used anywhere
def get_total_sent_notifications_for_day(day):
email_count = get_total_sent_notifications_for_day_and_type(day, "email")
sms_count = get_total_sent_notifications_for_day_and_type(day, "sms")
email_count = get_total_sent_notifications_for_day_and_type(day, NotificationType.EMAIL)
sms_count = get_total_sent_notifications_for_day_and_type(day, NotificationType.SMS)
return {
"email": email_count,

View File

@@ -1,3 +1,5 @@
from app.models import NotificationType
add_service_data_retention_request = {
"$schema": "http://json-schema.org/draft-07/schema#",
"description": "POST service data retention schema",
@@ -5,7 +7,7 @@ add_service_data_retention_request = {
"type": "object",
"properties": {
"days_of_retention": {"type": "integer"},
"notification_type": {"enum": ["sms", "email"]},
"notification_type": {"enum": [NotificationType.SMS.value, NotificationType.EMAIL.value]},
},
"required": ["days_of_retention", "notification_type"],
}

View File

@@ -16,7 +16,7 @@ post_create_template_schema = {
"created_by": uuid,
"parent_folder_id": uuid,
},
"if": {"properties": {"template_type": {"enum": ["email"]}}},
"if": {"properties": {"template_type": {"enum": [TemplateType.EMAIL.value]}}},
"then": {"required": ["subject"]},
"required": ["name", "template_type", "content", "service", "created_by"],
}

View File

@@ -39,6 +39,7 @@ from app.models import (
Permission,
Service,
TemplateType,
VerifyCodeType,
)
from app.notifications.process_notifications import (
persist_notification,
@@ -226,7 +227,7 @@ def verify_user_code(user_id):
user_to_verify.current_session_id = str(uuid.uuid4())
user_to_verify.logged_in_at = datetime.utcnow()
if data["code_type"] == "email":
if data["code_type"] == VerifyCodeType.EMAIL:
user_to_verify.email_access_validated_at = datetime.utcnow()
user_to_verify.failed_login_count = 0
save_model_user(user_to_verify)

View File

@@ -41,7 +41,7 @@ get_notification_response = {
"line_5": {"type": ["string", "null"]},
"line_6": {"type": ["string", "null"]},
"postcode": {"type": ["string", "null"]},
"type": {"enum": ["sms", "email"]},
"type": {"enum": [e.value for e in TemplateType]},
"status": {"type": "string"},
"template": template,
"body": {"type": "string"},

View File

@@ -120,7 +120,7 @@ def test_get_template_with_non_existent_template_id_returns_404(
}
@pytest.mark.parametrize("tmp_type", list(TemplateType))
@pytest.mark.parametrize("tmp_type", TemplateType)
def test_get_template_with_non_existent_version_returns_404(
client, sample_service, tmp_type
):

View File

@@ -59,7 +59,7 @@ valid_post = [
]
@pytest.mark.parametrize("tmp_type", list(TemplateType))
@pytest.mark.parametrize("tmp_type", TemplateType)
@pytest.mark.parametrize(
"subject,content,post_data,expected_subject,expected_content,expected_html",
valid_post,
@@ -125,7 +125,7 @@ def test_email_templates_not_rendered_into_content(client, sample_service):
assert resp_json["body"] == template.content
@pytest.mark.parametrize("tmp_type", list(TemplateType))
@pytest.mark.parametrize("tmp_type", TemplateType)
def test_invalid_post_template_returns_400(client, sample_service, tmp_type):
template = create_template(
sample_service,

View File

@@ -114,7 +114,7 @@ def test_get_template_request_schema_against_invalid_args_is_invalid(
assert error["message"] in error_message
@pytest.mark.parametrize("template_type", list(TemplateType))
@pytest.mark.parametrize("template_type", TemplateType)
@pytest.mark.parametrize(
"response", [valid_json_get_response, valid_json_get_response_with_optionals]
)
@@ -149,7 +149,7 @@ def test_post_template_preview_against_invalid_args_is_invalid(args, error_messa
assert error["message"] in error_messages
@pytest.mark.parametrize("template_type", list(TemplateType))
@pytest.mark.parametrize("template_type", TemplateType)
@pytest.mark.parametrize(
"response", [valid_json_post_response, valid_json_post_response_with_optionals]
)

View File

@@ -41,7 +41,7 @@ def test_get_all_templates_returns_200(client, sample_service):
assert template["subject"] == templates[index].subject
@pytest.mark.parametrize("tmp_type", list(TemplateType))
@pytest.mark.parametrize("tmp_type", TemplateType)
def test_get_all_templates_for_valid_type_returns_200(client, sample_service, tmp_type):
templates = [
create_template(
@@ -75,7 +75,7 @@ def test_get_all_templates_for_valid_type_returns_200(client, sample_service, tm
assert template["subject"] == templates[index].subject
@pytest.mark.parametrize("tmp_type", list(TemplateType))
@pytest.mark.parametrize("tmp_type", TemplateType)
def test_get_correct_num_templates_for_valid_type_returns_200(
client, sample_service, tmp_type
):

View File

@@ -256,19 +256,19 @@ invalid_json_get_all_response = [
]
@pytest.mark.parametrize("template_type", list(TemplateType))
@pytest.mark.parametrize("template_type", TemplateType)
def test_get_all_template_request_schema_against_no_args_is_valid(template_type):
data = {}
assert validate(data, get_all_template_request) == data
@pytest.mark.parametrize("template_type", list(TemplateType))
@pytest.mark.parametrize("template_type", TemplateType)
def test_get_all_template_request_schema_against_valid_args_is_valid(template_type):
data = {"type": template_type}
assert validate(data, get_all_template_request) == data
@pytest.mark.parametrize("template_type", list(TemplateType))
@pytest.mark.parametrize("template_type", TemplateType)
def test_get_all_template_request_schema_against_invalid_args_is_invalid(template_type):
data = {"type": "unknown"}