Merge pull request #3201 from alphagov/revamp-celery-stats

Migrate towards new metrics for Celery tasks
This commit is contained in:
Ben Thorner
2021-04-12 15:04:37 +01:00
committed by GitHub
4 changed files with 161 additions and 22 deletions

View File

@@ -110,7 +110,7 @@ def create_app(application):
email_clients = [aws_ses_stub_client] if application.config['SES_STUB_URL'] else [aws_ses_client]
notification_provider_clients.init_app(sms_clients=[firetext_client, mmg_client], email_clients=email_clients)
notify_celery.init_app(application)
notify_celery.init_app(application, statsd_client)
encryption.init_app(application)
redis_store.init_app(application)
document_download_client.init_app(application)

View File

@@ -4,7 +4,6 @@ from celery import Celery, Task
from celery.signals import worker_process_shutdown
from flask import g, request
from flask.ctx import has_app_context, has_request_context
from gds_metrics.metrics import Histogram
@worker_process_shutdown.connect
@@ -19,60 +18,82 @@ def log_on_worker_shutdown(sender, signal, pid, exitcode, **kwargs):
notify_celery._app.logger.info('worker shutdown: PID: {} Exitcode: {}'.format(pid, exitcode))
def make_task(app):
SQS_APPLY_ASYNC_DURATION_SECONDS = Histogram(
'sqs_apply_async_duration_seconds',
'Time taken to put task on queue',
['task_name']
)
def make_task(app, statsd_client):
class NotifyTask(Task):
abstract = True
start = None
def on_success(self, retval, task_id, args, kwargs):
elapsed_time = time.time() - self.start
elapsed_time = time.monotonic() - self.start
delivery_info = self.request.delivery_info or {}
queue_name = delivery_info.get('routing_key', 'none')
app.logger.info(
"{task_name} took {time}".format(
task_name=self.name, time="{0:.4f}".format(elapsed_time)
"Celery task {task_name} (queue: {queue_name}) took {time}".format(
task_name=self.name,
queue_name=queue_name,
time="{0:.4f}".format(elapsed_time)
)
)
statsd_client.timing(
"celery.{queue_name}.{task_name}.success".format(
task_name=self.name,
queue_name=queue_name
), elapsed_time
)
def on_failure(self, exc, task_id, args, kwargs, einfo):
# ensure task will log exceptions to correct handlers
app.logger.exception('Celery task: {} failed'.format(self.name))
delivery_info = self.request.delivery_info or {}
queue_name = delivery_info.get('routing_key', 'none')
app.logger.exception(
"Celery task {task_name} (queue: {queue_name}) failed".format(
task_name=self.name,
queue_name=queue_name,
)
)
statsd_client.incr(
"celery.{queue_name}.{task_name}.failure".format(
task_name=self.name,
queue_name=queue_name
)
)
super().on_failure(exc, task_id, args, kwargs, einfo)
def __call__(self, *args, **kwargs):
# ensure task has flask context to access config, logger, etc
with app.app_context():
self.start = time.time()
# Remove 'request_id' from the kwargs (so the task doesn't get an unexpected kwarg), then add it to g
# so that it gets logged
self.start = time.monotonic()
# Remove piggyback values from kwargs
# Add 'request_id' to 'g' so that it gets logged
g.request_id = kwargs.pop('request_id', None)
return super().__call__(*args, **kwargs)
def apply_async(self, args=None, kwargs=None, task_id=None, producer=None,
link=None, link_error=None, **options):
kwargs = kwargs or {}
if has_request_context() and hasattr(request, 'request_id'):
kwargs['request_id'] = request.request_id
elif has_app_context() and 'request_id' in g:
kwargs['request_id'] = g.request_id
with SQS_APPLY_ASYNC_DURATION_SECONDS.labels(self.name).time():
return super().apply_async(args, kwargs, task_id, producer, link, link_error, **options)
return super().apply_async(args, kwargs, task_id, producer, link, link_error, **options)
return NotifyTask
class NotifyCelery(Celery):
def init_app(self, app):
def init_app(self, app, statsd_client):
super().__init__(
app.import_name,
broker=app.config['BROKER_URL'],
task_cls=make_task(app),
task_cls=make_task(app, statsd_client),
)
self.conf.update(app.config)

View File

@@ -8,7 +8,7 @@ pytest-env==0.6.2
pytest-mock==3.3.1
pytest-cov==2.10.1
pytest-xdist==2.1.0
freezegun==1.0.0
freezegun==1.1.0
requests-mock==1.8.0
# used for creating manifest file locally
jinja2-cli[yaml]==0.7.0

View File

@@ -0,0 +1,118 @@
import uuid
import pytest
from flask import g
from freezegun import freeze_time
from app import notify_celery
# requiring notify_api ensures notify_celery.init_app has been called
@pytest.fixture(scope='session')
def celery_task(notify_api):
    """A registered Celery task with a unique name, for exercising NotifyTask.

    Depends on the notify_api fixture so that notify_celery.init_app has
    already been called before the task is registered.
    """
    # Celery task names are strings: convert the UUID instead of passing the
    # UUID object itself, so the name behaves consistently wherever it is
    # serialised or interpolated (statsd keys, log messages, the registry).
    @notify_celery.task(name=str(uuid.uuid4()), base=notify_celery.task_cls)
    def test_task(delivery_info=None):
        pass

    return test_task
@pytest.fixture
def async_task(celery_task):
    """The shared task, dressed up as though it arrived via a broker queue.

    Pushes a request context carrying delivery info with a routing key,
    which is what Celery provides for a genuinely asynchronous invocation.
    """
    delivery_info = {'routing_key': 'test-queue'}
    celery_task.push_request(delivery_info=delivery_info)
    yield celery_task
    celery_task.pop_request()
def test_success_should_log_and_call_statsd(mocker, notify_api, async_task):
    """On success of an async task, timing is sent to statsd and logged."""
    mock_timing = mocker.patch.object(notify_api.statsd_client, 'timing')
    mock_info = mocker.patch.object(notify_api.logger, 'info')

    with freeze_time() as frozen:
        async_task()
        frozen.tick(5)
        # invoke the success handler directly, as Celery would after the run
        async_task.on_success(retval=None, task_id=1234, args=[], kwargs={})

    expected_metric = f'celery.test-queue.{async_task.name}.success'
    expected_message = f'Celery task {async_task.name} (queue: test-queue) took 5.0000'
    mock_timing.assert_called_once_with(expected_metric, 5.0)
    mock_info.assert_called_once_with(expected_message)
def test_success_queue_when_applied_synchronously(mocker, notify_api, celery_task):
    """Without broker delivery info, the queue is reported as 'none'."""
    mock_timing = mocker.patch.object(notify_api.statsd_client, 'timing')
    mock_info = mocker.patch.object(notify_api.logger, 'info')

    with freeze_time() as frozen:
        celery_task()
        frozen.tick(5)
        # invoke the success handler directly, as Celery would after the run
        celery_task.on_success(retval=None, task_id=1234, args=[], kwargs={})

    expected_metric = f'celery.none.{celery_task.name}.success'
    expected_message = f'Celery task {celery_task.name} (queue: none) took 5.0000'
    mock_timing.assert_called_once_with(expected_metric, 5.0)
    mock_info.assert_called_once_with(expected_message)
def test_failure_should_log_and_call_statsd(mocker, notify_api, async_task):
    """On failure of an async task, a counter is incremented and the error logged."""
    mock_incr = mocker.patch.object(notify_api.statsd_client, 'incr')
    mock_exception = mocker.patch.object(notify_api.logger, 'exception')

    # invoke the failure handler directly, as Celery would on an exception
    async_task.on_failure(exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None)

    mock_incr.assert_called_once_with(f'celery.test-queue.{async_task.name}.failure')
    mock_exception.assert_called_once_with(
        f'Celery task {async_task.name} (queue: test-queue) failed'
    )
def test_failure_queue_when_applied_synchronously(mocker, notify_api, celery_task):
    """Without broker delivery info, a failure is attributed to queue 'none'."""
    mock_incr = mocker.patch.object(notify_api.statsd_client, 'incr')
    mock_exception = mocker.patch.object(notify_api.logger, 'exception')

    # invoke the failure handler directly, as Celery would on an exception
    celery_task.on_failure(exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None)

    mock_incr.assert_called_once_with(f'celery.none.{celery_task.name}.failure')
    mock_exception.assert_called_once_with(
        f'Celery task {celery_task.name} (queue: none) failed'
    )
def test_call_exports_request_id_from_kwargs(mocker, celery_task):
    """Calling the task pops 'request_id' from kwargs and stores it on flask.g."""
    mock_g = mocker.patch('app.celery.celery.g')

    # this would raise if the 'request_id' kwarg were passed through to the
    # task body, since test_task does not accept that keyword argument
    celery_task(request_id='1234')

    assert mock_g.request_id == '1234'
def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task):
    """apply_async copies the request id from flask.g into the task kwargs."""
    mock_super_apply = mocker.patch('celery.app.task.Task.apply_async')
    g.request_id = '1234'

    celery_task.apply_async()

    # positional args mirror NotifyTask.apply_async's call into the base class:
    # args, kwargs, task_id, producer, link, link_error
    mock_super_apply.assert_called_with(None, {'request_id': '1234'}, None, None, None, None)
def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task):
    """apply_async copies the trace id from the request headers into kwargs."""
    mock_super_apply = mocker.patch('celery.app.task.Task.apply_async')
    trace_header = notify_api.config['NOTIFY_TRACE_ID_HEADER']

    with notify_api.test_request_context(headers={trace_header: '1234'}):
        celery_task.apply_async()

    # positional args mirror NotifyTask.apply_async's call into the base class:
    # args, kwargs, task_id, producer, link, link_error
    mock_super_apply.assert_called_with(None, {'request_id': '1234'}, None, None, None, None)