diff --git a/app/__init__.py b/app/__init__.py index 17f2828d0..007b52aee 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,12 +1,13 @@ import os -from flask import Flask +from flask import Flask, session from flask._compat import string_types from flask.ext import assets from flask.ext.sqlalchemy import SQLAlchemy from flask_login import LoginManager from flask_wtf import CsrfProtect from webassets.filter import get_filter +from werkzeug.exceptions import abort from config import configs @@ -22,8 +23,10 @@ def create_app(config_name): application.config.from_object(configs[config_name]) db.init_app(application) init_app(application) - csrf.init_app(application) + init_csrf(application) + login_manager.init_app(application) + login_manager.login_view = 'main.sign_in.render_sign_in' from app.main import main as main_blueprint application.register_blueprint(main_blueprint) @@ -31,6 +34,25 @@ def create_app(config_name): return application +def init_csrf(application): + csrf.init_app(application) + + @csrf.error_handler + def csrf_handler(reason): + if 'user_id' not in session: + application.logger.info( + u'csrf.session_expired: Redirecting user to log in page' + ) + + return application.login_manager.unauthorized() + + application.logger.info( + u'csrf.invalid_token: Aborting request, user_id: {user_id}', + extra={'user_id': session['user_id']}) + + abort(400, reason) + + def init_app(app): for key, value in app.config.items(): if key in os.environ: diff --git a/app/main/dao/users_dao.py b/app/main/dao/users_dao.py index 95d68eafa..1cb7aeea2 100644 --- a/app/main/dao/users_dao.py +++ b/app/main/dao/users_dao.py @@ -1,5 +1,5 @@ from app import db -from app.models import Users +from app.models import User from app.main.encryption import encrypt @@ -10,12 +10,12 @@ def insert_user(user): def get_user_by_id(id): - return Users.query.filter_by(id=id).first() + return User.query.filter_by(id=id).first() def get_all_users(): - return Users.query.all() + return User.query.all() def get_user_by_email(email_address): - return Users.query.filter_by(email_address=email_address).first() + return User.query.filter_by(email_address=email_address).first() diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py index 6ffa19144..abf14e4b4 100644 --- a/app/main/views/sign_in.py +++ b/app/main/views/sign_in.py @@ -3,13 +3,19 @@ from datetime import datetime from flask import render_template, redirect, jsonify from flask_login import login_user +from app import login_manager from app.main import main from app.main.forms import LoginForm from app.main.dao import users_dao -from app.models import Users +from app.models import User from app.main.encryption import encrypt +@login_manager.user_loader +def load_user(user_id): + return users_dao.get_user_by_id(user_id) + + @main.route("/sign-in", methods=(['GET'])) def render_sign_in(): return render_template('signin.html', form=LoginForm()) @@ -40,12 +46,12 @@ def render_create_user(): def create_user_for_test(): form = LoginForm() if form.validate_on_submit(): - user = Users(email_address=form.email_address.data, - name=form.email_address.data, - password=form.password.data, - created_at=datetime.now(), - mobile_number='+447651234534', - role_id=1) + user = User(email_address=form.email_address.data, + name=form.email_address.data, + password=form.password.data, + created_at=datetime.now(), + mobile_number='+447651234534', + role_id=1) users_dao.insert_user(user) return redirect('/sign-in') diff --git a/app/models.py b/app/models.py index 08068b978..2ed53c257 100644 --- a/app/models.py +++ b/app/models.py @@ -12,7 +12,7 @@ class Roles(db.Model): role = db.Column(db.String, nullable=False, unique=True) -class Users(db.Model): +class User(db.Model): __tablename__ = 'users' id = db.Column(db.Integer, primary_key=True) @@ -63,7 +63,7 @@ class Users(db.Model): @staticmethod def load_user(user_id): - user = Users.query.filter_by(id=user_id).first() + user = User.query.filter_by(id=user_id).first() if user.is_active(): return user diff --git a/tests/app/main/dao/test_users_dao.py b/tests/app/main/dao/test_users_dao.py index 3f734ff87..694a0fe3b 100644 --- a/tests/app/main/dao/test_users_dao.py +++ b/tests/app/main/dao/test_users_dao.py @@ -3,17 +3,17 @@ from datetime import datetime import pytest import sqlalchemy -from app.models import Users +from app.models import User from app.main.dao import users_dao def test_insert_user_should_add_user(notifications_admin, notifications_admin_db): - user = Users(name='test insert', - password='somepassword', - email_address='test@insert.gov.uk', - mobile_number='+441234123412', - created_at=datetime.now(), - role_id=1) + user = User(name='test insert', + password='somepassword', + email_address='test@insert.gov.uk', + mobile_number='+441234123412', + created_at=datetime.now(), + role_id=1) users_dao.insert_user(user) saved_user = users_dao.get_user_by_id(user.id) @@ -21,24 +21,24 @@ def test_insert_user_should_add_user(notifications_admin, notifications_admin_db def test_insert_user_with_role_that_does_not_exist_fails(notifications_admin, notifications_admin_db): - user = Users(name='role does not exist', - password='somepassword', - email_address='test@insert.gov.uk', - mobile_number='+441234123412', - created_at=datetime.now(), - role_id=100) + user = User(name='role does not exist', + password='somepassword', + email_address='test@insert.gov.uk', + mobile_number='+441234123412', + created_at=datetime.now(), + role_id=100) with pytest.raises(sqlalchemy.exc.IntegrityError) as error: users_dao.insert_user(user) assert 'insert or update on table "users" violates foreign key constraint "users_role_id_fkey"' in str(error.value) def test_get_user_by_email(notifications_admin, notifications_admin_db): - user = Users(name='test_get_by_email', - password='somepassword', - email_address='email@example.gov.uk', - mobile_number='+441234153412', - created_at=datetime.now(), - role_id=1) + user = User(name='test_get_by_email', + password='somepassword', + email_address='email@example.gov.uk', + mobile_number='+441234153412', + created_at=datetime.now(), + role_id=1) users_dao.insert_user(user) retrieved = users_dao.get_user_by_email(user.email_address) @@ -46,24 +46,24 @@ def test_get_user_by_email(notifications_admin, notifications_admin_db): def test_get_all_users_returns_all_users(notifications_admin, notifications_admin_db): - user1 = Users(name='test one', - password='somepassword', - email_address='test1@get_all.gov.uk', - mobile_number='+441234123412', - created_at=datetime.now(), - role_id=1) - user2 = Users(name='test two', - password='some2ndpassword', - email_address='test2@get_all.gov.uk', - mobile_number='+441234123412', - created_at=datetime.now(), - role_id=1) - user3 = Users(name='test three', - password='some2ndpassword', - email_address='test2@get_all.gov.uk', - mobile_number='+441234123412', - created_at=datetime.now(), - role_id=1) + user1 = User(name='test one', + password='somepassword', + email_address='test1@get_all.gov.uk', + mobile_number='+441234123412', + created_at=datetime.now(), + role_id=1) + user2 = User(name='test two', + password='some2ndpassword', + email_address='test2@get_all.gov.uk', + mobile_number='+441234123412', + created_at=datetime.now(), + role_id=1) + user3 = User(name='test three', + password='some2ndpassword', + email_address='test2@get_all.gov.uk', + mobile_number='+441234123412', + created_at=datetime.now(), + role_id=1) users_dao.insert_user(user1) users_dao.insert_user(user2) diff --git a/tests/app/main/views/test_sign_in.py b/tests/app/main/views/test_sign_in.py index 90b6b1af0..255c9d90c 100644 --- a/tests/app/main/views/test_sign_in.py +++ b/tests/app/main/views/test_sign_in.py @@ -1,7 +1,7 @@ from datetime import datetime from app.main.dao import users_dao -from app.models import Users +from app.models import User def test_render_sign_in_returns_sign_in_template(notifications_admin): @@ -14,12 +14,12 @@ def test_render_sign_in_returns_sign_in_template(notifications_admin): def test_process_sign_in_return_2fa_template(notifications_admin, notifications_admin_db): - user = Users(email_address='valid@example.gov.uk', - password='val1dPassw0rd!', - mobile_number='+441234123123', - name='valid', - created_at=datetime.now(), - role_id=1) + user = User(email_address='valid@example.gov.uk', + password='val1dPassw0rd!', + mobile_number='+441234123123', + name='valid', + created_at=datetime.now(), + role_id=1) users_dao.insert_user(user) response = notifications_admin.test_client().post('/sign-in', data={'email_address': 'valid@example.gov.uk',