diff --git a/app/commands.py b/app/commands.py index 1c7918315..4b8f2b37f 100644 --- a/app/commands.py +++ b/app/commands.py @@ -6,6 +6,7 @@ import functools import flask from flask import current_app import click +from click_datetime import Datetime as click_dt from app import db from app.dao.monthly_billing_dao import ( @@ -49,26 +50,14 @@ class notify_command: @notify_command() -@click.option('-p', '--provider_name', required=True, help='Provider name') -@click.option('-c', '--cost', required=True, help='Cost (pence) per message including decimals') -@click.option('-d', '--valid_from', required=True, help="Date (%Y-%m-%dT%H:%M:%S) valid from") +@click.option('-p', '--provider_name', required=True, type=click.Choice(PROVIDERS)) +@click.option('-c', '--cost', required=True, help='Cost (pence) per message including decimals', type=float) +@click.option('-d', '--valid_from', required=True, type=click_dt(format='%Y-%m-%dT%H:%M:%S')) def create_provider_rates(provider_name, cost, valid_from): """ Backfill rates for a given provider """ - if provider_name not in PROVIDERS: - raise Exception("Invalid provider name, must be one of ({})".format(', '.join(PROVIDERS))) - - try: - cost = Decimal(cost) - except: - raise Exception("Invalid cost value.") - - try: - valid_from = datetime.strptime('%Y-%m-%dT%H:%M:%S', valid_from) - except: - raise Exception("Invalid valid_from date. Use the format %Y-%m-%dT%H:%M:%S") - + cost = Decimal(cost) dao_create_provider_rates(provider_name, valid_from, cost) @@ -203,18 +192,18 @@ def link_inbound_numbers_to_service(): @notify_command() -@click.option('-y', '--year', required=True, help="Use for integer value for year, e.g. 2017") +@click.option('-y', '--year', required=True, help="e.g. 2017", type=int) def populate_monthly_billing(year): """ Populate monthly billing table for all services for a given year. """ def populate(service_id, year, month): - create_or_update_monthly_billing(service_id, datetime(int(year), int(month), 1)) + create_or_update_monthly_billing(service_id, datetime(year, month, 1)) sms_res = get_monthly_billing_by_notification_type( - service_id, datetime(int(year), int(month), 1), SMS_TYPE + service_id, datetime(year, month, 1), SMS_TYPE ) email_res = get_monthly_billing_by_notification_type( - service_id, datetime(int(year), int(month), 1), EMAIL_TYPE + service_id, datetime(year, month, 1), EMAIL_TYPE ) print("Finished populating data for {} for service id {}".format(month, str(service_id))) print('SMS: {}'.format(sms_res.monthly_totals)) @@ -225,7 +214,7 @@ def populate_monthly_billing(year): ) start, end = 1, 13 - if year == '2016': + if year == 2016: start = 4 for service_id in service_ids: @@ -237,14 +226,12 @@ def populate_monthly_billing(year): @notify_command() -@click.option('-s', '--start_date', required=True, help="Date (%Y-%m-%d) start date inclusive") -@click.option('-e', '--end_date', required=True, help="Date (%Y-%m-%d) end date inclusive") +@click.option('-s', '--start_date', required=True, help="start date inclusive", type=click_dt(format='%Y-%m-%d')) +@click.option('-e', '--end_date', required=True, help="end date inclusive", type=click_dt(format='%Y-%m-%d')) def backfill_processing_time(start_date, end_date): """ Send historical performance platform stats. """ - start_date = datetime.strptime(start_date, '%Y-%m-%d') - end_date = datetime.strptime(end_date, '%Y-%m-%d') delta = end_date - start_date @@ -400,7 +387,7 @@ def populate_annual_billing(): @notify_command() -@click.option('-j', '--job_id', required=True, help="Enter the job id to rebuild the dvla file for") +@click.option('-j', '--job_id', required=True, help="Enter the job id to rebuild the dvla file for", type=click.UUID) def re_run_build_dvla_file_for_job(job_id): """ Rebuild dvla file for a job. diff --git a/requirements.txt b/requirements.txt index c186de052..68937f75d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ Flask-Marshmallow==0.8.0 Flask-Migrate==2.1.1 Flask-SQLAlchemy==2.3.2 Flask==0.12.2 +click-datetime==0.2 gunicorn==19.7.1 iso8601==0.1.12 jsonschema==2.6.0 diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py index d5f6e75e1..d5313d128 100644 --- a/tests/app/test_commands.py +++ b/tests/app/test_commands.py @@ -8,7 +8,7 @@ def test_backfill_processing_time_works_for_correct_dates(mocker, notify_api): # backfill_processing_time is a click.Command object - if you try invoking the callback on its own, it # throws a `RuntimeError: There is no active click context.` - so get at the original function using __wrapped__ - backfill_processing_time.callback.__wrapped__('2017-08-01', '2017-08-03') + backfill_processing_time.callback.__wrapped__(datetime(2017, 8, 1), datetime(2017, 8, 3)) assert send_mock.call_count == 3 send_mock.assert_any_call(datetime(2017, 7, 31, 23, 0), datetime(2017, 8, 1, 23, 0))