diff --git a/app/dao/dao_utils.py b/app/dao/dao_utils.py index c7af24845..fa038b55c 100644 --- a/app/dao/dao_utils.py +++ b/app/dao/dao_utils.py @@ -1,5 +1,5 @@ import itertools -from functools import wraps, partial +from functools import wraps from app import db from app.history_meta import create_history @@ -20,8 +20,18 @@ def transactional(func): return commit_or_rollback -def version_class(model_class, history_cls=None, must_write_history=True): - create_hist = partial(create_history, history_cls=history_cls) +class VersionOptions(): + + def __init__(self, model_class, history_class=None, must_write_history=True): + self.model_class = model_class + self.history_class = history_class + self.must_write_history = must_write_history + + +def version_class(*version_options): + + if len(version_options) == 1 and not isinstance(version_options[0], VersionOptions): + version_options = (VersionOptions(version_options[0]),) def versioned(func): @wraps(func) @@ -29,20 +39,35 @@ def version_class(model_class, history_cls=None, must_write_history=True): func(*args, **kwargs) - history_objects = [create_hist(obj) for obj in - itertools.chain(db.session.new, db.session.dirty) - if isinstance(obj, model_class)] + session_objects = [] - if history_objects == [] and must_write_history: - raise RuntimeError(( - 'Can\'t record history for {} ' - '(something in your code has casued the database to ' - 'flush the session early so there\'s nothing to ' - 'copy into the history table)' - ).format(model_class.__name__)) + for version_option in version_options: + tmp_session_objects = [ + ( + session_object, version_option.history_class + ) + for session_object in itertools.chain( + db.session.new, db.session.dirty + ) + if isinstance( + session_object, version_option.model_class + ) + ] - for h_obj in history_objects: - db.session.add(h_obj) + if tmp_session_objects == [] and version_option.must_write_history: + raise RuntimeError(( + 'Can\'t record history for {} ' + '(something in your code has casued the database to ' + 'flush the session early so there\'s nothing to ' + 'copy into the history table)' + ).format(version_option.model_class.__name__)) + + session_objects += tmp_session_objects + + for session_object, history_class in session_objects: + db.session.add( + create_history(session_object, history_cls=history_class) + ) return record_version return versioned diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index db9d8325e..c71d01851 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -9,7 +9,8 @@ from flask import current_app from app import db from app.dao.dao_utils import ( transactional, - version_class + version_class, + VersionOptions, ) from app.dao.organisation_dao import dao_get_organisation_by_email_address from app.dao.service_sms_sender_dao import insert_service_sms_sender @@ -117,9 +118,11 @@ def dao_fetch_all_services_by_user(user_id, only_active=False): @transactional -@version_class(Service) -@version_class(Template, TemplateHistory, must_write_history=False) -@version_class(ApiKey, must_write_history=False) +@version_class( + VersionOptions(ApiKey, must_write_history=False), + VersionOptions(Service), + VersionOptions(Template, history_class=TemplateHistory, must_write_history=False), +) def dao_archive_service(service_id): # have to eager load templates and api keys so that we don't flush when we loop through them # to ensure that db.session still contains the models when it comes to creating history objects @@ -372,8 +375,10 @@ def dao_fetch_todays_stats_for_all_services(include_from_test_key=True, only_act @transactional -@version_class(Service) -@version_class(ApiKey, must_write_history=False) +@version_class( + VersionOptions(ApiKey, must_write_history=False), + VersionOptions(Service), +) def dao_suspend_service(service_id): # have to eager load api keys so that we don't flush when we loop through them # to ensure that db.session still contains the models when it comes to creating history objects @@ -381,12 +386,12 @@ def dao_suspend_service(service_id): joinedload('api_keys'), ).filter(Service.id == service_id).one() - service.active = False - for api_key in service.api_keys: if not api_key.expiry_date: api_key.expiry_date = datetime.utcnow() + service.active = False + @transactional @version_class(Service) diff --git a/app/dao/templates_dao.py b/app/dao/templates_dao.py index 7861e20e1..63079ef68 100644 --- a/app/dao/templates_dao.py +++ b/app/dao/templates_dao.py @@ -11,12 +11,15 @@ from app.models import ( ) from app.dao.dao_utils import ( transactional, - version_class + version_class, + VersionOptions, ) @transactional -@version_class(Template, TemplateHistory) +@version_class( + VersionOptions(Template, history_class=TemplateHistory) +) def dao_create_template(template): template.id = uuid.uuid4() # must be set now so version history model can use same id template.archived = False @@ -36,7 +39,9 @@ def dao_create_template(template): @transactional -@version_class(Template, TemplateHistory) +@version_class( + VersionOptions(Template, history_class=TemplateHistory) +) def dao_update_template(template): if template.archived: template.folder = None