mirror of
https://github.com/GSA/notifications-api.git
synced 2026-02-03 18:01:08 -05:00
Merge branch 'main' of https://github.com/GSA/notifications-api into message-send-flow-docs
This commit is contained in:
@@ -3,15 +3,16 @@ import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from time import monotonic
|
||||
|
||||
from celery import current_task
|
||||
from celery import Celery, Task, current_task
|
||||
from flask import current_app, g, has_request_context, jsonify, make_response, request
|
||||
from flask.ctx import has_app_context
|
||||
from flask_marshmallow import Marshmallow
|
||||
from flask_migrate import Migrate
|
||||
from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy
|
||||
from notifications_utils import logging, request_helper
|
||||
from notifications_utils.celery import NotifyCelery
|
||||
from notifications_utils.clients.encryption.encryption_client import Encryption
|
||||
from notifications_utils.clients.redis.redis_client import RedisClient
|
||||
from notifications_utils.clients.zendesk.zendesk_client import ZendeskClient
|
||||
@@ -27,6 +28,25 @@ from app.clients.email.aws_ses_stub import AwsSesStubClient
|
||||
from app.clients.sms.aws_sns import AwsSnsClient
|
||||
|
||||
|
||||
class NotifyCelery(Celery):
|
||||
def init_app(self, app):
|
||||
self.task_cls = make_task(app)
|
||||
|
||||
# Configure Celery app with options from the main app config.
|
||||
self.config_from_object(app.config["CELERY"])
|
||||
|
||||
def send_task(self, name, args=None, kwargs=None, **other_kwargs):
|
||||
other_kwargs["headers"] = other_kwargs.get("headers") or {}
|
||||
|
||||
if has_request_context() and hasattr(request, "request_id"):
|
||||
other_kwargs["headers"]["notify_request_id"] = request.request_id
|
||||
|
||||
elif has_app_context() and "request_id" in g:
|
||||
other_kwargs["headers"]["notify_request_id"] = g.request_id
|
||||
|
||||
return super().send_task(name, args, kwargs, **other_kwargs)
|
||||
|
||||
|
||||
class SQLAlchemy(_SQLAlchemy):
|
||||
"""We need to subclass SQLAlchemy in order to override create_engine options"""
|
||||
|
||||
@@ -366,3 +386,58 @@ def setup_sqlalchemy_events(app):
|
||||
@event.listens_for(db.engine, "checkin")
|
||||
def checkin(dbapi_connection, connection_record): # noqa
|
||||
pass
|
||||
|
||||
|
||||
def make_task(app):
|
||||
class NotifyTask(Task):
|
||||
abstract = True
|
||||
start = None
|
||||
|
||||
@property
|
||||
def queue_name(self):
|
||||
delivery_info = self.request.delivery_info or {}
|
||||
return delivery_info.get("routing_key", "none")
|
||||
|
||||
@property
|
||||
def request_id(self):
|
||||
# Note that each header is a direct attribute of the
|
||||
# task context (aka "request").
|
||||
return self.request.get("notify_request_id")
|
||||
|
||||
@contextmanager
|
||||
def app_context(self):
|
||||
with app.app_context():
|
||||
# Add 'request_id' to 'g' so that it gets logged.
|
||||
g.request_id = self.request_id
|
||||
yield
|
||||
|
||||
def on_success(self, retval, task_id, args, kwargs): # noqa
|
||||
# enables request id tracing for these logs
|
||||
with self.app_context():
|
||||
elapsed_time = time.monotonic() - self.start
|
||||
|
||||
app.logger.info(
|
||||
"Celery task {task_name} (queue: {queue_name}) took {time}".format(
|
||||
task_name=self.name,
|
||||
queue_name=self.queue_name,
|
||||
time="{0:.4f}".format(elapsed_time),
|
||||
)
|
||||
)
|
||||
|
||||
def on_failure(self, exc, task_id, args, kwargs, einfo): # noqa
|
||||
# enables request id tracing for these logs
|
||||
with self.app_context():
|
||||
app.logger.exception(
|
||||
"Celery task {task_name} (queue: {queue_name}) failed".format(
|
||||
task_name=self.name,
|
||||
queue_name=self.queue_name,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# ensure task has flask context to access config, logger, etc
|
||||
with self.app_context():
|
||||
self.start = time.monotonic()
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
return NotifyTask
|
||||
|
||||
@@ -114,8 +114,7 @@ def extract_phones(job):
|
||||
job_row = 0
|
||||
for row in job:
|
||||
row = row.split(",")
|
||||
current_app.logger.info(f"PHONE INDEX IS NOW {phone_index}")
|
||||
current_app.logger.info(f"LENGTH OF ROW IS {len(row)}")
|
||||
|
||||
if phone_index >= len(row):
|
||||
phones[job_row] = "Unavailable"
|
||||
current_app.logger.error(
|
||||
|
||||
Reference in New Issue
Block a user