Mirror of https://github.com/GSA/notifications-api.git

Remove default creds from s3 module
--- a/app/aws/s3.py
+++ b/app/aws/s3.py
@@ -1,25 +1,19 @@
-import os
-
 import botocore
-from boto3 import Session, client
+from boto3 import Session
 from flask import current_app
 
 FILE_LOCATION_STRUCTURE = 'service-{}-notify/{}.csv'
 
-default_access_key = os.environ.get('AWS_ACCESS_KEY_ID')
-default_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
-default_region = os.environ.get('AWS_REGION')
-
 
 def get_s3_file(
-    bucket_name, file_location, access_key=default_access_key, secret_key=default_secret_key, region=default_region
+    bucket_name, file_location, access_key, secret_key, region
 ):
     s3_file = get_s3_object(bucket_name, file_location, access_key, secret_key, region)
     return s3_file.get()['Body'].read().decode('utf-8')
 
 
 def get_s3_object(
-    bucket_name, file_location, access_key=default_access_key, secret_key=default_secret_key, region=default_region
+    bucket_name, file_location, access_key, secret_key, region
 ):
     session = Session(aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=region)
     s3 = session.resource('s3')
@@ -27,7 +21,7 @@ def get_s3_object(
 
 
 def file_exists(
-    bucket_name, file_location, access_key=default_access_key, secret_key=default_secret_key, region=default_region
+    bucket_name, file_location, access_key, secret_key, region
 ):
     try:
         # try and access metadata of object
@@ -85,28 +79,3 @@ def remove_contact_list_from_s3(service_id, contact_list_id):
 def remove_s3_object(bucket_name, object_key, access_key, secret_key, region):
     obj = get_s3_object(bucket_name, object_key, access_key, secret_key, region)
     return obj.delete()
-
-
-def get_list_of_files_by_suffix(
-    bucket_name,
-    subfolder='',
-    suffix='',
-    last_modified=None,
-    access_key=default_access_key,
-    secret_key=default_secret_key,
-    region=default_region
-):
-    s3_client = client('s3', region, aws_access_key_id=access_key, aws_secret_access_key=secret_key)
-    paginator = s3_client.get_paginator('list_objects_v2')
-
-    page_iterator = paginator.paginate(
-        Bucket=bucket_name,
-        Prefix=subfolder
-    )
-
-    for page in page_iterator:
-        for obj in page.get('Contents', []):
-            key = obj['Key']
-            if key.lower().endswith(suffix.lower()):
-                if not last_modified or obj['LastModified'] >= last_modified:
-                    yield key
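With the module-level defaults gone, callers of get_s3_file, get_s3_object, and file_exists must supply credentials explicitly. A minimal sketch of the new calling convention, assuming the caller reads credentials from the environment at the call site (the bucket and file names below are illustrative, not from the codebase; the environment variable names match the ones the removed defaults used):

    import os

    from app.aws.s3 import get_s3_file

    contents = get_s3_file(
        'example-bucket',                    # illustrative bucket name
        'service-1234-notify/abcd.csv',      # illustrative key, shaped like FILE_LOCATION_STRUCTURE
        os.environ['AWS_ACCESS_KEY_ID'],
        os.environ['AWS_SECRET_ACCESS_KEY'],
        os.environ['AWS_REGION'],
    )

This way the values are looked up when the function is called, rather than frozen into the function signature when the module is imported.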
--- a/tests/app/aws/test_s3.py
+++ b/tests/app/aws/test_s3.py
@@ -1,17 +1,11 @@
-from datetime import datetime, timedelta
+from datetime import datetime
+from os import getenv
 
-import pytest
-import pytz
-from freezegun import freeze_time
+from app.aws.s3 import get_s3_file
 
-from app.aws.s3 import (
-    default_access_key,
-    default_region,
-    default_secret_key,
-    get_list_of_files_by_suffix,
-    get_s3_file,
-)
-from tests.app.conftest import datetime_in_past
+default_access_key = getenv('AWS_ACCESS_KEY_ID')
+default_secret_key = getenv('AWS_SECRET_ACCESS_KEY')
+default_region = getenv('AWS_REGION')
 
 
 def single_s3_object_stub(key='foo', last_modified=None):
@@ -24,7 +18,7 @@ def single_s3_object_stub(key='foo', last_modified=None):
 
 def test_get_s3_file_makes_correct_call(notify_api, mocker):
     get_s3_mock = mocker.patch('app.aws.s3.get_s3_object')
-    get_s3_file('foo-bucket', 'bar-file.txt')
+    get_s3_file('foo-bucket', 'bar-file.txt', default_access_key, default_secret_key, default_region)
 
     get_s3_mock.assert_called_with(
         'foo-bucket',
@@ -33,52 +27,3 @@ def test_get_s3_file_makes_correct_call(notify_api, mocker):
         default_secret_key,
         default_region,
     )
-
-
-@freeze_time("2018-01-11 00:00:00")
-@pytest.mark.parametrize('suffix_str, days_before, returned_no', [
-    ('.ACK.txt', None, 1),
-    ('.ack.txt', None, 1),
-    ('.ACK.TXT', None, 1),
-    ('', None, 2),
-    ('', 1, 1),
-])
-def test_get_list_of_files_by_suffix(notify_api, mocker, suffix_str, days_before, returned_no):
-    paginator_mock = mocker.patch('app.aws.s3.client')
-    multiple_pages_s3_object = [
-        {
-            "Contents": [
-                single_s3_object_stub('bar/foo.ACK.txt', datetime_in_past(1, 0)),
-            ]
-        },
-        {
-            "Contents": [
-                single_s3_object_stub('bar/foo1.rs.txt', datetime_in_past(2, 0)),
-            ]
-        }
-    ]
-    paginator_mock.return_value.get_paginator.return_value.paginate.return_value = multiple_pages_s3_object
-    if (days_before):
-        key = get_list_of_files_by_suffix('foo-bucket', subfolder='bar', suffix=suffix_str,
-                                          last_modified=datetime.now(tz=pytz.utc) - timedelta(days=days_before))
-    else:
-        key = get_list_of_files_by_suffix('foo-bucket', subfolder='bar', suffix=suffix_str)
-
-    assert sum(1 for x in key) == returned_no
-    for k in key:
-        assert k == 'bar/foo.ACK.txt'
-
-
-def test_get_list_of_files_by_suffix_empty_contents_return_with_no_error(notify_api, mocker):
-    paginator_mock = mocker.patch('app.aws.s3.client')
-    multiple_pages_s3_object = [
-        {
-            "other_content": [
-                'some_values',
-            ]
-        }
-    ]
-    paginator_mock.return_value.get_paginator.return_value.paginate.return_value = multiple_pages_s3_object
-    key = get_list_of_files_by_suffix('foo-bucket', subfolder='bar', suffix='.pdf')
-
-    assert sum(1 for x in key) == 0
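Beyond keeping credentials out of a shared module, the removed defaults carried a subtle hazard: Python evaluates default argument values once, at function definition time, so a default like access_key=default_access_key freezes whatever the environment held when the module was first imported. A self-contained illustration of that behaviour (the names here are hypothetical, not part of this codebase):

    import os

    os.environ['TOKEN'] = 'old'

    def connect(token=os.environ.get('TOKEN')):  # default captured at definition time
        return token

    os.environ['TOKEN'] = 'new'  # e.g. a rotated credential
    assert connect() == 'old'    # the stale default is still used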