diff --git a/app/celery/celery.py b/app/celery/celery.py index 819bd12c3..bd0b0b3de 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -73,15 +73,15 @@ def make_task(app): return super().__call__(*args, **kwargs) - def apply_async(self, *args, **kwargs): - kwargs['kwargs'] = kwargs.get('kwargs', {}) + def apply_async(self, args=None, kwargs=None, **other_kwargs): + kwargs = kwargs or {} if has_request_context() and hasattr(request, 'request_id'): - kwargs['kwargs']['request_id'] = request.request_id + kwargs['request_id'] = request.request_id elif has_app_context() and 'request_id' in g: - kwargs['kwargs']['request_id'] = g.request_id + kwargs['request_id'] = g.request_id - return super().apply_async(*args, **kwargs) + return super().apply_async(args, kwargs, **other_kwargs) return NotifyTask diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py index ac28248fe..73d976c87 100644 --- a/tests/app/celery/test_celery.py +++ b/tests/app/celery/test_celery.py @@ -89,7 +89,21 @@ def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task): super_apply = mocker.patch('celery.app.task.Task.apply_async') g.request_id = '1234' celery_task.apply_async() - super_apply.assert_called_with(kwargs={'request_id': '1234'}) + super_apply.assert_called_with(None, {'request_id': '1234'}) + + +def test_apply_async_inject_request_id_with_other_kwargs(mocker, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + g.request_id = '1234' + celery_task.apply_async(kwargs={'something': 'else'}) + super_apply.assert_called_with(None, {'request_id': '1234', 'something': 'else'}) + + +def test_apply_async_inject_request_id_with_positional_args(mocker, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + g.request_id = '1234' + celery_task.apply_async(['args'], {'something': 'else'}) + super_apply.assert_called_with(['args'], {'request_id': '1234', 'something': 'else'}) def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task): @@ -100,4 +114,4 @@ def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, cel with notify_api.test_request_context(headers=request_headers): celery_task.apply_async() - super_apply.assert_called_with(kwargs={'request_id': '1234'}) + super_apply.assert_called_with(None, {'request_id': '1234'})