Common
OAuth2 with JWT in Python Tornado
20. July 2018
0

I was asked to prototype a small API wrapper around a price prediction algorithm whose endpoints should be protected with OAuth2 using JWTs (JSON Web Tokens). The following code is not battle-tested, but it might give an idea of how it works.

I implemented a decorator that wraps a handler method to protect it. The decorator accepts a list of scopes (simple strings) that a client needs in order to access the resource. The very important part, validating the JWT with the auth server (i.e. verifying its signature), is not implemented yet!

from functools import wraps
import jwt

from api.util import get_logger


# PyJWT options for decoding WITHOUT any verification. They are passed via
# ``options`` because the ``verify=False`` keyword is ignored by PyJWT 2, which
# then demands an ``algorithms`` argument and rejects every token. The claim
# checks are switched off explicitly so PyJWT 1.x and 2.x behave the same.
_UNVERIFIED_DECODE_OPTIONS = {
    "verify_signature": False,
    "verify_exp": False,
    "verify_nbf": False,
    "verify_iat": False,
    "verify_aud": False,
}


def _extract_bearer_token(header_value):
    """Return the access token from an ``Authorization`` header value.

    Accepts the RFC 6750 form ``Bearer <token>`` (scheme matched
    case-insensitively) and, for clients that already send it, a bare token.
    Returns ``None`` when no token is present.
    """
    if not header_value:
        return None
    value = header_value.strip()
    scheme, _, rest = value.partition(" ")
    if scheme.lower() == "bearer":
        # Handing the whole "Bearer <token>" string to jwt.decode() fails to
        # parse, so the scheme must be stripped first.
        return rest.strip() or None
    return value or None


def _normalize_scopes(raw_scope):
    """Return the token's ``scope`` claim as a frozenset of scope names.

    RFC 8693 defines ``scope`` as one space-delimited string; some issuers emit
    a JSON array instead. Both are accepted. Splitting the string matters:
    testing ``"x" in "<string>"`` would be a substring match and could grant a
    scope the token does not have. Any other shape yields no scopes.
    """
    if isinstance(raw_scope, str):
        return frozenset(raw_scope.split())
    if isinstance(raw_scope, (list, tuple, set, frozenset)):
        return frozenset(s for s in raw_scope if isinstance(s, str))
    return frozenset()


def _reject(handler, status, challenge, body):
    """Finish *handler*'s request with a plain-text authentication error."""
    handler.set_status(status)
    handler.set_header("Content-Type", "text/plain")
    handler.set_header("WWW-Authenticate", challenge)
    handler.write(body)
    # finish() flushes the buffered output itself.
    handler.finish()


def oauth2_protected(required_scopes: list):
    """Protect an ``async def`` Tornado handler method with an OAuth2 bearer JWT.

    The wrapped method runs only when the request's ``Authorization`` header
    carries a JWT whose ``scope`` claim contains every entry of
    *required_scopes*. Handlers whose ``auth_disabled`` attribute is truthy
    skip the check. Otherwise the request is answered directly and the
    wrapped method is not called:

    * missing or undecodable token -> 401 with
      ``WWW-Authenticate: Bearer error="invalid_token"``;
    * token lacking a required scope -> 403 with
      ``WWW-Authenticate: Bearer error="insufficient_scope"`` (RFC 6750 3.1).

    WARNING: the token signature is NOT verified yet (see the TODO below), so
    a client can forge a token carrying arbitrary scopes.

    The handler must provide ``request``, ``logger`` and Tornado's response
    methods (``set_status``, ``set_header``, ``write``, ``finish``).

    :param required_scopes: scope names the token must all contain.
    :return: a decorator for coroutine handler methods.
    """
    def decorator(f):
        @wraps(f)
        async def wrapper(self, *args, **kwargs):
            # Fail closed: a handler that never defines ``auth_disabled`` is
            # protected instead of crashing with AttributeError.
            if getattr(self, "auth_disabled", False):
                return await f(self, *args, **kwargs)

            token = _extract_bearer_token(self.request.headers.get("authorization", None))
            if token is None:
                self.logger.warning("Client failed to authenticate: no bearer token provided")
                _reject(self, 401, "Bearer error=\"invalid_token\"", "Invalid token")
                return None

            try:
                # TODO: verify token!!! Without a signature check the claims
                # below are not trustworthy.
                payload = jwt.decode(token, options=dict(_UNVERIFIED_DECODE_OPTIONS))
            except Exception as e:
                # The token is untrusted input: any failure to parse it is an
                # authentication failure (401), never a server error.
                self.logger.warning("Client failed to authenticate: {}".format(e))
                _reject(self, 401, "Bearer error=\"invalid_token\"", "Invalid token")
                return None

            provided_scopes = _normalize_scopes(payload.get("scope"))
            if not all(scope in provided_scopes for scope in required_scopes):
                error = NotAuthorizedError("insufficient permissions",
                                           payload.get("client_id"),
                                           payload.get("user_name"),
                                           sorted(provided_scopes),
                                           required_scopes)
                self.logger.warning(error)
                # RFC 6750 3.1: a valid token with too few privileges gets
                # 403 / insufficient_scope, not 401 / invalid_token.
                _reject(self, 403,
                        "Bearer error=\"insufficient_scope\", error_description=\"insufficient permissions\"",
                        "Insufficient permissions")
                return None

            return await f(self, *args, **kwargs)

        return wrapper

    return decorator


