diff --git a/app/commands.py b/app/commands.py index 7844c942b..2adfb94a4 100644 --- a/app/commands.py +++ b/app/commands.py @@ -94,17 +94,25 @@ class notify_command: self.name = name def __call__(self, func): - # we need to call the flask with_appcontext decorator to ensure the config is loaded, db connected etc etc. - # we also need to use functools.wraps to carry through the names and docstrings etc of the functions. - # Then we need to turn it into a click.Command - that's what command_group.add_command expects. - @click.command(name=self.name) - @functools.wraps(func) - @flask.cli.with_appcontext + decorators = [ + click.command(name=self.name), # turn it into a click.Command + functools.wraps(func) # carry through function name, docstrings, etc. + ] + + # in the test environment the app context is already provided and having + # another will lead to the test db connection being closed prematurely + if os.getenv('NOTIFY_ENVIRONMENT', '') != 'test': + # with_appcontext ensures the config is loaded, db connected, etc. + decorators.insert(0, flask.cli.with_appcontext) + def wrapper(*args, **kwargs): return func(*args, **kwargs) - command_group.add_command(wrapper) + for decorator in decorators: + # this syntax is equivalent to e.g. "@flask.cli.with_appcontext" + wrapper = decorator(wrapper) + command_group.add_command(wrapper) return wrapper diff --git a/tests/app/test_commands.py b/tests/app/test_commands.py new file mode 100644 index 000000000..adc3d76e4 --- /dev/null +++ b/tests/app/test_commands.py @@ -0,0 +1,26 @@ +import uuid + +from app.commands import local_dev_broadcast_permissions +from app.dao.services_dao import dao_add_user_to_service +from tests.app.db import create_user + + +def test_local_dev_broadcast_permissions( + sample_service, + sample_broadcast_service, + notify_api, +): + # create_user will pull existing unless email is unique + user = create_user(email=f'{uuid.uuid4()}@example.com') + dao_add_user_to_service(sample_service, user) + dao_add_user_to_service(sample_broadcast_service, user) + + assert len(user.get_permissions(sample_service.id)) == 0 + assert len(user.get_permissions(sample_broadcast_service.id)) == 0 + + notify_api.test_cli_runner().invoke( + local_dev_broadcast_permissions, ['-u', user.id] + ) + + assert len(user.get_permissions(sample_service.id)) == 0 + assert len(user.get_permissions(sample_broadcast_service.id)) > 0