From be0257314710c3d4e1214687e39a916469f2a26a Mon Sep 17 00:00:00 2001 From: Ben Thorner Date: Thu, 15 Apr 2021 12:57:12 +0100 Subject: [PATCH] Fix apply_async not working with positional kwargs Celery's apply_async function accepts 'kwargs' as (get ready to be confused) either a positional argument, or a keyword argument: Positional: apply_async(['args'], {'kw': 'args'}) Keyword: apply_async(args=['args'], kwargs={'kw': 'args'}) We rely on the positional form in at least one place [1]. This fixes the overload of apply_async to cope with both forms, and continue to pass through any other (confusion time again) keyword args to super(), such as queue="queue". Note that we've also decided to stop accepting other positional args, since this is unnecessarily confusing, and we don't currently rely on it in our code. This stops it creeping in in future. [1]: https://github.com/alphagov/notifications-api/blob/fde927e00ed9f117f33ae4eedf4c3a05e4db207b/app/job/rest.py#L186 --- app/celery/celery.py | 10 +++++----- tests/app/celery/test_celery.py | 18 ++++++++++++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/app/celery/celery.py b/app/celery/celery.py index 819bd12c3..bd0b0b3de 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -73,15 +73,15 @@ def make_task(app): return super().__call__(*args, **kwargs) - def apply_async(self, *args, **kwargs): - kwargs['kwargs'] = kwargs.get('kwargs', {}) + def apply_async(self, args=None, kwargs=None, **other_kwargs): + kwargs = kwargs or {} if has_request_context() and hasattr(request, 'request_id'): - kwargs['kwargs']['request_id'] = request.request_id + kwargs['request_id'] = request.request_id elif has_app_context() and 'request_id' in g: - kwargs['kwargs']['request_id'] = g.request_id + kwargs['request_id'] = g.request_id - return super().apply_async(*args, **kwargs) + return super().apply_async(args, kwargs, **other_kwargs) return NotifyTask diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py index ac28248fe..73d976c87 100644 --- a/tests/app/celery/test_celery.py +++ b/tests/app/celery/test_celery.py @@ -89,7 +89,21 @@ def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task): super_apply = mocker.patch('celery.app.task.Task.apply_async') g.request_id = '1234' celery_task.apply_async() - super_apply.assert_called_with(kwargs={'request_id': '1234'}) + super_apply.assert_called_with(None, {'request_id': '1234'}) + + +def test_apply_async_inject_request_id_with_other_kwargs(mocker, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + g.request_id = '1234' + celery_task.apply_async(kwargs={'something': 'else'}) + super_apply.assert_called_with(None, {'request_id': '1234', 'something': 'else'}) + + +def test_apply_async_inject_request_id_with_positional_args(mocker, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + g.request_id = '1234' + celery_task.apply_async(['args'], {'something': 'else'}) + super_apply.assert_called_with(['args'], {'request_id': '1234', 'something': 'else'}) def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task): @@ -100,4 +114,4 @@ def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, cel with notify_api.test_request_context(headers=request_headers): celery_task.apply_async() - super_apply.assert_called_with(kwargs={'request_id': '1234'}) + super_apply.assert_called_with(None, {'request_id': '1234'})