157 lines
4.3 KiB
Python
157 lines
4.3 KiB
Python
import json
|
|
import os
|
|
import pathlib
|
|
import logging.config
|
|
import logging.handlers
|
|
|
|
import aioboto3
|
|
import contextlib
|
|
from contextlib import asynccontextmanager
|
|
from collections import defaultdict
|
|
from typing import List
|
|
from http import HTTPStatus
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.middleware import Middleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
|
|
import nltk
|
|
from starlette import status
|
|
|
|
from ielts_be.api import router
|
|
from ielts_be.configs import DependencyInjector
|
|
from ielts_be.exceptions import CustomException
|
|
from ielts_be.middlewares import AuthenticationMiddleware, AuthBackend
|
|
from ielts_be.services.impl import OpenAIWhisper
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """
    Startup and Shutdown logic is in this lifespan method

    https://fastapi.tiangolo.com/advanced/events/

    Startup: downloads the required NLTK datasets, creates the AWS Polly
    client, the shared HTTP client and the speech-to-text service, hands
    them to the DependencyInjector and configures logging from the JSON
    config file.

    Shutdown: every resource is registered on an AsyncExitStack as soon as
    it is created, so it is released in reverse creation order (stt, HTTP
    client, Polly client) both on normal shutdown and when startup or the
    application itself raises.
    """

    # NLTK required datasets download
    nltk.download('words')
    nltk.download("punkt")

    async with contextlib.AsyncExitStack() as context_stack:
        # AWS Polly client instantiation. Leaving the stack runs the client
        # context manager's exit, which is expected to close the client, so
        # no separate close() call is made.
        session = aioboto3.Session()
        polly_client = await context_stack.enter_async_context(
            session.client(
                'polly',
                region_name='eu-west-1',
                aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
                aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            )
        )

        http_client = httpx.AsyncClient()
        context_stack.push_async_callback(http_client.aclose)

        stt = OpenAIWhisper()
        # stt.close is invoked synchronously, exactly as the original
        # teardown called it.
        context_stack.callback(stt.close)

        DependencyInjector(
            polly_client,
            http_client,
            stt,
        ).inject()

        # Setup logging
        config_file = pathlib.Path("./ielts_be/configs/logging/logging_config.json")
        with open(config_file, encoding="utf-8") as f_in:
            config = json.load(f_in)

        logging.config.dictConfig(config)

        yield
|
|
|
|
|
|
def setup_listeners(_app: FastAPI) -> None:
    """
    Register the application's exception handlers on ``_app``.

    Three handlers are installed:

    * RequestValidationError -> 400 with the validation messages grouped
      by field path (or 401 when the failing field is the refresh-token
      cookie).
    * CustomException -> the status code, error code and message carried
      by the exception.
    * Exception -> 500 with the stringified exception as the body.
    """

    @_app.exception_handler(RequestValidationError)
    async def custom_form_validation_error(request, exc):
        """
        Don't delete request param
        """
        reformatted_message = defaultdict(list)
        for pydantic_error in exc.errors():
            loc, msg = pydantic_error["loc"], pydantic_error["msg"]
            # Drop the leading request-section marker; guard against an
            # empty location so indexing cannot raise.
            if loc and loc[0] in ("body", "query", "path"):
                filtered_loc = loc[1:]
            else:
                filtered_loc = loc
            # Location parts may be integers (list indices), so stringify
            # each part before joining.
            field_string = ".".join(str(part) for part in filtered_loc)
            if field_string == "cookie.refresh_token":
                # A missing/invalid refresh token cookie is reported as an
                # authentication failure rather than a validation error.
                return JSONResponse(
                    status_code=401,
                    content={"error_code": 401, "message": HTTPStatus.UNAUTHORIZED.description},
                )
            reformatted_message[field_string].append(msg)

        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=jsonable_encoder(
                {"details": "Invalid request!", "errors": reformatted_message}
            ),
        )

    @_app.exception_handler(CustomException)
    async def custom_exception_handler(request: Request, exc: CustomException):
        """
        Don't delete request param
        """
        return JSONResponse(
            status_code=exc.code,
            content={"error_code": exc.error_code, "message": exc.message},
        )

    @_app.exception_handler(Exception)
    async def default_exception_handler(request: Request, exc: Exception):
        """
        Don't delete request param
        """
        # NOTE(review): the raw exception text is returned to the client;
        # confirm this cannot expose internal details in production.
        return JSONResponse(
            status_code=500,
            content=str(exc),
        )
|
|
|
|
|
|
def setup_middleware() -> List[Middleware]:
    """
    Build the middleware stack passed to the FastAPI constructor.

    Returns the CORS middleware followed by the authentication middleware,
    in that order.
    """
    # NOTE(review): wildcard origins combined with allow_credentials=True
    # is very permissive — confirm this is intended outside development.
    cors = Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    authentication = Middleware(AuthenticationMiddleware, backend=AuthBackend())
    return [cors, authentication]
|
|
|
|
|
|
def create_app() -> FastAPI:
    """
    Assemble the FastAPI application.

    The interactive docs (/docs and /redoc) are disabled when the ENV
    environment variable equals "production". The middleware stack, the
    lifespan handler, the API router and the exception handlers are all
    attached here.
    """
    is_production = os.getenv("ENV") == "production"
    docs_url = None if is_production else "/docs"
    redoc_url = None if is_production else "/redoc"

    _app = FastAPI(
        docs_url=docs_url,
        redoc_url=redoc_url,
        middleware=setup_middleware(),
        lifespan=lifespan,
    )
    _app.include_router(router)
    setup_listeners(_app)
    return _app
|
|
|
|
|
|
# Module-level application instance, built once at import time.
# NOTE(review): presumably the object the ASGI server is pointed at — confirm.
app = create_app()
|