diff --git a/app/dao/templates_dao.py b/app/dao/templates_dao.py index 567e763f3..493c7c072 100644 --- a/app/dao/templates_dao.py +++ b/app/dao/templates_dao.py @@ -42,7 +42,16 @@ def dao_get_template_by_id(template_id, version=None): return Template.query.filter_by(id=template_id).one() -def dao_get_all_templates_for_service(service_id): +def dao_get_all_templates_for_service(service_id, template_type=None): + if template_type is not None: + return Template.query.filter_by( + service_id=service_id, + template_type=template_type, + archived=False + ).order_by( + desc(Template.created_at) + ).all() + return Template.query.filter_by( service_id=service_id, archived=False diff --git a/app/v2/templates/get_templates.py b/app/v2/templates/get_templates.py index a5f4e6f0c..101ace2f9 100644 --- a/app/v2/templates/get_templates.py +++ b/app/v2/templates/get_templates.py @@ -10,9 +10,9 @@ from app.v2.templates.templates_schemas import get_all_template_request @v2_templates_blueprint.route("/", methods=['GET']) def get_templates(): - validate(request.args.to_dict(), get_all_template_request) + data = validate(request.args.to_dict(), get_all_template_request) - templates = templates_dao.dao_get_all_templates_for_service(api_user.service_id) + templates = templates_dao.dao_get_all_templates_for_service(api_user.service_id, data.get('type')) return jsonify( templates=[template.serialize() for template in templates] diff --git a/tests/app/v2/templates/test_get_templates.py b/tests/app/v2/templates/test_get_templates.py index 5f706f0b9..180752271 100644 --- a/tests/app/v2/templates/test_get_templates.py +++ b/tests/app/v2/templates/test_get_templates.py @@ -62,11 +62,35 @@ def test_get_all_templates_for_valid_type_returns_200(client, sample_service, tm reverse_index = len(json_response['templates']) - 1 - i assert json_response['templates'][reverse_index]['id'] == str(templates[i].id) assert json_response['templates'][reverse_index]['body'] == templates[i].content - assert json_response['templates'][reverse_index]['type'] == templates[i].template_type + assert json_response['templates'][reverse_index]['type'] == tmp_type if templates[i].template_type == EMAIL_TYPE: assert json_response['templates'][reverse_index]['subject'] == templates[i].subject +@pytest.mark.parametrize("tmp_type", TEMPLATE_TYPES) +def test_get_correct_num_templates_for_valid_type_returns_200(client, sample_service, tmp_type): + num_templates = 3 + + templates = [] + for i in range(num_templates): + templates.append(create_template(sample_service, template_type=tmp_type)) + + for other_type in TEMPLATE_TYPES: + if other_type != tmp_type: + templates.append(create_template(sample_service, template_type=other_type)) + + auth_header = create_authorization_header(service_id=sample_service.id) + + response = client.get(path='/v2/templates/?type={}'.format(tmp_type), + headers=[('Content-Type', 'application/json'), auth_header]) + + assert response.status_code == 200 + + json_response = json.loads(response.get_data(as_text=True)) + + assert len(json_response['templates']) == num_templates + + def test_get_all_templates_for_invalid_type_returns_400(client, sample_service): auth_header = create_authorization_header(service_id=sample_service.id)