Files
notifications-api/app/dao/dao_utils.py

97 lines
2.7 KiB
Python

import itertools
from contextlib import contextmanager
from functools import wraps
from app import db
from app.history_meta import create_history
def autocommit(func):
@wraps(func)
def commit_or_rollback(*args, **kwargs):
try:
res = func(*args, **kwargs)
if not db.session().in_nested_transaction():
db.session.commit()
return res
except Exception:
db.session.rollback()
raise
return commit_or_rollback
@contextmanager
def transaction():
try:
db.session.begin_nested()
yield
db.session.commit()
if not db.session().in_nested_transaction():
db.session.commit()
except Exception:
db.session.rollback()
raise
class VersionOptions:
def __init__(self, model_class, history_class=None, must_write_history=True):
self.model_class = model_class
self.history_class = history_class
self.must_write_history = must_write_history
def version_class(*version_options):
if len(version_options) == 1 and not isinstance(version_options[0], VersionOptions):
version_options = (VersionOptions(version_options[0]),)
def versioned(func):
@wraps(func)
def record_version(*args, **kwargs):
func(*args, **kwargs)
session_objects = []
for version_option in version_options:
tmp_session_objects = [
(session_object, version_option.history_class)
for session_object in itertools.chain(
db.session.new, db.session.dirty
)
if isinstance(session_object, version_option.model_class)
]
if tmp_session_objects == [] and version_option.must_write_history:
raise RuntimeError(
(
"Can't record history for {} "
"(something in your code has casued the database to "
"flush the session early so there's nothing to "
"copy into the history table)"
).format(version_option.model_class.__name__)
)
session_objects += tmp_session_objects
for session_object, history_class in session_objects:
db.session.add(
create_history(session_object, history_cls=history_class)
)
return record_version
return versioned
def dao_rollback():
db.session.rollback()
@autocommit
def dao_save_object(obj):
# add/update object in db
db.session.add(obj)