Files
encoach_backend/app/services/impl/third_parties/openai.py
Carlos Mesquita 3cf9fa5cba Async release
2024-07-23 08:40:35 +01:00

97 lines
3.2 KiB
Python

import json
import re
import logging
from typing import List, Optional
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from app.services.abc import ILLMService
from app.helpers import count_tokens
from app.configs.constants import BLACKLISTED_WORDS
class OpenAI(ILLMService):
    """LLM service that talks to the OpenAI chat-completions endpoint.

    Requests are sent through the injected ``AsyncOpenAI`` client with a JSON
    ``response_format``. Replies are screened against ``BLACKLISTED_WORDS``,
    decoded from JSON and optionally validated for required top-level keys,
    retrying a bounded number of times when a reply is rejected.
    """

    # Token budget from which the completion's ``max_tokens`` is derived.
    # NOTE(review): 4097 looks like the context size of the targeted model - confirm.
    MAX_TOKENS = 4097
    # Highest retry index; starting from try_count == 0 a request is sent at
    # most TRY_LIMIT + 1 times.
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        """Store the async client and create a module-named logger."""
        self._client = client
        self._logger = logging.getLogger(__name__)

    async def prediction(
            self,
            model: str,
            messages: List[ChatCompletionMessageParam],
            fields_to_check: Optional[List[str]],
            temperature: float,
            check_blacklisted: bool = True,
            token_count: int = -1
    ):
        """Request a JSON completion and return it decoded.

        Args:
            model: Model identifier forwarded to the API.
            messages: Chat messages forwarded to the API.
            fields_to_check: Keys the decoded reply should contain, or ``None``
                to skip that validation.
            temperature: Sampling temperature (coerced to ``float``).
            check_blacklisted: Whether to screen the reply for blacklisted words.
            token_count: Token count of ``messages``; ``-1`` means "compute it
                here from the messages".

        Returns:
            The decoded JSON reply, or ``""`` when every attempt contained a
            blacklisted word.
        """
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(
            model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted
        )

    async def _prediction(
            self,
            model: str,
            messages: List[ChatCompletionMessageParam],
            token_count: int,
            fields_to_check: Optional[List[str]],
            temperature: float,
            try_count: int,
            check_blacklisted: bool,
    ):
        """Send the request, retrying while the reply is rejected.

        A reply is rejected when it contains a blacklisted word (if
        ``check_blacklisted``), or - when ``fields_to_check`` is given - when it
        is not decodable JSON or lacks one of the required keys. Retries stop
        once ``try_count`` reaches ``TRY_LIMIT``.

        Returns:
            The decoded JSON reply. On the last attempt a reply that merely
            lacks required keys is still returned (best effort), while a reply
            containing a blacklisted word yields ``""``.

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON and no retry
                remains (or ``fields_to_check`` is ``None``).
        """
        # Both values are invariant across attempts, so compute them once.
        # NOTE(review): the 300 looks like a safety margin on the token budget - confirm.
        # A prompt larger than the budget makes this non-positive; the API is
        # left to reject that, as before.
        max_tokens = int(self.MAX_TOKENS - token_count - 300)
        temperature = float(temperature)

        while True:
            can_retry = try_count < self.TRY_LIMIT
            try_count += 1

            completion = await self._client.chat.completions.create(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=messages,
                response_format={"type": "json_object"}
            )
            # NOTE(review): a None content is not handled here and will raise below.
            content = completion.choices[0].message.content

            if check_blacklisted:
                blacklisted = self._get_found_blacklisted_words(content)
                if blacklisted is not None:
                    if not can_retry:
                        return ""
                    self._logger.warning("Result contains blacklisted words: %s", blacklisted)
                    continue

            if fields_to_check is None:
                return json.loads(content)

            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                if not can_retry:
                    raise
                continue

            # Validate against the decoded object's keys; testing the raw JSON
            # text would also match field names that only occur inside values.
            has_fields = isinstance(parsed, dict) and self._check_fields(parsed, fields_to_check)
            if has_fields or not can_retry:
                return parsed

    async def prediction_override(self, **kwargs):
        """Forward ``kwargs`` unchanged to ``chat.completions.create`` and return its raw response."""
        return await self._client.chat.completions.create(
            **kwargs
        )

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        """Return the first entry of ``BLACKLISTED_WORDS`` found in ``text`` as a whole word, else ``None``.

        ``text`` is lower-cased before matching; the blacklist entries are used
        as given.
        """
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            if re.search(r'\b' + re.escape(word) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        """Sum ``count_tokens(...)["n_tokens"]`` over the ``"content"`` of every message."""
        return sum(count_tokens(message["content"])["n_tokens"] for message in messages)

    @staticmethod
    def _check_fields(obj, fields):
        """Return ``True`` when every item of ``fields`` is contained in ``obj`` (``in`` test)."""
        return all(field in obj for field in fields)