"""Bearer-token (JWT) authentication backend and middleware for Starlette."""

import os
from typing import Optional, Tuple

import jwt
from jwt import InvalidTokenError
from pydantic import BaseModel, Field
from starlette.authentication import AuthenticationBackend
from starlette.middleware.authentication import (
    AuthenticationMiddleware as BaseAuthenticationMiddleware,
)
from starlette.requests import HTTPConnection


class Session(BaseModel):
    """Per-request authentication state produced by ``AuthBackend``."""

    authenticated: bool = Field(False, description="Is user authenticated?")


def _parse_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from a ``Bearer <token>`` header value, or ``None``.

    The scheme is compared case-insensitively. Any run of whitespace between
    scheme and token is accepted (RFC 7235 permits one or more spaces); a
    value that does not split into exactly two parts is rejected.
    """
    if not authorization:
        return None
    parts = authorization.split()
    if len(parts) != 2:
        return None
    scheme, token = parts
    if scheme.lower() != "bearer":
        return None
    return token


class AuthBackend(AuthenticationBackend):
    """Authenticate requests that carry an HS256-signed JWT bearer token.

    The signing secret is read from the ``JWT_SECRET_KEY`` environment
    variable on each call; when it is unset or empty, every request is
    treated as unauthenticated.
    """

    async def authenticate(
        self, conn: HTTPConnection
    ) -> Tuple[bool, Session]:
        """Check the ``Authorization`` header of *conn*.

        Args:
            conn: The incoming HTTP/WebSocket connection.

        Returns:
            ``(True, session)`` with ``session.authenticated`` set to ``True``
            when the header holds a bearer token that ``jwt.decode`` accepts
            with the configured secret and ``HS256``; otherwise
            ``(False, session)`` with ``session.authenticated`` left ``False``.

        NOTE(review): Starlette's documented backend contract is ``None`` or
        ``(AuthCredentials, BaseUser)``; this returns ``(bool, Session)``
        instead, presumably because callers read ``request.auth`` /
        ``request.user`` in that shape -- confirm before changing.
        """
        session = Session()

        authorization: Optional[str] = conn.headers.get("Authorization")
        token = _parse_bearer_token(authorization)
        if token is None:
            return False, session

        # Not cached: looked up on every call. A missing secret fails closed.
        jwt_secret_key = os.getenv("JWT_SECRET_KEY")
        if not jwt_secret_key:
            return False, session

        try:
            # Pin the accepted algorithm so the token header cannot choose it.
            jwt.decode(token, jwt_secret_key, algorithms=["HS256"])
        except InvalidTokenError:
            return False, session

        session.authenticated = True
        return True, session


class AuthenticationMiddleware(BaseAuthenticationMiddleware):
    """Project-level alias of Starlette's ``AuthenticationMiddleware`` (no overrides)."""