diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index e5787720f..d7b049290 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -6,6 +6,7 @@ from datetime import ( ) from flask import current_app +from werkzeug.datastructures import MultiDict from app import db from app.models import ( @@ -191,10 +192,16 @@ def get_notifications_for_service(service_id, filter_dict=None, page=1): def filter_query(query, filter_dict=None): - if filter_dict and 'status' in filter_dict: - query = query.filter_by(status=filter_dict['status']) - if filter_dict and 'template_type' in filter_dict: - query = query.join(Template).filter(Template.template_type == filter_dict['template_type']) + if filter_dict is None: + filter_dict = MultiDict() + else: + filter_dict = MultiDict(filter_dict) + statuses = filter_dict.getlist('status') if 'status' in filter_dict else None + if statuses: + query = query.filter(Notification.status.in_(statuses)) + template_types = filter_dict.getlist('template_type') if 'template_type' in filter_dict else None + if template_types: + query = query.join(Template).filter(Template.template_type.in_(template_types)) return query diff --git a/app/dao/permissions_dao.py b/app/dao/permissions_dao.py index 2213c0040..eca2446e9 100644 --- a/app/dao/permissions_dao.py +++ b/app/dao/permissions_dao.py @@ -32,8 +32,10 @@ class PermissionDAO(DAOClass): class Meta: model = Permission - def get_query(self, filter_by_dict={}): - if isinstance(filter_by_dict, dict): + def get_query(self, filter_by_dict=None): + if filter_by_dict is None: + filter_by_dict = MultiDict() + else: filter_by_dict = MultiDict(filter_by_dict) query = self.Meta.model.query if 'id' in filter_by_dict: diff --git a/app/schemas.py b/app/schemas.py index dc1ce3c48..1181a7bef 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -2,7 +2,7 @@ from flask_marshmallow.fields import fields from . import ma from . import models from app.dao.permissions_dao import permission_dao -from marshmallow import (post_load, ValidationError, validates, validates_schema) +from marshmallow import (post_load, ValidationError, validates, validates_schema, pre_load) from marshmallow_sqlalchemy import field_for from utils.recipients import ( validate_email_address, InvalidEmailError, @@ -74,6 +74,11 @@ class ServiceSchema(BaseSchema): exclude = ("updated_at", "created_at", "api_keys", "templates", "jobs", 'old_id') +class NotificationModelSchema(BaseSchema): + class Meta: + model = models.Notification + + class TemplateSchema(BaseSchema): class Meta: model = models.Template @@ -203,10 +208,29 @@ class EmailDataSchema(ma.Schema): class NotificationsFilterSchema(ma.Schema): - template_type = field_for(models.Template, 'template_type', load_only=True, required=False) - status = field_for(models.Notification, 'status', load_only=True, required=False) + template_type = fields.Nested(TemplateSchema, only='template_type', many=True) + status = fields.Nested(NotificationModelSchema, only='status', many=True) page = fields.Int(required=False) + @pre_load + def handle_multidict(self, in_data): + if isinstance(in_data, dict) and hasattr(in_data, 'getlist'): + out_data = dict([(k, in_data.get(k)) for k in in_data.keys()]) + if 'template_type' in in_data: + out_data['template_type'] = [{'template_type': x} for x in in_data.getlist('template_type')] + if 'status' in in_data: + out_data['status'] = [{"status": x} for x in in_data.getlist('status')] + + return out_data + + @post_load + def convert_schema_object_to_field(self, in_data): + if 'template_type' in in_data: + in_data['template_type'] = [x.template_type for x in in_data['template_type']] + if 'status' in in_data: + in_data['status'] = [x.status for x in in_data['status']] + return in_data + @validates('page') def validate_page(self, value): try: @@ -216,6 +240,7 @@ class NotificationsFilterSchema(ma.Schema): except: raise ValidationError("Not a positive integer") + user_schema = UserSchema() user_schema_load_json = UserSchema(load_json=True) service_schema = ServiceSchema() diff --git a/tests/app/notifications/test_rest.py b/tests/app/notifications/test_rest.py index ba89ed108..4882ba420 100644 --- a/tests/app/notifications/test_rest.py +++ b/tests/app/notifications/test_rest.py @@ -230,28 +230,34 @@ def test_should_reject_invalid_page_param(notify_api, sample_email_template): def test_should_return_pagination_links(notify_api, notify_db, notify_db_session, sample_email_template): with notify_api.test_request_context(): with notify_api.test_client() as client: - notify_api.config['PAGE_SIZE'] = 1 + # Effectively mocking page size + original_page_size = notify_api.config['PAGE_SIZE'] + try: + notify_api.config['PAGE_SIZE'] = 1 - create_sample_notification(notify_db, notify_db_session, sample_email_template.service) - notification_2 = create_sample_notification(notify_db, notify_db_session, sample_email_template.service) - create_sample_notification(notify_db, notify_db_session, sample_email_template.service) + create_sample_notification(notify_db, notify_db_session, sample_email_template.service) + notification_2 = create_sample_notification(notify_db, notify_db_session, sample_email_template.service) + create_sample_notification(notify_db, notify_db_session, sample_email_template.service) - auth_header = create_authorization_header( - service_id=sample_email_template.service_id, - path='/notifications', - method='GET') + auth_header = create_authorization_header( + service_id=sample_email_template.service_id, + path='/notifications', + method='GET') - response = client.get( - '/notifications?page=2', - headers=[auth_header]) + response = client.get( + '/notifications?page=2', + headers=[auth_header]) - notifications = json.loads(response.get_data(as_text=True)) - assert len(notifications['notifications']) == 1 - assert notifications['links']['last'] == '/notifications?page=3' - assert notifications['links']['prev'] == '/notifications?page=1' - assert notifications['links']['next'] == '/notifications?page=3' - assert notifications['notifications'][0]['to'] == notification_2.to - assert response.status_code == 200 + notifications = json.loads(response.get_data(as_text=True)) + assert len(notifications['notifications']) == 1 + assert notifications['links']['last'] == '/notifications?page=3' + assert notifications['links']['prev'] == '/notifications?page=1' + assert notifications['links']['next'] == '/notifications?page=3' + assert notifications['notifications'][0]['to'] == notification_2.to + assert response.status_code == 200 + + finally: + notify_api.config['PAGE_SIZE'] = original_page_size def test_get_all_notifications_returns_empty_list(notify_api, sample_api_key): @@ -301,6 +307,41 @@ def test_filter_by_template_type(notify_api, notify_db, notify_db_session, sampl assert response.status_code == 200 +def test_filter_by_multiple_template_types(notify_api, + notify_db, + notify_db_session, + sample_template, + sample_email_template): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + + notification_1 = create_sample_notification( + notify_db, + notify_db_session, + service=sample_email_template.service, + template=sample_template) + notification_2 = create_sample_notification( + notify_db, + notify_db_session, + service=sample_email_template.service, + template=sample_email_template) + + auth_header = create_authorization_header( + service_id=sample_email_template.service_id, + path='/notifications', + method='GET') + + response = client.get( + '/notifications?template_type=sms&template_type=email', + headers=[auth_header]) + + assert response.status_code == 200 + notifications = json.loads(response.get_data(as_text=True)) + assert len(notifications['notifications']) == 2 + set(['sms', 'email']) == set( + [x['template']['template_type'] for x in notifications['notifications']]) + + def test_filter_by_status(notify_api, notify_db, notify_db_session, sample_email_template): with notify_api.test_request_context(): with notify_api.test_client() as client: @@ -309,8 +350,15 @@ def test_filter_by_status(notify_api, notify_db, notify_db_session, sample_email notify_db, notify_db_session, service=sample_email_template.service, + template=sample_email_template, status="delivered") + notification_2 = create_sample_notification( + notify_db, + notify_db_session, + service=sample_email_template.service, + template=sample_email_template) + auth_header = create_authorization_header( service_id=sample_email_template.service_id, path='/notifications', @@ -326,6 +374,42 @@ def test_filter_by_status(notify_api, notify_db, notify_db_session, sample_email assert response.status_code == 200 +def test_filter_by_multiple_statuss(notify_api, + notify_db, + notify_db_session, + sample_email_template): + with notify_api.test_request_context(): + with notify_api.test_client() as client: + + notification_1 = create_sample_notification( + notify_db, + notify_db_session, + service=sample_email_template.service, + template=sample_email_template, + status="delivered") + + notification_2 = create_sample_notification( + notify_db, + notify_db_session, + service=sample_email_template.service, + template=sample_email_template) + + auth_header = create_authorization_header( + service_id=sample_email_template.service_id, + path='/notifications', + method='GET') + + response = client.get( + '/notifications?status=delivered&status=sent', + headers=[auth_header]) + + assert response.status_code == 200 + notifications = json.loads(response.get_data(as_text=True)) + assert len(notifications['notifications']) == 2 + set(['delivered', 'sent']) == set( + [x['status'] for x in notifications['notifications']]) + + def test_filter_by_status_and_template_type(notify_api, notify_db, notify_db_session,