Merge branch 'main' of https://github.com/GSA/notifications-api into no-attribute-day-bug

Andrew Shumway
2024-11-07 09:46:13 -07:00
5 changed files with 54 additions and 31 deletions

View File

@@ -4,6 +4,7 @@ import string
 import time
 import uuid
 from contextlib import contextmanager
+from multiprocessing import Manager
 from time import monotonic

 from celery import Celery, Task, current_task
@@ -119,6 +120,9 @@ def create_app(application):
     redis_store.init_app(application)
     document_download_client.init_app(application)

+    manager = Manager()
+    application.config["job_cache"] = manager.dict()
+
     register_blueprint(application)

     # avoid circular imports by importing this file later

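For context, a minimal sketch (not the project's code) of the pattern this hunk introduces: a multiprocessing.Manager dict is created once in create_app and published through application.config, so anything running under the app context can reach the same cache. The cache_put helper below is hypothetical; the real helpers are added to app/aws/s3.py in the next file.

import time
from multiprocessing import Manager

from flask import Flask, current_app


def create_app():
    app = Flask(__name__)
    manager = Manager()                       # starts a separate manager process
    app.config["job_cache"] = manager.dict()  # proxy dict served by that process
    return app


def cache_put(key, value, ttl_seconds=8 * 24 * 60 * 60):
    # hypothetical helper; the real ones live in app/aws/s3.py
    current_app.config["job_cache"][key] = (value, time.time() + ttl_seconds)


if __name__ == "__main__":
    app = create_app()
    with app.app_context():
        cache_put("job-123", "csv file contents")
        print(current_app.config["job_cache"].get("job-123"))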
View File

@@ -2,7 +2,6 @@ import datetime
 import re
 import time
 from concurrent.futures import ThreadPoolExecutor
-from multiprocessing import Manager

 import botocore
 from boto3 import Session
@@ -16,8 +15,6 @@ NEW_FILE_LOCATION_STRUCTURE = "{}-service-notify/{}.csv"

 # Temporarily extend cache to 7 days
 ttl = 60 * 60 * 24 * 7

-manager = Manager()
-job_cache = manager.dict()

 # Global variable
@@ -25,17 +22,40 @@ s3_client = None
 s3_resource = None


-def set_job_cache(job_cache, key, value):
+def set_job_cache(key, value):
+    current_app.logger.info(f"Setting {key} in the job_cache.")
+    job_cache = current_app.config["job_cache"]
     job_cache[key] = (value, time.time() + 8 * 24 * 60 * 60)


+def get_job_cache(key):
+    job_cache = current_app.config["job_cache"]
+    ret = job_cache.get(key)
+    if ret is None:
+        current_app.logger.warning(f"Could not find {key} in the job_cache.")
+    else:
+        current_app.logger.info(f"Got {key} from job_cache.")
+    return ret
+
+
+def len_job_cache():
+    job_cache = current_app.config["job_cache"]
+    ret = len(job_cache)
+    current_app.logger.info(f"Length of job_cache is {ret}")
+    return ret
+
+
 def clean_cache():
+    job_cache = current_app.config["job_cache"]
     current_time = time.time()
     keys_to_delete = []
     for key, (_, expiry_time) in job_cache.items():
         if expiry_time < current_time:
             keys_to_delete.append(key)

+    current_app.logger.info(
+        f"Deleting the following keys from the job_cache: {keys_to_delete}"
+    )
     for key in keys_to_delete:
         del job_cache[key]

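As a quick illustration of the (value, expiry) tuple layout these helpers rely on, here is the same logic exercised with a plain dict standing in for the Manager-backed one in app.config (illustrative only, not the project's code):

import time

job_cache = {}
job_cache["job-1"] = ("csv contents", time.time() + 8 * 24 * 60 * 60)  # fresh entry
job_cache["job-2"] = ("old contents", time.time() - 1)                 # already expired

# mirrors clean_cache(): collect expired keys first, then delete them
current_time = time.time()
keys_to_delete = [k for k, (_, expiry) in job_cache.items() if expiry < current_time]
for k in keys_to_delete:
    del job_cache[k]

assert "job-1" in job_cache and "job-2" not in job_cache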
@@ -162,17 +182,16 @@ def read_s3_file(bucket_name, object_key, s3res):
""" """
try: try:
job_id = get_job_id_from_s3_object_key(object_key) job_id = get_job_id_from_s3_object_key(object_key)
if job_cache.get(job_id) is None: if get_job_cache(job_id) is None:
object = ( object = (
s3res.Object(bucket_name, object_key) s3res.Object(bucket_name, object_key)
.get()["Body"] .get()["Body"]
.read() .read()
.decode("utf-8") .decode("utf-8")
) )
set_job_cache(job_cache, job_id, object) set_job_cache(job_id, object)
set_job_cache(job_cache, f"{job_id}_phones", extract_phones(object)) set_job_cache(f"{job_id}_phones", extract_phones(object))
set_job_cache( set_job_cache(
job_cache,
f"{job_id}_personalisation", f"{job_id}_personalisation",
extract_personalisation(object), extract_personalisation(object),
) )
@@ -192,7 +211,7 @@ def get_s3_files():
     s3res = get_s3_resource()

     current_app.logger.info(
-        f"job_cache length before regen: {len(job_cache)} #notify-admin-1200"
+        f"job_cache length before regen: {len_job_cache()} #notify-admin-1200"
     )
     try:
         with ThreadPoolExecutor() as executor:
@@ -201,7 +220,7 @@ def get_s3_files():
         current_app.logger.exception("Connection pool issue")

     current_app.logger.info(
-        f"job_cache length after regen: {len(job_cache)} #notify-admin-1200"
+        f"job_cache length after regen: {len_job_cache()} #notify-admin-1200"
     )
@@ -424,12 +443,12 @@ def extract_personalisation(job):


 def get_phone_number_from_s3(service_id, job_id, job_row_number):
-    job = job_cache.get(job_id)
+    job = get_job_cache(job_id)
     if job is None:
         current_app.logger.info(f"job {job_id} was not in the cache")
         job = get_job_from_s3(service_id, job_id)
         # Even if it is None, put it here to avoid KeyErrors
-        set_job_cache(job_cache, job_id, job)
+        set_job_cache(job_id, job)
     else:
         # skip expiration date from cache, we don't need it here
         job = job[0]
@@ -441,7 +460,7 @@ def get_phone_number_from_s3(service_id, job_id, job_row_number):
         return "Unavailable"

     phones = extract_phones(job)
-    set_job_cache(job_cache, f"{job_id}_phones", phones)
+    set_job_cache(f"{job_id}_phones", phones)

     # If we can find the quick dictionary, use it
     phone_to_return = phones[job_row_number]
@@ -458,12 +477,12 @@ def get_personalisation_from_s3(service_id, job_id, job_row_number):
     # We don't want to constantly pull down a job from s3 every time we need the personalisation.
     # At the same time we don't want to store it in redis or the db
     # So this is a little recycling mechanism to reduce the number of downloads.
-    job = job_cache.get(job_id)
+    job = get_job_cache(job_id)
     if job is None:
         current_app.logger.info(f"job {job_id} was not in the cache")
         job = get_job_from_s3(service_id, job_id)
         # Even if it is None, put it here to avoid KeyErrors
-        set_job_cache(job_cache, job_id, job)
+        set_job_cache(job_id, job)
     else:
         # skip expiration date from cache, we don't need it here
         job = job[0]
@@ -478,9 +497,9 @@ def get_personalisation_from_s3(service_id, job_id, job_row_number):
         )
         return {}

-    set_job_cache(job_cache, f"{job_id}_personalisation", extract_personalisation(job))
+    set_job_cache(f"{job_id}_personalisation", extract_personalisation(job))

-    return job_cache.get(f"{job_id}_personalisation")[0].get(job_row_number)
+    return get_job_cache(f"{job_id}_personalisation")[0].get(job_row_number)


 def get_job_metadata_from_s3(service_id, job_id):

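Distilled, the lookup functions above share one pattern; the sketch below (hypothetical wrapper name, helpers as defined earlier in this file) shows the cache-then-S3 fallback, the fact that cached entries are (value, expiry) tuples, and that None is cached deliberately to avoid repeated misses:

def lookup_personalisation(service_id, job_id, row_number):
    cached = get_job_cache(job_id)                 # (value, expiry) tuple or None
    if cached is None:
        job = get_job_from_s3(service_id, job_id)  # may itself be None
        set_job_cache(job_id, job)                 # cache even None to avoid KeyErrors
    else:
        job = cached[0]                            # drop the expiry component
    if job is None:
        return {}
    set_job_cache(f"{job_id}_personalisation", extract_personalisation(job))
    return get_job_cache(f"{job_id}_personalisation")[0].get(row_number)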
View File

@@ -17,8 +17,8 @@ AWS_CLIENT_CONFIG = Config(
     # there may come a time when increasing this helps
     # with job cache management.
     # max_pool_connections=10,
-    # Reducing to 4 connections due to BrokenPipeErrors
-    max_pool_connections=4,
+    # Reducing to 7 connections due to BrokenPipeErrors
+    max_pool_connections=7,
 )

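The value here feeds botocore's HTTP connection pool; as a general point (not stated in the commit), it should roughly track how many threads the ThreadPoolExecutor lets hit S3 at once, or urllib3 starts discarding connections. A minimal sketch of wiring such a config into a boto3 resource:

from boto3 import Session
from botocore.config import Config

AWS_CLIENT_CONFIG = Config(
    # roughly match the number of concurrent S3 calls expected
    max_pool_connections=7,
)

session = Session()
s3_resource = session.resource("s3", config=AWS_CLIENT_CONFIG)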
View File

@@ -1,8 +1,8 @@
 env: production
 web_instances: 2
-web_memory: 2G
+web_memory: 3G
 worker_instances: 4
-worker_memory: 3G
+worker_memory: 2G
 scheduler_memory: 256M
 public_api_route: notify-api.app.cloud.gov
 admin_base_url: https://beta.notify.gov

View File

@@ -1,7 +1,7 @@
 import os
 from datetime import timedelta
 from os import getenv
-from unittest.mock import ANY, MagicMock, Mock, call, patch
+from unittest.mock import MagicMock, Mock, call, patch

 import botocore
 import pytest
@@ -70,7 +70,7 @@ def test_cleanup_old_s3_objects(mocker):
     mock_remove_csv_object.assert_called_once_with("A")


-def test_read_s3_file_success(mocker):
+def test_read_s3_file_success(client, mocker):
     mock_s3res = MagicMock()
     mock_extract_personalisation = mocker.patch("app.aws.s3.extract_personalisation")
     mock_extract_phones = mocker.patch("app.aws.s3.extract_phones")
@@ -89,16 +89,13 @@ def test_read_s3_file_success(client, mocker):
     mock_extract_phones.return_value = ["1234567890"]
     mock_extract_personalisation.return_value = {"name": "John Doe"}

-    global job_cache
-    job_cache = {}
-
     read_s3_file(bucket_name, object_key, mock_s3res)

     mock_get_job_id.assert_called_once_with(object_key)
     mock_s3res.Object.assert_called_once_with(bucket_name, object_key)
     expected_calls = [
-        call(ANY, job_id, file_content),
-        call(ANY, f"{job_id}_phones", ["1234567890"]),
-        call(ANY, f"{job_id}_personalisation", {"name": "John Doe"}),
+        call(job_id, file_content),
+        call(f"{job_id}_phones", ["1234567890"]),
+        call(f"{job_id}_personalisation", {"name": "John Doe"}),
     ]
     mock_set_job_cache.assert_has_calls(expected_calls, any_order=True)
@@ -380,9 +377,12 @@ def test_file_exists_false(notify_api, mocker):
     get_s3_mock.assert_called_once()


-def test_get_s3_files_success(notify_api, mocker):
+def test_get_s3_files_success(client, mocker):
     mock_current_app = mocker.patch("app.aws.s3.current_app")
-    mock_current_app.config = {"CSV_UPLOAD_BUCKET": {"bucket": "test-bucket"}}
+    mock_current_app.config = {
+        "CSV_UPLOAD_BUCKET": {"bucket": "test-bucket"},
+        "job_cache": {},
+    }
     mock_thread_pool_executor = mocker.patch("app.aws.s3.ThreadPoolExecutor")
     mock_read_s3_file = mocker.patch("app.aws.s3.read_s3_file")
     mock_list_s3_objects = mocker.patch("app.aws.s3.list_s3_objects")
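The switch to the client fixture presumably exists because the code under test now reads current_app.config["job_cache"], which requires an application context. A self-contained sketch of such a fixture and test (hypothetical, not the repo's actual conftest):

import pytest
from flask import Flask, current_app


@pytest.fixture
def client():
    app = Flask(__name__)
    app.config["job_cache"] = {}  # a plain dict is enough for tests
    with app.test_client() as test_client:
        with app.app_context():
            yield test_client


def test_job_cache_is_available(client):
    current_app.config["job_cache"]["key"] = ("value", 123)
    assert current_app.config["job_cache"]["key"][0] == "value"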