mirror of
https://github.com/GSA/notifications-api.git
synced 2026-01-30 06:21:50 -05:00
Merge pull request #1363 from GSA/notify-api-1323
upgrade org and template dao to sqlalchemy 2.0
This commit is contained in:
@@ -384,5 +384,5 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2024-10-11T19:26:50Z"
|
||||
"generated_at": "2024-10-14T17:46:47Z"
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.sql.expression import func
|
||||
|
||||
from app import db
|
||||
@@ -6,55 +7,57 @@ from app.models import Domain, Organization, Service, User
|
||||
|
||||
|
||||
def dao_get_organizations():
|
||||
return Organization.query.order_by(
|
||||
stmt = select(Organization).order_by(
|
||||
Organization.active.desc(), Organization.name.asc()
|
||||
).all()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def dao_count_organizations_with_live_services():
|
||||
return (
|
||||
db.session.query(Organization.id)
|
||||
stmt = (
|
||||
select(func.count(func.distinct(Organization.id)))
|
||||
.join(Organization.services)
|
||||
.filter(
|
||||
Service.active.is_(True),
|
||||
Service.restricted.is_(False),
|
||||
Service.count_as_live.is_(True),
|
||||
)
|
||||
.distinct()
|
||||
.count()
|
||||
)
|
||||
return db.session.execute(stmt).scalar() or 0
|
||||
|
||||
|
||||
def dao_get_organization_services(organization_id):
|
||||
return Organization.query.filter_by(id=organization_id).one().services
|
||||
stmt = select(Organization).filter_by(id=organization_id)
|
||||
return db.session.execute(stmt).scalars().one().services
|
||||
|
||||
|
||||
def dao_get_organization_live_services(organization_id):
|
||||
return Service.query.filter_by(
|
||||
organization_id=organization_id, restricted=False
|
||||
).all()
|
||||
stmt = select(Service).filter_by(organization_id=organization_id, restricted=False)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def dao_get_organization_by_id(organization_id):
|
||||
return Organization.query.filter_by(id=organization_id).one()
|
||||
stmt = select(Organization).filter_by(id=organization_id)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def dao_get_organization_by_email_address(email_address):
|
||||
email_address = email_address.lower().replace(".gsi.gov.uk", ".gov.uk")
|
||||
|
||||
for domain in Domain.query.order_by(func.char_length(Domain.domain).desc()).all():
|
||||
stmt = select(Domain).order_by(func.char_length(Domain.domain).desc())
|
||||
domains = db.session.execute(stmt).scalars().all()
|
||||
for domain in domains:
|
||||
if email_address.endswith(
|
||||
"@{}".format(domain.domain)
|
||||
) or email_address.endswith(".{}".format(domain.domain)):
|
||||
return Organization.query.filter_by(id=domain.organization_id).one()
|
||||
stmt = select(Organization).filter_by(id=domain.organization_id)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def dao_get_organization_by_service_id(service_id):
|
||||
return (
|
||||
Organization.query.join(Organization.services).filter_by(id=service_id).first()
|
||||
)
|
||||
stmt = select(Organization).join(Organization.services).filter_by(id=service_id)
|
||||
return db.session.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
@autocommit
|
||||
@@ -65,10 +68,14 @@ def dao_create_organization(organization):
|
||||
@autocommit
|
||||
def dao_update_organization(organization_id, **kwargs):
|
||||
domains = kwargs.pop("domains", None)
|
||||
num_updated = Organization.query.filter_by(id=organization_id).update(kwargs)
|
||||
stmt = (
|
||||
update(Organization).where(Organization.id == organization_id).values(**kwargs)
|
||||
)
|
||||
num_updated = db.session.execute(stmt).rowcount
|
||||
|
||||
if isinstance(domains, list):
|
||||
Domain.query.filter_by(organization_id=organization_id).delete()
|
||||
stmt = delete(Domain).filter_by(organization_id=organization_id)
|
||||
db.session.execute(stmt)
|
||||
db.session.bulk_save_objects(
|
||||
[
|
||||
Domain(domain=domain.lower(), organization_id=organization_id)
|
||||
@@ -76,7 +83,7 @@ def dao_update_organization(organization_id, **kwargs):
|
||||
]
|
||||
)
|
||||
|
||||
organization = Organization.query.get(organization_id)
|
||||
organization = db.session.get(Organization, organization_id)
|
||||
if "organization_type" in kwargs:
|
||||
_update_organization_services(
|
||||
organization, "organization_type", only_where_none=False
|
||||
@@ -101,7 +108,8 @@ def _update_organization_services(organization, attribute, only_where_none=True)
|
||||
@autocommit
|
||||
@version_class(Service)
|
||||
def dao_add_service_to_organization(service, organization_id):
|
||||
organization = Organization.query.filter_by(id=organization_id).one()
|
||||
stmt = select(Organization).filter_by(id=organization_id)
|
||||
organization = db.session.execute(stmt).scalars().one()
|
||||
|
||||
service.organization_id = organization_id
|
||||
service.organization_type = organization.organization_type
|
||||
@@ -122,7 +130,8 @@ def dao_get_users_for_organization(organization_id):
|
||||
@autocommit
|
||||
def dao_add_user_to_organization(organization_id, user_id):
|
||||
organization = dao_get_organization_by_id(organization_id)
|
||||
user = User.query.filter_by(id=user_id).one()
|
||||
stmt = select(User).filter_by(id=user_id)
|
||||
user = db.session.execute(stmt).scalars().one()
|
||||
user.organizations.append(organization)
|
||||
db.session.add(organization)
|
||||
return user
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from app import db
|
||||
from app.dao.dao_utils import autocommit
|
||||
from app.models import TemplateFolder
|
||||
|
||||
|
||||
def dao_get_template_folder_by_id_and_service_id(template_folder_id, service_id):
|
||||
return TemplateFolder.query.filter(
|
||||
stmt = select(TemplateFolder).filter(
|
||||
TemplateFolder.id == template_folder_id, TemplateFolder.service_id == service_id
|
||||
).one()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def dao_get_valid_template_folders_by_id(folder_ids):
|
||||
return TemplateFolder.query.filter(TemplateFolder.id.in_(folder_ids)).all()
|
||||
stmt = select(TemplateFolder).filter(TemplateFolder.id.in_(folder_ids))
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
@autocommit
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import asc, desc
|
||||
from sqlalchemy import asc, desc, select
|
||||
|
||||
from app import db
|
||||
from app.dao.dao_utils import VersionOptions, autocommit, version_class
|
||||
@@ -46,24 +46,29 @@ def dao_redact_template(template, user_id):
|
||||
|
||||
def dao_get_template_by_id_and_service_id(template_id, service_id, version=None):
|
||||
if version is not None:
|
||||
return TemplateHistory.query.filter_by(
|
||||
stmt = select(TemplateHistory).filter_by(
|
||||
id=template_id, hidden=False, service_id=service_id, version=version
|
||||
).one()
|
||||
return Template.query.filter_by(
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
stmt = select(Template).filter_by(
|
||||
id=template_id, hidden=False, service_id=service_id
|
||||
).one()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def dao_get_template_by_id(template_id, version=None):
|
||||
if version is not None:
|
||||
return TemplateHistory.query.filter_by(id=template_id, version=version).one()
|
||||
return Template.query.filter_by(id=template_id).one()
|
||||
stmt = select(TemplateHistory).filter_by(id=template_id, version=version)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
stmt = select(Template).filter_by(id=template_id)
|
||||
return db.session.execute(stmt).scalars().one()
|
||||
|
||||
|
||||
def dao_get_all_templates_for_service(service_id, template_type=None):
|
||||
if template_type is not None:
|
||||
return (
|
||||
Template.query.filter_by(
|
||||
stmt = (
|
||||
select(Template)
|
||||
.filter_by(
|
||||
service_id=service_id,
|
||||
template_type=template_type,
|
||||
hidden=False,
|
||||
@@ -73,26 +78,27 @@ def dao_get_all_templates_for_service(service_id, template_type=None):
|
||||
asc(Template.name),
|
||||
asc(Template.template_type),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return (
|
||||
Template.query.filter_by(service_id=service_id, hidden=False, archived=False)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
stmt = (
|
||||
select(Template)
|
||||
.filter_by(service_id=service_id, hidden=False, archived=False)
|
||||
.order_by(
|
||||
asc(Template.name),
|
||||
asc(Template.template_type),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def dao_get_template_versions(service_id, template_id):
|
||||
return (
|
||||
TemplateHistory.query.filter_by(
|
||||
stmt = (
|
||||
select(TemplateHistory)
|
||||
.filter_by(
|
||||
service_id=service_id,
|
||||
id=template_id,
|
||||
hidden=False,
|
||||
)
|
||||
.order_by(desc(TemplateHistory.version))
|
||||
.all()
|
||||
)
|
||||
return db.session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
|
||||
from app import db
|
||||
@@ -57,7 +58,8 @@ def test_get_organization_by_id_gets_correct_organization(notify_db_session):
|
||||
def test_update_organization(notify_db_session):
|
||||
create_organization()
|
||||
|
||||
organization = Organization.query.one()
|
||||
stmt = select(Organization)
|
||||
organization = db.session.execute(stmt).scalars().one()
|
||||
user = create_user()
|
||||
email_branding = create_email_branding()
|
||||
|
||||
@@ -78,7 +80,8 @@ def test_update_organization(notify_db_session):
|
||||
|
||||
dao_update_organization(organization.id, **data)
|
||||
|
||||
organization = Organization.query.one()
|
||||
stmt = select(Organization)
|
||||
organization = db.session.execute(stmt).scalars().one()
|
||||
|
||||
for attribute, value in data.items():
|
||||
assert getattr(organization, attribute) == value
|
||||
@@ -102,7 +105,8 @@ def test_update_organization_domains_lowercases(
|
||||
):
|
||||
create_organization()
|
||||
|
||||
organization = Organization.query.one()
|
||||
stmt = select(Organization)
|
||||
organization = db.session.execute(stmt).scalars().one()
|
||||
|
||||
# Seed some domains
|
||||
dao_update_organization(organization.id, domains=["123", "456"])
|
||||
@@ -121,7 +125,8 @@ def test_update_organization_domains_lowercases_integrity_error(
|
||||
):
|
||||
create_organization()
|
||||
|
||||
organization = Organization.query.one()
|
||||
stmt = select(Organization)
|
||||
organization = db.session.execute(stmt).scalars().one()
|
||||
|
||||
# Seed some domains
|
||||
dao_update_organization(organization.id, domains=["123", "456"])
|
||||
@@ -175,11 +180,11 @@ def test_update_organization_updates_the_service_org_type_if_org_type_is_provide
|
||||
|
||||
assert sample_organization.organization_type == OrganizationType.FEDERAL
|
||||
assert sample_service.organization_type == OrganizationType.FEDERAL
|
||||
stmt = select(Service.get_history_model()).filter_by(
|
||||
id=sample_service.id, version=2
|
||||
)
|
||||
assert (
|
||||
Service.get_history_model()
|
||||
.query.filter_by(id=sample_service.id, version=2)
|
||||
.one()
|
||||
.organization_type
|
||||
db.session.execute(stmt).scalars().one().organization_type
|
||||
== OrganizationType.FEDERAL
|
||||
)
|
||||
|
||||
@@ -229,11 +234,11 @@ def test_add_service_to_organization(sample_service, sample_organization):
|
||||
assert sample_organization.services[0].id == sample_service.id
|
||||
|
||||
assert sample_service.organization_type == sample_organization.organization_type
|
||||
stmt = select(Service.get_history_model()).filter_by(
|
||||
id=sample_service.id, version=2
|
||||
)
|
||||
assert (
|
||||
Service.get_history_model()
|
||||
.query.filter_by(id=sample_service.id, version=2)
|
||||
.one()
|
||||
.organization_type
|
||||
db.session.execute(stmt).scalars().one().organization_type
|
||||
== sample_organization.organization_type
|
||||
)
|
||||
assert sample_service.organization_id == sample_organization.id
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from app import db
|
||||
from app.dao.service_user_dao import dao_get_service_user
|
||||
from app.dao.template_folder_dao import (
|
||||
@@ -17,5 +19,5 @@ def test_dao_delete_template_folder_deletes_user_folder_permissions(
|
||||
dao_update_template_folder(folder)
|
||||
|
||||
dao_delete_template_folder(folder)
|
||||
|
||||
assert db.session.query(user_folder_permissions).all() == []
|
||||
stmt = select(user_folder_permissions)
|
||||
assert db.session.execute(stmt).scalars().all() == []
|
||||
|
||||
@@ -2,8 +2,10 @@ from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from app import db
|
||||
from app.dao.templates_dao import (
|
||||
dao_create_template,
|
||||
dao_get_all_templates_for_service,
|
||||
@@ -17,6 +19,16 @@ from app.models import Template, TemplateHistory, TemplateRedacted
|
||||
from tests.app.db import create_template
|
||||
|
||||
|
||||
def template_query_count():
|
||||
stmt = select(func.count()).select_from(Template)
|
||||
return db.session.execute(stmt).scalar() or 0
|
||||
|
||||
|
||||
def template_history_query_count():
|
||||
stmt = select(func.count()).select_from(TemplateHistory)
|
||||
return db.session.execute(stmt).scalar() or 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template_type, subject",
|
||||
[
|
||||
@@ -37,7 +49,7 @@ def test_create_template(sample_service, sample_user, template_type, subject):
|
||||
template = Template(**data)
|
||||
dao_create_template(template)
|
||||
|
||||
assert Template.query.count() == 1
|
||||
assert template_query_count() == 1
|
||||
assert len(dao_get_all_templates_for_service(sample_service.id)) == 1
|
||||
assert (
|
||||
dao_get_all_templates_for_service(sample_service.id)[0].name
|
||||
@@ -50,11 +62,13 @@ def test_create_template(sample_service, sample_user, template_type, subject):
|
||||
|
||||
|
||||
def test_create_template_creates_redact_entry(sample_service):
|
||||
assert TemplateRedacted.query.count() == 0
|
||||
stmt = select(func.count()).select_from(TemplateRedacted)
|
||||
assert db.session.execute(stmt).scalar() == 0
|
||||
|
||||
template = create_template(sample_service)
|
||||
|
||||
redacted = TemplateRedacted.query.one()
|
||||
stmt = select(TemplateRedacted)
|
||||
redacted = db.session.execute(stmt).scalars().one()
|
||||
assert redacted.template_id == template.id
|
||||
assert redacted.redact_personalisation is False
|
||||
assert redacted.updated_by_id == sample_service.created_by_id
|
||||
@@ -79,7 +93,8 @@ def test_update_template(sample_service, sample_user):
|
||||
|
||||
|
||||
def test_redact_template(sample_template):
|
||||
redacted = TemplateRedacted.query.one()
|
||||
stmt = select(TemplateRedacted)
|
||||
redacted = db.session.execute(stmt).scalars().one()
|
||||
assert redacted.template_id == sample_template.id
|
||||
assert redacted.redact_personalisation is False
|
||||
|
||||
@@ -96,7 +111,7 @@ def test_get_all_templates_for_service(service_factory):
|
||||
service_1 = service_factory.get("service 1", email_from="service.1")
|
||||
service_2 = service_factory.get("service 2", email_from="service.2")
|
||||
|
||||
assert Template.query.count() == 2
|
||||
assert template_query_count() == 2
|
||||
assert len(dao_get_all_templates_for_service(service_1.id)) == 1
|
||||
assert len(dao_get_all_templates_for_service(service_2.id)) == 1
|
||||
|
||||
@@ -119,7 +134,7 @@ def test_get_all_templates_for_service(service_factory):
|
||||
content="Template content",
|
||||
)
|
||||
|
||||
assert Template.query.count() == 5
|
||||
assert template_query_count() == 5
|
||||
assert len(dao_get_all_templates_for_service(service_1.id)) == 3
|
||||
assert len(dao_get_all_templates_for_service(service_2.id)) == 2
|
||||
|
||||
@@ -144,7 +159,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service):
|
||||
service=sample_service,
|
||||
)
|
||||
|
||||
assert Template.query.count() == 3
|
||||
assert template_query_count() == 3
|
||||
assert (
|
||||
dao_get_all_templates_for_service(sample_service.id)[0].name
|
||||
== "Sample Template 1"
|
||||
@@ -171,7 +186,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service):
|
||||
|
||||
|
||||
def test_get_all_returns_empty_list_if_no_templates(sample_service):
|
||||
assert Template.query.count() == 0
|
||||
assert template_query_count() == 0
|
||||
assert len(dao_get_all_templates_for_service(sample_service.id)) == 0
|
||||
|
||||
|
||||
@@ -257,8 +272,8 @@ def test_get_template_by_id_and_service_returns_none_if_no_template(
|
||||
def test_create_template_creates_a_history_record_with_current_data(
|
||||
sample_service, sample_user
|
||||
):
|
||||
assert Template.query.count() == 0
|
||||
assert TemplateHistory.query.count() == 0
|
||||
assert template_query_count() == 0
|
||||
assert template_history_query_count() == 0
|
||||
data = {
|
||||
"name": "Sample Template",
|
||||
"template_type": TemplateType.EMAIL,
|
||||
@@ -270,10 +285,12 @@ def test_create_template_creates_a_history_record_with_current_data(
|
||||
template = Template(**data)
|
||||
dao_create_template(template)
|
||||
|
||||
assert Template.query.count() == 1
|
||||
assert template_query_count() == 1
|
||||
|
||||
template_from_db = Template.query.first()
|
||||
template_history = TemplateHistory.query.first()
|
||||
stmt = select(Template)
|
||||
template_from_db = db.session.execute(stmt).scalars().first()
|
||||
stmt = select(TemplateHistory)
|
||||
template_history = db.session.execute(stmt).scalars().first()
|
||||
|
||||
assert template_from_db.id == template_history.id
|
||||
assert template_from_db.name == template_history.name
|
||||
@@ -286,8 +303,8 @@ def test_create_template_creates_a_history_record_with_current_data(
|
||||
def test_update_template_creates_a_history_record_with_current_data(
|
||||
sample_service, sample_user
|
||||
):
|
||||
assert Template.query.count() == 0
|
||||
assert TemplateHistory.query.count() == 0
|
||||
assert template_query_count() == 0
|
||||
assert template_history_query_count() == 0
|
||||
data = {
|
||||
"name": "Sample Template",
|
||||
"template_type": TemplateType.EMAIL,
|
||||
@@ -301,22 +318,26 @@ def test_update_template_creates_a_history_record_with_current_data(
|
||||
|
||||
created = dao_get_all_templates_for_service(sample_service.id)[0]
|
||||
assert created.name == "Sample Template"
|
||||
assert Template.query.count() == 1
|
||||
assert Template.query.first().version == 1
|
||||
assert TemplateHistory.query.count() == 1
|
||||
assert template_query_count() == 1
|
||||
stmt = select(Template)
|
||||
assert db.session.execute(stmt).scalars().first().version == 1
|
||||
assert template_history_query_count() == 1
|
||||
|
||||
created.name = "new name"
|
||||
dao_update_template(created)
|
||||
|
||||
assert Template.query.count() == 1
|
||||
assert TemplateHistory.query.count() == 2
|
||||
assert template_query_count() == 1
|
||||
assert template_history_query_count() == 2
|
||||
|
||||
template_from_db = Template.query.first()
|
||||
stmt = select(Template)
|
||||
template_from_db = db.session.execute(stmt).scalars().first()
|
||||
|
||||
assert template_from_db.version == 2
|
||||
|
||||
assert TemplateHistory.query.filter_by(name="Sample Template").one().version == 1
|
||||
assert TemplateHistory.query.filter_by(name="new name").one().version == 2
|
||||
stmt = select(TemplateHistory).filter_by(name="Sample Template")
|
||||
assert db.session.execute(stmt).scalars().one().version == 1
|
||||
stmt = select(TemplateHistory).filter_by(name="new name")
|
||||
assert db.session.execute(stmt).scalars().one().version == 2
|
||||
|
||||
|
||||
def test_get_template_history_version(sample_user, sample_service, sample_template):
|
||||
|
||||
Reference in New Issue
Block a user