49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import os
|
|
from typing import Tuple
|
|
|
|
import jwt
|
|
from jwt import InvalidTokenError
|
|
from pydantic import BaseModel, Field
|
|
from starlette.authentication import AuthenticationBackend
|
|
from starlette.middleware.authentication import (
|
|
AuthenticationMiddleware as BaseAuthenticationMiddleware,
|
|
)
|
|
from starlette.requests import HTTPConnection
|
|
|
|
|
|
class Session(BaseModel):
    """Authentication state produced by ``AuthBackend.authenticate``."""

    # Starts out False; ``AuthBackend.authenticate`` flips it to True only
    # after a bearer token has been successfully decoded.
    authenticated: bool = Field(default=False, description="Is user authenticated?")
|
|
|
|
|
|
class AuthBackend(AuthenticationBackend):
    """Authenticate connections that carry an HS256-signed JWT bearer token.

    ``authenticate`` returns ``(True, session)`` with ``session.authenticated``
    set when the ``Authorization`` header holds a ``Bearer`` token that PyJWT
    accepts under the ``JWT_SECRET_KEY`` environment variable. In every other
    case it returns ``(False, Session())``: header missing or malformed,
    non-bearer scheme, secret not configured, or token rejected.

    NOTE(review): Starlette's AuthenticationMiddleware unpacks the returned
    tuple into ``request.auth`` / ``request.user``, so ``request.auth`` is a
    bool here rather than ``AuthCredentials``, and Starlette's ``@requires``
    decorator will not work with this backend. Kept as-is for compatibility.
    """

    async def authenticate(
        self, conn: HTTPConnection
    ) -> Tuple[bool, Session]:
        """Check the connection's ``Authorization`` header.

        Args:
            conn: The incoming HTTP or WebSocket connection.

        Returns:
            A ``(is_authenticated, session)`` pair; ``session.authenticated``
            mirrors the first element.
        """
        session = Session()

        # ``headers.get`` returns None when the header is absent.
        authorization = conn.headers.get("Authorization")
        if not authorization:
            return False, session

        # RFC 7235 permits one or more spaces between the auth scheme and the
        # credentials; a bare split() tolerates that (and stray surrounding
        # whitespace) while still rejecting headers with extra fields.
        parts = authorization.split()
        if len(parts) != 2:
            return False, session
        scheme, token = parts
        if scheme.lower() != "bearer":
            return False, session

        # Read on every call (not cached), so environment changes take effect
        # without a restart; with no secret configured nobody authenticates.
        jwt_secret_key = os.getenv("JWT_SECRET_KEY")
        if not jwt_secret_key:
            return False, session

        try:
            # Only HS256 is accepted. Bad signatures, malformed tokens and
            # expired tokens all surface as InvalidTokenError subclasses.
            jwt.decode(token, jwt_secret_key, algorithms=["HS256"])
        except InvalidTokenError:
            return False, session

        session.authenticated = True
        return True, session
|
|
|
|
|
|
class AuthenticationMiddleware(BaseAuthenticationMiddleware):
    """Module-local alias of Starlette's ``AuthenticationMiddleware``.

    Adds no behavior of its own; presumably kept so the application can
    import the middleware from this module alongside ``AuthBackend``.
    """
|