class NotAuthorizedError(Exception):
    """Raised when an authenticated client lacks a required OAuth2 scope.

    Carries the token's identity claims and both scope lists so the failure can
    be logged meaningfully; ``str()`` renders that log line.
    """

    def __init__(self, message, client_id, user_name, provided_scopes: list, required_scopes: list):
        super().__init__(message)
        self.client_id = client_id
        self.user_name = user_name
        self.provided_scopes = provided_scopes
        self.required_scopes = required_scopes

    @property
    def identity(self) -> str:
        """Caller identity for logs.

        Returns ``client [user]`` when both are known. Otherwise it returns
        whichever part is known, marking a missing client as ``unknown``, and
        finally plain ``unknown``.
        """
        if self.client_id and self.user_name:
            return "{} [{}]".format(self.client_id, self.user_name)
        if self.client_id:
            return str(self.client_id)
        if self.user_name:
            # Previously the user name was dropped here and only "unknown"
            # was logged.
            return "unknown [{}]".format(self.user_name)
        return "unknown"

    def __str__(self):
        return "Access denied for client {}. Provided scopes: {} but requires {}".format(self.identity,
                                                                                         self.provided_scopes,
                                                                                         self.required_scopes)

 

Using the decorator is pretty simple:

import json
import uuid
from tornado.web import RequestHandler

from api.oauth2_auth_handler import oauth2_protected
from api.util import get_logger


class PricePredictionRequestHandler(RequestHandler):
    """Price-prediction endpoint guarded by the ``priceprediction.get`` scope."""

    # The annotation is quoted because PricePredictionService is not imported
    # in this module; unquoted, it would raise NameError when the class body
    # executes.
    def initialize(self, prediction_service: "PricePredictionService", auth_disabled: bool):
        """Receive the keyword arguments configured for this handler's route.

        :param prediction_service: the prediction backend (not used by
            :meth:`get` yet).
        :param auth_disabled: when truthy, ``oauth2_protected`` lets requests
            through without checking their token.
        """
        # Tornado calls initialize() from RequestHandler.__init__, so setting
        # the logger here replaces the former __init__ override and still
        # makes it available to every HTTP method and to oauth2_protected.
        self.logger = get_logger()
        self.prediction_service = prediction_service
        self.auth_disabled = auth_disabled

    def data_received(self, chunk):
        # No-op: Tornado only calls this for handlers decorated with
        # @stream_request_body, which this one is not.
        pass

    @oauth2_protected(["priceprediction.get"])
    async def get(self, user: str = None):
        """Answer a prediction request (currently a placeholder greeting).

        On failure the response becomes a plain-text 500 carrying a request id
        that is also written to the log.
        """
        try:
            self.write("Hello world")
        except Exception as e:
            request_id = uuid.uuid4()
            # A single call logs the message plus the traceback (assuming
            # get_logger() returns a standard logging.Logger).
            self.logger.exception("Request (id: {}) failed with {}".format(request_id, e))
            # Drop anything already buffered so partial output is not sent
            # alongside the error body.
            self.clear()
            self.set_status(500)
            self.set_header("Content-Type", "text/plain")
            self.write("prediction failed for reasons, please check the logs with request id: {}".format(request_id))

        # finish() flushes the buffered output itself.
        self.finish()

Every client that has valid credentials and the required scope “priceprediction.get” can access the endpoint.