Merge pull request #2385 from alphagov/raise-if-fail-to-write-history

Raise exception if history can’t be written
This commit is contained in:
Chris Hill-Scott
2019-04-17 11:17:49 +01:00
committed by GitHub
4 changed files with 76 additions and 19 deletions

View File

@@ -1,5 +1,5 @@
import itertools
from functools import wraps, partial
from functools import wraps
from app import db
from app.history_meta import create_history
@@ -20,18 +20,55 @@ def transactional(func):
return commit_or_rollback
def version_class(model_class, history_cls=None):
create_hist = partial(create_history, history_cls=history_cls)
class VersionOptions():
def __init__(self, model_class, history_class=None, must_write_history=True):
self.model_class = model_class
self.history_class = history_class
self.must_write_history = must_write_history
def version_class(*version_options):
if len(version_options) == 1 and not isinstance(version_options[0], VersionOptions):
version_options = (VersionOptions(version_options[0]),)
def versioned(func):
@wraps(func)
def record_version(*args, **kwargs):
func(*args, **kwargs)
history_objects = [create_hist(obj) for obj in
itertools.chain(db.session.new, db.session.dirty)
if isinstance(obj, model_class)]
for h_obj in history_objects:
db.session.add(h_obj)
session_objects = []
for version_option in version_options:
tmp_session_objects = [
(
session_object, version_option.history_class
)
for session_object in itertools.chain(
db.session.new, db.session.dirty
)
if isinstance(
session_object, version_option.model_class
)
]
if tmp_session_objects == [] and version_option.must_write_history:
raise RuntimeError((
'Can\'t record history for {} '
'(something in your code has casued the database to '
'flush the session early so there\'s nothing to '
'copy into the history table)'
).format(version_option.model_class.__name__))
session_objects += tmp_session_objects
for session_object, history_class in session_objects:
db.session.add(
create_history(session_object, history_cls=history_class)
)
return record_version
return versioned

View File

@@ -9,7 +9,8 @@ from flask import current_app
from app import db
from app.dao.dao_utils import (
transactional,
version_class
version_class,
VersionOptions,
)
from app.dao.email_branding_dao import dao_get_email_branding_by_name
from app.dao.letter_branding_dao import dao_get_letter_branding_by_name
@@ -127,9 +128,11 @@ def dao_fetch_all_services_by_user(user_id, only_active=False):
@transactional
@version_class(Service)
@version_class(Template, TemplateHistory)
@version_class(ApiKey)
@version_class(
VersionOptions(ApiKey, must_write_history=False),
VersionOptions(Service),
VersionOptions(Template, history_class=TemplateHistory, must_write_history=False),
)
def dao_archive_service(service_id):
# have to eager load templates and api keys so that we don't flush when we loop through them
# to ensure that db.session still contains the models when it comes to creating history objects
@@ -383,8 +386,10 @@ def dao_fetch_todays_stats_for_all_services(include_from_test_key=True, only_act
@transactional
@version_class(Service)
@version_class(ApiKey)
@version_class(
VersionOptions(ApiKey, must_write_history=False),
VersionOptions(Service),
)
def dao_suspend_service(service_id):
# have to eager load api keys so that we don't flush when we loop through them
# to ensure that db.session still contains the models when it comes to creating history objects
@@ -392,12 +397,12 @@ def dao_suspend_service(service_id):
joinedload('api_keys'),
).filter(Service.id == service_id).one()
service.active = False
for api_key in service.api_keys:
if not api_key.expiry_date:
api_key.expiry_date = datetime.utcnow()
service.active = False
@transactional
@version_class(Service)

View File

@@ -11,12 +11,15 @@ from app.models import (
)
from app.dao.dao_utils import (
transactional,
version_class
version_class,
VersionOptions,
)
@transactional
@version_class(Template, TemplateHistory)
@version_class(
VersionOptions(Template, history_class=TemplateHistory)
)
def dao_create_template(template):
template.id = uuid.uuid4() # must be set now so version history model can use same id
template.archived = False
@@ -36,7 +39,9 @@ def dao_create_template(template):
@transactional
@version_class(Template, TemplateHistory)
@version_class(
VersionOptions(Template, history_class=TemplateHistory)
)
def dao_update_template(template):
if template.archived:
template.folder = None

View File

@@ -864,6 +864,16 @@ def test_dao_fetch_todays_stats_for_all_services_can_exclude_from_test_key(notif
assert stats[0].count == 2
@freeze_time('2001-01-01T23:59:00')
def test_dao_suspend_service_with_no_api_keys(notify_db_session):
service = create_service()
dao_suspend_service(service.id)
service = Service.query.get(service.id)
assert not service.active
assert service.name == service.name
assert service.api_keys == []
@freeze_time('2001-01-01T23:59:00')
def test_dao_suspend_service_marks_service_as_inactive_and_expires_api_keys(notify_db_session):
service = create_service()