diff --git a/app/models.py b/app/models.py index e78fd18ea..7e064345a 100644 --- a/app/models.py +++ b/app/models.py @@ -133,6 +133,38 @@ class User(db.Model): def check_password(self, password): return check_hash(password, self._password) + def get_permissions(self): + from app.dao.permissions_dao import permission_dao + retval = {} + for x in permission_dao.get_permissions_by_user_id(self.id): + service_id = str(x.service_id) + if service_id not in retval: + retval[service_id] = [] + retval[service_id].append(x.permission) + return retval + + def serialize(self): + return { + 'id': self.id, + 'name': self.name, + 'email_address': self.email_address, + 'auth_type': self.auth_type, + 'current_session_id': self.current_session_id, + 'failed_login_count': self.failed_login_count, + 'logged_in_at': self.logged_in_at.strftime(DATETIME_FORMAT) if self.logged_in_at else None, + 'mobile_number': self.mobile_number, + 'organisations': [x.id for x in self.organisations if x.active], + 'password_changed_at': ( + self.password_changed_at.strftime('%Y-%m-%d %H:%M:%S.%f') + if self.password_changed_at + else None + ), + 'permissions': self.get_permissions(), + 'platform_admin': self.platform_admin, + 'services': [x.id for x in self.services if x.active], + 'state': self.state, + } + user_to_service = db.Table( 'user_to_service', diff --git a/app/organisation/rest.py b/app/organisation/rest.py index 04100ce2a..d436fc656 100644 --- a/app/organisation/rest.py +++ b/app/organisation/rest.py @@ -20,7 +20,6 @@ from app.organisation.organisation_schema import ( post_link_service_to_organisation_schema, ) from app.schema_validation import validate -from app.schemas import user_schema organisation_blueprint = Blueprint('organisation', __name__) register_errors(organisation_blueprint) @@ -98,15 +97,13 @@ def get_organisation_services(organisation_id): @organisation_blueprint.route('//users/', methods=['POST']) def add_user_to_organisation(organisation_id, user_id): new_org_user = dao_add_user_to_organisation(organisation_id, user_id) - return jsonify(data=user_schema.dump(new_org_user).data), 200 + return jsonify(data=new_org_user.serialize()) @organisation_blueprint.route('//users', methods=['GET']) def get_organisation_users(organisation_id): org_users = dao_get_users_for_organisation(organisation_id) - - result = user_schema.dump(org_users, many=True) - return jsonify(data=result.data) + return jsonify(data=[x.serialize() for x in org_users]) @organisation_blueprint.route('/unique', methods=["GET"]) diff --git a/app/schemas.py b/app/schemas.py index 2131bb9b7..b9d3b5abf 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -665,19 +665,15 @@ class UnarchivedTemplateSchema(BaseSchema): raise ValidationError('Template has been deleted', 'template') -user_schema = UserSchema() -user_schema_load_json = UserSchema(load_json=True) +# should not be used on its own for dumping - only for loading +create_user_schema = UserSchema() user_update_schema_load_json = UserUpdateAttributeSchema(load_json=True, partial=True) user_update_password_schema_load_json = UserUpdatePasswordSchema(load_json=True, partial=True) service_schema = ServiceSchema() -service_schema_load_json = ServiceSchema(load_json=True) detailed_service_schema = DetailedServiceSchema() template_schema = TemplateSchema() -template_schema_load_json = TemplateSchema(load_json=True) api_key_schema = ApiKeySchema() -api_key_schema_load_json = ApiKeySchema(load_json=True) job_schema = JobSchema() -job_schema_load_json = JobSchema(load_json=True) sms_admin_notification_schema = SmsAdminNotificationSchema() sms_template_notification_schema = SmsTemplateNotificationSchema() job_sms_template_notification_schema = JobSmsTemplateNotificationSchema() diff --git a/app/service/rest.py b/app/service/rest.py index 5ed4a5dee..a36b50664 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -82,7 +82,6 @@ from app.service.send_notification import send_one_off_notification from app.schemas import ( service_schema, api_key_schema, - user_schema, permission_schema, notification_with_template_schema, notifications_filter_schema, @@ -258,8 +257,7 @@ def get_api_keys(service_id, key_id=None): @service_blueprint.route('//users', methods=['GET']) def get_users_for_service(service_id): fetched = dao_fetch_service_by_id(service_id) - result = user_schema.dump(fetched.users, many=True) - return jsonify(data=result.data) + return jsonify(data=[x.serialize() for x in fetched.users]) @service_blueprint.route('//users/', methods=['POST']) diff --git a/app/user/rest.py b/app/user/rest.py index b4011ebe7..c29ec8342 100644 --- a/app/user/rest.py +++ b/app/user/rest.py @@ -31,7 +31,7 @@ from app.notifications.process_notifications import ( ) from app.schemas import ( email_data_request_schema, - user_schema, + create_user_schema, permission_schema, user_update_schema_load_json, user_update_password_schema_load_json @@ -67,13 +67,14 @@ def handle_integrity_error(exc): @user_blueprint.route('', methods=['POST']) def create_user(): - user_to_create, errors = user_schema.load(request.get_json()) + user_to_create, errors = create_user_schema.load(request.get_json()) req_json = request.get_json() if not req_json.get('password', None): errors.update({'password': ['Missing data for required field.']}) raise InvalidRequest(errors, status_code=400) save_model_user(user_to_create, pwd=req_json.get('password')) - return jsonify(data=user_schema.dump(user_to_create).data), 201 + result = user_to_create.serialize() + return jsonify(data=result), 201 @user_blueprint.route('/', methods=['POST']) @@ -84,7 +85,7 @@ def update_user_attribute(user_id): if errors: raise InvalidRequest(errors, status_code=400) save_user_attribute(user_to_update, update_dict=update_dct) - return jsonify(data=user_schema.dump(user_to_update).data), 200 + return jsonify(data=user_to_update.serialize()), 200 @user_blueprint.route('//activate', methods=['POST']) @@ -95,14 +96,14 @@ def activate_user(user_id): user.state = 'active' save_model_user(user) - return jsonify(data=user_schema.dump(user).data), 200 + return jsonify(data=user.serialize()), 200 @user_blueprint.route('//reset-failed-login-count', methods=['POST']) def user_reset_failed_login_count(user_id): user_to_update = get_user_by_id(user_id=user_id) reset_failed_login_count(user_to_update) - return jsonify(data=user_schema.dump(user_to_update).data), 200 + return jsonify(data=user_to_update.serialize()), 200 @user_blueprint.route('//verify/password', methods=['POST']) @@ -324,8 +325,8 @@ def send_already_registered_email(user_id): @user_blueprint.route('', methods=['GET']) def get_user(user_id=None): users = get_user_by_id(user_id=user_id) - result = user_schema.dump(users, many=True) if isinstance(users, list) else user_schema.dump(users) - return jsonify(data=result.data) + result = [x.serialize() for x in users] if isinstance(users, list) else users.serialize() + return jsonify(data=result) @user_blueprint.route('//service//permission', methods=['POST']) @@ -350,9 +351,8 @@ def get_by_email(): error = 'Invalid request. Email query string param required' raise InvalidRequest(error, status_code=400) fetched_user = get_user_by_email(email) - result = user_schema.dump(fetched_user) - - return jsonify(data=result.data) + result = fetched_user.serialize() + return jsonify(data=result) @user_blueprint.route('/reset-password', methods=['POST']) @@ -392,7 +392,7 @@ def update_password(user_id): if errors: raise InvalidRequest(errors, status_code=400) update_user_password(user, pwd) - return jsonify(data=user_schema.dump(user).data), 200 + return jsonify(data=user.serialize()), 200 def _create_reset_password_url(email): diff --git a/tests/app/user/test_rest.py b/tests/app/user/test_rest.py index cd635cd5f..5b79cd61b 100644 --- a/tests/app/user/test_rest.py +++ b/tests/app/user/test_rest.py @@ -36,28 +36,53 @@ def test_get_user_list(admin_request, sample_service): assert sorted(expected_permissions) == sorted(fetched['permissions'][str(sample_service.id)]) -def test_get_user(client, sample_service): +def test_get_user(admin_request, sample_service, sample_organisation): """ Tests GET endpoint '/' to retrieve a single service. """ sample_user = sample_service.users[0] - header = create_authorization_header() - resp = client.get(url_for('user.get_user', - user_id=sample_user.id), - headers=[header]) - assert resp.status_code == 200 - json_resp = json.loads(resp.get_data(as_text=True)) + sample_user.organisations = [sample_organisation] + json_resp = admin_request.get( + 'user.get_user', + user_id=sample_user.id + ) expected_permissions = default_service_permissions fetched = json_resp['data'] - assert str(sample_user.id) == fetched['id'] - assert sample_user.name == fetched['name'] - assert sample_user.mobile_number == fetched['mobile_number'] - assert sample_user.email_address == fetched['email_address'] - assert sample_user.state == fetched['state'] + assert fetched['id'] == str(sample_user.id) + assert fetched['name'] == sample_user.name + assert fetched['mobile_number'] == sample_user.mobile_number + assert fetched['email_address'] == sample_user.email_address + assert fetched['state'] == sample_user.state assert fetched['auth_type'] == SMS_AUTH_TYPE - assert sorted(expected_permissions) == sorted(fetched['permissions'][str(sample_service.id)]) + assert fetched['permissions'].keys() == {str(sample_service.id)} + assert fetched['services'] == [str(sample_service.id)] + assert fetched['organisations'] == [str(sample_organisation.id)] + assert sorted(fetched['permissions'][str(sample_service.id)]) == sorted(expected_permissions) + + +def test_get_user_doesnt_return_inactive_services_and_orgs(admin_request, sample_service, sample_organisation): + """ + Tests GET endpoint '/' to retrieve a single service. + """ + sample_service.active = False + sample_organisation.active = False + + sample_user = sample_service.users[0] + sample_user.organisations = [sample_organisation] + + json_resp = admin_request.get( + 'user.get_user', + user_id=sample_user.id + ) + + fetched = json_resp['data'] + + assert fetched['id'] == str(sample_user.id) + assert fetched['services'] == [] + assert fetched['organisations'] == [] + assert fetched['permissions'] == {} def test_post_user(client, notify_db, notify_db_session):