import json
import logging
import re
from typing import List, Optional

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from app.configs.constants import BLACKLISTED_WORDS
from app.helpers import count_tokens
from app.services.abc import ILLMService


class OpenAI(ILLMService):
    """ILLMService implementation that requests JSON chat completions.

    Each prediction asks the injected ``AsyncOpenAI`` client for a
    ``json_object`` response, optionally rejects answers containing
    blacklisted words or lacking required fields, and retries a bounded
    number of times before giving up.
    """

    # Budget from which the prompt token count and a fixed 300-token margin
    # are subtracted to obtain the ``max_tokens`` sent with each request.
    MAX_TOKENS = 4097
    # Number of retries allowed after the first attempt (so at most
    # TRY_LIMIT + 1 requests per prediction).
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        """Store the async client and create a module-named logger."""
        self._client = client
        self._logger = logging.getLogger(__name__)

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        """Run a JSON chat completion with validation and retries.

        Args:
            model: Model name forwarded to the completions endpoint.
            messages: Chat messages forwarded to the completions endpoint.
            fields_to_check: Top-level keys the JSON answer must contain, or
                ``None`` to skip the field check.
            temperature: Sampling temperature (coerced to ``float``).
            check_blacklisted: When true, answers containing a blacklisted
                word are rejected.
            token_count: Prompt token count; ``-1`` means "count it here"
                from ``messages``.

        Returns:
            The decoded JSON answer, or ``""`` when every attempt contained
            a blacklisted word.

        Raises:
            json.JSONDecodeError: If the last attempt is not valid JSON.
        """
        if token_count == -1:
            token_count = self._count_total_tokens(messages)

        return await self._prediction(
            model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted
        )

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        """Request completions until one validates or retries run out.

        ``try_count`` is the number of retries already used; a retry is
        permitted while it is below ``TRY_LIMIT``. See ``prediction`` for the
        meaning of the other arguments and for the return/raise contract.
        """
        while True:
            completion = await self._client.chat.completions.create(
                model=model,
                # NOTE(review): this goes non-positive when the prompt uses
                # more than MAX_TOKENS - 300 tokens; the request is sent
                # as-is in that case.
                max_tokens=int(self.MAX_TOKENS - token_count - 300),
                temperature=float(temperature),
                messages=messages,
                response_format={"type": "json_object"},
            )
            # ``content`` may be None; treat that like an empty answer so it
            # fails JSON decoding instead of raising AttributeError/TypeError.
            content = completion.choices[0].message.content or ""
            can_retry = try_count < self.TRY_LIMIT

            if check_blacklisted:
                blacklisted = self._get_found_blacklisted_words(content)
                if blacklisted is not None:
                    if not can_retry:
                        # Retries exhausted and the answer is still tainted.
                        return ""
                    self._logger.warning(
                        "Result contains blacklisted words: %s", blacklisted
                    )
                    try_count += 1
                    continue

            if fields_to_check is None:
                return json.loads(content)

            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                # Unparseable output cannot satisfy the field check: retry
                # while allowed, otherwise surface the decoding error.
                if not can_retry:
                    raise
                try_count += 1
                continue

            # Check keys of the decoded object rather than substrings of the
            # raw text, so a field name appearing only inside a value (or as
            # part of another key) no longer counts as present.
            has_fields = isinstance(parsed, dict) and self._check_fields(
                parsed, fields_to_check
            )
            if has_fields or not can_retry:
                # On the final attempt the answer is returned even if fields
                # are still missing (best effort).
                return parsed
            try_count += 1

    async def prediction_override(self, **kwargs):
        """Forward ``kwargs`` unchanged to ``chat.completions.create``."""
        return await self._client.chat.completions.create(**kwargs)

    @staticmethod
    def _get_found_blacklisted_words(text: str) -> Optional[str]:
        """Return the first blacklisted word found in ``text``, else ``None``.

        Matching is case-insensitive and anchored on ``\\b`` word boundaries.
        Empty or ``None`` text never matches.
        """
        if not text:
            return None

        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            # Lower-case the entry too; otherwise an entry containing
            # uppercase letters could never match the lowered text.
            pattern = r'\b' + re.escape(word.lower()) + r'\b'
            if re.search(pattern, text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        """Sum ``count_tokens(...)["n_tokens"]`` over each message's content."""
        return sum(
            count_tokens(message["content"])["n_tokens"] for message in messages
        )

    @staticmethod
    def _check_fields(obj, fields):
        """Return True when every entry of ``fields`` is contained in ``obj``."""
        return all(field in obj for field in fields)