From ec6d87cd0fba7cbb3214a1f844912586b821d5d5 Mon Sep 17 00:00:00 2001 From: Ben Thorner Date: Tue, 13 Apr 2021 14:49:15 +0100 Subject: [PATCH 1/2] Simplify argument passing in apply_async This avoids the need to keep in-sync with any future changes to the signature, and reduces the amount of irrelevant code to read. --- app/celery/celery.py | 11 +++++------ tests/app/celery/test_celery.py | 19 ++----------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/app/celery/celery.py b/app/celery/celery.py index ecc7486b8..546dac16d 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -73,16 +73,15 @@ def make_task(app, statsd_client): return super().__call__(*args, **kwargs) - def apply_async(self, args=None, kwargs=None, task_id=None, producer=None, - link=None, link_error=None, **options): - kwargs = kwargs or {} + def apply_async(self, *args, **kwargs): + kwargs['kwargs'] = kwargs.get('kwargs', {}) if has_request_context() and hasattr(request, 'request_id'): - kwargs['request_id'] = request.request_id + kwargs['kwargs']['request_id'] = request.request_id elif has_app_context() and 'request_id' in g: - kwargs['request_id'] = g.request_id + kwargs['kwargs']['request_id'] = g.request_id - return super().apply_async(args, kwargs, task_id, producer, link, link_error, **options) + return super().apply_async(*args, **kwargs) return NotifyTask diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py index 0acbd8d23..ac28248fe 100644 --- a/tests/app/celery/test_celery.py +++ b/tests/app/celery/test_celery.py @@ -89,15 +89,7 @@ def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task): super_apply = mocker.patch('celery.app.task.Task.apply_async') g.request_id = '1234' celery_task.apply_async() - - super_apply.assert_called_with( - None, - {'request_id': '1234'}, - None, - None, - None, - None - ) + super_apply.assert_called_with(kwargs={'request_id': '1234'}) def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task): @@ -108,11 +100,4 @@ def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, cel with notify_api.test_request_context(headers=request_headers): celery_task.apply_async() - super_apply.assert_called_with( - None, - {'request_id': '1234'}, - None, - None, - None, - None - ) + super_apply.assert_called_with(kwargs={'request_id': '1234'}) From 5eb265138be1f67156128bb15e43c90556d5c84b Mon Sep 17 00:00:00 2001 From: Ben Thorner Date: Tue, 13 Apr 2021 14:53:46 +0100 Subject: [PATCH 2/2] Remove unnecessary statsd_client parameter It turns out this is available from the app object [1], and we were already assuming this in the tests. [1]: https://github.com/alphagov/notifications-utils/blob/48c6c822e85d0d1893d2c239e14706cfe0ad8e16/notifications_utils/clients/statsd/statsd_client.py#L52 --- app/__init__.py | 2 +- app/celery/celery.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index 3b0fd05af..57cd1c2ef 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -110,7 +110,7 @@ def create_app(application): email_clients = [aws_ses_stub_client] if application.config['SES_STUB_URL'] else [aws_ses_client] notification_provider_clients.init_app(sms_clients=[firetext_client, mmg_client], email_clients=email_clients) - notify_celery.init_app(application, statsd_client) + notify_celery.init_app(application) encryption.init_app(application) redis_store.init_app(application) document_download_client.init_app(application) diff --git a/app/celery/celery.py b/app/celery/celery.py index 546dac16d..819bd12c3 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -18,7 +18,7 @@ def log_on_worker_shutdown(sender, signal, pid, exitcode, **kwargs): notify_celery._app.logger.info('worker shutdown: PID: {} Exitcode: {}'.format(pid, exitcode)) -def make_task(app, statsd_client): +def make_task(app): class NotifyTask(Task): abstract = True start = None @@ -36,7 +36,7 @@ def make_task(app, statsd_client): ) ) - statsd_client.timing( + app.statsd_client.timing( "celery.{queue_name}.{task_name}.success".format( task_name=self.name, queue_name=queue_name @@ -54,7 +54,7 @@ def make_task(app, statsd_client): ) ) - statsd_client.incr( + app.statsd_client.incr( "celery.{queue_name}.{task_name}.failure".format( task_name=self.name, queue_name=queue_name @@ -88,11 +88,11 @@ def make_task(app, statsd_client): class NotifyCelery(Celery): - def init_app(self, app, statsd_client): + def init_app(self, app): super().__init__( app.import_name, broker=app.config['BROKER_URL'], - task_cls=make_task(app, statsd_client), + task_cls=make_task(app), ) self.conf.update(app.config)