Do all version table writes in one commit

The behaviour of stacking the version decorators does not work as
expected.

What you would expect to happen is that each decorator causes a history
row to be written for its respective model object.

What actually happens is that the first decorator adds history records
to the database session, but then causes the database session to commit.
This means that subsequent uses of this decorator find a clean session,
and therefore no changes to copy to their respective history tables.

This commit changes the intended use of the decorator so that it is only
used once per function, and accepts multiple definitions of what to
record history for. This way it can record everything that needs to go
into the history before doing anything that would risk flushing the
session.
This commit is contained in:
Chris Hill-Scott
2019-03-07 17:39:38 +00:00
parent c257ec105c
commit eeb90bed57
3 changed files with 61 additions and 26 deletions

View File

@@ -1,5 +1,5 @@
import itertools import itertools
from functools import wraps, partial from functools import wraps
from app import db from app import db
from app.history_meta import create_history from app.history_meta import create_history
@@ -20,8 +20,18 @@ def transactional(func):
return commit_or_rollback return commit_or_rollback
def version_class(model_class, history_cls=None, must_write_history=True): class VersionOptions():
create_hist = partial(create_history, history_cls=history_cls)
def __init__(self, model_class, history_class=None, must_write_history=True):
self.model_class = model_class
self.history_class = history_class
self.must_write_history = must_write_history
def version_class(*version_options):
if len(version_options) == 1 and not isinstance(version_options[0], VersionOptions):
version_options = (VersionOptions(version_options[0]),)
def versioned(func): def versioned(func):
@wraps(func) @wraps(func)
@@ -29,20 +39,35 @@ def version_class(model_class, history_cls=None, must_write_history=True):
func(*args, **kwargs) func(*args, **kwargs)
history_objects = [create_hist(obj) for obj in session_objects = []
itertools.chain(db.session.new, db.session.dirty)
if isinstance(obj, model_class)]
if history_objects == [] and must_write_history: for version_option in version_options:
tmp_session_objects = [
(
session_object, version_option.history_class
)
for session_object in itertools.chain(
db.session.new, db.session.dirty
)
if isinstance(
session_object, version_option.model_class
)
]
if tmp_session_objects == [] and version_option.must_write_history:
raise RuntimeError(( raise RuntimeError((
'Can\'t record history for {} ' 'Can\'t record history for {} '
'(something in your code has casued the database to ' '(something in your code has casued the database to '
'flush the session early so there\'s nothing to ' 'flush the session early so there\'s nothing to '
'copy into the history table)' 'copy into the history table)'
).format(model_class.__name__)) ).format(version_option.model_class.__name__))
for h_obj in history_objects: session_objects += tmp_session_objects
db.session.add(h_obj)
for session_object, history_class in session_objects:
db.session.add(
create_history(session_object, history_cls=history_class)
)
return record_version return record_version
return versioned return versioned

View File

@@ -9,7 +9,8 @@ from flask import current_app
from app import db from app import db
from app.dao.dao_utils import ( from app.dao.dao_utils import (
transactional, transactional,
version_class version_class,
VersionOptions,
) )
from app.dao.organisation_dao import dao_get_organisation_by_email_address from app.dao.organisation_dao import dao_get_organisation_by_email_address
from app.dao.service_sms_sender_dao import insert_service_sms_sender from app.dao.service_sms_sender_dao import insert_service_sms_sender
@@ -117,9 +118,11 @@ def dao_fetch_all_services_by_user(user_id, only_active=False):
@transactional @transactional
@version_class(Service) @version_class(
@version_class(Template, TemplateHistory, must_write_history=False) VersionOptions(ApiKey, must_write_history=False),
@version_class(ApiKey, must_write_history=False) VersionOptions(Service),
VersionOptions(Template, history_class=TemplateHistory, must_write_history=False),
)
def dao_archive_service(service_id): def dao_archive_service(service_id):
# have to eager load templates and api keys so that we don't flush when we loop through them # have to eager load templates and api keys so that we don't flush when we loop through them
# to ensure that db.session still contains the models when it comes to creating history objects # to ensure that db.session still contains the models when it comes to creating history objects
@@ -372,8 +375,10 @@ def dao_fetch_todays_stats_for_all_services(include_from_test_key=True, only_act
@transactional @transactional
@version_class(Service) @version_class(
@version_class(ApiKey, must_write_history=False) VersionOptions(ApiKey, must_write_history=False),
VersionOptions(Service),
)
def dao_suspend_service(service_id): def dao_suspend_service(service_id):
# have to eager load api keys so that we don't flush when we loop through them # have to eager load api keys so that we don't flush when we loop through them
# to ensure that db.session still contains the models when it comes to creating history objects # to ensure that db.session still contains the models when it comes to creating history objects
@@ -381,12 +386,12 @@ def dao_suspend_service(service_id):
joinedload('api_keys'), joinedload('api_keys'),
).filter(Service.id == service_id).one() ).filter(Service.id == service_id).one()
service.active = False
for api_key in service.api_keys: for api_key in service.api_keys:
if not api_key.expiry_date: if not api_key.expiry_date:
api_key.expiry_date = datetime.utcnow() api_key.expiry_date = datetime.utcnow()
service.active = False
@transactional @transactional
@version_class(Service) @version_class(Service)

View File

@@ -11,12 +11,15 @@ from app.models import (
) )
from app.dao.dao_utils import ( from app.dao.dao_utils import (
transactional, transactional,
version_class version_class,
VersionOptions,
) )
@transactional @transactional
@version_class(Template, TemplateHistory) @version_class(
VersionOptions(Template, history_class=TemplateHistory)
)
def dao_create_template(template): def dao_create_template(template):
template.id = uuid.uuid4() # must be set now so version history model can use same id template.id = uuid.uuid4() # must be set now so version history model can use same id
template.archived = False template.archived = False
@@ -36,7 +39,9 @@ def dao_create_template(template):
@transactional @transactional
@version_class(Template, TemplateHistory) @version_class(
VersionOptions(Template, history_class=TemplateHistory)
)
def dao_update_template(template): def dao_update_template(template):
if template.archived: if template.archived:
template.folder = None template.folder = None