diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index 699695ef1..8eac58fc4 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -22,7 +22,7 @@ from sqlalchemy.sql import functions from notifications_utils.international_billing_rates import INTERNATIONAL_BILLING_RATES from app import db, create_uuid -from app.utils import midnight_n_days_ago +from app.utils import midnight_n_days_ago, escape_special_characters from app.errors import InvalidRequest from app.models import ( Notification, @@ -452,11 +452,7 @@ def dao_get_notifications_by_to_field(service_id, search_term, notification_type else: raise InvalidRequest("Only email and SMS can use search by recipient", 400) - for special_character in ('\\', '_', '%', '/'): - normalised = normalised.replace( - special_character, - '\{}'.format(special_character) - ) + normalised = escape_special_characters(normalised) filters = [ Notification.service_id == service_id, diff --git a/app/dao/users_dao.py b/app/dao/users_dao.py index 89d20bc84..f2965d3a2 100644 --- a/app/dao/users_dao.py +++ b/app/dao/users_dao.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload from app import db from app.models import (User, VerifyCode) +from app.utils import escape_special_characters def _remove_values_for_keys_if_present(dict, keys): @@ -98,7 +99,8 @@ def get_user_by_email(email): def get_users_by_partial_email(email): - return User.query.filter(User.email_address.ilike("\%{}\%".format(email))).all() + email = escape_special_characters(email) + return User.query.filter(User.email_address.ilike("%{}%".format(email))).all() def increment_failed_login_count(user): diff --git a/app/utils.py b/app/utils.py index 71468e135..3ceeed2e6 100644 --- a/app/utils.py +++ b/app/utils.py @@ -115,3 +115,12 @@ def last_n_days(limit_days): # reverse the countdown, -1 from first two args to ensure it stays 0-indexed for x in range(limit_days - 1, -1, -1) ] + + +def escape_special_characters(string): + for special_character in ('\\', '_', '%', '/'): + string = string.replace( + special_character, + '\{}'.format(special_character) + ) + return string