diff --git a/app/aws/s3.py b/app/aws/s3.py index c33366a2c..01cd6692e 100644 --- a/app/aws/s3.py +++ b/app/aws/s3.py @@ -10,6 +10,7 @@ from boto3 import Session from flask import current_app from app.clients import AWS_CLIENT_CONFIG +from app.utils import hilite from notifications_utils import aware_utcnow FILE_LOCATION_STRUCTURE = "service-{}-notify/{}.csv" @@ -65,6 +66,7 @@ def clean_cache(): def get_s3_client(): global s3_client if s3_client is None: + # print(hilite("S3 CLIENT IS NONE, CREATING IT!")) access_key = current_app.config["CSV_UPLOAD_BUCKET"]["access_key_id"] secret_key = current_app.config["CSV_UPLOAD_BUCKET"]["secret_access_key"] region = current_app.config["CSV_UPLOAD_BUCKET"]["region"] @@ -74,12 +76,15 @@ def get_s3_client(): region_name=region, ) s3_client = session.client("s3") + # else: + # print(hilite("S3 CLIENT ALREADY EXISTS, REUSING IT!")) return s3_client def get_s3_resource(): global s3_resource if s3_resource is None: + print(hilite("S3 RESOURCE IS NONE, CREATING IT!")) access_key = current_app.config["CSV_UPLOAD_BUCKET"]["access_key_id"] secret_key = current_app.config["CSV_UPLOAD_BUCKET"]["secret_access_key"] region = current_app.config["CSV_UPLOAD_BUCKET"]["region"] @@ -89,6 +94,8 @@ def get_s3_resource(): region_name=region, ) s3_resource = session.resource("s3", config=AWS_CLIENT_CONFIG) + else: + print(hilite("S3 RESOURCE ALREADY EXSITS, REUSING IT!")) return s3_resource diff --git a/app/config.py b/app/config.py index d3f2a5197..9ec37a71c 100644 --- a/app/config.py +++ b/app/config.py @@ -2,10 +2,12 @@ import json from datetime import datetime, timedelta from os import getenv, path +from boto3 import Session from celery.schedules import crontab from kombu import Exchange, Queue import notifications_utils +from app.clients import AWS_CLIENT_CONFIG from app.cloudfoundry_config import cloud_config @@ -51,6 +53,13 @@ class TaskNames(object): SCAN_FILE = "scan-file" +session = Session( + aws_access_key_id=getenv("CSV_AWS_ACCESS_KEY_ID"), + aws_secret_access_key=getenv("CSV_AWS_SECRET_ACCESS_KEY"), + region_name=getenv("CSV_AWS_REGION"), +) + + class Config(object): NOTIFY_APP_NAME = "api" DEFAULT_REDIS_EXPIRE_TIME = 4 * 24 * 60 * 60 @@ -166,6 +175,9 @@ class Config(object): current_minute = (datetime.now().minute + 1) % 60 + S3_CLIENT = session.client("s3") + S3_RESOURCE = session.resource("s3", config=AWS_CLIENT_CONFIG) + CELERY = { "worker_max_tasks_per_child": 500, "task_ignore_result": True, diff --git a/notifications_utils/s3.py b/notifications_utils/s3.py index 0a01f7493..46c89c68f 100644 --- a/notifications_utils/s3.py +++ b/notifications_utils/s3.py @@ -16,11 +16,32 @@ AWS_CLIENT_CONFIG = Config( use_fips_endpoint=True, ) +# Global variable +s3_resource = None + default_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") default_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") default_region = os.environ.get("AWS_REGION") +def get_s3_resource(): + global s3_resource + if s3_resource is None: + # print(hilite("S3 RESOURCE IS NONE, CREATING IT!")) + access_key = (default_access_key_id,) + secret_key = (default_secret_access_key,) + region = (default_region,) + session = Session( + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + ) + s3_resource = session.resource("s3", config=AWS_CLIENT_CONFIG) + # else: + # print(hilite("S3 RESOURCE ALREADY EXSITS, REUSING IT!")) + return s3_resource + + def s3upload( filedata, region, @@ -32,12 +53,7 @@ def s3upload( access_key=default_access_key_id, secret_key=default_secret_access_key, ): - session = Session( - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - ) - _s3 = session.resource("s3", config=AWS_CLIENT_CONFIG) + _s3 = get_s3_resource() key = _s3.Object(bucket_name, file_location) @@ -73,12 +89,7 @@ def s3download( secret_key=default_secret_access_key, ): try: - session = Session( - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - ) - s3 = session.resource("s3", config=AWS_CLIENT_CONFIG) + s3 = get_s3_resource() key = s3.Object(bucket_name, filename) return key.get()["Body"] except botocore.exceptions.ClientError as error: