diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index f55b6ec4c..c5679289c 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -127,7 +127,8 @@ def dao_update_service(service): db.session.add(service) -def dao_add_user_to_service(service, user, permissions=[]): +def dao_add_user_to_service(service, user, permissions=None): + permissions = permissions or [] try: from app.dao.permissions_dao import permission_dao service.users.append(user) @@ -214,7 +215,8 @@ def fetch_todays_total_message_count(service_id): def _stats_for_service_query(service_id): return db.session.query( Notification.notification_type, - Notification.status, + # see dao_fetch_todays_stats_for_all_services for why we have this label + Notification.status.label('status'), func.count(Notification.id).label('count') ).filter( Notification.service_id == service_id, @@ -232,13 +234,13 @@ def dao_fetch_monthly_historical_stats_by_template_for_service(service_id, year) start_date, end_date = get_financial_year(year) sq = db.session.query( NotificationHistory.template_id, - NotificationHistory.status, + # see dao_fetch_todays_stats_for_all_services for why we have this label + NotificationHistory.status.label('status'), month.label('month'), func.count().label('count') ).filter( NotificationHistory.service_id == service_id, NotificationHistory.created_at.between(start_date, end_date) - ).group_by( month, NotificationHistory.template_id, @@ -249,7 +251,7 @@ def dao_fetch_monthly_historical_stats_by_template_for_service(service_id, year) Template.id.label('template_id'), Template.name, Template.template_type, - sq.c.status, + sq.c.status.label('status'), sq.c.count.label('count'), sq.c.month ).join( @@ -267,7 +269,8 @@ def dao_fetch_monthly_historical_stats_for_service(service_id, year): start_date, end_date = get_financial_year(year) rows = db.session.query( NotificationHistory.notification_type, - NotificationHistory.status, + # see dao_fetch_todays_stats_for_all_services for why we have this label + NotificationHistory.status.label('status'), month, func.count(NotificationHistory.id).label('count') ).filter( @@ -306,7 +309,9 @@ def dao_fetch_monthly_historical_stats_for_service(service_id, year): def dao_fetch_todays_stats_for_all_services(include_from_test_key=True): query = db.session.query( Notification.notification_type, - Notification.status, + # this label is necessary as the column has a different name under the hood (_status_enum / _status_fkey), + # if we query the Notification object there is a hybrid property to translate, but here there isn't anything. + Notification.status.label('status'), Notification.service_id, func.count(Notification.id).label('count') ).filter( @@ -336,7 +341,8 @@ def fetch_stats_by_date_range_for_all_services(start_date, end_date, include_fro query = db.session.query( table.notification_type, - table.status, + # see dao_fetch_todays_stats_for_all_services for why we have this label + table.status.label('status'), table.service_id, func.count(table.id).label('count') ).filter( diff --git a/app/errors.py b/app/errors.py index 1c9c1691d..7e7790bc8 100644 --- a/app/errors.py +++ b/app/errors.py @@ -92,7 +92,7 @@ def register_errors(blueprint): @blueprint.errorhandler(SQLAlchemyError) def db_error(e): current_app.logger.exception(e) - if e.orig.pgerror and \ + if hasattr(e, 'orig') and hasattr(e.orig, 'pgerror') and e.orig.pgerror and \ ('duplicate key value violates unique constraint "services_name_key"' in e.orig.pgerror or 'duplicate key value violates unique constraint "services_email_from_key"' in e.orig.pgerror): return jsonify( diff --git a/app/models.py b/app/models.py index 0ef2de359..605b7bab8 100644 --- a/app/models.py +++ b/app/models.py @@ -1,8 +1,9 @@ import time import uuid import datetime -from flask import url_for +from flask import url_for, current_app +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.dialects.postgresql import ( UUID, JSON @@ -46,7 +47,12 @@ class HistoryModel: def update_from_original(self, original): for c in self.__table__.columns: - setattr(self, c.name, getattr(original, c.name)) + # in some cases, columns may have different names to their underlying db column - so only copy those + # that we can, and leave it up to subclasses to deal with any oddities/properties etc. + if hasattr(original, c.name): + setattr(self, c.name, getattr(original, c.name)) + else: + current_app.logger.debug('{} has no column {} to copy from'.format(original, c.name)) class User(db.Model): @@ -621,6 +627,12 @@ NOTIFICATION_STATUS_TYPES = [ NOTIFICATION_STATUS_TYPES_ENUM = db.Enum(*NOTIFICATION_STATUS_TYPES, name='notify_status_type') +class NotificationStatusTypes(db.Model): + __tablename__ = 'notification_status_types' + + name = db.Column(db.String(255), primary_key=True) + + class Notification(db.Model): __tablename__ = 'notifications' @@ -656,7 +668,15 @@ class Notification(db.Model): unique=False, nullable=True, onupdate=datetime.datetime.utcnow) - status = db.Column(NOTIFICATION_STATUS_TYPES_ENUM, index=True, nullable=False, default='created') + _status_enum = db.Column('status', NOTIFICATION_STATUS_TYPES_ENUM, index=True, nullable=False, default='created') + _status_fkey = db.Column( + 'notification_status', + db.String, + db.ForeignKey('notification_status_types.name'), + index=True, + nullable=True, + default='created' + ) reference = db.Column(db.String, nullable=True, index=True) client_reference = db.Column(db.String, index=True, nullable=True) _personalisation = db.Column(db.String, nullable=True) @@ -672,6 +692,15 @@ class Notification(db.Model): phone_prefix = db.Column(db.String, nullable=True) rate_multiplier = db.Column(db.Float(asdecimal=False), nullable=True) + @hybrid_property + def status(self): + return self._status_enum + + @status.setter + def status(self, status): + self._status_fkey = status + self._status_enum = status + @property def personalisation(self): if self._personalisation: @@ -844,7 +873,15 @@ class NotificationHistory(db.Model, HistoryModel): sent_at = db.Column(db.DateTime, index=False, unique=False, nullable=True) sent_by = db.Column(db.String, nullable=True) updated_at = db.Column(db.DateTime, index=False, unique=False, nullable=True) - status = db.Column(NOTIFICATION_STATUS_TYPES_ENUM, index=True, nullable=False, default='created') + _status_enum = db.Column('status', NOTIFICATION_STATUS_TYPES_ENUM, index=True, nullable=False, default='created') + _status_fkey = db.Column( + 'notification_status', + db.String, + db.ForeignKey('notification_status_types.name'), + index=True, + nullable=True, + default='created' + ) reference = db.Column(db.String, nullable=True, index=True) client_reference = db.Column(db.String, nullable=True) @@ -855,8 +892,18 @@ class NotificationHistory(db.Model, HistoryModel): @classmethod def from_original(cls, notification): history = super().from_original(notification) + history.status = notification.status return history + @hybrid_property + def status(self): + return self._status_enum + + @status.setter + def status(self, status): + self._status_fkey = status + self._status_enum = status + INVITED_USER_STATUS_TYPES = ['pending', 'accepted', 'cancelled'] diff --git a/tests/app/dao/test_notification_dao.py b/tests/app/dao/test_notification_dao.py index b5461e036..9c5006de7 100644 --- a/tests/app/dao/test_notification_dao.py +++ b/tests/app/dao/test_notification_dao.py @@ -311,6 +311,8 @@ def test_should_by_able_to_update_status_by_id(sample_template, sample_job, mmg_ data = _notification_json(sample_template, job_id=sample_job.id, status='sending') notification = Notification(**data) dao_create_notification(notification) + assert notification._status_enum == 'sending' + assert notification._status_fkey == 'sending' assert Notification.query.get(notification.id).status == 'sending' @@ -321,6 +323,8 @@ def test_should_by_able_to_update_status_by_id(sample_template, sample_job, mmg_ assert updated.updated_at == datetime(2000, 1, 2, 12, 0, 0) assert Notification.query.get(notification.id).status == 'delivered' assert notification.updated_at == datetime(2000, 1, 2, 12, 0, 0) + assert notification._status_enum == 'delivered' + assert notification._status_fkey == 'delivered' def test_should_not_update_status_by_id_if_not_sending_and_does_not_update_job(notify_db, notify_db_session): @@ -825,7 +829,7 @@ def test_get_notification_billable_unit_count_per_month(notify_db, notify_db_ses ) == months -def test_update_notification(sample_notification, sample_template): +def test_update_notification(sample_notification): assert sample_notification.status == 'created' sample_notification.status = 'failed' dao_update_notification(sample_notification) @@ -833,6 +837,24 @@ def test_update_notification(sample_notification, sample_template): assert notification_from_db.status == 'failed' +def test_update_notification_with_no_notification_status(sample_notification): + # specifically, it has an old enum status, but not a new status (because the upgrade script has just run) + sample_notification._status_fkey = None + sample_notification._enum_status = 'created' + dao_update_notification(sample_notification) + + assert sample_notification.status == 'created' + assert sample_notification._enum_status == 'created' + assert sample_notification._status_fkey == None + + sample_notification.status = 'failed' + dao_update_notification(sample_notification) + notification_from_db = Notification.query.get(sample_notification.id) + assert notification_from_db.status == 'failed' + assert notification_from_db._status_enum == 'failed' + assert notification_from_db._status_fkey == 'failed' + + @freeze_time("2016-01-10 12:00:00.000000") def test_should_delete_notifications_after_seven_days(notify_db, notify_db_session): assert len(Notification.query.all()) == 0 diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index 19f42c880..2e6fa9289 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -1240,17 +1240,16 @@ def test_get_monthly_notification_stats(mocker, client, sample_service, url, exp assert json.loads(response.get_data(as_text=True)) == expected_json -def test_get_services_with_detailed_flag(notify_api, notify_db, notify_db_session): +def test_get_services_with_detailed_flag(client, notify_db, notify_db_session): notifications = [ create_sample_notification(notify_db, notify_db_session), create_sample_notification(notify_db, notify_db_session), create_sample_notification(notify_db, notify_db_session, key_type=KEY_TYPE_TEST) ] - with notify_api.test_request_context(), notify_api.test_client() as client: - resp = client.get( - '/service?detailed=True', - headers=[create_authorization_header()] - ) + resp = client.get( + '/service?detailed=True', + headers=[create_authorization_header()] + ) assert resp.status_code == 200 data = json.loads(resp.get_data(as_text=True))['data'] diff --git a/tests/conftest.py b/tests/conftest.py index b08059725..61cd17ce9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,7 +76,8 @@ def notify_db_session(notify_db): "job_status", "provider_details_history", "template_process_type", - "dvla_organisation"]: + "dvla_organisation", + "notification_status_types"]: notify_db.engine.execute(tbl.delete()) notify_db.session.commit()