diff --git a/app/__init__.py b/app/__init__.py index 21682fa9e..cc21ef59a 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -20,6 +20,7 @@ from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy from gds_metrics import GDSMetrics from gds_metrics.metrics import Gauge, Histogram from notifications_utils import logging, request_helper +from notifications_utils.celery import NotifyCelery from notifications_utils.clients.encryption.encryption_client import Encryption from notifications_utils.clients.redis.redis_client import RedisClient from notifications_utils.clients.statsd.statsd_client import StatsdClient @@ -28,7 +29,6 @@ from sqlalchemy import event from werkzeug.exceptions import HTTPException as WerkzeugHTTPException from werkzeug.local import LocalProxy -from app.celery.celery import NotifyCelery from app.clients import NotificationProviderClients from app.clients.cbc_proxy import CBCProxyClient from app.clients.document_download import DocumentDownloadClient diff --git a/app/celery/celery.py b/app/celery/celery.py deleted file mode 100644 index d60f1903f..000000000 --- a/app/celery/celery.py +++ /dev/null @@ -1,117 +0,0 @@ -import time -from contextlib import contextmanager - -from celery import Celery, Task -from celery.signals import worker_process_shutdown -from flask import g, request -from flask.ctx import has_app_context, has_request_context - - -@worker_process_shutdown.connect -def log_on_worker_shutdown(sender, signal, pid, exitcode, **kwargs): - # imported here to avoid circular imports - from app import notify_celery - - # if the worker has already restarted at least once, then we no longer have app context and current_app won't work - # to create a new one. Instead we have to create a new app context from the original flask app and use that instead. - with notify_celery._app.app_context(): - # if the worker has restarted - notify_celery._app.logger.info('worker shutdown: PID: {} Exitcode: {}'.format(pid, exitcode)) - - -def make_task(app): - class NotifyTask(Task): - abstract = True - start = None - typing = False - - @property - def queue_name(self): - delivery_info = self.request.delivery_info or {} - return delivery_info.get('routing_key', 'none') - - @property - def request_id(self): - # Note that each header is a direct attribute of the - # task context (aka "request"). - return self.request.get('notify_request_id') - - @contextmanager - def app_context(self): - with app.app_context(): - # Add 'request_id' to 'g' so that it gets logged. - g.request_id = self.request_id - yield - - def on_success(self, retval, task_id, args, kwargs): - # enables request id tracing for these logs - with self.app_context(): - elapsed_time = time.monotonic() - self.start - - app.logger.info( - "Celery task {task_name} (queue: {queue_name}) took {time}".format( - task_name=self.name, - queue_name=self.queue_name, - time="{0:.4f}".format(elapsed_time) - ) - ) - - app.statsd_client.timing( - "celery.{queue_name}.{task_name}.success".format( - task_name=self.name, - queue_name=self.queue_name - ), elapsed_time - ) - - def on_failure(self, exc, task_id, args, kwargs, einfo): - # enables request id tracing for these logs - with self.app_context(): - app.logger.exception( - "Celery task {task_name} (queue: {queue_name}) failed".format( - task_name=self.name, - queue_name=self.queue_name, - ) - ) - - app.statsd_client.incr( - "celery.{queue_name}.{task_name}.failure".format( - task_name=self.name, - queue_name=self.queue_name - ) - ) - - super().on_failure(exc, task_id, args, kwargs, einfo) - - def __call__(self, *args, **kwargs): - # ensure task has flask context to access config, logger, etc - with self.app_context(): - self.start = time.monotonic() - # TEMPORARY: remove old piggyback values from kwargs - kwargs.pop('request_id', None) - - return super().__call__(*args, **kwargs) - - return NotifyTask - - -class NotifyCelery(Celery): - - def init_app(self, app): - super().__init__( - app.import_name, - broker=app.config['CELERY']['broker_url'], - task_cls=make_task(app), - ) - - self.conf.update(app.config['CELERY']) - self._app = app - - def send_task(self, name, args=None, kwargs=None, **other_kwargs): - other_kwargs['headers'] = other_kwargs.get('headers') or {} - - if has_request_context() and hasattr(request, 'request_id'): - other_kwargs['headers']['notify_request_id'] = request.request_id - elif has_app_context() and 'request_id' in g: - other_kwargs['headers']['notify_request_id'] = g.request_id - - return super().send_task(name, args, kwargs, **other_kwargs) diff --git a/requirements.in b/requirements.in index 91152a4d4..88c942024 100644 --- a/requirements.in +++ b/requirements.in @@ -36,7 +36,7 @@ notifications-python-client==6.0.2 # PaaS awscli-cwlogs==1.4.6 -git+https://github.com/alphagov/notifications-utils.git@48.0.0#egg=notifications-utils==48.0.0 +git+https://github.com/alphagov/notifications-utils.git@48.1.0#egg=notifications-utils==48.1.0 # gds-metrics requires prometheseus 0.2.0, override that requirement as 0.7.1 brings significant performance gains prometheus-client==0.10.1 diff --git a/requirements.txt b/requirements.txt index c1832f1f4..9a7c79e83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -151,7 +151,7 @@ mistune==0.8.4 # via notifications-utils notifications-python-client==6.0.2 # via -r requirements.in -notifications-utils @ git+https://github.com/alphagov/notifications-utils.git@48.0.0 +notifications-utils @ git+https://github.com/alphagov/notifications-utils.git@48.1.0 # via -r requirements.in orderedset==2.0.3 # via notifications-utils diff --git a/tests/app/celery/test_celery.py b/tests/app/celery/test_celery.py deleted file mode 100644 index 47ce58c2f..000000000 --- a/tests/app/celery/test_celery.py +++ /dev/null @@ -1,166 +0,0 @@ -import uuid - -import pytest -from flask import g -from freezegun import freeze_time - -from app import notify_celery - - -# requiring notify_api ensures notify_celery.init_app has been called -@pytest.fixture(scope='session') -def celery_task(notify_api): - @notify_celery.task(name=uuid.uuid4(), base=notify_celery.task_cls) - def test_task(delivery_info=None): pass - return test_task - - -@pytest.fixture -def async_task(celery_task): - celery_task.push_request(delivery_info={'routing_key': 'test-queue'}) - yield celery_task - celery_task.pop_request() - - -@pytest.fixture -def request_id_task(celery_task): - # Note that each header is a direct attribute of the - # task context (aka "request"). - celery_task.push_request(notify_request_id='1234') - yield celery_task - celery_task.pop_request() - - -def test_success_should_log_and_call_statsd(mocker, notify_api, async_task): - statsd = mocker.patch.object(notify_api.statsd_client, 'timing') - logger = mocker.patch.object(notify_api.logger, 'info') - - with freeze_time() as frozen: - async_task() - frozen.tick(5) - - async_task.on_success( - retval=None, task_id=1234, args=[], kwargs={} - ) - - statsd.assert_called_once_with(f'celery.test-queue.{async_task.name}.success', 5.0) - logger.assert_called_once_with(f'Celery task {async_task.name} (queue: test-queue) took 5.0000') - - -def test_success_queue_when_applied_synchronously(mocker, notify_api, celery_task): - statsd = mocker.patch.object(notify_api.statsd_client, 'timing') - logger = mocker.patch.object(notify_api.logger, 'info') - - with freeze_time() as frozen: - celery_task() - frozen.tick(5) - - celery_task.on_success( - retval=None, task_id=1234, args=[], kwargs={} - ) - - statsd.assert_called_once_with(f'celery.none.{celery_task.name}.success', 5.0) - logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) took 5.0000') - - -def test_failure_should_log_and_call_statsd(mocker, notify_api, async_task): - statsd = mocker.patch.object(notify_api.statsd_client, 'incr') - logger = mocker.patch.object(notify_api.logger, 'exception') - - async_task.on_failure( - exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None - ) - - statsd.assert_called_once_with(f'celery.test-queue.{async_task.name}.failure') - logger.assert_called_once_with(f'Celery task {async_task.name} (queue: test-queue) failed') - - -def test_failure_queue_when_applied_synchronously(mocker, notify_api, celery_task): - statsd = mocker.patch.object(notify_api.statsd_client, 'incr') - logger = mocker.patch.object(notify_api.logger, 'exception') - - celery_task.on_failure( - exc=Exception, task_id=1234, args=[], kwargs={}, einfo=None - ) - - statsd.assert_called_once_with(f'celery.none.{celery_task.name}.failure') - logger.assert_called_once_with(f'Celery task {celery_task.name} (queue: none) failed') - - -def test_call_exports_request_id_from_headers(mocker, request_id_task): - g = mocker.patch('app.celery.celery.g') - request_id_task() - assert g.request_id == '1234' - - -def test_call_copes_if_request_id_not_in_headers(mocker, celery_task): - g = mocker.patch('app.celery.celery.g') - celery_task() - assert g.request_id is None - - -def test_send_task_injects_global_request_id_into_headers(mocker, notify_api): - super_apply = mocker.patch('celery.Celery.send_task') - g.request_id = '1234' - notify_celery.send_task('some-task') - - super_apply.assert_called_with( - 'some-task', # name - None, # args - None, # kwargs - headers={'notify_request_id': '1234'} # other kwargs - ) - - -def test_send_task_injects_request_id_with_existing_headers(mocker, notify_api): - super_apply = mocker.patch('celery.Celery.send_task') - g.request_id = '1234' - - notify_celery.send_task( - 'some-task', - None, # args - None, # kwargs - headers={'something': 'else'} # other kwargs - ) - - super_apply.assert_called_with( - 'some-task', # name - None, # args - None, # kwargs - headers={'notify_request_id': '1234', 'something': 'else'} # other kwargs - ) - - -def test_send_task_injects_request_id_with_none_headers(mocker, notify_api): - super_apply = mocker.patch('celery.Celery.send_task') - g.request_id = '1234' - - notify_celery.send_task( - 'some-task', - None, # args - None, # kwargs - headers=None, # other kwargs (task retry set headers to "None") - ) - - super_apply.assert_called_with( - 'some-task', # name - None, # args - None, # kwargs - headers={'notify_request_id': '1234'} # other kwargs - ) - - -def test_send_task_injects_id_into_kwargs_from_request(mocker, notify_api): - super_apply = mocker.patch('celery.Celery.send_task') - request_id_header = notify_api.config['NOTIFY_TRACE_ID_HEADER'] - request_headers = {request_id_header: '1234'} - - with notify_api.test_request_context(headers=request_headers): - notify_celery.send_task('some-task') - - super_apply.assert_called_with( - 'some-task', # name - None, # args - None, # kwargs - headers={'notify_request_id': '1234'} # other kwargs - )