Move code that escapes special chars to helper function and use it

in query get_users_by_partial_email
This commit is contained in:
Pea Tyczynska
2018-07-13 15:26:42 +01:00
parent 782a8ab9e7
commit a69dee5e6d
3 changed files with 14 additions and 7 deletions

View File

@@ -22,7 +22,7 @@ from sqlalchemy.sql import functions
from notifications_utils.international_billing_rates import INTERNATIONAL_BILLING_RATES from notifications_utils.international_billing_rates import INTERNATIONAL_BILLING_RATES
from app import db, create_uuid from app import db, create_uuid
from app.utils import midnight_n_days_ago from app.utils import midnight_n_days_ago, escape_special_characters
from app.errors import InvalidRequest from app.errors import InvalidRequest
from app.models import ( from app.models import (
Notification, Notification,
@@ -452,11 +452,7 @@ def dao_get_notifications_by_to_field(service_id, search_term, notification_type
else: else:
raise InvalidRequest("Only email and SMS can use search by recipient", 400) raise InvalidRequest("Only email and SMS can use search by recipient", 400)
for special_character in ('\\', '_', '%', '/'): normalised = escape_special_characters(normalised)
normalised = normalised.replace(
special_character,
'\{}'.format(special_character)
)
filters = [ filters = [
Notification.service_id == service_id, Notification.service_id == service_id,

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import joinedload
from app import db from app import db
from app.models import (User, VerifyCode) from app.models import (User, VerifyCode)
from app.utils import escape_special_characters
def _remove_values_for_keys_if_present(dict, keys): def _remove_values_for_keys_if_present(dict, keys):
@@ -98,7 +99,8 @@ def get_user_by_email(email):
def get_users_by_partial_email(email): def get_users_by_partial_email(email):
return User.query.filter(User.email_address.ilike("\%{}\%".format(email))).all() email = escape_special_characters(email)
return User.query.filter(User.email_address.ilike("%{}%".format(email))).all()
def increment_failed_login_count(user): def increment_failed_login_count(user):

View File

@@ -115,3 +115,12 @@ def last_n_days(limit_days):
# reverse the countdown, -1 from first two args to ensure it stays 0-indexed # reverse the countdown, -1 from first two args to ensure it stays 0-indexed
for x in range(limit_days - 1, -1, -1) for x in range(limit_days - 1, -1, -1)
] ]
def escape_special_characters(string):
for special_character in ('\\', '_', '%', '/'):
string = string.replace(
special_character,
'\{}'.format(special_character)
)
return string