FastAPI refactor update
@@ -1,150 +1,150 @@
import json
import re
import logging
from typing import List, Optional, Callable, TypeVar

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from app.services.abc import ILLMService
from app.helpers import count_tokens
from app.configs.constants import BLACKLISTED_WORDS

T = TypeVar('T', bound=BaseModel)


class OpenAI(ILLMService):

    MAX_TOKENS = 4097
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        self._client = client
        self._logger = logging.getLogger(__name__)
        self._default_model = "gpt-4o-2024-08-06"

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        # A sentinel of -1 means the caller did not pre-compute the prompt size.
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted)

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        result = await self._client.chat.completions.create(
            model=model,
            # Reserve a 300-token safety margin between the prompt and the context window.
            max_tokens=int(self.MAX_TOKENS - token_count - 300),
            temperature=float(temperature),
            messages=messages,
            response_format={"type": "json_object"}
        )
        result = result.choices[0].message.content

        if check_blacklisted:
            found_blacklisted_word = self._get_found_blacklisted_words(result)

            if found_blacklisted_word is not None and try_count < self.TRY_LIMIT:
                self._logger.warning("Result contains blacklisted word: %s", found_blacklisted_word)
                return await self._prediction(
                    model, messages, token_count, fields_to_check, temperature, (try_count + 1), check_blacklisted
                )
            elif found_blacklisted_word is not None and try_count >= self.TRY_LIMIT:
                return ""

        # Parse once, so the field check runs against dict keys rather than
        # doing a substring search on the raw JSON string.
        parsed = json.loads(result)

        if fields_to_check is None:
            return parsed

        # Retry when any expected top-level field is missing from the response.
        if not self._check_fields(parsed, fields_to_check) and try_count < self.TRY_LIMIT:
            return await self._prediction(
                model, messages, token_count, fields_to_check, temperature, (try_count + 1), check_blacklisted
            )

        return parsed

    async def prediction_override(self, **kwargs):
        # Pass-through for callers that need full control over the request parameters.
        return await self._client.chat.completions.create(**kwargs)

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        # Return the first blacklisted word found as a whole word, or None.
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            if re.search(r'\b' + re.escape(word) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        total_tokens = 0
        for message in messages:
            total_tokens += count_tokens(message["content"])["n_tokens"]
        return total_tokens

    @staticmethod
    def _check_fields(obj, fields):
        return all(field in obj for field in fields)

    async def pydantic_prediction(
        self,
        messages: List[ChatCompletionMessageParam],
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3
    ) -> List[T] | T | None:
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model if model else self._default_model
        }

        # `is not None` so an explicit temperature of 0.0 is not silently dropped.
        if temperature is not None:
            params["temperature"] = temperature

        attempt = 0
        while attempt < max_retries:
            result = await self._client.chat.completions.create(**params)
            result_content = result.choices[0].message.content
            try:
                result_json = json.loads(result_content)
                return map_to_model(result_json)
            except Exception as e:
                attempt += 1
                self._logger.info(f"GPT returned malformed response: {result_content}\n {str(e)}")
                # Feed the malformed response back so the model can correct its own output.
                params["messages"] = [
                    {
                        "role": "user",
                        "content": (
                            "Your previous response wasn't in the JSON format I explicitly told you to output. "
                            "In your next response, fix it and return only the JSON I asked for."
                        )
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Previous response: {result_content}\n"
                            f"JSON format: {json_scheme}"
                        )
                    }
                ]
        if attempt >= max_retries:
            self._logger.error("Max retries exceeded!")
        return None
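
Since the commit message references a FastAPI refactor, here is a minimal usage sketch (not part of this commit) showing how the service could be wired into a FastAPI endpoint via a dependency. The `Answer` model, the `/answer` route, and `get_llm_service` are hypothetical names for illustration; the sketch assumes it lives in the same module as the `OpenAI` class above and that API credentials come from the environment.

    # Hypothetical wiring, assuming this code sits next to the OpenAI class above.
    from fastapi import Depends, FastAPI

    app = FastAPI()


    class Answer(BaseModel):  # hypothetical response model
        text: str


    def get_llm_service() -> ILLMService:
        # In a real app the client would come from settings / a DI container.
        return OpenAI(AsyncOpenAI())


    @app.post("/answer")
    async def answer(question: str, llm: ILLMService = Depends(get_llm_service)) -> Answer | None:
        # pydantic_prediction retries on malformed JSON and maps the parsed
        # dict into the Pydantic model; it returns None once retries run out.
        return await llm.pydantic_prediction(
            messages=[{"role": "user", "content": question}],
            map_to_model=lambda data: Answer(**data),
            json_scheme='{"text": "<string>"}',
        )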