65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
import json
|
|
from logging import getLogger
|
|
|
|
from typing import List, Optional, Callable
|
|
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class GPT:
    """Wrapper around a chat-completions client that requests JSON output.

    Each response is decoded with ``json.loads`` and handed to a caller-supplied
    mapper. If decoding or mapping fails, the model is sent a correction prompt
    and the request is retried, up to ``max_retries`` total requests.
    """

    def __init__(self, openai_client):
        """Store the client and the default model name.

        Args:
            openai_client: Object exposing ``chat.completions.create(**kwargs)``
                whose result has ``.choices[0].message.content``. Presumably an
                ``openai.OpenAI`` instance; this is not checked here.
        """
        self._client = openai_client
        self._default_model = "gpt-4o"
        # Module-scoped logger rather than the root logger, so this class's
        # output can be filtered/configured independently.
        self._logger = getLogger(__name__)

    def prediction(
        self,
        messages: "List[ChatCompletionMessageParam]",
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3,
    ) -> "List[BaseModel] | BaseModel | str | None":
        """Request a JSON completion and map it to a result object.

        Args:
            messages: Conversation sent on the first request. The list itself
                is not mutated.
            map_to_model: Called with the decoded JSON value; its return value
                is returned from this method. Any exception it raises is treated
                as a malformed response and triggers a retry.
            json_scheme: Text description of the expected JSON. It is included
                in the correction prompt sent on retries.
            model: Model name; falls back to ``"gpt-4o"`` when falsy.
            temperature: Sampling temperature. It is omitted from the request
                only when ``None``; ``0.0`` is sent as-is.
            max_retries: Total number of requests to make, including the first.
                A value ``<= 0`` makes no request and returns ``None``.

        Returns:
            Whatever ``map_to_model`` returns for the first response that both
            parses and maps successfully, or ``None`` if every attempt failed.

        Raises:
            Any exception raised by ``chat.completions.create`` propagates
            unchanged. Only parsing/mapping failures are retried.
        """
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model or self._default_model,
        }
        # Compare against None: a truthiness check would silently drop an
        # explicit temperature of 0.0 and fall back to the API default.
        if temperature is not None:
            params["temperature"] = temperature

        for _ in range(max_retries):
            response = self._client.chat.completions.create(**params)
            content = response.choices[0].message.content
            try:
                # json.loads(None) raises TypeError, so a missing content also
                # lands in the retry path below.
                return map_to_model(json.loads(content))
            except Exception as exc:  # map_to_model is arbitrary; catch broadly
                self._logger.warning(
                    "GPT returned malformed response: %s\n %s", content, exc
                )
                # NOTE(review): the retry replaces the original conversation, so
                # the model only sees its previous output and json_scheme.
                # Confirm this is intended.
                params["messages"] = self._correction_messages(content, json_scheme)

        self._logger.error("Max retries exceeded!")
        return None

    @staticmethod
    def _correction_messages(previous_content, json_scheme):
        """Build the follow-up prompt asking the model to repair its JSON output."""
        return [
            {
                "role": "user",
                "content": (
                    "Your previous response wasn't in the json format I've explicitly told you to output. "
                    "In your next response, you will fix it and return me just the json I've asked."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Previous response: {previous_content}\n"
                    f"JSON format: {json_scheme}"
                ),
            },
        ]