Brushed up the backend, added Writing Task 1 (Academic) prompt generation and grading ENCOA-274
15 ielts_be/services/impl/third_parties/__init__.py Normal file
@@ -0,0 +1,15 @@
from .aws_polly import AWSPolly
from .heygen import Heygen
from .openai import OpenAI
from .whisper import OpenAIWhisper
from .gpt_zero import GPTZero
from .elai import ELAI

__all__ = [
    "AWSPolly",
    "Heygen",
    "OpenAI",
    "OpenAIWhisper",
    "GPTZero",
    "ELAI"
]
86 ielts_be/services/impl/third_parties/aws_polly.py Normal file
@@ -0,0 +1,86 @@
import random

from aiobotocore.client import BaseClient

from ielts_be.dtos.listening import Dialog
from ielts_be.services import ITextToSpeechService
from ielts_be.configs.constants import NeuralVoices


class AWSPolly(ITextToSpeechService):

    def __init__(self, client: BaseClient):
        self._client = client

    async def synthesize_speech(self, text: str, voice: str, engine: str = "neural", output_format: str = "mp3"):
        tts_response = await self._client.synthesize_speech(
            Engine=engine,
            Text=text,
            OutputFormat=output_format,
            VoiceId=voice
        )
        return await tts_response['AudioStream'].read()

    async def text_to_speech(self, dialog: Dialog) -> bytes:
        if not dialog.conversation and not dialog.monologue:
            raise ValueError("Unsupported argument for text_to_speech")

        if not dialog.conversation:
            audio_segments = await self._text_to_speech(dialog.monologue)
        else:
            audio_segments = await self._conversation_to_speech(dialog)

        final_message = await self.synthesize_speech(
            "This audio recording, for the listening exercise, has finished.",
            "Stephen"
        )

        # Add the finish message
        audio_segments.append(final_message)

        # Combine the audio segments into a single audio file
        combined_audio = b"".join(audio_segments)

        return combined_audio

    async def _text_to_speech(self, text: str):
        voice = random.choice(NeuralVoices.ALL_NEURAL_VOICES)['Id']
        audio_segments = []
        for part in self._divide_text(text):
            audio_segments.append(await self.synthesize_speech(part, voice))

        return audio_segments

    async def _conversation_to_speech(self, dialog: Dialog):
        audio_segments = []
        for convo_payload in dialog.conversation:
            audio_segments.append(await self.synthesize_speech(convo_payload.text, convo_payload.voice))

        return audio_segments

    @staticmethod
    def _divide_text(text, max_length=3000):
        if len(text) <= max_length:
            return [text]

        divisions = []
        current_position = 0

        while current_position < len(text):
            next_position = min(current_position + max_length, len(text))
            next_period_position = text.rfind('.', current_position, next_position)

            if next_period_position != -1 and next_period_position > current_position:
                divisions.append(text[current_position:next_period_position + 1])
                current_position = next_period_position + 1
            else:
                # If no '.' is found in the next chunk, split at max_length
                divisions.append(text[current_position:next_position])
                current_position = next_position

        return divisions
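For reference, a small standalone sketch (not part of the commit above) of how the sentence-aware splitting in _divide_text behaves: each chunk stays within max_length and ends on a period whenever one falls inside the window, which keeps every Polly synthesize_speech call within its per-request text size. The function below simply mirrors the static method so the chunking can be exercised without an aiobotocore client.

def divide_text(text: str, max_length: int = 3000) -> list:
    # Mirrors AWSPolly._divide_text so the chunking can be run in isolation.
    if len(text) <= max_length:
        return [text]

    divisions, current_position = [], 0
    while current_position < len(text):
        next_position = min(current_position + max_length, len(text))
        next_period_position = text.rfind('.', current_position, next_position)

        if next_period_position != -1 and next_period_position > current_position:
            divisions.append(text[current_position:next_period_position + 1])
            current_position = next_period_position + 1
        else:
            divisions.append(text[current_position:next_position])
            current_position = next_position

    return divisions


if __name__ == "__main__":
    sample = ("This is sentence one. This is sentence two. " * 10).strip()
    for chunk in divide_text(sample, max_length=100):
        # Every chunk is at most 100 characters and ends on a sentence boundary where possible.
        print(len(chunk), repr(chunk[-25:]))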
84 ielts_be/services/impl/third_parties/elai/__init__.py Normal file
@@ -0,0 +1,84 @@
from copy import deepcopy
from logging import getLogger
from httpx import AsyncClient

from ielts_be.dtos.video import Task, TaskStatus
from ielts_be.services import IVideoGeneratorService


class ELAI(IVideoGeneratorService):

    _ELAI_ENDPOINT = 'https://apis.elai.io/api/v1/videos'

    def __init__(self, client: AsyncClient, token: str, avatars: dict, *, conf: dict):
        super().__init__(deepcopy(avatars))

        self._http_client = client
        self._conf = deepcopy(conf)
        self._logger = getLogger(__name__)
        self._GET_HEADER = {
            "accept": "application/json",
            "Authorization": f"Bearer {token}"
        }
        self._POST_HEADER = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {token}"
        }

    async def create_video(self, text: str, avatar: str):
        avatar_url = self._avatars[avatar].get("avatar_url")
        avatar_code = self._avatars[avatar].get("avatar_code")
        avatar_gender = self._avatars[avatar].get("avatar_gender")
        avatar_canvas = self._avatars[avatar].get("avatar_canvas")
        voice_id = self._avatars[avatar].get("voice_id")
        voice_provider = self._avatars[avatar].get("voice_provider")

        # conf.json defines "slides" as a list, so configure the first (and only) slide
        self._conf["slides"][0]["canvas"]["objects"][0]["src"] = avatar_url
        self._conf["slides"][0]["avatar"] = {
            "code": avatar_code,
            "gender": avatar_gender,
            "canvas": avatar_canvas
        }
        self._conf["slides"][0]["speech"] = text
        self._conf["slides"][0]["voice"] = voice_id
        self._conf["slides"][0]["voiceProvider"] = voice_provider

        response = await self._http_client.post(self._ELAI_ENDPOINT, headers=self._POST_HEADER, json=self._conf)

        self._logger.info(response.status_code)
        self._logger.info(response.json())

        video_id = response.json().get("_id")

        if video_id:
            await self._http_client.post(f'{self._ELAI_ENDPOINT}/render/{video_id}', headers=self._GET_HEADER)
            return Task(
                result=video_id,
                status=TaskStatus.STARTED,
            )
        else:
            return Task(status=TaskStatus.ERROR)

    async def pool_status(self, video_id: str) -> Task:
        response = await self._http_client.get(f'{self._ELAI_ENDPOINT}/{video_id}', headers=self._GET_HEADER)
        response_data = response.json()

        if response_data['status'] == 'ready':
            self._logger.info(response_data)
            return Task(
                status=TaskStatus.COMPLETED,
                result=response_data.get('url')
            )
        elif response_data['status'] == 'failed':
            self._logger.error('Video creation failed.')
            return Task(
                status=TaskStatus.ERROR,
                result=response_data.get('url')
            )
        else:
            self._logger.info('Video is still processing.')
            return Task(
                status=TaskStatus.IN_PROGRESS,
                result=video_id
            )
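A hedged usage sketch (not part of the diff) of the create-then-poll flow for ELAI: start the render with create_video, then call pool_status until the Task reports COMPLETED or ERROR. The token, avatar name, config paths, and poll interval below are placeholder assumptions, not values from this commit.

import asyncio
import json

from httpx import AsyncClient

from ielts_be.dtos.video import TaskStatus
from ielts_be.services.impl.third_parties import ELAI


async def generate_video(text: str):
    # Placeholder paths and token; in the app these come from configuration.
    with open("ielts_be/services/impl/third_parties/elai/avatars.json") as f:
        avatars = json.load(f)
    with open("ielts_be/services/impl/third_parties/elai/conf.json") as f:
        conf = json.load(f)

    async with AsyncClient() as client:
        elai = ELAI(client, token="ELAI_API_TOKEN", avatars=avatars, conf=conf)

        task = await elai.create_video(text, avatar="Gia")
        if task.status == TaskStatus.ERROR:
            return None

        # Poll until ELAI reports the render as ready or failed.
        while True:
            task = await elai.pool_status(task.result)
            if task.status == TaskStatus.COMPLETED:
                return task.result  # video URL
            if task.status == TaskStatus.ERROR:
                return None
            await asyncio.sleep(10)


if __name__ == "__main__":
    print(asyncio.run(generate_video("Welcome to your IELTS Writing Task 1 feedback.")))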
58 ielts_be/services/impl/third_parties/elai/avatars.json Normal file
@@ -0,0 +1,58 @@
{
  "Gia": {
    "avatar_code": "gia.business",
    "avatar_gender": "female",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/gia/business/gia_business.png",
    "avatar_canvas": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/gia/business/gia_business.png",
    "voice_id": "EXAVITQu4vr4xnSDxMaL",
    "voice_provider": "elevenlabs"
  },
  "Vadim": {
    "avatar_code": "vadim.business",
    "avatar_gender": "male",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/vadim/business/vadim_business.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/vadim/business/vadim_business.png",
    "voice_id": "flq6f7yk4E4fJM5XTYuZ",
    "voice_provider": "elevenlabs"
  },
  "Orhan": {
    "avatar_code": "orhan.business",
    "avatar_gender": "male",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/orhan/business/orhan.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/orhan/business/orhan.png",
    "voice_id": "en-US-AndrewMultilingualNeural",
    "voice_provider": "azure"
  },
  "Flora": {
    "avatar_code": "flora.business",
    "avatar_gender": "female",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/flora/business/flora_business.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/flora/business/flora_business.png",
    "voice_id": "en-US-JaneNeural",
    "voice_provider": "azure"
  },
  "Scarlett": {
    "avatar_code": "scarlett.business",
    "avatar_gender": "female",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/scarlett/business/scarlett_business.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/scarlett/business/scarlett_business.png",
    "voice_id": "en-US-NancyNeural",
    "voice_provider": "azure"
  },
  "Parker": {
    "avatar_code": "parker.casual",
    "avatar_gender": "male",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/parker/casual/parker_casual.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/parker/casual/parker_casual.png",
    "voice_id": "en-US-TonyNeural",
    "voice_provider": "azure"
  },
  "Ethan": {
    "avatar_code": "ethan.business",
    "avatar_gender": "male",
    "avatar_url": "https://elai-avatars.s3.us-east-2.amazonaws.com/common/ethan/business/ethan_business_low.png",
    "avatar_canvas": "https://d3u63mhbhkevz8.cloudfront.net/common/ethan/business/ethan_business_low.png",
    "voice_id": "en-US-JasonNeural",
    "voice_provider": "azure"
  }
}
72 ielts_be/services/impl/third_parties/elai/conf.json Normal file
@@ -0,0 +1,72 @@
{
  "name": "API test",
  "slides": [
    {
      "id": 1,
      "canvas": {
        "objects": [
          {
            "type": "avatar",
            "left": 151.5,
            "top": 36,
            "fill": "#4868FF",
            "scaleX": 0.3,
            "scaleY": 0.3,
            "width": 1080,
            "height": 1080,
            "avatarType": "transparent",
            "animation": {
              "type": null,
              "exitType": null
            }
          },
          {
            "type": "image",
            "version": "5.3.0",
            "originX": "left",
            "originY": "top",
            "left": 30,
            "top": 30,
            "width": 800,
            "height": 600,
            "fill": "rgb(0,0,0)",
            "stroke": null,
            "strokeWidth": 0,
            "strokeDashArray": null,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": false,
            "strokeMiterLimit": 4,
            "scaleX": 0.18821429,
            "scaleY": 0.18821429,
            "angle": 0,
            "flipX": false,
            "flipY": false,
            "opacity": 1,
            "shadow": null,
            "visible": true,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            "cropX": 0,
            "cropY": 0,
            "id": 676845479989,
            "src": "https://d3u63mhbhkevz8.cloudfront.net/production/uploads/66f5190349f943682dd776ff/en-coach-main-logo-800x600_sm1ype.jpg?Expires=1727654400&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9kM3U2M21oYmhrZXZ6OC5jbG91ZGZyb250Lm5ldC9wcm9kdWN0aW9uL3VwbG9hZHMvNjZmNTE5MDM0OWY5NDM2ODJkZDc3NmZmL2VuLWNvYWNoLW1haW4tbG9nby04MDB4NjAwX3NtMXlwZS5qcGciLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3Mjc2NTQ0MDB9fX1dfQ__&Signature=kTVzlDeS7cua2HiAE5G%7E-yFqbhu0bHraFH5SauUln7yuNXoX7vtiKIBYiL%7Eps3LCLEZS77arSZ7H%7EG8CKzabHDjAR-Y6Uc%7ELD5KQaMmk0jbAxbC3Wdoq6cfd0qIwEuodQYlC0It2WBidP8KsgOy3uUQ%7EvcBoqlb255yMFw4pHuptOBB1kPs%7EFyzDV0fnRNsKaYRcy0Fn2EFUp13axm0CZQclazuLFM622AyCydKMy0vfxV%7Etny3sskwPaUe2OANGMFg07Q1pRuy6fUON0DsbhAh1tA2H6-nnem5KbFwiZK3IIwwYGBx3H41ovzC6Ejt80Fd0%7EPSHw7GzVBnUmtP-IA__&Key-Pair-Id=K1Y7U91AR6T7E5",
            "crossOrigin": "anonymous",
            "filters": [],
            "_exists": true
          }
        ],
        "background": "#ffffff",
        "version": "4.4.0"
      },
      "animation": "fade_in",
      "language": "English",
      "voiceType": "text"
    }
  ]
}
52 ielts_be/services/impl/third_parties/gpt_zero.py Normal file
@@ -0,0 +1,52 @@
from logging import getLogger
from typing import Dict, Optional

from httpx import AsyncClient

from ielts_be.services import IAIDetectorService


class GPTZero(IAIDetectorService):

    _GPT_ZERO_ENDPOINT = 'https://api.gptzero.me/v2/predict/text'

    def __init__(self, client: AsyncClient, gpt_zero_key: str):
        self._header = {
            'x-api-key': gpt_zero_key
        }
        self._http_client = client
        self._logger = getLogger(__name__)

    async def run_detection(self, text: str):
        data = {
            'document': text,
            'version': '',
            'multilingual': False
        }

        response = await self._http_client.post(self._GPT_ZERO_ENDPOINT, headers=self._header, json=data)
        if response.status_code != 200:
            return None
        return self._parse_detection(response.json())

    def _parse_detection(self, response: Dict) -> Optional[Dict]:
        try:
            text_scan = response["documents"][0]

            filtered_sentences = [
                {
                    "sentence": item["sentence"],
                    "highlight_sentence_for_ai": item["highlight_sentence_for_ai"]
                }
                for item in text_scan["sentences"]
            ]

            return {
                "class_probabilities": text_scan["class_probabilities"],
                "confidence_category": text_scan["confidence_category"],
                "predicted_class": text_scan["predicted_class"],
                "sentences": filtered_sentences
            }
        except Exception as e:
            self._logger.error(f"Failed to parse GPTZero's response: {str(e)}")
            return None
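A minimal usage sketch (not part of the commit) for the detector above; the API key is a placeholder, and the fields read from the report are exactly the ones _parse_detection extracts from the GPTZero payload.

import asyncio

from httpx import AsyncClient

from ielts_be.services.impl.third_parties import GPTZero


async def main():
    async with AsyncClient() as client:
        detector = GPTZero(client, gpt_zero_key="GPT_ZERO_API_KEY")  # placeholder key

        report = await detector.run_detection("Candidate essay text to scan for AI generation.")
        if report is None:
            print("Detection request failed or the response could not be parsed.")
            return

        # Fields mirror what _parse_detection returns.
        print(report["predicted_class"], report["confidence_category"])
        for item in report["sentences"]:
            print(item["highlight_sentence_for_ai"], item["sentence"])


if __name__ == "__main__":
    asyncio.run(main())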
82 ielts_be/services/impl/third_parties/heygen/__init__.py Normal file
@@ -0,0 +1,82 @@
import logging
from copy import deepcopy

from httpx import AsyncClient

from ielts_be.dtos.video import Task, TaskStatus
from ielts_be.services import IVideoGeneratorService


class Heygen(IVideoGeneratorService):

    _GET_VIDEO_URL = 'https://api.heygen.com/v1/video_status.get'

    def __init__(self, client: AsyncClient, token: str, avatars: dict):
        super().__init__(deepcopy(avatars))
        self._get_header = {
            'X-Api-Key': token
        }
        self._post_header = {
            'X-Api-Key': token,
            'Content-Type': 'application/json'
        }
        self._http_client = client
        self._logger = logging.getLogger(__name__)

    async def create_video(self, text: str, avatar: str):
        avatar = self._avatars[avatar]["id"]

        create_video_url = f'https://api.heygen.com/v2/template/{avatar}/generate'
        data = {
            "test": False,
            "caption": False,
            "title": "video_title",
            "variables": {
                "script_here": {
                    "name": "script_here",
                    "type": "text",
                    "properties": {
                        "content": text
                    }
                }
            }
        }
        response = await self._http_client.post(create_video_url, headers=self._post_header, json=data)
        self._logger.info(response.status_code)
        self._logger.info(response.json())
        video_id = response.json()["data"]["video_id"]

        return Task(
            result=video_id,
            status=TaskStatus.STARTED,
        )

    async def poll_status(self, video_id: str) -> Task:
        response = await self._http_client.get(self._GET_VIDEO_URL, headers=self._get_header, params={
            'video_id': video_id
        })
        response_data = response.json()

        status = response_data["data"]["status"]
        error = response_data["data"]["error"]
        if status != "completed" and error is None:
            self._logger.info(f"Status: {status}")
            return Task(
                status=TaskStatus.IN_PROGRESS,
                result=video_id
            )

        if error:
            self._logger.error('Video creation failed.')
            return Task(
                status=TaskStatus.ERROR,
                result=response_data.get('url')
            )

        url = response.json()['data']['video_url']
        self._logger.info(f'Successfully generated video: {url}')
        return Task(
            status=TaskStatus.COMPLETED,
            result=url
        )
30 ielts_be/services/impl/third_parties/heygen/avatars.json Normal file
@@ -0,0 +1,30 @@
{
  "Matthew Noah": {
    "id": "5912afa7c77c47d3883af3d874047aaf",
    "avatar_gender": "male"
  },
  "Vera Cerise": {
    "id": "9e58d96a383e4568a7f1e49df549e0e4",
    "avatar_gender": "female"
  },
  "Edward Tony": {
    "id": "d2cdd9c0379a4d06ae2afb6e5039bd0c",
    "avatar_gender": "male"
  },
  "Tanya Molly": {
    "id": "045cb5dcd00042b3a1e4f3bc1c12176b",
    "avatar_gender": "female"
  },
  "Kayla Abbi": {
    "id": "1ae1e5396cc444bfad332155fdb7a934",
    "avatar_gender": "female"
  },
  "Jerome Ryan": {
    "id": "0ee6aa7cc1084063a630ae514fccaa31",
    "avatar_gender": "male"
  },
  "Tyler Christopher": {
    "id": "5772cff935844516ad7eeff21f839e43",
    "avatar_gender": "male"
  }
}
153 ielts_be/services/impl/third_parties/openai.py Normal file
@@ -0,0 +1,153 @@
import json
import re
import logging
from typing import List, Optional, Callable, TypeVar

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from ielts_be.services.abc import ILLMService
from ielts_be.helpers import count_tokens
from ielts_be.configs.constants import BLACKLISTED_WORDS
from pydantic import BaseModel

T = TypeVar('T', bound=BaseModel)


class OpenAI(ILLMService):

    MAX_TOKENS = 4097
    TRY_LIMIT = 2

    def __init__(self, client: AsyncOpenAI):
        self._client = client
        self._logger = logging.getLogger(__name__)
        self._default_model = "gpt-4o"

    async def prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        fields_to_check: Optional[List[str]],
        temperature: float,
        check_blacklisted: bool = True,
        token_count: int = -1
    ):
        if token_count == -1:
            token_count = self._count_total_tokens(messages)
        return await self._prediction(model, messages, token_count, fields_to_check, temperature, 0, check_blacklisted)

    async def _prediction(
        self,
        model: str,
        messages: List[ChatCompletionMessageParam],
        token_count: int,
        fields_to_check: Optional[List[str]],
        temperature: float,
        try_count: int,
        check_blacklisted: bool,
    ):
        result = await self._client.chat.completions.create(
            model=model,
            max_tokens=int(self.MAX_TOKENS - token_count - 300),
            temperature=float(temperature),
            messages=messages,
            response_format={"type": "json_object"}
        )
        result = result.choices[0].message.content

        if check_blacklisted:
            found_blacklisted_word = self._get_found_blacklisted_words(result)

            if found_blacklisted_word is not None and try_count < self.TRY_LIMIT:
                self._logger.warning("Result contains blacklisted words: " + str(found_blacklisted_word))
                return await self._prediction(
                    model, messages, token_count, fields_to_check, temperature, (try_count + 1), check_blacklisted
                )
            elif found_blacklisted_word is not None and try_count >= self.TRY_LIMIT:
                return ""

        if fields_to_check is None:
            return json.loads(result)

        if not self._check_fields(result, fields_to_check) and try_count < self.TRY_LIMIT:
            return await self._prediction(
                model, messages, token_count, fields_to_check, temperature, (try_count + 1), check_blacklisted
            )
        return json.loads(result)

    async def prediction_override(self, **kwargs):
        return await self._client.chat.completions.create(
            **kwargs
        )

    @staticmethod
    def _get_found_blacklisted_words(text: str):
        text_lower = text.lower()
        for word in BLACKLISTED_WORDS:
            if re.search(r'\b' + re.escape(word) + r'\b', text_lower):
                return word
        return None

    @staticmethod
    def _count_total_tokens(messages):
        total_tokens = 0
        for message in messages:
            total_tokens += count_tokens(message["content"])["n_tokens"]
        return total_tokens

    @staticmethod
    def _check_fields(obj, fields):
        return all(field in obj for field in fields)

    async def pydantic_prediction(
        self,
        messages: List[ChatCompletionMessageParam],
        map_to_model: Callable,
        json_scheme: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_retries: int = 3
    ) -> List[T] | T | None:
        params = {
            "messages": messages,
            "response_format": {"type": "json_object"},
            "model": model if model else self._default_model
        }

        if temperature:
            params["temperature"] = temperature

        attempt = 0
        while attempt < max_retries:
            result = await self._client.chat.completions.create(**params)
            result_content = result.choices[0].message.content

            try:
                result_json = json.loads(result_content)
                self._logger.debug(result_json)
                return map_to_model(result_json)
            except Exception as e:
                attempt += 1
                self._logger.info(f"GPT returned malformed response: {result_content}\n {str(e)}")
                params["messages"] = [
                    {
                        "role": "user",
                        "content": (
                            "Your previous response wasn't in the JSON format I explicitly told you to use. "
                            "In your next response, fix it and return only the JSON I asked for."
                        )
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Previous response: {result_content}\n"
                            f"JSON format: {json_scheme}\n"
                            f"Validation errors: {e}"
                        )
                    }
                ]
                if attempt >= max_retries:
                    self._logger.error("Max retries exceeded!")
                    return None
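To show how the retry-on-malformed-JSON flow in pydantic_prediction is meant to be driven, here is a hedged caller sketch (not part of the diff): the GradedEssay model, prompt text, and json_scheme string are invented for illustration; only the method signature comes from the code above.

import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel

from ielts_be.services.impl.third_parties import OpenAI


class GradedEssay(BaseModel):
    # Hypothetical shape for a Writing Task 1 grade; not defined anywhere in this commit.
    band_score: float
    feedback: str


async def main():
    llm = OpenAI(AsyncOpenAI())  # AsyncOpenAI reads OPENAI_API_KEY from the environment

    json_scheme = '{"band_score": <float>, "feedback": "<string>"}'
    messages = [
        {"role": "system", "content": f"Grade the report as an IELTS examiner. Respond only with JSON: {json_scheme}"},
        {"role": "user", "content": "The chart shows household recycling rates in three countries..."},
    ]

    graded = await llm.pydantic_prediction(
        messages,
        map_to_model=lambda data: GradedEssay(**data),  # raises on a bad payload, triggering a retry
        json_scheme=json_scheme,
    )
    if graded is not None:
        print(graded.band_score, graded.feedback)


if __name__ == "__main__":
    asyncio.run(main())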
106 ielts_be/services/impl/third_parties/whisper.py Normal file
@@ -0,0 +1,106 @@
import threading
import whisper
import asyncio
import numpy as np
import soundfile as sf
import librosa
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

from logging import getLogger
from whisper import Whisper

from ielts_be.services import ISpeechToTextService

"""
The Whisper model is not thread-safe, so a thread pool
backed by 4 model instances is created, allowing up to
4 transcriptions to be processed at a time.

The base model requires ~1 GB, so 4 instances are a safe bet:
https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
"""
class OpenAIWhisper(ISpeechToTextService):
    def __init__(self, model_name: str = "base", num_models: int = 4):
        self._model_name = model_name
        self._num_models = num_models
        self._models: Dict[int, 'Whisper'] = {}
        self._lock = threading.Lock()
        self._next_model_id = 0
        self._is_closed = False
        self._logger = getLogger(__name__)

        for i in range(num_models):
            self._models[i] = whisper.load_model(self._model_name, in_memory=True)

        self._executor = ThreadPoolExecutor(
            max_workers=num_models,
            thread_name_prefix="whisper_worker"
        )

    def get_model(self) -> 'Whisper':
        # Hand out models round-robin so concurrent transcriptions don't share an instance
        with self._lock:
            model_id = self._next_model_id
            self._next_model_id = (self._next_model_id + 1) % self._num_models
            return self._models[model_id]

    async def speech_to_text(self, path: str) -> str:
        def transcribe():
            try:
                audio, sr = sf.read(path)

                # Convert to mono first to reduce memory usage
                if len(audio.shape) > 1:
                    audio = audio.mean(axis=1)

                # Resample to the 16 kHz rate Whisper expects
                audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

                # Normalize to [-1, 1] range
                audio = audio.astype(np.float32)
                if np.max(np.abs(audio)) > 0:
                    audio = audio / np.max(np.abs(audio))

                # Break up long audio into chunks (30 seconds at 16kHz = 480000 samples)
                max_samples = 480000
                if len(audio) > max_samples:
                    chunks = []
                    for i in range(0, len(audio), max_samples):
                        chunk = audio[i:i + max_samples]
                        chunks.append(chunk)

                    model = self.get_model()
                    texts = []
                    for chunk in chunks:
                        result = model.transcribe(
                            chunk,
                            fp16=False,
                            language='English',
                            verbose=False
                        )["text"]
                        texts.append(result)
                    return " ".join(texts)
                else:
                    model = self.get_model()
                    return model.transcribe(
                        audio,
                        fp16=False,
                        language='English',
                        verbose=False
                    )["text"]

            except Exception as e:
                self._logger.error(f"Transcription of {path} failed: {e}")
                raise

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, transcribe)

    def close(self):
        with self._lock:
            if not self._is_closed:
                self._is_closed = True
                if self._executor:
                    self._executor.shutdown(wait=True, cancel_futures=True)

    def __del__(self):
        self.close()
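A hedged usage sketch (not part of the commit) of the pooling described in the module docstring: one OpenAIWhisper instance serving several concurrent transcriptions, each dispatched to a pooled model via the executor. The audio file names below are placeholders.

import asyncio

from ielts_be.services.impl.third_parties import OpenAIWhisper


async def main():
    # Eagerly loads num_models copies of the "base" model (roughly 1 GB each).
    stt = OpenAIWhisper(model_name="base", num_models=4)
    try:
        paths = ["answer_1.wav", "answer_2.wav", "answer_3.wav"]  # placeholder recordings

        # Up to num_models transcriptions run in parallel on the worker threads.
        transcripts = await asyncio.gather(*(stt.speech_to_text(p) for p in paths))
        for path, text in zip(paths, transcripts):
            print(path, "->", text[:80])
    finally:
        stt.close()


if __name__ == "__main__":
    asyncio.run(main())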