diff --git a/app/commands.py b/app/commands.py index 5580e7632..c88e2bcb3 100644 --- a/app/commands.py +++ b/app/commands.py @@ -12,7 +12,7 @@ from click_datetime import Datetime as click_dt from faker import Faker from flask import current_app, json from notifications_python_client.authentication import create_jwt_token -from sqlalchemy import and_, text +from sqlalchemy import and_, select, text, update from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound @@ -123,8 +123,8 @@ def purge_functional_test_data(user_email_prefix): if getenv("NOTIFY_ENVIRONMENT", "") not in ["development", "test"]: current_app.logger.error("Can only be run in development") return - - users = User.query.filter(User.email_address.like(f"{user_email_prefix}%")).all() + stmt = select(User).where(User.email_address.like(f"{user_email_prefix}%")) + users = db.session.execute(stmt).scalars().all() for usr in users: # Make sure the full email includes a uuid in it # Just in case someone decides to use a similar email address. @@ -338,9 +338,10 @@ def populate_organizations_from_file(file_name): email_branding = None email_branding_column = columns[5].strip() if len(email_branding_column) > 0: - email_branding = EmailBranding.query.filter( + stmt = select(EmailBranding).where( EmailBranding.name == email_branding_column - ).one() + ) + email_branding = db.session.execute(stmt).scalars().one() data = { "name": columns[0], "active": True, @@ -406,10 +407,14 @@ def populate_organization_agreement_details_from_file(file_name): @notify_command(name="associate-services-to-organizations") def associate_services_to_organizations(): - services = Service.get_history_model().query.filter_by(version=1).all() + stmt = select(Service.get_history_model()).where( + Service.get_history_model().version == 1 + ) + services = db.session.execute(stmt).scalars().all() for s in services: - created_by_user = User.query.filter_by(id=s.created_by_id).first() + stmt = select(User).where(User.id == s.created_by_id) + created_by_user = db.session.execute(stmt).scalars().first() organization = dao_get_organization_by_email_address( created_by_user.email_address ) @@ -467,15 +472,16 @@ def populate_go_live(file_name): @notify_command(name="fix-billable-units") def fix_billable_units(): - query = Notification.query.filter( + stmt = select(Notification).where( Notification.notification_type == NotificationType.SMS, Notification.status != NotificationStatus.CREATED, Notification.sent_at == None, # noqa Notification.billable_units == 0, Notification.key_type != KeyType.TEST, ) + all = db.session.execute(stmt).scalars().all() - for notification in query.all(): + for notification in all: template_model = dao_get_template_by_id( notification.template_id, notification.template_version ) @@ -490,9 +496,12 @@ def fix_billable_units(): f"Updating notification: {notification.id} with {template.fragment_count} billable_units" ) - Notification.query.filter(Notification.id == notification.id).update( - {"billable_units": template.fragment_count} + stmt = ( + update(Notification) + .where(Notification.id == notification.id) + .values({"billable_units": template.fragment_count}) ) + db.session.execute(stmt) db.session.commit() current_app.logger.info("End fix_billable_units") @@ -637,8 +646,8 @@ def populate_annual_billing_with_defaults(year, missing_services_only): This is useful to ensure all services start the new year with the correct annual billing. """ if missing_services_only: - active_services = ( - Service.query.filter(Service.active) + stmt = ( + select(Service) .outerjoin( AnnualBilling, and_( @@ -646,20 +655,18 @@ def populate_annual_billing_with_defaults(year, missing_services_only): AnnualBilling.financial_year_start == year, ), ) - .filter(AnnualBilling.id == None) # noqa - .all() + .where(Service.active, AnnualBilling.id == None) # noqa ) + active_services = db.session.execute(stmt).scalars().all() else: - active_services = Service.query.filter(Service.active).all() + stmt = select(Service).where(Service.active) + active_services = db.session.execute(stmt).scalars().all() previous_year = year - 1 - services_with_zero_free_allowance = ( - db.session.query(AnnualBilling.service_id) - .filter( - AnnualBilling.financial_year_start == previous_year, - AnnualBilling.free_sms_fragment_limit == 0, - ) - .all() + stmt = select(AnnualBilling.id).where( + AnnualBilling.financial_year_start == previous_year, + AnnualBilling.free_sms_fragment_limit == 0, ) + services_with_zero_free_allowance = db.session.execute(stmt).scalars().all() for service in active_services: # If a service has free_sms_fragment_limit for the previous year @@ -750,7 +757,8 @@ def create_user_jwt(token): def _update_template(id, name, template_type, content, subject): - template = Template.query.filter_by(id=id).first() + stmt = select(Template).where(Template.id == id) + template = db.session.execute(stmt).scalars().first() if not template: template = Template(id=id) template.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553" @@ -761,7 +769,8 @@ def _update_template(id, name, template_type, content, subject): template.content = "\n".join(content) template.subject = subject - history = TemplateHistory.query.filter_by(id=id).first() + stmt = select(TemplateHistory).where(TemplateHistory.id == id) + history = db.session.execute(stmt).scalars().first() if not history: history = TemplateHistory(id=id) history.service_id = "d6aa2c68-a2d9-4437-ab19-3ae8eb202553"