From fc9b3bea1d0c42e417726b1b090aa6196874f1cd Mon Sep 17 00:00:00 2001 From: Katie Smith Date: Fri, 6 May 2022 15:55:53 +0100 Subject: [PATCH] Make the post and pre decorators take kwargs https://marshmallow.readthedocs.io/en/stable/upgrading.html#decorated-methods-and-handle-error-receive-many-and-partial --- app/schemas.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/app/schemas.py b/app/schemas.py index efb914df3..f7883501e 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -47,7 +47,7 @@ def _validate_datetime_not_in_past(dte, msg="Date cannot be in the past"): class UUIDsAsStringsMixin: @post_dump() - def __post_dump(self, data): + def __post_dump(self, data, **kwargs): for key, value in data.items(): if isinstance(value, UUID): @@ -71,7 +71,7 @@ class BaseSchema(ma.SQLAlchemyAutoSchema): super(BaseSchema, self).__init__(*args, **kwargs) @post_load - def make_instance(self, data): + def make_instance(self, data, **kwargs): """Deserialize data to an instance of the model. Update an existing row if specified in `self.instance` or loaded by primary key(s) in the data; else create a new row. @@ -169,7 +169,7 @@ class UserUpdateAttributeSchema(BaseSchema): raise ValidationError('Invalid phone number: {}'.format(error)) @validates_schema(pass_original=True) - def check_unknown_fields(self, data, original_data): + def check_unknown_fields(self, data, original_data, **kwargs): for key in original_data: if key not in self.fields: raise ValidationError('Unknown field name {}'.format(key)) @@ -181,7 +181,7 @@ class UserUpdatePasswordSchema(BaseSchema): model = models.User @validates_schema(pass_original=True) - def check_unknown_fields(self, data, original_data): + def check_unknown_fields(self, data, original_data, **kwargs): for key in original_data: if key not in self.fields: raise ValidationError('Unknown field name {}'.format(key)) @@ -270,7 +270,7 @@ class ServiceSchema(BaseSchema, UUIDsAsStringsMixin): raise ValidationError('Duplicate Service Permission: {}'.format(duplicates)) @pre_load() - def format_for_data_model(self, in_data): + def format_for_data_model(self, in_data, **kwargs): if isinstance(in_data, dict) and 'permissions' in in_data: str_permissions = in_data['permissions'] permissions = [] @@ -347,7 +347,7 @@ class TemplateSchema(BaseTemplateSchema, UUIDsAsStringsMixin): return template.redact_personalisation @validates_schema - def validate_type(self, data): + def validate_type(self, data, **kwargs): if data.get('template_type') in {models.EMAIL_TYPE, models.LETTER_TYPE}: subject = data.get('subject') if not subject or subject.strip() == '': @@ -377,7 +377,7 @@ class TemplateSchemaNoDetail(TemplateSchema): ) @pre_dump - def remove_content_for_non_broadcast_templates(self, template): + def remove_content_for_non_broadcast_templates(self, template, **kwargs): if template.template_type != models.BROADCAST_TYPE: template.content = None @@ -464,7 +464,7 @@ class SmsNotificationSchema(NotificationSchema): raise ValidationError('Invalid phone number: {}'.format(error)) @post_load - def format_phone_number(self, item): + def format_phone_number(self, item, **kwargs): item['to'] = validate_and_format_phone_number(item['to'], international=True) return item @@ -513,7 +513,7 @@ class NotificationWithTemplateSchema(BaseSchema): key_name = fields.String() @pre_dump - def add_api_key_name(self, in_data): + def add_api_key_name(self, in_data, **kwargs): if in_data.api_key: in_data.key_name = in_data.api_key.name else: @@ -557,12 +557,12 @@ class NotificationWithPersonalisationSchema(NotificationWithTemplateSchema): exclude = () @pre_dump - def handle_personalisation_property(self, in_data): + def handle_personalisation_property(self, in_data, **kwargs): self.personalisation = in_data.personalisation return in_data @post_dump - def handle_template_merge(self, in_data): + def handle_template_merge(self, in_data, **kwargs): in_data['template'] = in_data.pop('template_history') template = get_template_instance(in_data['template'], in_data['personalisation']) in_data['body'] = template.content_with_placeholders_filled_in @@ -624,7 +624,7 @@ class NotificationsFilterSchema(ma.Schema): count_pages = fields.Boolean(required=False) @pre_load - def handle_multidict(self, in_data): + def handle_multidict(self, in_data, **kwargs): if isinstance(in_data, dict) and hasattr(in_data, 'getlist'): out_data = dict([(k, in_data.get(k)) for k in in_data.keys()]) if 'template_type' in in_data: @@ -635,7 +635,7 @@ class NotificationsFilterSchema(ma.Schema): return out_data @post_load - def convert_schema_object_to_field(self, in_data): + def convert_schema_object_to_field(self, in_data, **kwargs): if 'template_type' in in_data: in_data['template_type'] = [x.template_type for x in in_data['template_type']] if 'status' in in_data: @@ -683,7 +683,7 @@ class UnarchivedTemplateSchema(BaseSchema): archived = fields.Boolean(required=True) @validates_schema - def validate_archived(self, data): + def validate_archived(self, data, **kwargs): if data['archived']: raise ValidationError('Template has been deleted', 'template')