diff --git a/app/main/views/send.py b/app/main/views/send.py index 37415b54b..1c3e9f206 100644 --- a/app/main/views/send.py +++ b/app/main/views/send.py @@ -44,7 +44,7 @@ manage_templates_page_headings = { def get_page_headings(template_type): - if current_user.has_permissions(session.get('service_id', ''), 'manage_service'): + if current_user.has_permissions(['manage_service']): return manage_service_page_headings[template_type] else: return manage_templates_page_headings[template_type] diff --git a/app/notify_client/models.py b/app/notify_client/models.py index 3cf6497b6..b881e8390 100644 --- a/app/notify_client/models.py +++ b/app/notify_client/models.py @@ -1,4 +1,5 @@ from flask.ext.login import (UserMixin, login_fresh) +from flask import session class User(UserMixin): @@ -81,7 +82,9 @@ class User(UserMixin): def permissions(self, permissions): raise AttributeError("Read only property") - def has_permissions(self, service_id, permissions, or_=False): + def has_permissions(self, permissions, service_id=None, or_=False): + if service_id is None: + service_id = session.get('service_id', '') if service_id in self._permissions: if or_: return any([x in self._permissions[service_id] for x in permissions]) diff --git a/app/templates/main_nav.html b/app/templates/main_nav.html index cfc6cd0d1..96a360727 100644 --- a/app/templates/main_nav.html +++ b/app/templates/main_nav.html @@ -2,26 +2,26 @@
- {% if current_user.has_permissions(session.get('service_id', ''), ['manage_templates']) %} + {% if current_user.has_permissions(['manage_templates']) %} Add a new template {% endif %}
diff --git a/app/utils.py b/app/utils.py index 965939028..11eb18876 100644 --- a/app/utils.py +++ b/app/utils.py @@ -100,10 +100,8 @@ def user_has_permissions(*permissions, or_=False): def wrap(func): @wraps(func) def wrap_func(*args, **kwargs): - # We are making the assumption that the user is logged in. from flask_login import current_user - service_id = session.get('service_id', '') - if current_user and current_user.has_permissions(service_id, permissions, or_=or_): + if current_user and current_user.has_permissions(permissions, or_=or_): return func(*args, **kwargs) else: abort(403) diff --git a/tests/app/main/test_utils.py b/tests/app/main/test_utils.py index 67262c360..95b435f6e 100644 --- a/tests/app/main/test_utils.py +++ b/tests/app/main/test_utils.py @@ -6,18 +6,53 @@ from app.main.views.index import index from werkzeug.exceptions import Forbidden -# def test_user_has_permissions(app_, -# api_user_active, -# mock_get_user, -# mock_get_user_by_email, -# mock_login): -# with app_.test_request_context(): -# with app_.test_client() as client: -# client.login(api_user_active) -# decorator = user_has_permissions('something') -# decorated_index = decorator(index) -# try: -# response = decorated_index() -# pytest.fail("Failed to throw a forbidden exception") -# except Forbidden: -# pass +def test_user_has_permissions_on_endpoint_fail(app_, + api_user_active, + mock_login, + mock_get_user_with_permissions): + with app_.test_request_context(): + with app_.test_client() as client: + client.login(api_user_active) + decorator = user_has_permissions('something') + decorated_index = decorator(index) + try: + response = decorated_index() + pytest.fail("Failed to throw a forbidden exception") + except Forbidden: + pass + + +def test_user_has_permissions_success(app_, + api_user_active, + mock_login, + mock_get_user_with_permissions): + with app_.test_request_context(): + with app_.test_client() as client: + client.login(api_user_active) + decorator = user_has_permissions('manage_users') + decorated_index = decorator(index) + response = decorated_index() + + +def test_user_has_permissions_or(app_, + api_user_active, + mock_login, + mock_get_user_with_permissions): + with app_.test_request_context(): + with app_.test_client() as client: + client.login(api_user_active) + decorator = user_has_permissions('something', 'manage_users', or_=True) + decorated_index = decorator(index) + response = decorated_index() + + +def test_user_has_permissions_multiple(app_, + api_user_active, + mock_login, + mock_get_user_with_permissions): + with app_.test_request_context(): + with app_.test_client() as client: + client.login(api_user_active) + decorator = user_has_permissions('manage_templates', 'manage_users') + decorated_index = decorator(index) + response = decorated_index() diff --git a/tests/conftest.py b/tests/conftest.py index 81c159cf2..a3d4f54aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -305,6 +305,15 @@ def mock_get_user_by_email(mocker, api_user_active): return mocker.patch('app.user_api_client.get_user_by_email', side_effect=_get_user) +@pytest.fixture(scope='function') +def mock_get_user_with_permissions(mocker, api_user_active): + def _get_user(id): + api_user_active._permissions[''] = ['manage_users', 'manage_templates', 'manage_settings'] + return api_user_active + return mocker.patch( + 'app.user_api_client.get_user', side_effect=_get_user) + + @pytest.fixture(scope='function') def mock_dont_get_user_by_email(mocker): @@ -523,7 +532,7 @@ def mock_get_jobs(mocker): @pytest.fixture(scope='function') def mock_has_permissions(mocker): - def _has_permission(service_id, permissions, or_=False): + def _has_permission(permissions, service_id=None, or_=False): return True return mocker.patch( 'app.notify_client.user_api_client.User.has_permissions',