This commit is contained in:
Kenneth Kehl
2024-10-31 14:38:07 -07:00
parent 78ac1ee094
commit c5b227403e

View File

@@ -4,8 +4,10 @@ from unittest.mock import Mock
import pytest import pytest
from flask import current_app from flask import current_app
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app import db
from app.dao.organization_dao import ( from app.dao.organization_dao import (
dao_add_service_to_organization, dao_add_service_to_organization,
dao_add_user_to_organization, dao_add_user_to_organization,
@@ -175,7 +177,7 @@ def test_post_create_organization(admin_request, notify_db_session):
"organization.create_organization", _data=data, _expected_status=201 "organization.create_organization", _data=data, _expected_status=201
) )
organizations = Organization.query.all() organizations = _get_organizations()
assert data["name"] == response["name"] assert data["name"] == response["name"]
assert data["active"] == response["active"] assert data["active"] == response["active"]
@@ -186,6 +188,11 @@ def test_post_create_organization(admin_request, notify_db_session):
assert organizations[0].email_branding_id is None assert organizations[0].email_branding_id is None
def _get_organizations():
stmt = select(Organization)
return db.session.execute(stmt).scalars().all()
@pytest.mark.parametrize("org_type", ["nhs_central", "nhs_local", "nhs_gp"]) @pytest.mark.parametrize("org_type", ["nhs_central", "nhs_local", "nhs_gp"])
@pytest.mark.skip(reason="Update for TTS") @pytest.mark.skip(reason="Update for TTS")
def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs( def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs(
@@ -201,7 +208,7 @@ def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs(
"organization.create_organization", _data=data, _expected_status=201 "organization.create_organization", _data=data, _expected_status=201
) )
organizations = Organization.query.all() organizations = _get_organizations()
assert len(organizations) == 1 assert len(organizations) == 1
assert organizations[0].email_branding_id == uuid.UUID( assert organizations[0].email_branding_id == uuid.UUID(
@@ -212,7 +219,7 @@ def test_post_create_organization_sets_default_nhs_branding_for_nhs_orgs(
def test_post_create_organization_existing_name_raises_400( def test_post_create_organization_existing_name_raises_400(
admin_request, sample_organization admin_request, sample_organization
): ):
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
data = { data = {
@@ -225,14 +232,14 @@ def test_post_create_organization_existing_name_raises_400(
"organization.create_organization", _data=data, _expected_status=400 "organization.create_organization", _data=data, _expected_status=400
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
assert response["message"] == "Organization name already exists" assert response["message"] == "Organization name already exists"
def test_post_create_organization_works(admin_request, sample_organization): def test_post_create_organization_works(admin_request, sample_organization):
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
data = { data = {
@@ -245,7 +252,7 @@ def test_post_create_organization_works(admin_request, sample_organization):
"organization.create_organization", _data=data, _expected_status=201 "organization.create_organization", _data=data, _expected_status=201
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 2 assert len(organization) == 2
@@ -310,7 +317,7 @@ def test_post_update_organization_updates_fields(
_expected_status=204, _expected_status=204,
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
assert organization[0].id == org.id assert organization[0].id == org.id
@@ -343,7 +350,7 @@ def test_post_update_organization_updates_domains(
_expected_status=204, _expected_status=204,
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
assert [domain.domain for domain in organization[0].domains] == domain_list assert [domain.domain for domain in organization[0].domains] == domain_list
@@ -383,7 +390,7 @@ def test_post_update_organization_to_nhs_type_updates_branding_if_none_present(
_expected_status=204, _expected_status=204,
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
assert organization[0].id == org.id assert organization[0].id == org.id
@@ -413,7 +420,7 @@ def test_post_update_organization_to_nhs_type_does_not_update_branding_if_defaul
_expected_status=204, _expected_status=204,
) )
organization = Organization.query.all() organization = _get_organizations()
assert len(organization) == 1 assert len(organization) == 1
assert organization[0].id == org.id assert organization[0].id == org.id
@@ -471,7 +478,7 @@ def test_post_update_organization_gives_404_status_if_org_does_not_exist(
_expected_status=404, _expected_status=404,
) )
organization = Organization.query.all() organization = _get_organizations()
assert not organization assert not organization