97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
import json
|
|
import re
|
|
import logging
|
|
from typing import List, Optional
|
|
from openai import AsyncOpenAI
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from app.services.abc import ILLMService
|
|
from app.helpers import count_tokens
|
|
from app.configs.constants import BLACKLISTED_WORDS
|
|
|
|
|
|
class OpenAI(ILLMService):
    """ILLMService implementation that talks to the chat-completions endpoint
    of an ``AsyncOpenAI`` client and returns the reply parsed as JSON.

    Replies are requested in JSON-object mode. A reply can be rejected (and
    the request re-sent) when it has no content, contains a blacklisted word,
    is not valid JSON, or lacks one of the required fields.
    """

    # Overall token budget: the completion limit sent to the API is this value
    # minus the prompt's token count minus a fixed 300.
    # NOTE(review): 4097 looks like a single model's context size -- confirm it
    # is appropriate for every `model` passed to prediction().
    MAX_TOKENS = 4097

    # A rejected reply is retried only while try_count < TRY_LIMIT; starting
    # from try_count == 0 this allows at most TRY_LIMIT + 1 requests per call.
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        """Store the API client and create a module-named logger."""
        self._client = client
        self._logger = logging.getLogger(__name__)

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        """Request a JSON completion and return it parsed.

        Args:
            model: Model name forwarded to the API.
            messages: Chat messages forwarded to the API.
            fields_to_check: Key names the reply must contain, or ``None`` to
                skip the check.
            temperature: Sampling temperature (converted with ``float()``).
            check_blacklisted: When true, replies containing a word from
                ``BLACKLISTED_WORDS`` are rejected.
            token_count: Token count of ``messages``; ``-1`` means "compute it
                here" via ``_count_total_tokens``.

        Returns:
            The parsed JSON reply, or ``""`` when a blacklisted word is still
            present after the last allowed attempt.

        Raises:
            json.JSONDecodeError: The last allowed attempt was not valid JSON.
            ValueError: The last allowed attempt carried no content.
        """
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(
            model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted
        )

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        """Send the request, validate the reply, and retry while allowed.

        ``try_count`` is the number of attempts already used; every kind of
        rejection draws from the same budget (see ``TRY_LIMIT``). Return value
        and exceptions are the ones documented on ``prediction``.
        """
        while True:
            can_retry = try_count < self.TRY_LIMIT

            response = await self._client.chat.completions.create(
                model=model,
                # Whatever is left of the budget after the prompt and a fixed
                # 300-token reserve.
                # NOTE(review): this goes non-positive for prompts longer than
                # MAX_TOKENS - 300 tokens; the value is passed through as-is.
                max_tokens=int(self.MAX_TOKENS - token_count - 300),
                temperature=float(temperature),
                messages=messages,
                response_format={"type": "json_object"}
            )
            content = response.choices[0].message.content

            # A reply without text cannot be scanned or parsed; treat it like
            # any other rejected attempt instead of failing on None.lower().
            if content is None:
                if not can_retry:
                    raise ValueError("Completion returned no message content")
                self._logger.warning("Result has no content")
                try_count += 1
                continue

            if check_blacklisted:
                found_blacklisted_word = self._get_found_blacklisted_words(content)
                if found_blacklisted_word is not None:
                    if not can_retry:
                        # Out of attempts: callers receive an empty string
                        # rather than blacklisted text.
                        return ""
                    self._logger.warning(
                        "Result contains blacklisted words: %s", found_blacklisted_word
                    )
                    try_count += 1
                    continue

            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                # Malformed/truncated JSON gets the same retry budget; the
                # original error propagates once the budget is used up.
                if not can_retry:
                    raise
                self._logger.warning("Result is not valid JSON")
                try_count += 1
                continue

            # With no fields requested, or no attempts left, hand back what we
            # have (missing fields on the final attempt are tolerated).
            if fields_to_check is None or not can_retry:
                return parsed

            # Validate against real JSON keys (at any depth), not against the
            # raw text, so a field name inside a string value does not count.
            if self._check_fields(self._json_keys(parsed), fields_to_check):
                return parsed
            try_count += 1

    async def prediction_override(self, **kwargs):
        """Forward ``kwargs`` unchanged to ``chat.completions.create`` and
        return the raw API response (no validation, no retries)."""
        return await self._client.chat.completions.create(
            **kwargs
        )

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        """Return the first entry of ``BLACKLISTED_WORDS`` that occurs in
        ``text`` as a whole word (case-insensitive), or ``None``."""
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            # Lower-case the entry as well: the text is already lower-cased, so
            # an entry with uppercase letters could otherwise never match.
            if re.search(r'\b' + re.escape(word.lower()) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        """Sum ``count_tokens(...)["n_tokens"]`` over each message's
        ``"content"`` entry."""
        return sum(count_tokens(message["content"])["n_tokens"] for message in messages)

    @staticmethod
    def _json_keys(value):
        """Collect every object key found anywhere inside a parsed JSON value."""
        keys = set()
        pending = [value]
        while pending:
            node = pending.pop()
            if isinstance(node, dict):
                keys.update(node)
                pending.extend(node.values())
            elif isinstance(node, list):
                pending.extend(node)
        return keys

    @staticmethod
    def _check_fields(obj, fields):
        """Return True when every name in ``fields`` satisfies ``name in obj``."""
        return all(field in obj for field in fields)