Upload level exam without hooking up to firestore and running in thread, will do this when I have the edit view done
This commit is contained in:
66
modules/gpt.py
Normal file
66
modules/gpt.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import json
|
||||
from logging import getLogger
|
||||
|
||||
from typing import List, Optional, Callable, TypeVar
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
|
||||
class GPT:
|
||||
|
||||
def __init__(self, openai_client):
|
||||
self._client = openai_client
|
||||
self._default_model = "gpt-4o-2024-08-06"
|
||||
self._logger = getLogger(__name__)
|
||||
|
||||
def prediction(
|
||||
self,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
map_to_model: Callable,
|
||||
json_scheme: str,
|
||||
*,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[T] | T | None:
|
||||
params = {
|
||||
"messages": messages,
|
||||
"response_format": {"type": "json_object"},
|
||||
"model": model if model else self._default_model
|
||||
}
|
||||
|
||||
if temperature:
|
||||
params["temperature"] = temperature
|
||||
|
||||
attempt = 0
|
||||
while attempt < max_retries:
|
||||
result = self._client.chat.completions.create(**params)
|
||||
result_content = result.choices[0].message.content
|
||||
try:
|
||||
result_json = json.loads(result_content)
|
||||
return map_to_model(result_json)
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
self._logger.info(f"GPT returned malformed response: {result_content}\n {str(e)}")
|
||||
params["messages"] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous response wasn't in the json format I've explicitly told you to output. "
|
||||
f"In your next response, you will fix it and return me just the json I've asked."
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Previous response: {result_content}\n"
|
||||
f"JSON format: {json_scheme}"
|
||||
)
|
||||
}
|
||||
]
|
||||
if attempt >= max_retries:
|
||||
self._logger.error(f"Max retries exceeded!")
|
||||
return None
|
||||
Reference in New Issue
Block a user