Updated this to the latest version of develop, got rid of most of the duplication, might be missing some packages in toml, needs testing

This commit is contained in:
Carlos Mesquita
2024-08-30 02:35:11 +01:00
parent 3cf9fa5cba
commit f92a803d96
73 changed files with 3642 additions and 2703 deletions

View File

@@ -1,13 +1,16 @@
import json
import re
import logging
from typing import List, Optional
from typing import List, Optional, Callable, TypeVar
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from app.services.abc import ILLMService
from app.helpers import count_tokens
from app.configs.constants import BLACKLISTED_WORDS
from pydantic import BaseModel
T = TypeVar('T', bound=BaseModel)
class OpenAI(ILLMService):
@@ -18,6 +21,7 @@ class OpenAI(ILLMService):
def __init__(self, client: AsyncOpenAI):
self._client = client
self._logger = logging.getLogger(__name__)
self._default_model = "gpt-4o-2024-08-06"
async def prediction(
self,
@@ -94,4 +98,53 @@ class OpenAI(ILLMService):
@staticmethod
def _check_fields(obj, fields):
return all(field in obj for field in fields)
return all(field in obj for field in fields)
async def pydantic_prediction(
self,
messages: List[ChatCompletionMessageParam],
map_to_model: Callable,
json_scheme: str,
*,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_retries: int = 3
) -> List[T] | T | None:
params = {
"messages": messages,
"response_format": {"type": "json_object"},
"model": model if model else self._default_model
}
if temperature:
params["temperature"] = temperature
attempt = 0
while attempt < max_retries:
result = await self._client.chat.completions.create(**params)
result_content = result.choices[0].message.content
try:
result_json = json.loads(result_content)
return map_to_model(result_json)
except Exception as e:
attempt += 1
self._logger.info(f"GPT returned malformed response: {result_content}\n {str(e)}")
params["messages"] = [
{
"role": "user",
"content": (
"Your previous response wasn't in the json format I've explicitly told you to output. "
f"In your next response, you will fix it and return me just the json I've asked."
)
},
{
"role": "user",
"content": (
f"Previous response: {result_content}\n"
f"JSON format: {json_scheme}"
)
}
]
if attempt >= max_retries:
self._logger.error(f"Max retries exceeded!")
return None