diff --git a/app/authentication/auth.py b/app/authentication/auth.py index ab32404dc..8c9aac294 100644 --- a/app/authentication/auth.py +++ b/app/authentication/auth.py @@ -5,54 +5,65 @@ from notifications_python_client.errors import TokenDecodeError, TokenExpiredErr from app.dao.api_key_dao import get_model_api_keys -def authentication_response(message, code): - return jsonify(result='error', - message={"token": [message]} - ), code +class AuthError(Exception): + def __init__(self, message, code): + self.message = {"token": [message]} + self.code = code -def requires_auth(): - auth_header = request.headers.get('Authorization', None) +def get_auth_token(req): + auth_header = req.headers.get('Authorization', None) if not auth_header: - return authentication_response('Unauthorized, authentication token must be provided', 401) + raise AuthError('Unauthorized, authentication token must be provided', 401) auth_scheme = auth_header[:7] if auth_scheme != 'Bearer ': - return authentication_response('Unauthorized, authentication bearer scheme must be used', 401) + raise AuthError('Unauthorized, authentication bearer scheme must be used', 401) - auth_token = auth_header[7:] + return auth_header[7:] + + +def requires_auth(): + auth_token = get_auth_token(request) try: client = get_token_issuer(auth_token) except TokenDecodeError: - return authentication_response("Invalid token: signature", 403) + raise AuthError("Invalid token: signature", 403) if client == current_app.config.get('ADMIN_CLIENT_USER_NAME'): - errors_resp = get_decode_errors(auth_token, current_app.config.get('ADMIN_CLIENT_SECRET'), expiry_date=None) - return errors_resp + return handle_admin_key(auth_token, current_app.config.get('ADMIN_CLIENT_SECRET')) - secret_keys = get_model_api_keys(client) - for api_key in secret_keys: - errors_resp = get_decode_errors(auth_token, api_key.unsigned_secret, api_key.expiry_date) - if not errors_resp: - if api_key.expiry_date: - return authentication_response("Invalid token: revoked", 403) - else: - _request_ctx_stack.top.api_user = api_key - return + api_keys = get_model_api_keys(client) - if not secret_keys: - errors_resp = authentication_response("Invalid token: no api keys for service", 403) - current_app.logger.info(errors_resp) - return errors_resp + for api_key in api_keys: + try: + get_decode_errors(auth_token, api_key.unsigned_secret) + except TokenDecodeError: + continue + + if api_key.expiry_date: + raise AuthError("Invalid token: revoked", 403) + + _request_ctx_stack.top.api_user = api_key + return + + if not api_keys: + raise AuthError("Invalid token: no api keys for service", 403) + else: + raise AuthError("Invalid token: signature", 403) -def get_decode_errors(auth_token, unsigned_secret, expiry_date=None): +def handle_admin_key(auth_token, secret): + try: + get_decode_errors(auth_token, secret) + return + except TokenDecodeError as e: + raise AuthError("Invalid token: signature", 403) + + +def get_decode_errors(auth_token, unsigned_secret): try: decode_jwt_token(auth_token, unsigned_secret) - except TokenExpiredError: - return authentication_response("Invalid token: expired", 403) - except TokenDecodeError: - return authentication_response("Invalid token: signature", 403) - else: - return None + except TokenExpiredError as e: + raise AuthError("Invalid token: expired") diff --git a/app/errors.py b/app/errors.py index 9d90dee15..046f21690 100644 --- a/app/errors.py +++ b/app/errors.py @@ -5,6 +5,7 @@ from flask import ( from sqlalchemy.exc import SQLAlchemyError, DataError from sqlalchemy.orm.exc import NoResultFound from marshmallow import ValidationError +from app.authentication.auth import AuthError class InvalidRequest(Exception): @@ -23,6 +24,10 @@ class InvalidRequest(Exception): def register_errors(blueprint): + @blueprint.app_errorhandler(AuthError) + def authentication_error(error): + return jsonify(result='error', message=error.message), error.code + @blueprint.app_errorhandler(ValidationError) def validation_error(error): current_app.logger.error(error) diff --git a/tests/app/authentication/test_authentication.py b/tests/app/authentication/test_authentication.py index daa430190..11b54f279 100644 --- a/tests/app/authentication/test_authentication.py +++ b/tests/app/authentication/test_authentication.py @@ -171,9 +171,9 @@ def test_authentication_returns_token_expired_when_service_uses_expired_key_and_ assert data['message'] == {"token": ['Invalid token: revoked']} -def test_authentication_returns_error_when_api_client_has_no_secrets(notify_api, - notify_db, - notify_db_session): +def test_authentication_returns_error_when_admin_client_has_no_secrets(notify_api, + notify_db, + notify_db_session): with notify_api.test_request_context(): with notify_api.test_client() as client: api_secret = notify_api.config.get('ADMIN_CLIENT_SECRET')