2016-04-20 17:25:20 +01:00
|
|
|
import itertools
|
2019-03-07 17:39:38 +00:00
|
|
|
from functools import wraps
|
2016-08-02 16:23:14 +01:00
|
|
|
|
2016-09-23 12:21:00 +01:00
|
|
|
from app import db
|
2016-08-02 16:23:14 +01:00
|
|
|
from app.history_meta import create_history
|
2016-04-14 15:09:59 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def transactional(func):
    """Decorate *func* so that ``db.session`` is committed when it succeeds.

    On success the session is committed and *func*'s return value is passed
    back to the caller. If *func* or the commit raises any ``Exception``:
    - the error is logged, with its traceback, through ``current_app.logger``;
    - the session is rolled back;
    - the original exception is re-raised unchanged.
    """
    @wraps(func)
    def commit_or_rollback(*args, **kwargs):
        # Imported at call time rather than at module load.
        from flask import current_app

        try:
            result = func(*args, **kwargs)
            # The commit stays inside the try so a failing commit is also
            # logged and rolled back.
            db.session.commit()
        except Exception as e:
            # logger.exception logs at ERROR level like logger.error, but also
            # attaches the traceback, which logger.error(e) silently dropped.
            current_app.logger.exception(e)
            db.session.rollback()
            raise
        return result

    return commit_or_rollback
|
|
|
|
|
|
|
|
|
|
|
2019-03-07 17:39:38 +00:00
|
|
|
class VersionOptions:
    """Describe how a single model class should be versioned.

    Instances are plain attribute holders; nothing is validated here.
    """

    def __init__(self, model_class, history_class=None, must_write_history=True):
        # When true, callers expect at least one pending instance of
        # model_class to exist when history is recorded.
        self.must_write_history = must_write_history
        # History class to build rows with; None means "no explicit class".
        self.history_class = history_class
        # Class whose pending session instances get history rows.
        self.model_class = model_class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def version_class(*version_options):
    """Build a decorator that records history rows for versioned models.

    Each positional argument is either a ``VersionOptions`` instance or a bare
    model class. Bare model classes are wrapped as ``VersionOptions(model_class)``,
    i.e. with ``history_class=None`` and ``must_write_history=True``.

    The returned decorator runs the wrapped function, then checks each option:
    - it collects the objects in ``db.session.new`` and ``db.session.dirty``
      that are instances of that option's ``model_class``;
    - once all options are checked, it adds
      ``create_history(obj, history_cls=option.history_class)`` to the
      session for every collected object.

    Nothing is committed here. The wrapped function's return value is
    returned to the caller.

    Raises:
        RuntimeError: if an option has ``must_write_history`` set and no
            matching object is pending in the session after the wrapped
            function ran.
    """
    # Previously only a single bare model class was normalised; mixing a bare
    # class with VersionOptions failed later with AttributeError.
    version_options = tuple(
        option if isinstance(option, VersionOptions) else VersionOptions(option)
        for option in version_options
    )

    def versioned(func):
        @wraps(func)
        def record_version(*args, **kwargs):
            # Keep the wrapped function's result; it used to be discarded.
            result = func(*args, **kwargs)

            session_objects = []
            for version_option in version_options:
                matching = [
                    (session_object, version_option.history_class)
                    for session_object in itertools.chain(db.session.new, db.session.dirty)
                    if isinstance(session_object, version_option.model_class)
                ]

                if not matching and version_option.must_write_history:
                    raise RuntimeError((
                        'Can\'t record history for {} '
                        '(something in your code has caused the database to '
                        'flush the session early so there\'s nothing to '
                        'copy into the history table)'
                    ).format(version_option.model_class.__name__))

                session_objects.extend(matching)

            # History rows are added only after every option has been checked,
            # so a failed check leaves no partial history rows in the session.
            for session_object, history_class in session_objects:
                db.session.add(
                    create_history(session_object, history_cls=history_class)
                )

            return result

        return record_version

    return versioned
|
2016-09-23 12:21:00 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def dao_rollback():
    """Call ``rollback()`` on the shared ``db.session``."""
    session = db.session
    session.rollback()
|