From b33312b8557f21722ab1212b739ee97f4786e94c Mon Sep 17 00:00:00 2001 From: Adam Shimali Date: Tue, 14 Jun 2016 15:07:23 +0100 Subject: [PATCH 1/2] Change endpoint responses where there are marshalling, unmarshalling or param errors to raise invalid data exception. That will cause those responses to be handled in by errors.py, which will log the errors. Set most of schemas to strict mode so that marshmallow will raise exception rather than checking for errors in return tuple from load. Added handler to errors.py for marshmallow validation errors. --- app/accept_invite/rest.py | 9 +- app/errors.py | 10 +- app/events/rest.py | 5 +- app/invite/rest.py | 7 +- app/job/rest.py | 23 ++--- app/notifications/rest.py | 146 +++++++++------------------ app/notifications_statistics/rest.py | 15 +-- app/permission/rest.py | 15 ++- app/provider_details/rest.py | 22 ++-- app/schemas.py | 55 +++++++++- app/service/rest.py | 78 +++++--------- app/template/rest.py | 56 +++++----- app/template_statistics/rest.py | 12 +-- app/user/rest.py | 50 +++++---- tests/app/dao/test_templates_dao.py | 2 +- tests/app/user/test_rest.py | 2 +- 16 files changed, 240 insertions(+), 267 deletions(-) diff --git a/app/accept_invite/rest.py b/app/accept_invite/rest.py index a914654e9..1a0366f5e 100644 --- a/app/accept_invite/rest.py +++ b/app/accept_invite/rest.py @@ -10,7 +10,11 @@ from notifications_utils.url_safe_token import check_token from app.dao.invited_user_dao import get_invited_user_by_id -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) + from app.schemas import invited_user_schema @@ -30,7 +34,8 @@ def get_invited_user_by_token(token): max_age_seconds) except SignatureExpired: message = 'Invitation with id {} expired'.format(invited_user_id) - return jsonify(result='error', message=message), 400 + errors = {'invitation': [message]} + raise InvalidRequest(errors, status_code=400) invited_user = get_invited_user_by_id(invited_user_id) diff --git a/app/errors.py b/app/errors.py index c3cc70c01..9d90dee15 100644 --- a/app/errors.py +++ b/app/errors.py @@ -4,9 +4,10 @@ from flask import ( ) from sqlalchemy.exc import SQLAlchemyError, DataError from sqlalchemy.orm.exc import NoResultFound +from marshmallow import ValidationError -class InvalidData(Exception): +class InvalidRequest(Exception): def __init__(self, message, status_code): super().__init__() @@ -22,7 +23,12 @@ class InvalidData(Exception): def register_errors(blueprint): - @blueprint.app_errorhandler(InvalidData) + @blueprint.app_errorhandler(ValidationError) + def validation_error(error): + current_app.logger.error(error) + return jsonify(result='error', message=error.messages), 400 + + @blueprint.app_errorhandler(InvalidRequest) def invalid_data(error): response = jsonify(error.to_dict()) response.status_code = error.status_code diff --git a/app/events/rest.py b/app/events/rest.py index d3d38454e..6bd189617 100644 --- a/app/events/rest.py +++ b/app/events/rest.py @@ -5,6 +5,7 @@ from flask import ( ) from app.errors import register_errors + from app.schemas import event_schema from app.dao.events_dao import dao_create_event @@ -15,8 +16,6 @@ register_errors(events) @events.route('', methods=['POST']) def create_event(): data = request.get_json() - event, errors = event_schema.load(data) - if errors: - return jsonify(result="error", message=errors), 400 + event = event_schema.load(data).data dao_create_event(event) return jsonify(data=event_schema.dump(event).data), 201 diff --git a/app/invite/rest.py b/app/invite/rest.py index 14a6dae84..9190ad252 100644 --- a/app/invite/rest.py +++ b/app/invite/rest.py @@ -19,14 +19,13 @@ from app.celery.tasks import (email_invited_user) invite = Blueprint('invite', __name__, url_prefix='/service//invite') from app.errors import register_errors + register_errors(invite) @invite.route('', methods=['POST']) def create_invited_user(service_id): invited_user, errors = invited_user_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 save_invited_user(invited_user) invitation = _create_invitation(invited_user) encrypted_invitation = encryption.encrypt(invitation) @@ -53,9 +52,7 @@ def update_invited_user(service_id, invited_user_id): current_data = dict(invited_user_schema.dump(fetched).data.items()) current_data.update(request.get_json()) - update_dict, errors = invited_user_schema.load(current_data) - if errors: - return jsonify(result='error', message=errors), 400 + update_dict = invited_user_schema.load(current_data).data save_invited_user(update_dict) return jsonify(data=invited_user_schema.dump(fetched).data), 200 diff --git a/app/job/rest.py b/app/job/rest.py index 9b2d44382..f4c090b19 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -22,7 +22,10 @@ from app.celery.tasks import process_job job = Blueprint('job', __name__, url_prefix='/service//job') -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(job) @@ -30,7 +33,7 @@ register_errors(job) @job.route('/', methods=['GET']) def get_job_by_service_and_job_id(service_id, job_id): job = dao_get_job_by_service_id_and_job_id(service_id, job_id) - data, errors = job_schema.dump(job) + data = job_schema.dump(job).data return jsonify(data=data) @@ -40,16 +43,13 @@ def get_jobs_by_service(service_id): try: limit_days = int(request.args['limit_days']) except ValueError as e: - error = '{} is not an integer'.format(request.args['limit_days']) - current_app.logger.error(error) - return jsonify(result="error", message={'limit_days': [error]}), 400 + errors = {'error': ['{} is not an integer'.format(request.args['limit_days'])]} + raise InvalidRequest(errors, status_code=400) else: limit_days = None jobs = dao_get_jobs_by_service_id(service_id, limit_days) - data, errors = job_schema.dump(jobs, many=True) - if errors: - return jsonify(result="error", message=errors), 400 + data = job_schema.dump(jobs, many=True).data return jsonify(data=data) @@ -66,13 +66,10 @@ def create_job(service_id): errors = unarchived_template_schema.validate({'archived': template.archived}) if errors: - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) data.update({"template_version": template.version}) - job, errors = job_schema.load(data) - if errors: - return jsonify(result="error", message=errors), 400 - + job = job_schema.load(data).data dao_create_job(job) process_job.apply_async([str(job.id)], queue="process-job") return jsonify(data=job_schema.dump(job).data), 201 diff --git a/app/notifications/rest.py b/app/notifications/rest.py index 7f4029a65..933297e2f 100644 --- a/app/notifications/rest.py +++ b/app/notifications/rest.py @@ -35,7 +35,10 @@ from app.celery.tasks import send_sms, send_email notifications = Blueprint('notifications', __name__) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(notifications) @@ -47,16 +50,12 @@ def process_ses_response(): ses_request = json.loads(request.data) errors = validate_callback_data(data=ses_request, fields=['Message'], client_name=client_name) if errors: - return jsonify( - result="error", message=errors - ), 400 + raise InvalidRequest(errors, status_code=400) ses_message = json.loads(ses_request['Message']) errors = validate_callback_data(data=ses_message, fields=['notificationType'], client_name=client_name) if errors: - return jsonify( - result="error", message=errors - ), 400 + raise InvalidRequest(errors, status_code=400) notification_type = ses_message['notificationType'] if notification_type == 'Bounce': @@ -67,12 +66,8 @@ def process_ses_response(): try: aws_response_dict = get_aws_responses(notification_type) except KeyError: - message = "{} callback failed: status {} not found".format(client_name, notification_type) - current_app.logger.info(message) - return jsonify( - result="error", - message=message - ), 400 + error = "{} callback failed: status {} not found".format(client_name, notification_type) + raise InvalidRequest(error, status_code=400) notification_status = aws_response_dict['notification_status'] notification_statistics_status = aws_response_dict['notification_statistics_status'] @@ -93,15 +88,9 @@ def process_ses_response(): notification_status, notification_statistics_status ): - message = "SES callback failed: notification either not found or already updated " \ - "from sending. Status {}".format(notification_status) - current_app.logger.info( - message - ) - return jsonify( - result="error", - message=message - ), 404 + error = "SES callback failed: notification either not found or already updated " \ + "from sending. Status {}".format(notification_status) + raise InvalidRequest(error, status_code=404) if not aws_response_dict['success']: current_app.logger.info( @@ -117,20 +106,12 @@ def process_ses_response(): ), 200 except KeyError: - current_app.logger.error( - "SES callback failed: messageId missing" - ) - return jsonify( - result="error", message="SES callback failed: messageId missing" - ), 400 + message = "SES callback failed: messageId missing" + raise InvalidRequest(message, status_code=400) except ValueError as ex: - current_app.logger.exception( - "{} callback failed: invalid json {}".format(client_name, ex) - ) - return jsonify( - result="error", message="{} callback failed: invalid json".format(client_name) - ), 400 + error = "{} callback failed: invalid json".format(client_name) + raise InvalidRequest(error, status_code=400) def is_not_a_notification(source): @@ -149,19 +130,17 @@ def is_not_a_notification(source): def process_mmg_response(): client_name = 'MMG' data = json.loads(request.data) - validation_errors = validate_callback_data(data=data, - fields=['status', 'CID'], - client_name=client_name) - if validation_errors: - [current_app.logger.info(e) for e in validation_errors] - return jsonify(result='error', message=validation_errors), 400 + errors = validate_callback_data(data=data, + fields=['status', 'CID'], + client_name=client_name) + if errors: + raise InvalidRequest(errors, status_code=400) success, errors = process_sms_client_response(status=str(data.get('status')), reference=data.get('CID'), client_name=client_name) if errors: - [current_app.logger.info(e) for e in errors] - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) else: return jsonify(result='success', message=success), 200 @@ -169,12 +148,11 @@ def process_mmg_response(): @notifications.route('/notifications/sms/firetext', methods=['POST']) def process_firetext_response(): client_name = 'Firetext' - validation_errors = validate_callback_data(data=request.form, - fields=['status', 'reference'], - client_name=client_name) - if validation_errors: - current_app.logger.info(validation_errors) - return jsonify(result='error', message=validation_errors), 400 + errors = validate_callback_data(data=request.form, + fields=['status', 'reference'], + client_name=client_name) + if errors: + raise InvalidRequest(errors, status_code=400) response_code = request.form.get('code') status = request.form.get('status') @@ -185,8 +163,7 @@ def process_firetext_response(): reference=request.form.get('reference'), client_name=client_name) if errors: - [current_app.logger.info(e) for e in errors] - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) else: return jsonify(result='success', message=success), 200 @@ -199,10 +176,7 @@ def get_notifications(notification_id): @notifications.route('/notifications', methods=['GET']) def get_all_notifications(): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') limit_days = data.get('limit_days') @@ -228,10 +202,7 @@ def get_all_notifications(): @notifications.route('/service//notifications', methods=['GET']) @require_admin() def get_all_notifications_for_service(service_id): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') limit_days = data.get('limit_days') @@ -259,10 +230,7 @@ def get_all_notifications_for_service(service_id): @notifications.route('/service//job//notifications', methods=['GET']) @require_admin() def get_all_notifications_for_service_job(service_id, job_id): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') @@ -317,16 +285,13 @@ def send_notification(notification_type): total_email_count = service_stats.emails_requested if (total_email_count + total_sms_count >= service.message_limit) and service.restricted: - return jsonify(result="error", message='Exceeded send limits ({}) for today'.format( - service.message_limit)), 429 + error = 'Exceeded send limits ({}) for today'.format(service.message_limit) + raise InvalidRequest(error, status_code=429) notification, errors = ( sms_template_notification_schema if notification_type == 'sms' else email_notification_schema ).load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 - template = templates_dao.dao_get_template_by_id_and_service_id( template_id=notification['template'], service_id=service_id @@ -334,33 +299,24 @@ def send_notification(notification_type): errors = unarchived_template_schema.validate({'archived': template.archived}) if errors: - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) template_object = Template(template.__dict__, notification.get('personalisation', {})) if template_object.missing_data: - return jsonify( - result="error", - message={ - 'template': ['Missing personalisation: {}'.format( - ", ".join(template_object.missing_data) - )] - } - ), 400 + message = 'Missing personalisation: {}'.format(", ".join(template_object.missing_data)) + errors = {'template': [message]} + raise InvalidRequest(errors, status_code=400) + if template_object.additional_data: - return jsonify( - result="error", - message={ - 'template': ['Personalisation not needed for template: {}'.format( - ", ".join(template_object.additional_data) - )] - } - ), 400 + message = 'Personalisation not needed for template: {}'.format(", ".join(template_object.additional_data)) + errors = {'template': [message]} + raise InvalidRequest(errors, status_code=400) if template_object.replaced_content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT'): - return jsonify( - result="error", - message={'content': ['Content has a character count greater than the limit of {}'.format( - current_app.config.get('SMS_CHAR_COUNT_LIMIT'))]}), 400 + char_count = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) if service.restricted and not allowed_to_send_to( notification['to'], @@ -368,11 +324,9 @@ def send_notification(notification_type): [user.mobile_number, user.email_address] for user in service.users ) ): - return jsonify( - result="error", message={ - 'to': ['Invalid {} for restricted service'.format(first_column_heading[notification_type])] - } - ), 400 + message = 'Invalid {} for restricted service'.format(first_column_heading[notification_type]) + errors = {'to': [message]} + raise InvalidRequest(errors, status_code=400) notification_id = create_uuid() notification.update({"template_version": template.version}) @@ -397,13 +351,9 @@ def send_notification(notification_type): @notifications.route('/notifications/statistics') def get_notification_statistics_for_day(): - data, errors = day_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 - + data = day_schema.load(request.args).data statistics = notifications_dao.dao_get_potential_notification_statistics_for_day( day=data['day'] ) - data, errors = notifications_statistics_schema.dump(statistics, many=True) return jsonify(data=data), 200 diff --git a/app/notifications_statistics/rest.py b/app/notifications_statistics/rest.py index f864e473b..69b1eafab 100644 --- a/app/notifications_statistics/rest.py +++ b/app/notifications_statistics/rest.py @@ -19,7 +19,10 @@ notifications_statistics = Blueprint( __name__, url_prefix='/service//notifications-statistics' ) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(notifications_statistics) @@ -34,9 +37,9 @@ def get_all_notification_statistics_for_service(service_id): limit_days=int(request.args['limit_days']) ) except ValueError as e: - error = '{} is not an integer'.format(request.args['limit_days']) - current_app.logger.error(error) - return jsonify(result="error", message={'limit_days': [error]}), 400 + message = '{} is not an integer'.format(request.args['limit_days']) + errors = {'limit_days': [message]} + raise InvalidRequest(errors, status_code=400) else: statistics = dao_get_notification_statistics_for_service(service_id=service_id) @@ -46,9 +49,7 @@ def get_all_notification_statistics_for_service(service_id): @notifications_statistics.route('/seven_day_aggregate') def get_notification_statistics_for_service_seven_day_aggregate(service_id): - data, errors = week_aggregate_notification_statistics_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 + data = week_aggregate_notification_statistics_schema.load(request.args).data date_from = data['date_from'] if 'date_from' in data else date(date.today().year, 4, 1) week_count = data['week_count'] if 'week_count' in data else 52 stats = dao_get_7_day_agg_notification_statistics_for_service( diff --git a/app/permission/rest.py b/app/permission/rest.py index 89266b88b..bf017bfe9 100644 --- a/app/permission/rest.py +++ b/app/permission/rest.py @@ -1,5 +1,9 @@ +from flask import ( + jsonify, + request, + Blueprint +) -from flask import (jsonify, request, abort, Blueprint, current_app) from app.schemas import permission_schema from app.errors import register_errors from app.dao.permissions_dao import permission_dao @@ -10,17 +14,12 @@ register_errors(permission) @permission.route('', methods=['GET']) def get_permissions(): - data, errors = permission_schema.dump( - permission_dao.get_query(filter_by_dict=request.args), many=True) - if errors: - abort(500, errors) + data = permission_schema.dump(permission_dao.get_query(filter_by_dict=request.args), many=True).data return jsonify(data=data) @permission.route('/', methods=['GET']) def get_permission(permission_id): inst = permission_dao.get_query(filter_by_dict={'id': permission_id}).one() - data, errors = permission_schema.dump(inst) - if errors: - abort(500, errors) + data = permission_schema.dump(inst).data return jsonify(data=data) diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py index f976f1560..b075653f5 100644 --- a/app/provider_details/rest.py +++ b/app/provider_details/rest.py @@ -1,25 +1,31 @@ from flask import Blueprint, jsonify, request from app.schemas import provider_details_schema + from app.dao.provider_details_dao import ( get_provider_details, get_provider_details_by_id, - get_provider_details_by_id, dao_update_provider_details ) +from app.errors import ( + register_errors, + InvalidRequest +) + provider_details = Blueprint('provider_details', __name__) +register_errors(provider_details) @provider_details.route('', methods=['GET']) def get_providers(): - data, errors = provider_details_schema.dump(get_provider_details(), many=True) + data = provider_details_schema.dump(get_provider_details(), many=True).data return jsonify(provider_details=data) @provider_details.route('/', methods=['GET']) def get_provider_by_id(provider_details_id): - data, errors = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)) + data = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)).data return jsonify(provider_details=data) @@ -29,14 +35,12 @@ def update_provider_details(provider_details_id): current_data = dict(provider_details_schema.dump(fetched_provider_details).data.items()) current_data.update(request.get_json()) - update_dict, errors = provider_details_schema.load(current_data) - if errors: - return jsonify(result="error", message=errors), 400 + update_dict = provider_details_schema.load(current_data).data if "identifier" in request.get_json().keys(): - return jsonify(message={ - "identifier": ["Not permitted to be updated"] - }, result='error'), 400 + message = "Not permitted to be updated" + errors = {'identifier': [message]} + raise InvalidRequest(errors, status_code=400) dao_update_provider_details(update_dict) return jsonify(provider_details=provider_details_schema.dump(fetched_provider_details).data), 200 diff --git a/app/schemas.py b/app/schemas.py index 7d20012cf..97e0890cd 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,13 +1,16 @@ -from datetime import date +from datetime import ( + datetime, + date +) from flask_marshmallow.fields import fields -from sqlalchemy.orm import load_only from marshmallow import ( post_load, ValidationError, validates, validates_schema, - pre_load + pre_load, + pre_dump ) from marshmallow_sqlalchemy import field_for @@ -46,6 +49,7 @@ def _validate_not_in_future(dte, msg="Date cannot be in the future"): class BaseSchema(ma.ModelSchema): + def __init__(self, load_json=False, *args, **kwargs): self.load_json = load_json super(BaseSchema, self).__init__(*args, **kwargs) @@ -82,12 +86,14 @@ class UserSchema(BaseSchema): exclude = ( "updated_at", "created_at", "user_to_service", "_password", "verify_codes") + strict = True class ProviderDetailsSchema(BaseSchema): class Meta: model = models.ProviderDetails exclude = ("provider_rates", "provider_stats") + strict = True class ServiceSchema(BaseSchema): @@ -97,11 +103,13 @@ class ServiceSchema(BaseSchema): class Meta: model = models.Service exclude = ("updated_at", "created_at", "api_keys", "templates", "jobs", 'old_id') + strict = True class NotificationModelSchema(BaseSchema): class Meta: model = models.Notification + strict = True class BaseTemplateSchema(BaseSchema): @@ -109,6 +117,7 @@ class BaseTemplateSchema(BaseSchema): class Meta: model = models.Template exclude = ("service_id", "jobs") + strict = True class TemplateSchema(BaseTemplateSchema): @@ -142,6 +151,13 @@ class TemplateHistorySchema(BaseSchema): class NotificationsStatisticsSchema(BaseSchema): class Meta: model = models.NotificationStatistics + strict = True + + @pre_dump + def handle_date_str(self, in_data): + if isinstance(in_data, dict) and 'day' in in_data: + in_data['day'] = datetime.strptime(in_data['day'], '%Y-%m-%d').date() + return in_data class ApiKeySchema(BaseSchema): @@ -151,6 +167,7 @@ class ApiKeySchema(BaseSchema): class Meta: model = models.ApiKey exclude = ("service", "secret") + strict = True class JobSchema(BaseSchema): @@ -161,13 +178,22 @@ class JobSchema(BaseSchema): class Meta: model = models.Job exclude = ('notifications',) + strict = True class RequestVerifyCodeSchema(ma.Schema): + + class Meta: + strict = True + to = fields.Str(required=False) class NotificationSchema(ma.Schema): + + class Meta: + strict = True + personalisation = fields.Dict(required=False) @@ -225,12 +251,14 @@ class NotificationStatusSchema(BaseSchema): class Meta: model = models.Notification + strict = True class InvitedUserSchema(BaseSchema): class Meta: model = models.InvitedUser + strict = True @validates('email_address') def validate_to(self, value): @@ -255,9 +283,14 @@ class PermissionSchema(BaseSchema): class Meta: model = models.Permission exclude = ("created_at",) + strict = True class EmailDataSchema(ma.Schema): + + class Meta: + strict = True + email = fields.Str(required=False) @validates('email') @@ -269,6 +302,10 @@ class EmailDataSchema(ma.Schema): class NotificationsFilterSchema(ma.Schema): + + class Meta: + strict = True + template_type = fields.Nested(BaseTemplateSchema, only=['template_type'], many=True) status = fields.Nested(NotificationModelSchema, only=['status'], many=True) page = fields.Int(required=False) @@ -309,6 +346,7 @@ class TemplateStatisticsSchema(BaseSchema): class Meta: model = models.TemplateStatistics + strict = True class ServiceHistorySchema(ma.Schema): @@ -337,10 +375,14 @@ class ApiKeyHistorySchema(ma.Schema): class EventSchema(BaseSchema): class Meta: model = models.Event + strict = True class FromToDateSchema(ma.Schema): + class Meta: + strict = True + date_from = fields.Date() date_to = fields.Date() @@ -361,6 +403,10 @@ class FromToDateSchema(ma.Schema): class DaySchema(ma.Schema): + + class Meta: + strict = True + day = fields.Date(required=True) @validates('day') @@ -370,6 +416,9 @@ class DaySchema(ma.Schema): class WeekAggregateNotificationStatisticsSchema(ma.Schema): + class Meta: + strict = True + date_from = fields.Date() week_count = fields.Int() diff --git a/app/service/rest.py b/app/service/rest.py index 0d2da9cb0..1b01ec9a1 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -6,9 +6,9 @@ from datetime import ( from flask import ( jsonify, request, - abort, Blueprint ) + from sqlalchemy.orm.exc import NoResultFound from app.dao.api_key_dao import ( @@ -39,11 +39,12 @@ from app.schemas import ( permission_schema ) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) service = Blueprint('service', __name__) - - register_errors(service) @@ -54,7 +55,7 @@ def get_services(): services = dao_fetch_all_services_by_user(user_id) else: services = dao_fetch_all_services() - data, errors = service_schema.dump(services, many=True) + data = service_schema.dump(services, many=True).data return jsonify(data=data) @@ -66,7 +67,7 @@ def get_service_by_id(service_id): else: fetched = dao_fetch_service_by_id(service_id) - data, errors = service_schema.dump(fetched) + data = service_schema.dump(fetched).data return jsonify(data=data) @@ -74,16 +75,12 @@ def get_service_by_id(service_id): def create_service(): data = request.get_json() if not data.get('user_id', None): - return jsonify(result="error", message={'user_id': ['Missing data for required field.']}), 400 + errors = {'user_id': ['Missing data for required field.']} + raise InvalidRequest(errors, status_code=400) user = get_model_users(data['user_id']) - data.pop('user_id', None) - valid_service, errors = service_schema.load(request.get_json()) - - if errors: - return jsonify(result="error", message=errors), 400 - + valid_service = service_schema.load(request.get_json()).data dao_create_service(valid_service, user) return jsonify(data=service_schema.dump(valid_service).data), 201 @@ -93,9 +90,7 @@ def update_service(service_id): fetched_service = dao_fetch_service_by_id(service_id) current_data = dict(service_schema.dump(fetched_service).data.items()) current_data.update(request.get_json()) - update_dict, errors = service_schema.load(current_data) - if errors: - return jsonify(result="error", message=errors), 400 + update_dict = service_schema.load(current_data).data dao_update_service(update_dict) return jsonify(data=service_schema.dump(fetched_service).data), 200 @@ -103,14 +98,9 @@ def update_service(service_id): @service.route('//api-key', methods=['POST']) def renew_api_key(service_id=None): fetched_service = dao_fetch_service_by_id(service_id=service_id) - - valid_api_key, errors = api_key_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 + valid_api_key = api_key_schema.load(request.get_json()).data valid_api_key.service = fetched_service - save_model_api_key(valid_api_key) - unsigned_api_key = get_unsigned_secret(valid_api_key.id) return jsonify(data=unsigned_api_key), 201 @@ -133,7 +123,8 @@ def get_api_keys(service_id, key_id=None): else: api_keys = get_model_api_keys(service_id=service_id) except NoResultFound: - return jsonify(result="error", message="API key not found for id: {}".format(service_id)), 404 + error = "API key not found for id: {}".format(service_id) + raise InvalidRequest(error, status_code=404) return jsonify(apiKeys=api_key_schema.dump(api_keys, many=True).data), 200 @@ -141,7 +132,6 @@ def get_api_keys(service_id, key_id=None): @service.route('//users', methods=['GET']) def get_users_for_service(service_id): fetched = dao_fetch_service_by_id(service_id) - result = user_schema.dump(fetched.users, many=True) return jsonify(data=result.data) @@ -152,15 +142,12 @@ def add_user_to_service(service_id, user_id): user = get_model_users(user_id=user_id) if user in service.users: - return jsonify(result='error', - message='User id: {} already part of service id: {}'.format(user_id, service_id)), 400 - - permissions, errors = permission_schema.load(request.get_json(), many=True) - if errors: - abort(400, errors) + error = 'User id: {} already part of service id: {}'.format(user_id, service_id) + raise InvalidRequest(error, status_code=400) + permissions = permission_schema.load(request.get_json(), many=True).data dao_add_user_to_service(service, user, permissions) - data, errors = service_schema.dump(service) + data = service_schema.dump(service).data return jsonify(data=data), 201 @@ -169,13 +156,13 @@ def remove_user_from_service(service_id, user_id): service = dao_fetch_service_by_id(service_id) user = get_model_users(user_id=user_id) if user not in service.users: - return jsonify( - result='error', - message='User not found'), 404 + error = 'User not found' + raise InvalidRequest(error, status_code=404) + elif len(service.users) == 1: - return jsonify( - result='error', - message='You cannot remove the only user for a service'), 400 + error = 'You cannot remove the only user for a service' + raise InvalidRequest(error, status_code=400) + dao_remove_user_from_service(service, user) return jsonify({}), 204 @@ -183,10 +170,7 @@ def remove_user_from_service(service_id, user_id): @service.route('//fragment/aggregate_statistics') def get_service_provider_aggregate_statistics(service_id): service = dao_fetch_service_by_id(service_id) - data, errors = from_to_date_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 - + data = from_to_date_schema.load(request.args).data return jsonify(data=get_fragment_count( service, date_from=(data.pop('date_from') if 'date_from' in data else date.today()), @@ -208,21 +192,15 @@ def get_service_history(service_id): ) service_history = Service.get_history_model().query.filter_by(id=service_id).all() - service_data, errors = service_history_schema.dump(service_history, many=True) - if errors: - return jsonify(result="error", message=errors), 400 - + service_data = service_history_schema.dump(service_history, many=True).data api_key_history = ApiKey.get_history_model().query.filter_by(service_id=service_id).all() - - api_keys_data, errors = api_key_history_schema.dump(api_key_history, many=True) - if errors: - return jsonify(result="error", message=errors), 400 + api_keys_data = api_key_history_schema.dump(api_key_history, many=True).data template_history = Template.get_history_model().query.filter_by(service_id=service_id).all() template_data, errors = template_history_schema.dump(template_history, many=True) events = Event.query.all() - events_data, errors = event_schema.dump(events, many=True) + events_data = event_schema.dump(events, many=True).data data = { 'service_history': service_data, diff --git a/app/template/rest.py b/app/template/rest.py index 168d483ea..8ba246d34 100644 --- a/app/template/rest.py +++ b/app/template/rest.py @@ -19,7 +19,10 @@ from app.schemas import (template_schema, template_history_schema) template = Blueprint('template', __name__, url_prefix='/service//template') -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(template) @@ -28,27 +31,23 @@ def _content_count_greater_than_limit(content, template_type): template = Template({'content': content, 'template_type': template_type}) if template_type == 'sms' and \ template.content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT'): - return True, jsonify( - result="error", - message={'content': ['Content has a character count greater than the limit of {}'.format( - current_app.config.get('SMS_CHAR_COUNT_LIMIT'))]} - ) - return False, '' + return True + return False @template.route('', methods=['POST']) def create_template(service_id): fetched_service = dao_fetch_service_by_id(service_id=service_id) - new_template, errors = template_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 + new_template = template_schema.load(request.get_json()).data new_template.service = fetched_service new_template.content = _strip_html(new_template.content) - over_limit, json_resp = _content_count_greater_than_limit( - new_template.content, - new_template.template_type) + over_limit = _content_count_greater_than_limit(new_template.content, new_template.template_type) if over_limit: - return json_resp, 400 + char_count_limit = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count_limit) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) + dao_create_template(new_template) return jsonify(data=template_schema.dump(new_template).data), 201 @@ -65,14 +64,13 @@ def update_template(service_id, template_id): if _template_has_not_changed(current_data, updated_template): return jsonify(data=updated_template), 200 - update_dict, errors = template_schema.load(updated_template) - if errors: - return jsonify(result="error", message=errors), 400 - over_limit, json_resp = _content_count_greater_than_limit( - updated_template['content'], - fetched_template.template_type) + update_dict = template_schema.load(updated_template).data + over_limit = _content_count_greater_than_limit(updated_template['content'], fetched_template.template_type) if over_limit: - return json_resp, 400 + char_count_limit = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count_limit) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) dao_update_template(update_dict) return jsonify(data=template_schema.dump(update_dict).data), 200 @@ -80,39 +78,35 @@ def update_template(service_id, template_id): @template.route('', methods=['GET']) def get_all_templates_for_service(service_id): templates = dao_get_all_templates_for_service(service_id=service_id) - data, errors = template_schema.dump(templates, many=True) + data = template_schema.dump(templates, many=True).data return jsonify(data=data) @template.route('/', methods=['GET']) def get_template_by_id_and_service_id(service_id, template_id): fetched_template = dao_get_template_by_id_and_service_id(template_id=template_id, service_id=service_id) - data, errors = template_schema.dump(fetched_template) + data = template_schema.dump(fetched_template).data return jsonify(data=data) @template.route('//version/') def get_template_version(service_id, template_id, version): - data, errors = template_history_schema.dump( + data = template_history_schema.dump( dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id, version=version ) - ) - if errors: - return jsonify(result='error', message=errors), 400 + ).data return jsonify(data=data) @template.route('//versions') def get_template_versions(service_id, template_id): - data, errors = template_history_schema.dump( + data = template_history_schema.dump( dao_get_template_versions(service_id=service_id, template_id=template_id), many=True - ) - if errors: - return jsonify(result='error', message=errors), 400 + ).data return jsonify(data=data) diff --git a/app/template_statistics/rest.py b/app/template_statistics/rest.py index 1e46843c5..d4990ba04 100644 --- a/app/template_statistics/rest.py +++ b/app/template_statistics/rest.py @@ -15,7 +15,7 @@ template_statistics = Blueprint('template-statistics', __name__, url_prefix='/service//template-statistics') -from app.errors import register_errors, InvalidData +from app.errors import register_errors, InvalidRequest register_errors(template_statistics) @@ -28,20 +28,16 @@ def get_template_statistics_for_service(service_id): except ValueError as e: error = '{} is not an integer'.format(request.args['limit_days']) message = {'limit_days': [error]} - raise InvalidData(message, status_code=400) + raise InvalidRequest(message, status_code=400) else: limit_days = None stats = dao_get_template_statistics_for_service(service_id, limit_days=limit_days) - data, errors = template_statistics_schema.dump(stats, many=True) - if errors: - raise InvalidData(errors, status_code=400) + data = template_statistics_schema.dump(stats, many=True).data return jsonify(data=data) @template_statistics.route('/') def get_template_statistics_for_template_id(service_id, template_id): stats = dao_get_template_statistics_for_template(template_id) - data, errors = template_statistics_schema.dump(stats, many=True) - if errors: - raise InvalidData(errors, status_code=400) + data = template_statistics_schema.dump(stats, many=True).data return jsonify(data=data) diff --git a/app/user/rest.py b/app/user/rest.py index 46668b98c..93d7f1d9e 100644 --- a/app/user/rest.py +++ b/app/user/rest.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from flask import (jsonify, request, abort, Blueprint, current_app) +from flask import (jsonify, request, Blueprint, current_app) from app import encryption, DATETIME_FORMAT from app.dao.users_dao import ( get_model_users, @@ -28,9 +28,13 @@ from app.schemas import ( from app.celery.tasks import ( send_sms, email_reset_password, - send_email) + send_email +) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) user = Blueprint('user', __name__) register_errors(user) @@ -43,9 +47,7 @@ def create_user(): # TODO password policy, what is valid password if not req_json.get('password', None): errors.update({'password': ['Missing data for required field.']}) - return jsonify(result="error", message=errors), 400 - if errors: - return jsonify(result="error", message=errors), 400 + raise InvalidRequest(errors, status_code=400) save_model_user(user_to_create, pwd=req_json.get('password')) return jsonify(data=user_schema.dump(user_to_create).data), 201 @@ -60,11 +62,9 @@ def update_user(user_id): # but would be good to have the same validation here. if pwd is not None and not pwd: errors.update({'password': ['Invalid data for field']}) - if errors: - return jsonify(result="error", message=errors), 400 - status_code = 200 + raise InvalidRequest(errors, status_code=400) save_model_user(user_to_update, update_dict=update_dct, pwd=pwd) - return jsonify(data=user_schema.dump(user_to_update).data), status_code + return jsonify(data=user_schema.dump(user_to_update).data), 200 @user.route('//verify/password', methods=['POST']) @@ -75,9 +75,10 @@ def verify_user_password(user_id): try: txt_pwd = request.get_json()['password'] except KeyError: - return jsonify( - result="error", - message={'password': ['Required field missing data']}), 400 + message = 'Required field missing data' + errors = {'password': [message]} + raise InvalidRequest(errors, status_code=400) + if user_to_verify.check_password(txt_pwd): user_to_verify.logged_in_at = datetime.utcnow() save_model_user(user_to_verify) @@ -85,7 +86,9 @@ def verify_user_password(user_id): return jsonify({}), 204 else: increment_failed_login_count(user_to_verify) - return jsonify(result='error', message={'password': ['Incorrect password']}), 400 + message = 'Incorrect password' + errors = {'password': [message]} + raise InvalidRequest(errors, status_code=400) @user.route('//verify/code', methods=['POST']) @@ -105,12 +108,13 @@ def verify_user_code(user_id): except KeyError: errors.update({'code_type': ['Required field missing data']}) if errors: - return jsonify(result="error", message=errors), 400 + raise InvalidRequest(errors, status_code=400) + code = get_user_code(user_to_verify, txt_code, txt_type) if not code: - return jsonify(result="error", message="Code not found"), 404 + raise InvalidRequest("Code not found", status_code=404) if datetime.utcnow() > code.expiry_datetime or code.code_used: - return jsonify(result="error", message="Code has expired"), 400 + raise InvalidRequest("Code has expired", status_code=400) use_user_code(code.id) return jsonify({}), 204 @@ -119,8 +123,6 @@ def verify_user_code(user_id): def send_user_sms_code(user_id): user_to_send_to = get_model_users(user_id=user_id) verify_code, errors = request_verify_code_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 secret_code = create_secret_code() create_user_code(user_to_send_to, secret_code, 'sms') @@ -150,8 +152,6 @@ def send_user_sms_code(user_id): def send_user_email_verification(user_id): user_to_send_to = get_model_users(user_id=user_id) verify_code, errors = request_verify_code_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 secret_code = create_secret_code() create_user_code(user_to_send_to, secret_code, 'email') @@ -191,8 +191,7 @@ def set_permissions(user_id, service_id): user = get_model_users(user_id=user_id) service = dao_fetch_service_by_id(service_id=service_id) permissions, errors = permission_schema.load(request.get_json(), many=True) - if errors: - abort(400, errors) + for p in permissions: p.user = user p.service = service @@ -204,7 +203,8 @@ def set_permissions(user_id, service_id): def get_by_email(): email = request.args.get('email') if not email: - return jsonify(result="error", message="invalid request"), 400 + error = 'Invalid request. Email query string param required' + raise InvalidRequest(error, status_code=400) fetched_user = get_user_by_email(email) result = user_schema.dump(fetched_user) @@ -214,8 +214,6 @@ def get_by_email(): @user.route('/reset-password', methods=['POST']) def send_user_reset_password(): email, errors = email_data_request_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 user_to_send_to = get_user_by_email(email['email']) diff --git a/tests/app/dao/test_templates_dao.py b/tests/app/dao/test_templates_dao.py index ad4b4372e..82b943963 100644 --- a/tests/app/dao/test_templates_dao.py +++ b/tests/app/dao/test_templates_dao.py @@ -267,6 +267,6 @@ def test_get_template_versions(sample_template): assert x.content == 'new version' else: assert x.content == original_content - from app.schemas import (template_history_schema) + from app.schemas import template_history_schema v = template_history_schema.load(versions, many=True) assert v.__len__() == 2 diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index 82b93d4cb..4f6286f6f 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -296,7 +296,7 @@ def test_get_user_by_email_bad_url_returns_404(notify_api, assert resp.status_code == 400 json_resp = json.loads(resp.get_data(as_text=True)) assert json_resp['result'] == 'error' - assert json_resp['message'] == 'invalid request' + assert json_resp['message'] == 'Invalid request. Email query string param required' def test_get_user_with_permissions(notify_api, From c26840155496e76e2224b14068bff2e591622faf Mon Sep 17 00:00:00 2001 From: Adam Shimali Date: Wed, 15 Jun 2016 16:19:28 +0100 Subject: [PATCH 2/2] Updated for pr comments --- app/job/rest.py | 6 +++--- app/notifications/rest.py | 3 +++ app/template/rest.py | 5 +---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/app/job/rest.py b/app/job/rest.py index f4c090b19..fc3d49943 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -1,8 +1,8 @@ from flask import ( Blueprint, jsonify, - request, - current_app) + request +) from app.dao.jobs_dao import ( dao_create_job, @@ -43,7 +43,7 @@ def get_jobs_by_service(service_id): try: limit_days = int(request.args['limit_days']) except ValueError as e: - errors = {'error': ['{} is not an integer'.format(request.args['limit_days'])]} + errors = {'limit_days': ['{} is not an integer'.format(request.args['limit_days'])]} raise InvalidRequest(errors, status_code=400) else: limit_days = None diff --git a/app/notifications/rest.py b/app/notifications/rest.py index 933297e2f..ab5cf2168 100644 --- a/app/notifications/rest.py +++ b/app/notifications/rest.py @@ -292,6 +292,9 @@ def send_notification(notification_type): sms_template_notification_schema if notification_type == 'sms' else email_notification_schema ).load(request.get_json()) + if errors: + raise InvalidRequest(errors, status_code=400) + template = templates_dao.dao_get_template_by_id_and_service_id( template_id=notification['template'], service_id=service_id diff --git a/app/template/rest.py b/app/template/rest.py index 8ba246d34..dc285e9d1 100644 --- a/app/template/rest.py +++ b/app/template/rest.py @@ -29,10 +29,7 @@ register_errors(template) def _content_count_greater_than_limit(content, template_type): template = Template({'content': content, 'template_type': template_type}) - if template_type == 'sms' and \ - template.content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT'): - return True - return False + return template_type == 'sms' and template.content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT') @template.route('', methods=['POST'])