Merge pull request #1427 from alphagov/fix-commands

ensure the app context is included in every single flask command
This commit is contained in:
Leo Hemsted
2017-11-24 12:18:20 +00:00
committed by GitHub
2 changed files with 41 additions and 20 deletions

View File

@@ -1,6 +1,7 @@
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
import functools
import flask import flask
from flask import current_app from flask import current_app
@@ -24,11 +25,30 @@ from app.performance_platform.processing_time import send_processing_time_for_st
@click.group(name='command', help='Additional commands') @click.group(name='command', help='Additional commands')
def commands(): def command_group():
pass pass
@commands.command() class notify_command:
def __init__(self, name=None):
self.name = name
def __call__(self, func):
# we need to call the flask with_appcontext decorator to ensure the config is loaded, db connected etc etc.
# we also need to use functools.wraps to carry through the names and docstrings etc of the functions.
# Then we need to turn it into a click.Command - that's what command_group.add_command expects.
@click.command(name=self.name)
@functools.wraps(func)
@flask.cli.with_appcontext
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
command_group.add_command(wrapper)
return wrapper
@notify_command()
@click.option('-p', '--provider_name', required=True, help='Provider name') @click.option('-p', '--provider_name', required=True, help='Provider name')
@click.option('-c', '--cost', required=True, help='Cost (pence) per message including decimals') @click.option('-c', '--cost', required=True, help='Cost (pence) per message including decimals')
@click.option('-d', '--valid_from', required=True, help="Date (%Y-%m-%dT%H:%M:%S) valid from") @click.option('-d', '--valid_from', required=True, help="Date (%Y-%m-%dT%H:%M:%S) valid from")
@@ -52,7 +72,7 @@ def create_provider_rates(provider_name, cost, valid_from):
dao_create_provider_rates(provider_name, valid_from, cost) dao_create_provider_rates(provider_name, valid_from, cost)
@commands.command() @notify_command()
@click.option('-u', '--user_email_prefix', required=True, help=""" @click.option('-u', '--user_email_prefix', required=True, help="""
Functional test user email prefix. eg "notify-test-preview" Functional test user email prefix. eg "notify-test-preview"
""") # noqa """) # noqa
@@ -80,7 +100,7 @@ def purge_functional_test_data(user_email_prefix):
delete_model_user(usr) delete_model_user(usr)
@commands.command() @notify_command()
def backfill_notification_statuses(): def backfill_notification_statuses():
""" """
DEPRECATED. Populates notification_status. DEPRECATED. Populates notification_status.
@@ -100,7 +120,7 @@ def backfill_notification_statuses():
result = db.session.execute(subq).fetchall() result = db.session.execute(subq).fetchall()
@commands.command() @notify_command()
def update_notification_international_flag(): def update_notification_international_flag():
""" """
DEPRECATED. Set notifications.international=false. DEPRECATED. Set notifications.international=false.
@@ -127,7 +147,7 @@ def update_notification_international_flag():
result_history = db.session.execute(subq_history).fetchall() result_history = db.session.execute(subq_history).fetchall()
@commands.command() @notify_command()
def fix_notification_statuses_not_in_sync(): def fix_notification_statuses_not_in_sync():
""" """
DEPRECATED. DEPRECATED.
@@ -161,7 +181,7 @@ def fix_notification_statuses_not_in_sync():
result = db.session.execute(subq_hist).fetchall() result = db.session.execute(subq_hist).fetchall()
@commands.command() @notify_command()
def link_inbound_numbers_to_service(): def link_inbound_numbers_to_service():
""" """
DEPRECATED. DEPRECATED.
@@ -182,7 +202,7 @@ def link_inbound_numbers_to_service():
print("Linked {} inbound numbers to service".format(result.rowcount)) print("Linked {} inbound numbers to service".format(result.rowcount))
@commands.command() @notify_command()
@click.option('-y', '--year', required=True, help="Use for integer value for year, e.g. 2017") @click.option('-y', '--year', required=True, help="Use for integer value for year, e.g. 2017")
def populate_monthly_billing(year): def populate_monthly_billing(year):
""" """
@@ -216,7 +236,7 @@ def populate_monthly_billing(year):
populate(service_id, year, i) populate(service_id, year, i)
@commands.command() @notify_command()
@click.option('-s', '--start_date', required=True, help="Date (%Y-%m-%d) start date inclusive") @click.option('-s', '--start_date', required=True, help="Date (%Y-%m-%d) start date inclusive")
@click.option('-e', '--end_date', required=True, help="Date (%Y-%m-%d) end date inclusive") @click.option('-e', '--end_date', required=True, help="Date (%Y-%m-%d) end date inclusive")
def backfill_processing_time(start_date, end_date): def backfill_processing_time(start_date, end_date):
@@ -245,7 +265,7 @@ def backfill_processing_time(start_date, end_date):
send_processing_time_for_start_and_end(process_start_date, process_end_date) send_processing_time_for_start_and_end(process_start_date, process_end_date)
@commands.command() @notify_command()
def populate_service_email_reply_to(): def populate_service_email_reply_to():
""" """
Migrate reply to emails. Migrate reply to emails.
@@ -267,7 +287,7 @@ def populate_service_email_reply_to():
print("Populated email reply to addresses for {}".format(result.rowcount)) print("Populated email reply to addresses for {}".format(result.rowcount))
@commands.command() @notify_command()
def populate_service_sms_sender(): def populate_service_sms_sender():
""" """
Migrate sms senders. Must be called when working on a fresh db! Migrate sms senders. Must be called when working on a fresh db!
@@ -306,7 +326,7 @@ def populate_service_sms_sender():
print("{} service_sms_senders".format(service_sms_sender_count_query)) print("{} service_sms_senders".format(service_sms_sender_count_query))
@commands.command() @notify_command()
def populate_service_letter_contact(): def populate_service_letter_contact():
""" """
Migrates letter contact blocks. Migrates letter contact blocks.
@@ -328,7 +348,7 @@ def populate_service_letter_contact():
print("Populated letter contacts for {} services".format(result.rowcount)) print("Populated letter contacts for {} services".format(result.rowcount))
@commands.command() @notify_command()
def populate_service_and_service_history_free_sms_fragment_limit(): def populate_service_and_service_history_free_sms_fragment_limit():
""" """
DEPRECATED. Set services to have 250k sms limit. DEPRECATED. Set services to have 250k sms limit.
@@ -354,7 +374,7 @@ def populate_service_and_service_history_free_sms_fragment_limit():
print("Populated free sms fragment limits for {} services history".format(services_history_result.rowcount)) print("Populated free sms fragment limits for {} services history".format(services_history_result.rowcount))
@commands.command() @notify_command()
def populate_annual_billing(): def populate_annual_billing():
""" """
add annual_billing for 2016, 2017 and 2018. add annual_billing for 2016, 2017 and 2018.
@@ -379,7 +399,7 @@ def populate_annual_billing():
print("Populated annual billing {} for {} services".format(fy, services_result1.rowcount)) print("Populated annual billing {} for {} services".format(fy, services_result1.rowcount))
@commands.command() @notify_command()
@click.option('-j', '--job_id', required=True, help="Enter the job id to rebuild the dvla file for") @click.option('-j', '--job_id', required=True, help="Enter the job id to rebuild the dvla file for")
def re_run_build_dvla_file_for_job(job_id): def re_run_build_dvla_file_for_job(job_id):
""" """
@@ -390,8 +410,7 @@ def re_run_build_dvla_file_for_job(job_id):
build_dvla_file.apply_async([job_id], queue=QueueNames.JOBS) build_dvla_file.apply_async([job_id], queue=QueueNames.JOBS)
@commands.command(name='list-routes') @notify_command(name='list-routes')
@flask.cli.with_appcontext
def list_routes(): def list_routes():
"""List URLs of all application routes.""" """List URLs of all application routes."""
for rule in sorted(current_app.url_map.iter_rules(), key=lambda r: r.rule): for rule in sorted(current_app.url_map.iter_rules(), key=lambda r: r.rule):
@@ -399,4 +418,4 @@ def list_routes():
def setup_commands(application): def setup_commands(application):
application.cli.add_command(commands) application.cli.add_command(command_group)

View File

@@ -3,10 +3,12 @@ from datetime import datetime
from app.commands import backfill_processing_time from app.commands import backfill_processing_time
def test_backfill_processing_time_works_for_correct_dates(mocker): def test_backfill_processing_time_works_for_correct_dates(mocker, notify_api):
send_mock = mocker.patch('app.commands.send_processing_time_for_start_and_end') send_mock = mocker.patch('app.commands.send_processing_time_for_start_and_end')
backfill_processing_time.callback('2017-08-01', '2017-08-03') # backfill_processing_time is a click.Command object - if you try invoking the callback on its own, it
# throws a `RuntimeError: There is no active click context.` - so get at the original function using __wrapped__
backfill_processing_time.callback.__wrapped__('2017-08-01', '2017-08-03')
assert send_mock.call_count == 3 assert send_mock.call_count == 3
send_mock.assert_any_call(datetime(2017, 7, 31, 23, 0), datetime(2017, 8, 1, 23, 0)) send_mock.assert_any_call(datetime(2017, 7, 31, 23, 0), datetime(2017, 8, 1, 23, 0))