mirror of
https://github.com/GSA/notifications-api.git
synced 2025-12-17 18:52:30 -05:00
This is a fiendishly difficult error to discover on your own. It's caused when, during the creation of a row in the database, you run a query on the same table, or on a table that joins to the table you're inserting into. What I think is happening is that the session is forced to flush (SQLAlchemy's autoflush) before the query runs, in order to maintain consistency. This means that the session is clean by the time the history code comes to do its work, so there's nothing for it to copy into the history table, and it silently fails to record history. Hopefully raising an exception will (1) prevent this from failing silently, and (2) save whoever comes across this issue in the future a whole load of time.
53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
import itertools
|
|
from functools import wraps, partial
|
|
|
|
from app import db
|
|
from app.history_meta import create_history
|
|
|
|
|
|
def transactional(func):
    """Decorate ``func`` so it runs as a single database transaction.

    When ``func`` returns normally, ``db.session`` is committed and ``func``'s
    return value is passed back to the caller. If ``func`` or the commit
    raises, the exception is logged through the Flask application logger, the
    session is rolled back, and the same exception is re-raised.
    """
    @wraps(func)
    def commit_or_rollback(*args, **kwargs):
        # Imported at call time rather than at module level.
        from flask import current_app

        try:
            outcome = func(*args, **kwargs)
            # The commit is inside the try so a failed commit is rolled back too.
            db.session.commit()
        except Exception as err:
            current_app.logger.error(err)
            db.session.rollback()
            raise
        else:
            return outcome

    return commit_or_rollback
|
|
|
|
|
|
def version_class(model_class, history_cls=None, must_write_history=True):
    """Build a decorator that writes history rows for ``model_class``.

    After the decorated function returns, every instance of ``model_class``
    that is still pending in ``db.session`` (in ``session.new`` or
    ``session.dirty``) is turned into a history object via
    ``create_history`` and added to the session. The decorated function must
    therefore leave its changes pending: it must not commit, and nothing it
    does may flush the session early.

    Args:
        model_class: Model type whose pending instances are versioned.
        history_cls: Forwarded to ``create_history`` as ``history_cls``.
        must_write_history: When true, raise if no pending instance of
            ``model_class`` is found.

    Returns:
        A decorator. The wrapped function returns whatever the decorated
        function returns.

    Raises:
        RuntimeError: (from the wrapped function) when ``must_write_history``
            is true and no pending ``model_class`` instance is in the session.
    """
    create_hist = partial(create_history, history_cls=history_cls)

    def versioned(func):
        @wraps(func)
        def record_version(*args, **kwargs):
            result = func(*args, **kwargs)

            # Only objects still pending in the session can be copied. If the
            # session was flushed while ``func`` ran (for example by a query
            # touching the same table), they are no longer in new/dirty.
            history_objects = [
                create_hist(obj)
                for obj in itertools.chain(db.session.new, db.session.dirty)
                if isinstance(obj, model_class)
            ]

            if not history_objects and must_write_history:
                raise RuntimeError((
                    "Can't record history for {} "
                    "(something in your code has casued the database to "
                    "flush the session early so there's nothing to "
                    "copy into the history table)"
                ).format(model_class.__name__))

            db.session.add_all(history_objects)
            # Propagate the decorated function's result instead of dropping it.
            return result

        return record_version

    return versioned
|
|
|
|
|
|
def dao_rollback():
    """Roll back the current ``db.session``, discarding its uncommitted changes."""
    session = db.session
    session.rollback()
|