import json
import re
import logging
from typing import List, Optional, Callable, TypeVar

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ielts_be.services.abc import ILLMService
from ielts_be.helpers import count_tokens
from ielts_be.configs.constants import BLACKLISTED_WORDS

T = TypeVar('T', bound=BaseModel)


class OpenAI(ILLMService):
    """LLM service that sends chat-completion requests through an ``AsyncOpenAI`` client.

    All requests made by this class ask for ``{"type": "json_object"}`` responses
    (except ``prediction_override``, which forwards caller-supplied kwargs as-is).
    """

    # Token budget from which the prompt size and a 300-token margin are
    # subtracted to compute ``max_tokens`` in ``_prediction``.
    MAX_TOKENS = 4097
    # Maximum number of retries ``_prediction`` performs after the first call.
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        """Store the client and set up the logger and default model name."""
        self._client = client
        self._logger = logging.getLogger(__name__)
        self._default_model = "gpt-4o"

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        """Request a JSON completion, retrying on blacklisted words or missing fields.

        Args:
            model: Model name passed to the chat-completions endpoint.
            messages: Chat messages to send.
            fields_to_check: Names that must appear in the response; ``None``
                disables the check.
            temperature: Sampling temperature.
            check_blacklisted: Whether to scan the response for blacklisted words.
            token_count: Prompt token count; ``-1`` means "compute it from
                ``messages``".

        Returns:
            The parsed JSON response, or ``""`` when the response still contains
            a blacklisted word after all retries.
        """
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(
            model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted
        )

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        """Call the chat-completions endpoint until the response is acceptable.

        A response is rejected (and the request repeated) while retries remain
        (``try_count < TRY_LIMIT``) and either a blacklisted word is found or
        ``fields_to_check`` is not satisfied.

        Args:
            model: Model name passed to the endpoint.
            messages: Chat messages to send.
            token_count: Prompt token count used to size ``max_tokens``.
            fields_to_check: Names that must appear in the response, or ``None``.
            temperature: Sampling temperature (converted with ``float``).
            try_count: Number of retries already consumed.
            check_blacklisted: Whether to scan the response for blacklisted words.

        Returns:
            The parsed JSON response, or ``""`` when a blacklisted word is still
            present and no retries remain.

        Raises:
            json.JSONDecodeError: If the accepted response text is not valid JSON.
        """
        # NOTE(review): this can become zero or negative when the prompt is
        # longer than MAX_TOKENS - 300; left as-is so the API error surfaces.
        max_tokens = int(self.MAX_TOKENS - token_count - 300)

        while True:
            response = await self._client.chat.completions.create(
                model=model,
                max_tokens=max_tokens,
                temperature=float(temperature),
                messages=messages,
                response_format={"type": "json_object"}
            )
            content = response.choices[0].message.content
            can_retry = try_count < self.TRY_LIMIT

            if check_blacklisted:
                blacklisted_word = self._get_found_blacklisted_words(content)
                if blacklisted_word is not None:
                    if not can_retry:
                        # Retries exhausted: signal failure with an empty string.
                        return ""
                    self._logger.warning(
                        "Result contains blacklisted words: %s", blacklisted_word
                    )
                    try_count += 1
                    continue

            # NOTE(review): ``content`` is the raw response text here, so
            # ``_check_fields`` performs a substring test, not a JSON key lookup.
            if (
                fields_to_check is not None
                and can_retry
                and not self._check_fields(content, fields_to_check)
            ):
                try_count += 1
                continue

            # Either everything checks out or retries are exhausted for the
            # field check; in both cases the last response is returned.
            return json.loads(content)

    async def prediction_override(self, **kwargs):
        """Forward ``kwargs`` unchanged to the chat-completions endpoint and return its result."""
        return await self._client.chat.completions.create(**kwargs)

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        """Return the first entry of ``BLACKLISTED_WORDS`` found in ``text``, else ``None``.

        Matching is case-insensitive and bounded by ``\\b`` on both sides.
        """
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            # Lower-case the word as well: the text is lower-cased, so an entry
            # containing capitals could otherwise never match.
            if re.search(r'\b' + re.escape(word.lower()) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        """Sum ``count_tokens(...)["n_tokens"]`` over every message whose content is a string."""
        total_tokens = 0
        for message in messages:
            message_content = message.get("content", None)
            # Skip when content isn't text (missing, None, or a non-string payload).
            if isinstance(message_content, str):
                total_tokens += count_tokens(message_content)["n_tokens"]
        return total_tokens

    @staticmethod
    def _check_fields(obj, fields):
        """Return True when every item of ``fields`` satisfies ``field in obj``."""
        return all(field in obj for field in fields)

    async def pydantic_prediction(
        self,
        messages: List[ChatCompletionMessageParam],
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3
    ) -> List[T] | T | None:
        """Request a JSON completion and convert it with ``map_to_model``.

        When parsing or mapping fails, the request is repeated with a correction
        prompt that contains the previous response, ``json_scheme`` and the error.

        Args:
            messages: Chat messages for the first request.
            map_to_model: Callable applied to the parsed JSON; its return value
                is returned to the caller.
            json_scheme: Description of the expected JSON, embedded in the
                correction prompt.
            model: Model name; falls back to the instance default when falsy.
            temperature: Sampling temperature; omitted from the request when
                ``None`` (``0`` is forwarded).
            max_retries: Maximum number of requests to make; at least one
                request is always made.

        Returns:
            Whatever ``map_to_model`` returns, or ``None`` when every attempt failed.
        """
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model if model else self._default_model
        }
        # Compare against None so that an explicit temperature of 0 is honoured.
        if temperature is not None:
            params["temperature"] = temperature

        total_attempts = max(1, max_retries)
        for attempt in range(1, total_attempts + 1):
            result = await self._client.chat.completions.create(**params)
            result_content = result.choices[0].message.content
            try:
                result_json = json.loads(result_content)
                self._logger.debug("GPT response JSON: %s", result_json)
                return map_to_model(result_json)
            except Exception as e:
                # Broad on purpose: both JSON decoding and the caller-supplied
                # mapper may fail, and either failure triggers a corrective retry.
                self._logger.info(
                    "GPT returned malformed response: %s\n %s", result_content, e
                )
                if attempt >= total_attempts:
                    break
                # The follow-up request carries only the correction prompt.
                params["messages"] = [
                    {
                        "role": "user",
                        "content": (
                            "Your previous response wasn't in the json format I've explicitly told you to output. "
                            "In your next response, you will fix it and return me just the json I've asked."
                        )
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Previous response: {result_content}\n"
                            f"JSON format: {json_scheme}\n"
                            f"Validation errors: {e}"
                        )
                    }
                ]

        self._logger.error("Max retries exceeded!")
        return None