"""Small wrapper around an OpenAI chat client that returns parsed JSON replies."""

import json
from logging import getLogger
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, Union

if TYPE_CHECKING:
    # These names are used for static typing only, so importing them lazily
    # keeps the module importable without pulling them in at runtime.
    from openai.types.chat import ChatCompletionMessageParam
    from pydantic import BaseModel

T = TypeVar("T", bound="BaseModel")


class GPT:
    """Request JSON-object chat completions and map them onto caller models."""

    def __init__(self, openai_client):
        """Store the client and set up the default model name and logger.

        Args:
            openai_client: Object exposing ``chat.completions.create(**kwargs)``.
        """
        self._client = openai_client
        self._default_model = "gpt-4o-2024-08-06"
        self._logger = getLogger(__name__)

    def prediction(
        self,
        messages: List["ChatCompletionMessageParam"],
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3,
    ) -> Optional[Union[List[T], T]]:
        """Ask the model for a JSON reply and convert it with ``map_to_model``.

        The first request sends ``messages`` unchanged. Whenever the reply
        cannot be decoded as JSON, or ``map_to_model`` raises, the next request
        replaces the conversation with a short correction prompt that quotes
        the rejected reply together with ``json_scheme``.

        Args:
            messages: Conversation sent on the first attempt. The list itself
                is never mutated.
            map_to_model: Called with the decoded JSON value; its return value
                is returned to the caller. Any exception it raises counts as a
                malformed reply and triggers a retry.
            json_scheme: Description of the expected JSON, embedded verbatim in
                the correction prompt.
            model: Model name; a falsy value selects the instance default.
            temperature: Sampling temperature. Forwarded whenever it is not
                ``None``, so an explicit ``0.0`` is honoured.
            max_retries: Maximum number of requests made to the client. With a
                value of zero or less no request is made at all.

        Returns:
            Whatever ``map_to_model`` returns for the first acceptable reply,
            or ``None`` once ``max_retries`` requests have all been rejected.

        Raises:
            Exception: Anything raised by the client call itself propagates
                unchanged; only decoding/mapping failures are retried.
        """
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model or self._default_model,
        }
        # Compare against None: 0.0 is a legitimate temperature but is falsy.
        if temperature is not None:
            params["temperature"] = temperature

        for _ in range(max_retries):
            result = self._client.chat.completions.create(**params)
            result_content = result.choices[0].message.content
            try:
                return map_to_model(json.loads(result_content))
            except Exception as exc:
                # Deliberately broad: map_to_model is caller-supplied and may
                # raise any validation error; all of them mean "ask again".
                self._logger.info(
                    "GPT returned malformed response: %s\n %s",
                    result_content,
                    exc,
                )
            # Rebind (never mutate) so the caller's list stays untouched.
            params["messages"] = self._repair_messages(result_content, json_scheme)

        self._logger.error("Max retries exceeded!")
        return None

    @staticmethod
    def _repair_messages(previous_response, json_scheme: str) -> List[Dict[str, str]]:
        """Build the two user messages asking the model to fix its last reply."""
        return [
            {
                "role": "user",
                "content": (
                    "Your previous response wasn't in the json format I've explicitly told you to output. "
                    "In your next response, you will fix it and return me just the json I've asked."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Previous response: {previous_response}\n"
                    f"JSON format: {json_scheme}"
                ),
            },
        ]