Files
encoach_backend/app/services/impl/third_parties/openai.py

152 lines
5.4 KiB
Python

import json
import re
import logging
from typing import List, Optional, Callable, TypeVar
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from app.services.abc import ILLMService
from app.helpers import count_tokens
from app.configs.constants import BLACKLISTED_WORDS
from pydantic import BaseModel
# Generic placeholder for "some pydantic model"; referenced by the return
# annotation of OpenAI.pydantic_prediction.
T = TypeVar("T", bound=BaseModel)
class OpenAI(ILLMService):
    """ILLMService implementation that talks to the chat-completions API.

    All requests are issued through the injected ``AsyncOpenAI`` client and
    ask for a JSON-object response.
    """

    # Token budget from which the ``max_tokens`` of a request is derived in
    # ``_prediction`` (budget - prompt tokens - margin).
    MAX_TOKENS = 4097
    # Highest ``try_count`` at which ``_prediction`` will still retry, i.e.
    # at most TRY_LIMIT retries after the first attempt.
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI) -> None:
        """Store the API client and set up the logger and default model."""
        self._client = client
        self._logger = logging.getLogger(__name__)
        # Model used by ``pydantic_prediction`` when the caller passes none.
        self._default_model = "gpt-4o-2024-08-06"

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        """Request a JSON completion, retrying on blacklisted/incomplete output.

        Args:
            model: Model name forwarded to the API.
            messages: Chat messages forwarded to the API.
            fields_to_check: Names that must be present in the response, or
                ``None`` to skip that check.
            temperature: Sampling temperature (converted with ``float()``).
            check_blacklisted: Whether to scan the response for entries of
                ``BLACKLISTED_WORDS``.
            token_count: Prompt token count; ``-1`` (the default) means
                "compute it from ``messages``".

        Returns:
            The response decoded with ``json.loads``, or ``""`` when the
            response still contains a blacklisted word after all retries.
        """
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(
            model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted
        )

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        """Run the request/validate/retry cycle behind ``prediction``.

        ``try_count`` is the number of attempts already made; a retry is
        only issued while it is below ``TRY_LIMIT``.

        Returns:
            The response decoded with ``json.loads``; ``""`` when a
            blacklisted word is still present and no retries remain.

        Raises:
            Whatever ``json.loads`` raises when the final response text is
            not valid JSON (not caught here).
        """
        while True:
            completion = await self._client.chat.completions.create(
                model=model,
                # NOTE(review): the 300-token margin is undocumented; it looks
                # like safety headroom. The value goes non-positive when the
                # prompt uses more than MAX_TOKENS - 300 tokens - confirm
                # callers never get there.
                max_tokens=int(self.MAX_TOKENS - token_count - 300),
                temperature=float(temperature),
                messages=messages,
                response_format={"type": "json_object"},
            )
            content = completion.choices[0].message.content
            can_retry = try_count < self.TRY_LIMIT

            if check_blacklisted:
                blacklisted = self._get_found_blacklisted_words(content)
                if blacklisted is not None:
                    if not can_retry:
                        # Out of retries and still blacklisted: give up with
                        # an empty string rather than a decoded object.
                        return ""
                    self._logger.warning("Result contains blacklisted words: %s", blacklisted)
                    try_count += 1
                    continue

            # NOTE(review): the check runs against the raw response text, so
            # each field is matched as a substring, not as a JSON key -
            # confirm this is intended before tightening it.
            if (
                fields_to_check is not None
                and can_retry
                and not self._check_fields(content, fields_to_check)
            ):
                try_count += 1
                continue

            # Either everything passed or retries are exhausted: hand back
            # whatever was received.
            return json.loads(content)

    async def prediction_override(self, **kwargs):
        """Forward ``kwargs`` unchanged to ``chat.completions.create``.

        Returns the raw completion object produced by the client.
        """
        return await self._client.chat.completions.create(**kwargs)

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        """Return the first entry of ``BLACKLISTED_WORDS`` found in ``text``.

        The text is lower-cased before matching and every word is matched
        between ``\\b`` word boundaries. Returns ``None`` when nothing matches.
        """
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            if re.search(r'\b' + re.escape(word) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        """Sum ``count_tokens(...)["n_tokens"]`` over every message's content."""
        return sum(
            count_tokens(message["content"])["n_tokens"] for message in messages
        )

    @staticmethod
    def _check_fields(obj, fields):
        """Return True when every entry of ``fields`` satisfies ``field in obj``.

        This is plain ``in`` semantics: key lookup for a dict, substring
        search for a str.
        """
        return all(field in obj for field in fields)

    async def pydantic_prediction(
        self,
        messages: List[ChatCompletionMessageParam],
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3
    ) -> List[T] | T | None:
        """Request a JSON completion and convert it with ``map_to_model``.

        When decoding or mapping fails, the model is asked to repair its
        previous answer; this is repeated until ``max_retries`` attempts
        have failed.

        Args:
            messages: Chat messages for the first attempt.
            map_to_model: Callable applied to the decoded JSON; its return
                value is returned to the caller.
            json_scheme: Description of the expected JSON, embedded in the
                repair prompt.
            model: Model name; falls back to the instance default when falsy.
            temperature: Sampling temperature; only sent when not ``None``.
            max_retries: Number of failed attempts after which the method
                gives up. At least one attempt is always made.

        Returns:
            The value produced by ``map_to_model``, or ``None`` when every
            attempt failed.
        """
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model or self._default_model,
        }
        # Compare against None so that an explicit temperature of 0.0 is
        # forwarded instead of being dropped as falsy.
        if temperature is not None:
            params["temperature"] = temperature

        attempt = 0
        # Loop until success or until ``max_retries`` attempts have failed
        # (previously the loop was capped at a hard-coded 3).
        while True:
            completion = await self._client.chat.completions.create(**params)
            result_content = completion.choices[0].message.content
            try:
                return map_to_model(json.loads(result_content))
            except Exception as e:
                # Deliberately broad: both JSON decoding and the caller's
                # mapping/validation may fail, and every failure is retried.
                attempt += 1
                self._logger.info(
                    "GPT returned malformed response: %s\n %s", result_content, e
                )
                if attempt >= max_retries:
                    self._logger.error("Max retries exceeded!")
                    return None
                # The follow-up request replaces the conversation with a
                # repair prompt carrying the bad answer, the expected format
                # and the error text.
                params["messages"] = [
                    {
                        "role": "user",
                        "content": (
                            "Your previous response wasn't in the json format I've explicitly told you to output. "
                            "In your next response, you will fix it and return me just the json I've asked."
                        )
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Previous response: {result_content}\n"
                            f"JSON format: {json_scheme}\n"
                            f"Validation errors: {e}"
                        )
                    }
                ]