diff --git a/app/accept_invite/rest.py b/app/accept_invite/rest.py index a914654e9..1a0366f5e 100644 --- a/app/accept_invite/rest.py +++ b/app/accept_invite/rest.py @@ -10,7 +10,11 @@ from notifications_utils.url_safe_token import check_token from app.dao.invited_user_dao import get_invited_user_by_id -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) + from app.schemas import invited_user_schema @@ -30,7 +34,8 @@ def get_invited_user_by_token(token): max_age_seconds) except SignatureExpired: message = 'Invitation with id {} expired'.format(invited_user_id) - return jsonify(result='error', message=message), 400 + errors = {'invitation': [message]} + raise InvalidRequest(errors, status_code=400) invited_user = get_invited_user_by_id(invited_user_id) diff --git a/app/errors.py b/app/errors.py index c3cc70c01..9d90dee15 100644 --- a/app/errors.py +++ b/app/errors.py @@ -4,9 +4,10 @@ from flask import ( ) from sqlalchemy.exc import SQLAlchemyError, DataError from sqlalchemy.orm.exc import NoResultFound +from marshmallow import ValidationError -class InvalidData(Exception): +class InvalidRequest(Exception): def __init__(self, message, status_code): super().__init__() @@ -22,7 +23,12 @@ class InvalidData(Exception): def register_errors(blueprint): - @blueprint.app_errorhandler(InvalidData) + @blueprint.app_errorhandler(ValidationError) + def validation_error(error): + current_app.logger.error(error) + return jsonify(result='error', message=error.messages), 400 + + @blueprint.app_errorhandler(InvalidRequest) def invalid_data(error): response = jsonify(error.to_dict()) response.status_code = error.status_code diff --git a/app/events/rest.py b/app/events/rest.py index d3d38454e..6bd189617 100644 --- a/app/events/rest.py +++ b/app/events/rest.py @@ -5,6 +5,7 @@ from flask import ( ) from app.errors import register_errors + from app.schemas import event_schema from app.dao.events_dao import dao_create_event @@ -15,8 +16,6 @@ register_errors(events) @events.route('', methods=['POST']) def create_event(): data = request.get_json() - event, errors = event_schema.load(data) - if errors: - return jsonify(result="error", message=errors), 400 + event = event_schema.load(data).data dao_create_event(event) return jsonify(data=event_schema.dump(event).data), 201 diff --git a/app/invite/rest.py b/app/invite/rest.py index 14a6dae84..9190ad252 100644 --- a/app/invite/rest.py +++ b/app/invite/rest.py @@ -19,14 +19,13 @@ from app.celery.tasks import (email_invited_user) invite = Blueprint('invite', __name__, url_prefix='/service//invite') from app.errors import register_errors + register_errors(invite) @invite.route('', methods=['POST']) def create_invited_user(service_id): invited_user, errors = invited_user_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 save_invited_user(invited_user) invitation = _create_invitation(invited_user) encrypted_invitation = encryption.encrypt(invitation) @@ -53,9 +52,7 @@ def update_invited_user(service_id, invited_user_id): current_data = dict(invited_user_schema.dump(fetched).data.items()) current_data.update(request.get_json()) - update_dict, errors = invited_user_schema.load(current_data) - if errors: - return jsonify(result='error', message=errors), 400 + update_dict = invited_user_schema.load(current_data).data save_invited_user(update_dict) return jsonify(data=invited_user_schema.dump(fetched).data), 200 diff --git a/app/job/rest.py b/app/job/rest.py index 9b2d44382..fc3d49943 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -1,8 +1,8 @@ from flask import ( Blueprint, jsonify, - request, - current_app) + request +) from app.dao.jobs_dao import ( dao_create_job, @@ -22,7 +22,10 @@ from app.celery.tasks import process_job job = Blueprint('job', __name__, url_prefix='/service//job') -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(job) @@ -30,7 +33,7 @@ register_errors(job) @job.route('/', methods=['GET']) def get_job_by_service_and_job_id(service_id, job_id): job = dao_get_job_by_service_id_and_job_id(service_id, job_id) - data, errors = job_schema.dump(job) + data = job_schema.dump(job).data return jsonify(data=data) @@ -40,16 +43,13 @@ def get_jobs_by_service(service_id): try: limit_days = int(request.args['limit_days']) except ValueError as e: - error = '{} is not an integer'.format(request.args['limit_days']) - current_app.logger.error(error) - return jsonify(result="error", message={'limit_days': [error]}), 400 + errors = {'limit_days': ['{} is not an integer'.format(request.args['limit_days'])]} + raise InvalidRequest(errors, status_code=400) else: limit_days = None jobs = dao_get_jobs_by_service_id(service_id, limit_days) - data, errors = job_schema.dump(jobs, many=True) - if errors: - return jsonify(result="error", message=errors), 400 + data = job_schema.dump(jobs, many=True).data return jsonify(data=data) @@ -66,13 +66,10 @@ def create_job(service_id): errors = unarchived_template_schema.validate({'archived': template.archived}) if errors: - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) data.update({"template_version": template.version}) - job, errors = job_schema.load(data) - if errors: - return jsonify(result="error", message=errors), 400 - + job = job_schema.load(data).data dao_create_job(job) process_job.apply_async([str(job.id)], queue="process-job") return jsonify(data=job_schema.dump(job).data), 201 diff --git a/app/notifications/rest.py b/app/notifications/rest.py index 7f4029a65..ab5cf2168 100644 --- a/app/notifications/rest.py +++ b/app/notifications/rest.py @@ -35,7 +35,10 @@ from app.celery.tasks import send_sms, send_email notifications = Blueprint('notifications', __name__) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(notifications) @@ -47,16 +50,12 @@ def process_ses_response(): ses_request = json.loads(request.data) errors = validate_callback_data(data=ses_request, fields=['Message'], client_name=client_name) if errors: - return jsonify( - result="error", message=errors - ), 400 + raise InvalidRequest(errors, status_code=400) ses_message = json.loads(ses_request['Message']) errors = validate_callback_data(data=ses_message, fields=['notificationType'], client_name=client_name) if errors: - return jsonify( - result="error", message=errors - ), 400 + raise InvalidRequest(errors, status_code=400) notification_type = ses_message['notificationType'] if notification_type == 'Bounce': @@ -67,12 +66,8 @@ def process_ses_response(): try: aws_response_dict = get_aws_responses(notification_type) except KeyError: - message = "{} callback failed: status {} not found".format(client_name, notification_type) - current_app.logger.info(message) - return jsonify( - result="error", - message=message - ), 400 + error = "{} callback failed: status {} not found".format(client_name, notification_type) + raise InvalidRequest(error, status_code=400) notification_status = aws_response_dict['notification_status'] notification_statistics_status = aws_response_dict['notification_statistics_status'] @@ -93,15 +88,9 @@ def process_ses_response(): notification_status, notification_statistics_status ): - message = "SES callback failed: notification either not found or already updated " \ - "from sending. Status {}".format(notification_status) - current_app.logger.info( - message - ) - return jsonify( - result="error", - message=message - ), 404 + error = "SES callback failed: notification either not found or already updated " \ + "from sending. Status {}".format(notification_status) + raise InvalidRequest(error, status_code=404) if not aws_response_dict['success']: current_app.logger.info( @@ -117,20 +106,12 @@ def process_ses_response(): ), 200 except KeyError: - current_app.logger.error( - "SES callback failed: messageId missing" - ) - return jsonify( - result="error", message="SES callback failed: messageId missing" - ), 400 + message = "SES callback failed: messageId missing" + raise InvalidRequest(message, status_code=400) except ValueError as ex: - current_app.logger.exception( - "{} callback failed: invalid json {}".format(client_name, ex) - ) - return jsonify( - result="error", message="{} callback failed: invalid json".format(client_name) - ), 400 + error = "{} callback failed: invalid json".format(client_name) + raise InvalidRequest(error, status_code=400) def is_not_a_notification(source): @@ -149,19 +130,17 @@ def is_not_a_notification(source): def process_mmg_response(): client_name = 'MMG' data = json.loads(request.data) - validation_errors = validate_callback_data(data=data, - fields=['status', 'CID'], - client_name=client_name) - if validation_errors: - [current_app.logger.info(e) for e in validation_errors] - return jsonify(result='error', message=validation_errors), 400 + errors = validate_callback_data(data=data, + fields=['status', 'CID'], + client_name=client_name) + if errors: + raise InvalidRequest(errors, status_code=400) success, errors = process_sms_client_response(status=str(data.get('status')), reference=data.get('CID'), client_name=client_name) if errors: - [current_app.logger.info(e) for e in errors] - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) else: return jsonify(result='success', message=success), 200 @@ -169,12 +148,11 @@ def process_mmg_response(): @notifications.route('/notifications/sms/firetext', methods=['POST']) def process_firetext_response(): client_name = 'Firetext' - validation_errors = validate_callback_data(data=request.form, - fields=['status', 'reference'], - client_name=client_name) - if validation_errors: - current_app.logger.info(validation_errors) - return jsonify(result='error', message=validation_errors), 400 + errors = validate_callback_data(data=request.form, + fields=['status', 'reference'], + client_name=client_name) + if errors: + raise InvalidRequest(errors, status_code=400) response_code = request.form.get('code') status = request.form.get('status') @@ -185,8 +163,7 @@ def process_firetext_response(): reference=request.form.get('reference'), client_name=client_name) if errors: - [current_app.logger.info(e) for e in errors] - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) else: return jsonify(result='success', message=success), 200 @@ -199,10 +176,7 @@ def get_notifications(notification_id): @notifications.route('/notifications', methods=['GET']) def get_all_notifications(): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') limit_days = data.get('limit_days') @@ -228,10 +202,7 @@ def get_all_notifications(): @notifications.route('/service//notifications', methods=['GET']) @require_admin() def get_all_notifications_for_service(service_id): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') limit_days = data.get('limit_days') @@ -259,10 +230,7 @@ def get_all_notifications_for_service(service_id): @notifications.route('/service//job//notifications', methods=['GET']) @require_admin() def get_all_notifications_for_service_job(service_id, job_id): - data, errors = notifications_filter_schema.load(request.args) - if errors: - return jsonify(result="error", message=errors), 400 - + data = notifications_filter_schema.load(request.args).data page = data['page'] if 'page' in data else 1 page_size = data['page_size'] if 'page_size' in data else current_app.config.get('PAGE_SIZE') @@ -317,15 +285,15 @@ def send_notification(notification_type): total_email_count = service_stats.emails_requested if (total_email_count + total_sms_count >= service.message_limit) and service.restricted: - return jsonify(result="error", message='Exceeded send limits ({}) for today'.format( - service.message_limit)), 429 + error = 'Exceeded send limits ({}) for today'.format(service.message_limit) + raise InvalidRequest(error, status_code=429) notification, errors = ( sms_template_notification_schema if notification_type == 'sms' else email_notification_schema ).load(request.get_json()) if errors: - return jsonify(result="error", message=errors), 400 + raise InvalidRequest(errors, status_code=400) template = templates_dao.dao_get_template_by_id_and_service_id( template_id=notification['template'], @@ -334,33 +302,24 @@ def send_notification(notification_type): errors = unarchived_template_schema.validate({'archived': template.archived}) if errors: - return jsonify(result='error', message=errors), 400 + raise InvalidRequest(errors, status_code=400) template_object = Template(template.__dict__, notification.get('personalisation', {})) if template_object.missing_data: - return jsonify( - result="error", - message={ - 'template': ['Missing personalisation: {}'.format( - ", ".join(template_object.missing_data) - )] - } - ), 400 + message = 'Missing personalisation: {}'.format(", ".join(template_object.missing_data)) + errors = {'template': [message]} + raise InvalidRequest(errors, status_code=400) + if template_object.additional_data: - return jsonify( - result="error", - message={ - 'template': ['Personalisation not needed for template: {}'.format( - ", ".join(template_object.additional_data) - )] - } - ), 400 + message = 'Personalisation not needed for template: {}'.format(", ".join(template_object.additional_data)) + errors = {'template': [message]} + raise InvalidRequest(errors, status_code=400) if template_object.replaced_content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT'): - return jsonify( - result="error", - message={'content': ['Content has a character count greater than the limit of {}'.format( - current_app.config.get('SMS_CHAR_COUNT_LIMIT'))]}), 400 + char_count = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) if service.restricted and not allowed_to_send_to( notification['to'], @@ -368,11 +327,9 @@ def send_notification(notification_type): [user.mobile_number, user.email_address] for user in service.users ) ): - return jsonify( - result="error", message={ - 'to': ['Invalid {} for restricted service'.format(first_column_heading[notification_type])] - } - ), 400 + message = 'Invalid {} for restricted service'.format(first_column_heading[notification_type]) + errors = {'to': [message]} + raise InvalidRequest(errors, status_code=400) notification_id = create_uuid() notification.update({"template_version": template.version}) @@ -397,13 +354,9 @@ def send_notification(notification_type): @notifications.route('/notifications/statistics') def get_notification_statistics_for_day(): - data, errors = day_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 - + data = day_schema.load(request.args).data statistics = notifications_dao.dao_get_potential_notification_statistics_for_day( day=data['day'] ) - data, errors = notifications_statistics_schema.dump(statistics, many=True) return jsonify(data=data), 200 diff --git a/app/notifications_statistics/rest.py b/app/notifications_statistics/rest.py index f864e473b..69b1eafab 100644 --- a/app/notifications_statistics/rest.py +++ b/app/notifications_statistics/rest.py @@ -19,7 +19,10 @@ notifications_statistics = Blueprint( __name__, url_prefix='/service//notifications-statistics' ) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(notifications_statistics) @@ -34,9 +37,9 @@ def get_all_notification_statistics_for_service(service_id): limit_days=int(request.args['limit_days']) ) except ValueError as e: - error = '{} is not an integer'.format(request.args['limit_days']) - current_app.logger.error(error) - return jsonify(result="error", message={'limit_days': [error]}), 400 + message = '{} is not an integer'.format(request.args['limit_days']) + errors = {'limit_days': [message]} + raise InvalidRequest(errors, status_code=400) else: statistics = dao_get_notification_statistics_for_service(service_id=service_id) @@ -46,9 +49,7 @@ def get_all_notification_statistics_for_service(service_id): @notifications_statistics.route('/seven_day_aggregate') def get_notification_statistics_for_service_seven_day_aggregate(service_id): - data, errors = week_aggregate_notification_statistics_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 + data = week_aggregate_notification_statistics_schema.load(request.args).data date_from = data['date_from'] if 'date_from' in data else date(date.today().year, 4, 1) week_count = data['week_count'] if 'week_count' in data else 52 stats = dao_get_7_day_agg_notification_statistics_for_service( diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py index f976f1560..b075653f5 100644 --- a/app/provider_details/rest.py +++ b/app/provider_details/rest.py @@ -1,25 +1,31 @@ from flask import Blueprint, jsonify, request from app.schemas import provider_details_schema + from app.dao.provider_details_dao import ( get_provider_details, get_provider_details_by_id, - get_provider_details_by_id, dao_update_provider_details ) +from app.errors import ( + register_errors, + InvalidRequest +) + provider_details = Blueprint('provider_details', __name__) +register_errors(provider_details) @provider_details.route('', methods=['GET']) def get_providers(): - data, errors = provider_details_schema.dump(get_provider_details(), many=True) + data = provider_details_schema.dump(get_provider_details(), many=True).data return jsonify(provider_details=data) @provider_details.route('/', methods=['GET']) def get_provider_by_id(provider_details_id): - data, errors = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)) + data = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)).data return jsonify(provider_details=data) @@ -29,14 +35,12 @@ def update_provider_details(provider_details_id): current_data = dict(provider_details_schema.dump(fetched_provider_details).data.items()) current_data.update(request.get_json()) - update_dict, errors = provider_details_schema.load(current_data) - if errors: - return jsonify(result="error", message=errors), 400 + update_dict = provider_details_schema.load(current_data).data if "identifier" in request.get_json().keys(): - return jsonify(message={ - "identifier": ["Not permitted to be updated"] - }, result='error'), 400 + message = "Not permitted to be updated" + errors = {'identifier': [message]} + raise InvalidRequest(errors, status_code=400) dao_update_provider_details(update_dict) return jsonify(provider_details=provider_details_schema.dump(fetched_provider_details).data), 200 diff --git a/app/schemas.py b/app/schemas.py index 07385df81..a9d39c5c6 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,13 +1,16 @@ -from datetime import date +from datetime import ( + datetime, + date +) from flask_marshmallow.fields import fields -from sqlalchemy.orm import load_only from marshmallow import ( post_load, ValidationError, validates, validates_schema, - pre_load + pre_load, + pre_dump ) from marshmallow_sqlalchemy import field_for @@ -46,6 +49,7 @@ def _validate_not_in_future(dte, msg="Date cannot be in the future"): class BaseSchema(ma.ModelSchema): + def __init__(self, load_json=False, *args, **kwargs): self.load_json = load_json super(BaseSchema, self).__init__(*args, **kwargs) @@ -82,12 +86,14 @@ class UserSchema(BaseSchema): exclude = ( "updated_at", "created_at", "user_to_service", "_password", "verify_codes") + strict = True class ProviderDetailsSchema(BaseSchema): class Meta: model = models.ProviderDetails exclude = ("provider_rates", "provider_stats") + strict = True class ServiceSchema(BaseSchema): @@ -97,11 +103,13 @@ class ServiceSchema(BaseSchema): class Meta: model = models.Service exclude = ("updated_at", "created_at", "api_keys", "templates", "jobs", 'old_id') + strict = True class NotificationModelSchema(BaseSchema): class Meta: model = models.Notification + strict = True class BaseTemplateSchema(BaseSchema): @@ -109,6 +117,7 @@ class BaseTemplateSchema(BaseSchema): class Meta: model = models.Template exclude = ("service_id", "jobs") + strict = True class TemplateSchema(BaseTemplateSchema): @@ -142,6 +151,13 @@ class TemplateHistorySchema(BaseSchema): class NotificationsStatisticsSchema(BaseSchema): class Meta: model = models.NotificationStatistics + strict = True + + @pre_dump + def handle_date_str(self, in_data): + if isinstance(in_data, dict) and 'day' in in_data: + in_data['day'] = datetime.strptime(in_data['day'], '%Y-%m-%d').date() + return in_data class ApiKeySchema(BaseSchema): @@ -151,6 +167,7 @@ class ApiKeySchema(BaseSchema): class Meta: model = models.ApiKey exclude = ("service", "secret") + strict = True class JobSchema(BaseSchema): @@ -161,13 +178,22 @@ class JobSchema(BaseSchema): class Meta: model = models.Job exclude = ('notifications',) + strict = True class RequestVerifyCodeSchema(ma.Schema): + + class Meta: + strict = True + to = fields.Str(required=False) class NotificationSchema(ma.Schema): + + class Meta: + strict = True + personalisation = fields.Dict(required=False) @@ -225,12 +251,14 @@ class NotificationStatusSchema(BaseSchema): class Meta: model = models.Notification + strict = True class InvitedUserSchema(BaseSchema): class Meta: model = models.InvitedUser + strict = True @validates('email_address') def validate_to(self, value): @@ -255,9 +283,14 @@ class PermissionSchema(BaseSchema): class Meta: model = models.Permission exclude = ("created_at",) + strict = True class EmailDataSchema(ma.Schema): + + class Meta: + strict = True + email = fields.Str(required=False) @validates('email') @@ -269,6 +302,10 @@ class EmailDataSchema(ma.Schema): class NotificationsFilterSchema(ma.Schema): + + class Meta: + strict = True + template_type = fields.Nested(BaseTemplateSchema, only=['template_type'], many=True) status = fields.Nested(NotificationModelSchema, only=['status'], many=True) page = fields.Int(required=False) @@ -309,6 +346,7 @@ class TemplateStatisticsSchema(BaseSchema): class Meta: model = models.TemplateStatistics + strict = True class ServiceHistorySchema(ma.Schema): @@ -337,10 +375,14 @@ class ApiKeyHistorySchema(ma.Schema): class EventSchema(BaseSchema): class Meta: model = models.Event + strict = True class FromToDateSchema(ma.Schema): + class Meta: + strict = True + date_from = fields.Date() date_to = fields.Date() @@ -361,6 +403,10 @@ class FromToDateSchema(ma.Schema): class DaySchema(ma.Schema): + + class Meta: + strict = True + day = fields.Date(required=True) @validates('day') @@ -370,6 +416,9 @@ class DaySchema(ma.Schema): class WeekAggregateNotificationStatisticsSchema(ma.Schema): + class Meta: + strict = True + date_from = fields.Date() week_count = fields.Int() diff --git a/app/service/rest.py b/app/service/rest.py index 0d2da9cb0..1b01ec9a1 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -6,9 +6,9 @@ from datetime import ( from flask import ( jsonify, request, - abort, Blueprint ) + from sqlalchemy.orm.exc import NoResultFound from app.dao.api_key_dao import ( @@ -39,11 +39,12 @@ from app.schemas import ( permission_schema ) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) service = Blueprint('service', __name__) - - register_errors(service) @@ -54,7 +55,7 @@ def get_services(): services = dao_fetch_all_services_by_user(user_id) else: services = dao_fetch_all_services() - data, errors = service_schema.dump(services, many=True) + data = service_schema.dump(services, many=True).data return jsonify(data=data) @@ -66,7 +67,7 @@ def get_service_by_id(service_id): else: fetched = dao_fetch_service_by_id(service_id) - data, errors = service_schema.dump(fetched) + data = service_schema.dump(fetched).data return jsonify(data=data) @@ -74,16 +75,12 @@ def get_service_by_id(service_id): def create_service(): data = request.get_json() if not data.get('user_id', None): - return jsonify(result="error", message={'user_id': ['Missing data for required field.']}), 400 + errors = {'user_id': ['Missing data for required field.']} + raise InvalidRequest(errors, status_code=400) user = get_model_users(data['user_id']) - data.pop('user_id', None) - valid_service, errors = service_schema.load(request.get_json()) - - if errors: - return jsonify(result="error", message=errors), 400 - + valid_service = service_schema.load(request.get_json()).data dao_create_service(valid_service, user) return jsonify(data=service_schema.dump(valid_service).data), 201 @@ -93,9 +90,7 @@ def update_service(service_id): fetched_service = dao_fetch_service_by_id(service_id) current_data = dict(service_schema.dump(fetched_service).data.items()) current_data.update(request.get_json()) - update_dict, errors = service_schema.load(current_data) - if errors: - return jsonify(result="error", message=errors), 400 + update_dict = service_schema.load(current_data).data dao_update_service(update_dict) return jsonify(data=service_schema.dump(fetched_service).data), 200 @@ -103,14 +98,9 @@ def update_service(service_id): @service.route('//api-key', methods=['POST']) def renew_api_key(service_id=None): fetched_service = dao_fetch_service_by_id(service_id=service_id) - - valid_api_key, errors = api_key_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 + valid_api_key = api_key_schema.load(request.get_json()).data valid_api_key.service = fetched_service - save_model_api_key(valid_api_key) - unsigned_api_key = get_unsigned_secret(valid_api_key.id) return jsonify(data=unsigned_api_key), 201 @@ -133,7 +123,8 @@ def get_api_keys(service_id, key_id=None): else: api_keys = get_model_api_keys(service_id=service_id) except NoResultFound: - return jsonify(result="error", message="API key not found for id: {}".format(service_id)), 404 + error = "API key not found for id: {}".format(service_id) + raise InvalidRequest(error, status_code=404) return jsonify(apiKeys=api_key_schema.dump(api_keys, many=True).data), 200 @@ -141,7 +132,6 @@ def get_api_keys(service_id, key_id=None): @service.route('//users', methods=['GET']) def get_users_for_service(service_id): fetched = dao_fetch_service_by_id(service_id) - result = user_schema.dump(fetched.users, many=True) return jsonify(data=result.data) @@ -152,15 +142,12 @@ def add_user_to_service(service_id, user_id): user = get_model_users(user_id=user_id) if user in service.users: - return jsonify(result='error', - message='User id: {} already part of service id: {}'.format(user_id, service_id)), 400 - - permissions, errors = permission_schema.load(request.get_json(), many=True) - if errors: - abort(400, errors) + error = 'User id: {} already part of service id: {}'.format(user_id, service_id) + raise InvalidRequest(error, status_code=400) + permissions = permission_schema.load(request.get_json(), many=True).data dao_add_user_to_service(service, user, permissions) - data, errors = service_schema.dump(service) + data = service_schema.dump(service).data return jsonify(data=data), 201 @@ -169,13 +156,13 @@ def remove_user_from_service(service_id, user_id): service = dao_fetch_service_by_id(service_id) user = get_model_users(user_id=user_id) if user not in service.users: - return jsonify( - result='error', - message='User not found'), 404 + error = 'User not found' + raise InvalidRequest(error, status_code=404) + elif len(service.users) == 1: - return jsonify( - result='error', - message='You cannot remove the only user for a service'), 400 + error = 'You cannot remove the only user for a service' + raise InvalidRequest(error, status_code=400) + dao_remove_user_from_service(service, user) return jsonify({}), 204 @@ -183,10 +170,7 @@ def remove_user_from_service(service_id, user_id): @service.route('//fragment/aggregate_statistics') def get_service_provider_aggregate_statistics(service_id): service = dao_fetch_service_by_id(service_id) - data, errors = from_to_date_schema.load(request.args) - if errors: - return jsonify(result='error', message=errors), 400 - + data = from_to_date_schema.load(request.args).data return jsonify(data=get_fragment_count( service, date_from=(data.pop('date_from') if 'date_from' in data else date.today()), @@ -208,21 +192,15 @@ def get_service_history(service_id): ) service_history = Service.get_history_model().query.filter_by(id=service_id).all() - service_data, errors = service_history_schema.dump(service_history, many=True) - if errors: - return jsonify(result="error", message=errors), 400 - + service_data = service_history_schema.dump(service_history, many=True).data api_key_history = ApiKey.get_history_model().query.filter_by(service_id=service_id).all() - - api_keys_data, errors = api_key_history_schema.dump(api_key_history, many=True) - if errors: - return jsonify(result="error", message=errors), 400 + api_keys_data = api_key_history_schema.dump(api_key_history, many=True).data template_history = Template.get_history_model().query.filter_by(service_id=service_id).all() template_data, errors = template_history_schema.dump(template_history, many=True) events = Event.query.all() - events_data, errors = event_schema.dump(events, many=True) + events_data = event_schema.dump(events, many=True).data data = { 'service_history': service_data, diff --git a/app/template/rest.py b/app/template/rest.py index 168d483ea..dc285e9d1 100644 --- a/app/template/rest.py +++ b/app/template/rest.py @@ -19,36 +19,32 @@ from app.schemas import (template_schema, template_history_schema) template = Blueprint('template', __name__, url_prefix='/service//template') -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) register_errors(template) def _content_count_greater_than_limit(content, template_type): template = Template({'content': content, 'template_type': template_type}) - if template_type == 'sms' and \ - template.content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT'): - return True, jsonify( - result="error", - message={'content': ['Content has a character count greater than the limit of {}'.format( - current_app.config.get('SMS_CHAR_COUNT_LIMIT'))]} - ) - return False, '' + return template_type == 'sms' and template.content_count > current_app.config.get('SMS_CHAR_COUNT_LIMIT') @template.route('', methods=['POST']) def create_template(service_id): fetched_service = dao_fetch_service_by_id(service_id=service_id) - new_template, errors = template_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 + new_template = template_schema.load(request.get_json()).data new_template.service = fetched_service new_template.content = _strip_html(new_template.content) - over_limit, json_resp = _content_count_greater_than_limit( - new_template.content, - new_template.template_type) + over_limit = _content_count_greater_than_limit(new_template.content, new_template.template_type) if over_limit: - return json_resp, 400 + char_count_limit = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count_limit) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) + dao_create_template(new_template) return jsonify(data=template_schema.dump(new_template).data), 201 @@ -65,14 +61,13 @@ def update_template(service_id, template_id): if _template_has_not_changed(current_data, updated_template): return jsonify(data=updated_template), 200 - update_dict, errors = template_schema.load(updated_template) - if errors: - return jsonify(result="error", message=errors), 400 - over_limit, json_resp = _content_count_greater_than_limit( - updated_template['content'], - fetched_template.template_type) + update_dict = template_schema.load(updated_template).data + over_limit = _content_count_greater_than_limit(updated_template['content'], fetched_template.template_type) if over_limit: - return json_resp, 400 + char_count_limit = current_app.config.get('SMS_CHAR_COUNT_LIMIT') + message = 'Content has a character count greater than the limit of {}'.format(char_count_limit) + errors = {'content': [message]} + raise InvalidRequest(errors, status_code=400) dao_update_template(update_dict) return jsonify(data=template_schema.dump(update_dict).data), 200 @@ -80,39 +75,35 @@ def update_template(service_id, template_id): @template.route('', methods=['GET']) def get_all_templates_for_service(service_id): templates = dao_get_all_templates_for_service(service_id=service_id) - data, errors = template_schema.dump(templates, many=True) + data = template_schema.dump(templates, many=True).data return jsonify(data=data) @template.route('/', methods=['GET']) def get_template_by_id_and_service_id(service_id, template_id): fetched_template = dao_get_template_by_id_and_service_id(template_id=template_id, service_id=service_id) - data, errors = template_schema.dump(fetched_template) + data = template_schema.dump(fetched_template).data return jsonify(data=data) @template.route('//version/') def get_template_version(service_id, template_id, version): - data, errors = template_history_schema.dump( + data = template_history_schema.dump( dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id, version=version ) - ) - if errors: - return jsonify(result='error', message=errors), 400 + ).data return jsonify(data=data) @template.route('//versions') def get_template_versions(service_id, template_id): - data, errors = template_history_schema.dump( + data = template_history_schema.dump( dao_get_template_versions(service_id=service_id, template_id=template_id), many=True - ) - if errors: - return jsonify(result='error', message=errors), 400 + ).data return jsonify(data=data) diff --git a/app/template_statistics/rest.py b/app/template_statistics/rest.py index 1e46843c5..d4990ba04 100644 --- a/app/template_statistics/rest.py +++ b/app/template_statistics/rest.py @@ -15,7 +15,7 @@ template_statistics = Blueprint('template-statistics', __name__, url_prefix='/service//template-statistics') -from app.errors import register_errors, InvalidData +from app.errors import register_errors, InvalidRequest register_errors(template_statistics) @@ -28,20 +28,16 @@ def get_template_statistics_for_service(service_id): except ValueError as e: error = '{} is not an integer'.format(request.args['limit_days']) message = {'limit_days': [error]} - raise InvalidData(message, status_code=400) + raise InvalidRequest(message, status_code=400) else: limit_days = None stats = dao_get_template_statistics_for_service(service_id, limit_days=limit_days) - data, errors = template_statistics_schema.dump(stats, many=True) - if errors: - raise InvalidData(errors, status_code=400) + data = template_statistics_schema.dump(stats, many=True).data return jsonify(data=data) @template_statistics.route('/') def get_template_statistics_for_template_id(service_id, template_id): stats = dao_get_template_statistics_for_template(template_id) - data, errors = template_statistics_schema.dump(stats, many=True) - if errors: - raise InvalidData(errors, status_code=400) + data = template_statistics_schema.dump(stats, many=True).data return jsonify(data=data) diff --git a/app/user/rest.py b/app/user/rest.py index 46668b98c..93d7f1d9e 100644 --- a/app/user/rest.py +++ b/app/user/rest.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from flask import (jsonify, request, abort, Blueprint, current_app) +from flask import (jsonify, request, Blueprint, current_app) from app import encryption, DATETIME_FORMAT from app.dao.users_dao import ( get_model_users, @@ -28,9 +28,13 @@ from app.schemas import ( from app.celery.tasks import ( send_sms, email_reset_password, - send_email) + send_email +) -from app.errors import register_errors +from app.errors import ( + register_errors, + InvalidRequest +) user = Blueprint('user', __name__) register_errors(user) @@ -43,9 +47,7 @@ def create_user(): # TODO password policy, what is valid password if not req_json.get('password', None): errors.update({'password': ['Missing data for required field.']}) - return jsonify(result="error", message=errors), 400 - if errors: - return jsonify(result="error", message=errors), 400 + raise InvalidRequest(errors, status_code=400) save_model_user(user_to_create, pwd=req_json.get('password')) return jsonify(data=user_schema.dump(user_to_create).data), 201 @@ -60,11 +62,9 @@ def update_user(user_id): # but would be good to have the same validation here. if pwd is not None and not pwd: errors.update({'password': ['Invalid data for field']}) - if errors: - return jsonify(result="error", message=errors), 400 - status_code = 200 + raise InvalidRequest(errors, status_code=400) save_model_user(user_to_update, update_dict=update_dct, pwd=pwd) - return jsonify(data=user_schema.dump(user_to_update).data), status_code + return jsonify(data=user_schema.dump(user_to_update).data), 200 @user.route('//verify/password', methods=['POST']) @@ -75,9 +75,10 @@ def verify_user_password(user_id): try: txt_pwd = request.get_json()['password'] except KeyError: - return jsonify( - result="error", - message={'password': ['Required field missing data']}), 400 + message = 'Required field missing data' + errors = {'password': [message]} + raise InvalidRequest(errors, status_code=400) + if user_to_verify.check_password(txt_pwd): user_to_verify.logged_in_at = datetime.utcnow() save_model_user(user_to_verify) @@ -85,7 +86,9 @@ def verify_user_password(user_id): return jsonify({}), 204 else: increment_failed_login_count(user_to_verify) - return jsonify(result='error', message={'password': ['Incorrect password']}), 400 + message = 'Incorrect password' + errors = {'password': [message]} + raise InvalidRequest(errors, status_code=400) @user.route('//verify/code', methods=['POST']) @@ -105,12 +108,13 @@ def verify_user_code(user_id): except KeyError: errors.update({'code_type': ['Required field missing data']}) if errors: - return jsonify(result="error", message=errors), 400 + raise InvalidRequest(errors, status_code=400) + code = get_user_code(user_to_verify, txt_code, txt_type) if not code: - return jsonify(result="error", message="Code not found"), 404 + raise InvalidRequest("Code not found", status_code=404) if datetime.utcnow() > code.expiry_datetime or code.code_used: - return jsonify(result="error", message="Code has expired"), 400 + raise InvalidRequest("Code has expired", status_code=400) use_user_code(code.id) return jsonify({}), 204 @@ -119,8 +123,6 @@ def verify_user_code(user_id): def send_user_sms_code(user_id): user_to_send_to = get_model_users(user_id=user_id) verify_code, errors = request_verify_code_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 secret_code = create_secret_code() create_user_code(user_to_send_to, secret_code, 'sms') @@ -150,8 +152,6 @@ def send_user_sms_code(user_id): def send_user_email_verification(user_id): user_to_send_to = get_model_users(user_id=user_id) verify_code, errors = request_verify_code_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 secret_code = create_secret_code() create_user_code(user_to_send_to, secret_code, 'email') @@ -191,8 +191,7 @@ def set_permissions(user_id, service_id): user = get_model_users(user_id=user_id) service = dao_fetch_service_by_id(service_id=service_id) permissions, errors = permission_schema.load(request.get_json(), many=True) - if errors: - abort(400, errors) + for p in permissions: p.user = user p.service = service @@ -204,7 +203,8 @@ def set_permissions(user_id, service_id): def get_by_email(): email = request.args.get('email') if not email: - return jsonify(result="error", message="invalid request"), 400 + error = 'Invalid request. Email query string param required' + raise InvalidRequest(error, status_code=400) fetched_user = get_user_by_email(email) result = user_schema.dump(fetched_user) @@ -214,8 +214,6 @@ def get_by_email(): @user.route('/reset-password', methods=['POST']) def send_user_reset_password(): email, errors = email_data_request_schema.load(request.get_json()) - if errors: - return jsonify(result="error", message=errors), 400 user_to_send_to = get_user_by_email(email['email']) diff --git a/tests/app/dao/test_templates_dao.py b/tests/app/dao/test_templates_dao.py index ad4b4372e..82b943963 100644 --- a/tests/app/dao/test_templates_dao.py +++ b/tests/app/dao/test_templates_dao.py @@ -267,6 +267,6 @@ def test_get_template_versions(sample_template): assert x.content == 'new version' else: assert x.content == original_content - from app.schemas import (template_history_schema) + from app.schemas import template_history_schema v = template_history_schema.load(versions, many=True) assert v.__len__() == 2 diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index 82b93d4cb..4f6286f6f 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -296,7 +296,7 @@ def test_get_user_by_email_bad_url_returns_404(notify_api, assert resp.status_code == 400 json_resp = json.loads(resp.get_data(as_text=True)) assert json_resp['result'] == 'error' - assert json_resp['message'] == 'invalid request' + assert json_resp['message'] == 'Invalid request. Email query string param required' def test_get_user_with_permissions(notify_api,