diff --git a/app/__init__.py b/app/__init__.py index 57cd1c2ef..3b0fd05af 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -110,7 +110,7 @@ def create_app(application): email_clients = [aws_ses_stub_client] if application.config['SES_STUB_URL'] else [aws_ses_client] notification_provider_clients.init_app(sms_clients=[firetext_client, mmg_client], email_clients=email_clients) - notify_celery.init_app(application) + notify_celery.init_app(application, statsd_client) encryption.init_app(application) redis_store.init_app(application) document_download_client.init_app(application) diff --git a/app/celery/celery.py b/app/celery/celery.py index 6018319c8..ecc7486b8 100644 --- a/app/celery/celery.py +++ b/app/celery/celery.py @@ -4,7 +4,6 @@ from celery import Celery, Task from celery.signals import worker_process_shutdown from flask import g, request from flask.ctx import has_app_context, has_request_context -from gds_metrics.metrics import Histogram @worker_process_shutdown.connect @@ -19,60 +18,82 @@ def log_on_worker_shutdown(sender, signal, pid, exitcode, **kwargs): notify_celery._app.logger.info('worker shutdown: PID: {} Exitcode: {}'.format(pid, exitcode)) -def make_task(app): - SQS_APPLY_ASYNC_DURATION_SECONDS = Histogram( - 'sqs_apply_async_duration_seconds', - 'Time taken to put task on queue', - ['task_name'] - ) - +def make_task(app, statsd_client): class NotifyTask(Task): abstract = True start = None def on_success(self, retval, task_id, args, kwargs): - elapsed_time = time.time() - self.start + elapsed_time = time.monotonic() - self.start + delivery_info = self.request.delivery_info or {} + queue_name = delivery_info.get('routing_key', 'none') + app.logger.info( - "{task_name} took {time}".format( - task_name=self.name, time="{0:.4f}".format(elapsed_time) + "Celery task {task_name} (queue: {queue_name}) took {time}".format( + task_name=self.name, + queue_name=queue_name, + time="{0:.4f}".format(elapsed_time) ) ) + statsd_client.timing( + "celery.{queue_name}.{task_name}.success".format( + task_name=self.name, + queue_name=queue_name + ), elapsed_time + ) + def on_failure(self, exc, task_id, args, kwargs, einfo): - # ensure task will log exceptions to correct handlers - app.logger.exception('Celery task: {} failed'.format(self.name)) + delivery_info = self.request.delivery_info or {} + queue_name = delivery_info.get('routing_key', 'none') + + app.logger.exception( + "Celery task {task_name} (queue: {queue_name}) failed".format( + task_name=self.name, + queue_name=queue_name, + ) + ) + + statsd_client.incr( + "celery.{queue_name}.{task_name}.failure".format( + task_name=self.name, + queue_name=queue_name + ) + ) + super().on_failure(exc, task_id, args, kwargs, einfo) def __call__(self, *args, **kwargs): # ensure task has flask context to access config, logger, etc with app.app_context(): - self.start = time.time() - # Remove 'request_id' from the kwargs (so the task doesn't get an unexpected kwarg), then add it to g - # so that it gets logged + self.start = time.monotonic() + # Remove piggyback values from kwargs + # Add 'request_id' to 'g' so that it gets logged g.request_id = kwargs.pop('request_id', None) + return super().__call__(*args, **kwargs) def apply_async(self, args=None, kwargs=None, task_id=None, producer=None, link=None, link_error=None, **options): kwargs = kwargs or {} + if has_request_context() and hasattr(request, 'request_id'): kwargs['request_id'] = request.request_id elif has_app_context() and 'request_id' in g: kwargs['request_id'] = g.request_id - with SQS_APPLY_ASYNC_DURATION_SECONDS.labels(self.name).time(): - return super().apply_async(args, kwargs, task_id, producer, link, link_error, **options) + return super().apply_async(args, kwargs, task_id, producer, link, link_error, **options) return NotifyTask class NotifyCelery(Celery): - def init_app(self, app): + def init_app(self, app, statsd_client): super().__init__( app.import_name, broker=app.config['BROKER_URL'], - task_cls=make_task(app), + task_cls=make_task(app, statsd_client), ) self.conf.update(app.config) diff --git a/requirements_for_test.txt b/requirements_for_test.txt index 3b9beafba..c9862ea0a 100644 --- a/requirements_for_test.txt +++ b/requirements_for_test.txt @@ -8,7 +8,7 @@ pytest-env==0.6.2 pytest-mock==3.3.1 pytest-cov==2.10.1 pytest-xdist==2.1.0 -freezegun==1.0.0 +freezegun==1.1.0 requests-mock==1.8.0 # used for creating manifest file locally jinja2-cli[yaml]==0.7.0 diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py new file mode 100644 index 000000000..0acbd8d23 --- /dev/null +++ b/tests/app/celery/test_celery.py @@ -0,0 +1,118 @@ +import uuid + +import pytest +from flask import g +from freezegun import freeze_time + +from app import notify_celery + + +# requiring notify_api ensures notify_celery.init_app has been called +@pytest.fixture(scope='session') +def celery_task(notify_api): + @notify_celery.task(name=uuid.uuid4(), base=notify_celery.task_cls) + def test_task(delivery_info=None): pass + return test_task + + +@pytest.fixture +def async_task(celery_task): + celery_task.push_request(delivery_info={'routing_key': 'test-queue'}) + yield celery_task + celery_task.pop_request() + + +def test_success_should_log_and_call_statsd(mocker, notify_api, async_task): + statsd = mocker.patch.object(notify_api.statsd_client, 'timing') + logger = mocker.patch.object(notify_api.logger, 'info') + + with freeze_time() as frozen: + async_task() + frozen.tick(5) + + async_task.on_success( + retval=None, task_id=1234, args=[], kwargs={} + ) + + statsd.assert_called_once_with(f'celery.test-queue.{async_task.name}.success', 5.0) + logger.assert_called_once_with(f'Celery task {async_task.name} (queue: test-queue) took 5.0000') + + +def test_success_queue_when_applied_synchronously(mocker, notify_api, celery_task): + statsd = mocker.patch.object(notify_api.statsd_client, 'timing') + logger = mocker.patch.object(notify_api.logger, 'info') + + with freeze_time() as frozen: + celery_task() + frozen.tick(5) + + celery_task.on_success( + retval=None, task_id=1234, args=[], kwargs={} + ) + + statsd.assert_called_once_with(f'celery.none.{celery_task.name}.success', 5.0) + logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) took 5.0000') + + +def test_failure_should_log_and_call_statsd(mocker, notify_api, async_task): + statsd = mocker.patch.object(notify_api.statsd_client, 'incr') + logger = mocker.patch.object(notify_api.logger, 'exception') + + async_task.on_failure( + exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None + ) + + statsd.assert_called_once_with(f'celery.test-queue.{async_task.name}.failure') + logger.assert_called_once_with(f'Celery task {async_task.name} (queue: test-queue) failed') + + +def test_failure_queue_when_applied_synchronously(mocker, notify_api, celery_task): + statsd = mocker.patch.object(notify_api.statsd_client, 'incr') + logger = mocker.patch.object(notify_api.logger, 'exception') + + celery_task.on_failure( + exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None + ) + + statsd.assert_called_once_with(f'celery.none.{celery_task.name}.failure') + logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) failed') + + +def test_call_exports_request_id_from_kwargs(mocker, celery_task): + g = mocker.patch('app.celery.celery.g') + # this would fail if the kwarg was passed through unexpectedly + celery_task(request_id='1234') + assert g.request_id == '1234' + + +def test_apply_async_injects_global_request_id_into_kwargs(mocker, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + g.request_id = '1234' + celery_task.apply_async() + + super_apply.assert_called_with( + None, + {'request_id': '1234'}, + None, + None, + None, + None + ) + + +def test_apply_async_injects_id_into_kwargs_from_request(mocker, notify_api, celery_task): + super_apply = mocker.patch('celery.app.task.Task.apply_async') + request_id_header = notify_api.config['NOTIFY_TRACE_ID_HEADER'] + request_headers = {request_id_header: '1234'} + + with notify_api.test_request_context(headers=request_headers): + celery_task.apply_async() + + super_apply.assert_called_with( + None, + {'request_id': '1234'}, + None, + None, + None, + None + )