Move Celery task Request ID injection into headers

Previously we passed along this piece of state via the kwargs for
a task, but this runs the risk of the task accidentally receiving
the extra kwarg unless we've covered all the code paths that could
invoke it directly e.g. retries don't invoke __call__.

This switches to using Celery "headers" to pass the extra state. It
turns out that a Celery has two "header" concepts, which leads to
some confusion and even a bug with the framework [1]:

- In older (pre v4.4) versions of Celery, the "headers" specified
by apply_async() would become _the_ headers in the message that
gets passed around workers, etc. These would be available later on
via "self.request.headers".

- Since Celery protocol v2, the meaning of "headers" in the message
changed to become (basically) _all_ metadata about the task [2],
with the "headers" option in apply_async() being merged [3] into
the big dict of metadata.

This makes using headers a bit confusing unfortunately, since the
data structure we put in is subtly different to what comes out in
the request context. Nonetheless, it still works. I've added some
comments to try and clarify it.

Note that one of the original tests is no longer necessary, since we
don't need to worry about argument passing styles with headers.

[1]: https://github.com/celery/celery/issues/4875
[2]: 663e4d3a0b (diff-07a65448b2db3252a9711766beec23372715cd7597c3e309bf53859eabc0107fR343)
[3]: 681a922220/celery/app/amqp.py (L495)
This commit is contained in:
Ben Thorner
2021-11-10 15:42:51 +00:00
parent 98b6c1d67d
commit 89a8dd1a03
2 changed files with 70 additions and 18 deletions

View File

@@ -68,9 +68,12 @@ def make_task(app):
# ensure task has flask context to access config, logger, etc
with app.app_context():
self.start = time.monotonic()
# Remove piggyback values from kwargs
# Add 'request_id' to 'g' so that it gets logged
g.request_id = kwargs.pop('request_id', None)
# TEMPORARY: remove old piggyback values from kwargs
kwargs.pop('request_id', None)
# Add 'request_id' to 'g' so that it gets logged. Note
# that each header is a direct attribute of the task
# context (aka "request").
g.request_id = self.request.get('notify_request_id')
return super().__call__(*args, **kwargs)
@@ -90,11 +93,11 @@ class NotifyCelery(Celery):
self._app = app
def send_task(self, name, args=None, kwargs=None, **other_kwargs):
kwargs = kwargs or {}
other_kwargs['headers'] = other_kwargs.get('headers') or {}
if has_request_context() and hasattr(request, 'request_id'):
kwargs['request_id'] = request.request_id
other_kwargs['headers']['notify_request_id'] = request.request_id
elif has_app_context() and 'request_id' in g:
kwargs['request_id'] = g.request_id
other_kwargs['headers']['notify_request_id'] = g.request_id
return super().send_task(name, args, kwargs, **other_kwargs)

View File

@@ -22,6 +22,15 @@ def async_task(celery_task):
celery_task.pop_request()
@pytest.fixture
def request_id_task(celery_task):
# Note that each header is a direct attribute of the
# task context (aka "request").
celery_task.push_request(notify_request_id='1234')
yield celery_task
celery_task.pop_request()
def test_success_should_log_and_call_statsd(mocker, notify_api, async_task):
statsd = mocker.patch.object(notify_api.statsd_client, 'timing')
logger = mocker.patch.object(notify_api.logger, 'info')
@@ -78,32 +87,67 @@ def test_failure_queue_when_applied_synchronously(mocker, notify_api, celery_tas
logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) failed')
def test_call_exports_request_id_from_kwargs(mocker, celery_task):
def test_call_exports_request_id_from_headers(mocker, request_id_task):
g = mocker.patch('app.celery.celery.g')
# this would fail if the kwarg was passed through unexpectedly
celery_task(request_id='1234')
request_id_task()
assert g.request_id == '1234'
def test_send_task_injects_global_request_id_into_kwargs(mocker, notify_api):
def test_call_copes_if_request_id_not_in_headers(mocker, celery_task):
g = mocker.patch('app.celery.celery.g')
celery_task()
assert g.request_id is None
def test_send_task_injects_global_request_id_into_headers(mocker, notify_api):
super_apply = mocker.patch('celery.Celery.send_task')
g.request_id = '1234'
notify_celery.send_task('some-task')
super_apply.assert_called_with('some-task', None, {'request_id': '1234'})
super_apply.assert_called_with(
'some-task', # name
None, # args
None, # kwargs
headers={'notify_request_id': '1234'} # other kwargs
)
def test_send_task_injects_request_id_with_other_kwargs(mocker, notify_api):
def test_send_task_injects_request_id_with_existing_headers(mocker, notify_api):
super_apply = mocker.patch('celery.Celery.send_task')
g.request_id = '1234'
notify_celery.send_task('some-task', kwargs={'something': 'else'})
super_apply.assert_called_with('some-task', None, {'request_id': '1234', 'something': 'else'})
notify_celery.send_task(
'some-task',
None, # args
None, # kwargs
headers={'something': 'else'} # other kwargs
)
super_apply.assert_called_with(
'some-task', # name
None, # args
None, # kwargs
headers={'notify_request_id': '1234', 'something': 'else'} # other kwargs
)
def test_send_task_injects_request_id_with_positional_args(mocker, notify_api):
def test_send_task_injects_request_id_with_none_headers(mocker, notify_api):
super_apply = mocker.patch('celery.Celery.send_task')
g.request_id = '1234'
notify_celery.send_task('some-task', ['args'], {'kw': 'args'})
super_apply.assert_called_with('some-task', ['args'], {'request_id': '1234', 'kw': 'args'})
notify_celery.send_task(
'some-task',
None, # args
None, # kwargs
headers=None, # other kwargs (task retry set headers to "None")
)
super_apply.assert_called_with(
'some-task', # name
None, # args
None, # kwargs
headers={'notify_request_id': '1234'} # other kwargs
)
def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api):
@@ -114,4 +158,9 @@ def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api):
with notify_api.test_request_context(headers=request_headers):
notify_celery.send_task('some-task')
super_apply.assert_called_with('some-task', None, {'request_id': '1234'})
super_apply.assert_called_with(
'some-task', # name
None, # args
None, # kwargs
headers={'notify_request_id': '1234'} # other kwargs
)