fix more tests

This commit is contained in:
Kenneth Kehl
2024-10-30 11:44:51 -07:00
parent 7ee741b91c
commit 9dfbd991d5

View File

@@ -12,7 +12,7 @@ from click_datetime import Datetime as click_dt
from faker import Faker
from flask import current_app, json
from notifications_python_client.authentication import create_jwt_token
from sqlalchemy import and_, text
from sqlalchemy import and_, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
@@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix):
if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]:
current_app.logger.error("Can only be run in development")
return
users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all()
stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%"))
users = db.session.execute(stmt).scalars().all()
for usr in users:
# Make sure the full email includes a uuid in it
# Just in case someone decides to use a similar email address.
@@ -338,9 +338,10 @@ def populate_organizations_from_file(file_name):
email_branding = None
email_branding_column = columns[5].strip()
if len(email_branding_column) > 0:
email_branding = EmailBranding.query.filter(
stmt = select(EmailBranding).where(
EmailBranding.name == email_branding_column
).one()
)
email_branding = db.session.execute(stmt).scalars().one()
data = {
"name": columns[0],
"active": True,
@@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name):
@notify_command(name="associate-services-to-organizations")
def associate_services_to_organizations():
services = Service.get_history_model().query.filter_by(version=1).all()
stmt = select(Service.get_history_model()).where(
Service.get_history_model().version == 1
)
services = db.session.execute(stmt).scalars().all()
for s in services:
created_by_user = User.query.filter_by(id=s.created_by_id).first()
stmt = select(User).where(User.id == s.created_by_id)
created_by_user = db.session.execute(stmt).scalars().first()
organization = dao_get_organization_by_email_address(
created_by_user.email_address
)
@@ -467,15 +472,16 @@ def populate_go_live(file_name):
@notify_command(name="fix-billable-units")
def fix_billable_units():
query = Notification.query.filter(
stmt = select(Notification).where(
Notification.notification_type == NotificationType.SMS,
Notification.status != NotificationStatus.CREATED,
Notification.sent_at == None, # noqa
Notification.billable_units == 0,
Notification.key_type != KeyType.TEST,
)
all = db.session.execute(stmt).scalars().all()
for notification in query.all():
for notification in all:
template_model = dao_get_template_by_id(
notification.template_id, notification.template_version
)
@@ -490,9 +496,12 @@ def fix_billable_units():
f"Updating notification: {notification.id} with {template.fragment_count} billable_units"
)
Notification.query.filter(Notification.id == notification.id).update(
{"billable_units": template.fragment_count}
stmt = (
update(Notification)
.where(Notification.id == notification.id)
.values({"billable_units": template.fragment_count})
)
db.session.execute(stmt)
db.session.commit()
current_app.logger.info("End fix_billable_units")
@@ -637,8 +646,8 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
This is useful to ensure all services start the new year with the correct annual billing.
"""
if missing_services_only:
active_services = (
Service.query.filter(Service.active)
stmt = (
select(Service)
.outerjoin(
AnnualBilling,
and_(
@@ -646,20 +655,18 @@ def populate_annual_billing_with_defaults(year, missing_services_only):
AnnualBilling.financial_year_start == year,
),
)
.filter(AnnualBilling.id == None) # noqa
.all()
.where(Service.active, AnnualBilling.id == None) # noqa
)
active_services = db.session.execute(stmt).scalars().all()
else:
active_services = Service.query.filter(Service.active).all()
stmt = select(Service).where(Service.active)
active_services = db.session.execute(stmt).scalars().all()
previous_year = year - 1
services_with_zero_free_allowance = (
db.session.query(AnnualBilling.service_id)
.filter(
AnnualBilling.financial_year_start == previous_year,
AnnualBilling.free_sms_fragment_limit == 0,
)
.all()
stmt = select(AnnualBilling.id).where(
AnnualBilling.financial_year_start == previous_year,
AnnualBilling.free_sms_fragment_limit == 0,
)
services_with_zero_free_allowance = db.session.execute(stmt).scalars().all()
for service in active_services:
# If a service has free_sms_fragment_limit for the previous year
@@ -750,7 +757,8 @@ def create_user_jwt(token):
def _update_template(id, name, template_type, content, subject):
template = Template.query.filter_by(id=id).first()
stmt = select(Template).where(Template.id == id)
template = db.session.execute(stmt).scalars().first()
if not template:
template = Template(id=id)
template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"
@@ -761,7 +769,8 @@ def _update_template(id, name, template_type, content, subject):
template.content = "\n".join(content)
template.subject = subject
history = TemplateHistory.query.filter_by(id=id).first()
stmt = select(TemplateHistory).where(TemplateHistory.id == id)
history = db.session.execute(stmt).scalars().first()
if not history:
history = TemplateHistory(id=id)
history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"