142 lines
5.4 KiB
Python
142 lines
5.4 KiB
Python
import asyncio
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from typing import Dict, List, Optional, Union

import librosa
import numpy as np
import soundfile as sf
import whisper
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from whisper import Whisper

from ielts_be.configs.constants import GPTModels, TemperatureSettings
from ielts_be.exceptions.exceptions import TranscriptionException
from ielts_be.services import ILLMService, ISpeechToTextService
|
|
|
|
"""
|
|
The whisper model is not thread safe, a thread pool
|
|
with 4 whisper models will be created so it can
|
|
process up to 4 transcriptions at a time.
|
|
|
|
The base model requires ~1GB so 4 instances is the safe bet:
|
|
https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
|
|
"""
|
|
class OpenAIWhisper(ISpeechToTextService):
    """Speech-to-text service backed by a round-robin pool of local Whisper models.

    Whisper models are not thread safe, so ``num_models`` independent model
    instances are loaded up front and handed out round-robin, paired with a
    thread pool of the same size so up to ``num_models`` transcriptions can
    run concurrently.
    """

    def __init__(self, model_name: str = "base", num_models: int = 4):
        """Load ``num_models`` Whisper instances and create the worker pool.

        Args:
            model_name: Whisper checkpoint name (e.g. ``"base"``).
            num_models: Number of model instances and worker threads.
        """
        self._model_name = model_name
        self._num_models = num_models
        self._models: Dict[int, 'Whisper'] = {}
        # Guards the round-robin counter and the idempotent close() flag.
        self._lock = threading.Lock()
        self._next_model_id = 0
        self._is_closed = False
        self._logger = getLogger(__name__)

        # Eagerly load every instance; in_memory avoids re-reading the
        # checkpoint from disk for each model.
        for i in range(num_models):
            self._models[i] = whisper.load_model(self._model_name, in_memory=True)

        self._executor = ThreadPoolExecutor(
            max_workers=num_models,
            thread_name_prefix="whisper_worker",
        )

    def get_model(self) -> 'Whisper':
        """Return the next model in round-robin order (thread safe)."""
        with self._lock:
            model_id = self._next_model_id
            self._next_model_id = (self._next_model_id + 1) % self._num_models
            return self._models[model_id]

    @retry(
        stop=stop_after_attempt(3),
        retry=retry_if_exception_type(Exception),
        reraise=True,
    )
    async def speech_to_text(self, path: str, *, index: Optional[int] = None) -> Union[str, List[str]]:
        """Transcribe the audio file at ``path``.

        The audio is mixed down to mono, resampled to 16 kHz, and
        peak-normalized. Clips longer than 30 seconds are split into
        overlapping windows and a list of per-window transcripts is
        returned (merge them with :meth:`fix_overlap`); shorter clips
        return a single string.

        Args:
            path: Path to an audio file readable by ``soundfile``.
            index: Optional exercise index, used only to enrich the
                error message on failure.

        Returns:
            The transcript string, or a list of overlapping segment
            transcripts for audio longer than 30 seconds.

        Raises:
            TranscriptionException: If transcription fails (retried up
                to 3 times by the ``@retry`` decorator).
        """
        def transcribe():
            try:
                audio, sr = sf.read(path)
                # Mix multi-channel audio down to mono.
                if len(audio.shape) > 1:
                    audio = audio.mean(axis=1)

                # Whisper expects 16 kHz float32 input.
                audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
                audio = audio.astype(np.float32)
                # Hoisted: peak was previously computed twice.
                peak = np.max(np.abs(audio))
                if peak > 0:
                    audio = audio / peak

                max_samples = 480000          # 30 seconds at 16kHz
                overlap = max_samples // 4    # adjacent windows share a 1/4 overlap

                model = self.get_model()

                # 30 seconds or less: a single transcription call.
                if len(audio) <= max_samples:
                    return model.transcribe(
                        audio,
                        fp16=False,
                        language='English',
                        verbose=False,
                    )["text"]

                # Greater than 30 secs: transcribe overlapping windows.
                # (The old dead `chunks` accumulator, which pinned every
                # audio window in memory, has been removed.)
                texts = []
                for start in range(0, len(audio) - overlap, max_samples - overlap):
                    chunk = audio[start:start + max_samples]
                    texts.append(
                        model.transcribe(
                            chunk,
                            fp16=False,
                            language='English',
                            verbose=False,
                        )["text"]
                    )
                return texts

            except Exception as e:
                # BUG FIX: was `if index else`, which is falsy for index == 0
                # and silently dropped the exercise number for the first one.
                msg = (
                    f"Failed to transcribe exercise {index+1} after 3 attempts: {str(e)}"
                    if index is not None else
                    f"Transcription failed after 3 attempts: {str(e)}"
                )
                raise TranscriptionException(msg)

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, transcribe)

    def close(self):
        """Shut down the worker pool and release model references (idempotent)."""
        with self._lock:
            if not self._is_closed:
                self._is_closed = True
                if self._executor:
                    self._executor.shutdown(wait=True, cancel_futures=True)
                # Drop model references so the ~1GB-per-instance memory
                # can actually be reclaimed.
                self._models.clear()

    def __del__(self):
        # Guard with getattr: __del__ can run on a partially-constructed
        # instance if __init__ raised before _lock was assigned.
        if getattr(self, "_lock", None) is not None:
            self.close()

    @staticmethod
    async def fix_overlap(llm: ILLMService, segments: List[str]):
        """Merge overlapping transcript segments into one text via the LLM.

        Args:
            llm: LLM service used to deduplicate the overlap.
            segments: Overlapping transcript chunks from speech_to_text.

        Returns:
            The merged transcript string.
        """
        messages = [
            {
                "role": "system",
                "content": (
                    'You are a helpful assistant designed to fix transcription segments. You will receive '
                    'a string array with transcriptions segments that have overlap, your job is to only '
                    'remove duplicated words between segments and join them into one single text. You cannot '
                    'correct phrasing or wording, your job is to simply make sure that there is no repeated words '
                    'between the end of a segment and at the start of the next segment. Your response must be formatted '
                    'as JSON in the following format: {"fixed_text": ""}'
                )
            },
            {
                "role": "user",
                # BUG FIX: json.dumps escapes quotes/backslashes inside
                # segments; the previous f-string interpolation produced
                # invalid JSON for any segment containing a double quote.
                "content": json.dumps(segments, indent=4, ensure_ascii=False)
            }
        ]
        response = await llm.prediction(
            GPTModels.GPT_4_O, messages, ["fixed_text"], 0.1
        )
        return response["fixed_text"]
|