From 99bc29418ec368acc1659f4aa475e3c6bd96fcac Mon Sep 17 00:00:00 2001 From: Ben Thorner Date: Tue, 27 Apr 2021 10:35:21 +0100 Subject: [PATCH] Move request_id injection into send_task override This applies the same change we made in other apps [1][2]. Adding the override here is special, though, because it means the others will now get triggered, since this app is the start of the chain of tasks for a request. We will also retain existing request_id tracing for tasks within this app, since "apply_async" calls the "send_task" method internally, which is the one we're overriding. [1]: https://github.com/alphagov/notifications-template-preview/pull/531/commits/6f3c118a1e071f8169698806450b45aa685859aa [2]: https://github.com/alphagov/notifications-antivirus/pull/69/commits/2e08b7aa954c120eb0265fda41ad2025616426ba --- app/celery/celery.py | 20 ++++++++++---------- tests/app/celery/test_celery.py | 32 ++++++++++++++++---------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/app/celery/celery.py b/app/celery/celery.py index bd0b0b3de..16c50da23 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -73,16 +73,6 @@ def make_task(app): return super().__call__(*args, **kwargs) - def apply_async(self, args=None, kwargs=None, **other_kwargs): - kwargs = kwargs or {} - - if has_request_context() and hasattr(request, 'request_id'): - kwargs['request_id'] = request.request_id - elif has_app_context() and 'request_id' in g: - kwargs['request_id'] = g.request_id - - return super().apply_async(args, kwargs, **other_kwargs) - return NotifyTask @@ -97,3 +87,13 @@ class NotifyCelery(Celery): self.conf.update(app.config) self._app = app + + def send_task(self, name, args=None, kwargs=None, **other_kwargs): + kwargs = kwargs or {} + + if has_request_context() and hasattr(request, 'request_id'): + kwargs['request_id'] = request.request_id + elif has_app_context() and 'request_id' in g: + kwargs['request_id'] = g.request_id + + return super().send_task(name, args, kwargs, **other_kwargs) diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py index 73d976c87..3406cae38 100644 --- a/tests/app/celery/test_celery.py +++ b/tests/app/celery/test_celery.py @@ -85,33 +85,33 @@ def test_call_exports_request_id_from_kwargs(mocker, celery_task): assert g.request_id == '1234' -def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task): - super_apply = mocker.patch('celery.app.task.Task.apply_async') +def test_send_task_injects_global_request_id_into_kwargs(mocker, notify_api): + super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' - celery_task.apply_async() - super_apply.assert_called_with(None, {'request_id': '1234'}) + notify_celery.send_task('some-task') + super_apply.assert_called_with('some-task', None, {'request_id': '1234'}) -def test_apply_async_inject_request_id_with_other_kwargs(mocker, celery_task): - super_apply = mocker.patch('celery.app.task.Task.apply_async') +def test_send_task_injects_request_id_with_other_kwargs(mocker, notify_api): + super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' - celery_task.apply_async(kwargs={'something': 'else'}) - super_apply.assert_called_with(None, {'request_id': '1234', 'something': 'else'}) + notify_celery.send_task('some-task', kwargs={'something': 'else'}) + super_apply.assert_called_with('some-task', None, {'request_id': '1234', 'something': 'else'}) -def test_apply_async_inject_request_id_with_positional_args(mocker, celery_task): - super_apply = mocker.patch('celery.app.task.Task.apply_async') +def test_send_task_injects_request_id_with_positional_args(mocker, notify_api): + super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' - celery_task.apply_async(['args'], {'something': 'else'}) - super_apply.assert_called_with(['args'], {'request_id': '1234', 'something': 'else'}) + notify_celery.send_task('some-task', ['args'], {'kw': 'args'}) + super_apply.assert_called_with('some-task', ['args'], {'request_id': '1234', 'kw': 'args'}) -def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task): - super_apply = mocker.patch('celery.app.task.Task.apply_async') +def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api): + super_apply = mocker.patch('celery.Celery.send_task') request_id_header = notify_api.config['NOTIFY_TRACE_ID_HEADER'] request_headers = {request_id_header: '1234'} with notify_api.test_request_context(headers=request_headers): - celery_task.apply_async() + notify_celery.send_task('some-task') - super_apply.assert_called_with(None, {'request_id': '1234'}) + super_apply.assert_called_with('some-task', None, {'request_id': '1234'})