diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index 6a7be124e..dde38af32 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -41,22 +41,32 @@ def dao_fetch_all_services(only_active=False): return query.all() -def dao_fetch_service_by_id(service_id): - return Service.query.filter_by( +def dao_fetch_service_by_id(service_id, only_active=False): + query = Service.query.filter_by( id=service_id ).options( joinedload('users') - ).one() + ) + + if only_active: + query = query.filter(Service.active) + + return query.one() -def dao_fetch_all_services_by_user(user_id): - return Service.query.filter( +def dao_fetch_all_services_by_user(user_id, only_active=False): + query = Service.query.filter( Service.users.any(id=user_id) ).order_by( asc(Service.created_at) ).options( joinedload('users') - ).all() + ) + + if only_active: + query = query.filter(Service.active) + + return query.all() def dao_fetch_service_by_id_and_user(service_id, user_id): diff --git a/app/service/rest.py b/app/service/rest.py index fbdbb04bf..78138b434 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -61,13 +61,16 @@ register_errors(service_blueprint) @service_blueprint.route('', methods=['GET']) def get_services(): + only_active = request.args.get('only_active') == 'True' + detailed = request.args.get('detailed') == 'True' user_id = request.args.get('user_id', None) + if user_id: - services = dao_fetch_all_services_by_user(user_id) - elif request.args.get('detailed') == 'True': - return jsonify(data=get_detailed_services()) + services = dao_fetch_all_services_by_user(user_id, only_active) + elif detailed: + return jsonify(data=get_detailed_services(only_active)) else: - services = dao_fetch_all_services(only_active=request.args.get('only_active') == 'True') + services = dao_fetch_all_services(only_active) data = service_schema.dump(services, many=True).data return jsonify(data=data) @@ -267,8 +270,8 @@ def get_detailed_service(service_id, today_only=False): return detailed_service_schema.dump(service).data -def get_detailed_services(): - services = {service.id: service for service in dao_fetch_all_services()} +def get_detailed_services(only_active=False): + services = {service.id: service for service in dao_fetch_all_services(only_active)} stats = dao_fetch_todays_stats_for_all_services() for service_id, rows in itertools.groupby(stats, lambda x: x.service_id): diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index 42184d6de..3b3bdc744 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -22,9 +22,9 @@ from app.models import KEY_TYPE_TEST def test_get_service_list(notify_api, service_factory): with notify_api.test_request_context(): with notify_api.test_client() as client: - service_factory.get('one', email_from='one') - service_factory.get('two', email_from='two') - service_factory.get('three', email_from='three') + service_factory.get('one') + service_factory.get('two') + service_factory.get('three') auth_header = create_authorization_header() response = client.get( '/service', @@ -39,8 +39,8 @@ def test_get_service_list(notify_api, service_factory): def test_get_service_list_with_only_active_flag(client, service_factory): - inactive = service_factory.get('one', email_from='one') - active = service_factory.get('two', email_from='two') + inactive = service_factory.get('one') + active = service_factory.get('two') inactive.active = False @@ -55,11 +55,37 @@ def test_get_service_list_with_only_active_flag(client, service_factory): assert json_resp['data'][0]['id'] == str(active.id) +def test_get_service_list_with_user_id_and_only_active_flag( + notify_db, + notify_db_session, + client, + sample_user, + service_factory +): + other_user = create_sample_user(notify_db, notify_db_session, email='foo@bar.gov.uk') + + inactive = service_factory.get('one', user=sample_user) + active = service_factory.get('two', user=sample_user) + from_other_user = service_factory.get('three', user=other_user) + + inactive.active = False + + auth_header = create_authorization_header() + response = client.get( + '/service?user_id={}&only_active=True'.format(sample_user.id), + headers=[auth_header] + ) + assert response.status_code == 200 + json_resp = json.loads(response.get_data(as_text=True)) + assert len(json_resp['data']) == 1 + assert json_resp['data'][0]['id'] == str(active.id) + + def test_get_service_list_by_user(notify_db, notify_db_session, client, sample_user, service_factory): other_user = create_sample_user(notify_db, notify_db_session, email='foo@bar.gov.uk') - service_factory.get('one', sample_user, email_from='one') - service_factory.get('two', sample_user, email_from='two') - service_factory.get('three', other_user, email_from='three') + service_factory.get('one', sample_user) + service_factory.get('two', sample_user) + service_factory.get('three', other_user) auth_header = create_authorization_header() response = client.get( @@ -140,7 +166,7 @@ def test_get_service_by_id_should_404_if_no_service(notify_api, notify_db): def test_get_service_by_id_and_user(notify_api, service_factory, sample_user): with notify_api.test_request_context(): with notify_api.test_client() as client: - service = service_factory.get('new service', sample_user, email_from='new.service') + service = service_factory.get('new.service', sample_user) auth_header = create_authorization_header() resp = client.get( '/service/{}?user_id={}'.format(service.id, sample_user.id),