diff --git a/app/celery/celery.py b/app/celery/celery.py index 98ca6812f..d44500aec 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -68,9 +68,12 @@ def make_task(app): # ensure task has flask context to access config, logger, etc with app.app_context(): self.start = time.monotonic() - # Remove piggyback values from kwargs - # Add 'request_id' to 'g' so that it gets logged - g.request_id = kwargs.pop('request_id', None) + # TEMPORARY: remove old piggyback values from kwargs + kwargs.pop('request_id', None) + # Add 'request_id' to 'g' so that it gets logged. Note + # that each header is a direct attribute of the task + # context (aka "request"). + g.request_id = self.request.get('notify_request_id') return super().__call__(*args, **kwargs) @@ -90,11 +93,11 @@ class NotifyCelery(Celery): self._app = app def send_task(self, name, args=None, kwargs=None, **other_kwargs): - kwargs = kwargs or {} + other_kwargs['headers'] = other_kwargs.get('headers') or {} if has_request_context() and hasattr(request, 'request_id'): - kwargs['request_id'] = request.request_id + other_kwargs['headers']['notify_request_id'] = request.request_id elif has_app_context() and 'request_id' in g: - kwargs['request_id'] = g.request_id + other_kwargs['headers']['notify_request_id'] = g.request_id return super().send_task(name, args, kwargs, **other_kwargs) diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py index 3406cae38..47ce58c2f 100644 --- a/tests/app/celery/test_celery.py +++ b/tests/app/celery/test_celery.py @@ -22,6 +22,15 @@ def async_task(celery_task): celery_task.pop_request() +@pytest.fixture +def request_id_task(celery_task): + # Note that each header is a direct attribute of the + # task context (aka "request"). + celery_task.push_request(notify_request_id='1234') + yield celery_task + celery_task.pop_request() + + def test_success_should_log_and_call_statsd(mocker, notify_api, async_task): statsd = mocker.patch.object(notify_api.statsd_client, 'timing') logger = mocker.patch.object(notify_api.logger, 'info') @@ -78,32 +87,67 @@ def test_failure_queue_when_applied_synchronously(mocker, notify_api, celery_tas logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) failed') -def test_call_exports_request_id_from_kwargs(mocker, celery_task): +def test_call_exports_request_id_from_headers(mocker, request_id_task): g = mocker.patch('app.celery.celery.g') - # this would fail if the kwarg was passed through unexpectedly - celery_task(request_id='1234') + request_id_task() assert g.request_id == '1234' -def test_send_task_injects_global_request_id_into_kwargs(mocker, notify_api): +def test_call_copes_if_request_id_not_in_headers(mocker, celery_task): + g = mocker.patch('app.celery.celery.g') + celery_task() + assert g.request_id is None + + +def test_send_task_injects_global_request_id_into_headers(mocker, notify_api): super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' notify_celery.send_task('some-task') - super_apply.assert_called_with('some-task', None, {'request_id': '1234'}) + + super_apply.assert_called_with( + 'some-task', # name + None, # args + None, # kwargs + headers={'notify_request_id': '1234'} # other kwargs + ) -def test_send_task_injects_request_id_with_other_kwargs(mocker, notify_api): +def test_send_task_injects_request_id_with_existing_headers(mocker, notify_api): super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' - notify_celery.send_task('some-task', kwargs={'something': 'else'}) - super_apply.assert_called_with('some-task', None, {'request_id': '1234', 'something': 'else'}) + + notify_celery.send_task( + 'some-task', + None, # args + None, # kwargs + headers={'something': 'else'} # other kwargs + ) + + super_apply.assert_called_with( + 'some-task', # name + None, # args + None, # kwargs + headers={'notify_request_id': '1234', 'something': 'else'} # other kwargs + ) -def test_send_task_injects_request_id_with_positional_args(mocker, notify_api): +def test_send_task_injects_request_id_with_none_headers(mocker, notify_api): super_apply = mocker.patch('celery.Celery.send_task') g.request_id = '1234' - notify_celery.send_task('some-task', ['args'], {'kw': 'args'}) - super_apply.assert_called_with('some-task', ['args'], {'request_id': '1234', 'kw': 'args'}) + + notify_celery.send_task( + 'some-task', + None, # args + None, # kwargs + headers=None, # other kwargs (task retry set headers to "None") + ) + + super_apply.assert_called_with( + 'some-task', # name + None, # args + None, # kwargs + headers={'notify_request_id': '1234'} # other kwargs + ) def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api): @@ -114,4 +158,9 @@ def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api): with notify_api.test_request_context(headers=request_headers): notify_celery.send_task('some-task') - super_apply.assert_called_with('some-task', None, {'request_id': '1234'}) + super_apply.assert_called_with( + 'some-task', # name + None, # args + None, # kwargs + headers={'notify_request_id': '1234'} # other kwargs + )