Files
encoach_backend/ielts_be/services/impl/third_parties/whisper.py
2024-12-21 19:27:14 +00:00

142 lines
5.4 KiB
Python

import asyncio
import json
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from logging import getLogger
from typing import Dict, Iterator, List, Optional

import librosa
import numpy as np
import soundfile as sf
import whisper
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from whisper import Whisper

from ielts_be.configs.constants import GPTModels, TemperatureSettings
from ielts_be.exceptions.exceptions import TranscriptionException
from ielts_be.services import ISpeechToTextService, ILLMService
"""
The Whisper model is not thread-safe, so a thread pool backed by
4 Whisper model instances is created, allowing up to 4
transcriptions to run concurrently.
The base model requires ~1 GB of memory, so 4 instances is a safe choice:
https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
"""
class OpenAIWhisper(ISpeechToTextService):
    """Speech-to-text service backed by a pool of local Whisper models.

    Each transcription runs on a worker thread of a private
    ``ThreadPoolExecutor`` and checks out one Whisper model exclusively for
    the duration of the job, so a model instance is never used by two
    threads at the same time.
    """

    # Audio is resampled to 16 kHz mono; recordings longer than 30 seconds
    # are split into overlapping 30 second chunks.
    _SAMPLE_RATE = 16000
    _MAX_SAMPLES = 30 * _SAMPLE_RATE  # 480000 samples == 30 seconds
    _OVERLAP_SAMPLES = _MAX_SAMPLES // 4  # consecutive chunks share 1/4 of a window

    def __init__(self, model_name: str = "base", num_models: int = 4):
        """Load ``num_models`` copies of the Whisper model ``model_name``.

        Args:
            model_name: Checkpoint name passed to ``whisper.load_model``.
            num_models: Number of model instances to load. It is also the
                number of worker threads, i.e. the maximum number of
                concurrent transcriptions.
        """
        self._model_name = model_name
        self._num_models = num_models
        self._models: Dict[int, 'Whisper'] = {}
        self._lock = threading.Lock()
        self._next_model_id = 0
        self._is_closed = False
        self._logger = getLogger(__name__)
        # Models that are currently free; a job takes one and puts it back.
        self._model_pool: 'queue.Queue[Whisper]' = queue.Queue()
        for i in range(num_models):
            model = whisper.load_model(self._model_name, in_memory=True)
            self._models[i] = model
            self._model_pool.put(model)
        self._executor = ThreadPoolExecutor(
            max_workers=num_models,
            thread_name_prefix="whisper_worker"
        )

    def get_model(self) -> 'Whisper':
        """Return the next model in round-robin order.

        The returned model is NOT reserved for the caller. With long-running
        jobs the same instance can be handed out again while it is still in
        use. Internal transcription uses ``_checkout_model`` instead; this
        method is kept for backward compatibility.
        """
        with self._lock:
            model_id = self._next_model_id
            self._next_model_id = (self._next_model_id + 1) % self._num_models
            return self._models[model_id]

    @contextmanager
    def _checkout_model(self) -> Iterator['Whisper']:
        """Borrow a free model exclusively and return it to the pool on exit.

        Blocks until a model is available.
        """
        model = self._model_pool.get()
        try:
            yield model
        finally:
            self._model_pool.put(model)

    @staticmethod
    def _load_audio(path: str, target_sr: int) -> np.ndarray:
        """Read ``path`` and return peak-normalized mono float32 audio at ``target_sr`` Hz."""
        audio, sr = sf.read(path)
        if audio.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            audio = audio.mean(axis=1)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        audio = audio.astype(np.float32)
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        return audio

    @staticmethod
    def _split_audio(audio: np.ndarray, max_samples: int, overlap: int) -> List[np.ndarray]:
        """Split ``audio`` into windows of at most ``max_samples`` samples.

        Consecutive windows start ``max_samples - overlap`` samples apart, so
        each window repeats the last ``overlap`` samples of the previous one.
        When ``len(audio) > overlap``, the final window reaches the end of
        ``audio``.
        """
        step = max_samples - overlap
        return [
            audio[start:start + max_samples]
            for start in range(0, len(audio) - overlap, step)
        ]

    @staticmethod
    def _run_model(model: 'Whisper', audio: np.ndarray) -> str:
        """Transcribe one in-memory audio array with ``model`` (English, fp32)."""
        return model.transcribe(
            audio,
            fp16=False,
            language='English',
            verbose=False
        )["text"]

    @retry(
        stop=stop_after_attempt(3),
        retry=retry_if_exception_type(Exception),
        reraise=True
    )
    async def speech_to_text(self, path: str, *, index: Optional[int] = None) -> str:
        """Transcribe the English audio file at ``path``.

        Decoding, resampling and inference run on the worker pool, so the
        event loop is not blocked. ``tenacity`` attempts the whole call up to
        3 times and re-raises the last error.

        Args:
            path: Path to an audio file readable by ``soundfile``.
            index: Optional zero-based exercise index. It is used only to make
                the error message more specific.

        Returns:
            For audio up to 30 seconds, the transcript as a ``str``. For
            longer audio, a ``list`` of per-chunk transcripts whose edges
            overlap. NOTE(review): this contradicts the ``-> str`` annotation;
            presumably callers merge the list with ``fix_overlap`` — confirm.

        Raises:
            TranscriptionException: If reading or transcribing the audio fails.
        """
        def transcribe():
            try:
                audio = self._load_audio(path, self._SAMPLE_RATE)
                with self._checkout_model() as model:
                    if len(audio) <= self._MAX_SAMPLES:
                        return self._run_model(model, audio)
                    return [
                        self._run_model(model, chunk)
                        for chunk in self._split_audio(
                            audio, self._MAX_SAMPLES, self._OVERLAP_SAMPLES
                        )
                    ]
            except Exception as e:
                # ``index`` may legitimately be 0, so compare against None.
                msg = (
                    f"Failed to transcribe exercise {index+1} after 3 attempts: {str(e)}"
                    if index is not None else
                    f"Transcription failed after 3 attempts: {str(e)}"
                )
                raise TranscriptionException(msg) from e

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, transcribe)

    def close(self):
        """Shut down the worker pool. Calling it again has no effect.

        Waits for running transcriptions to finish and cancels queued ones.
        """
        with self._lock:
            if self._is_closed:
                return
            self._is_closed = True
        # Shut down outside the lock. A running job that calls ``get_model``
        # needs the same lock, so waiting on it while holding the lock
        # would deadlock.
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=True, cancel_futures=True)

    def __del__(self):
        # ``__init__`` may have raised before the lock was created.
        if getattr(self, "_lock", None) is not None:
            self.close()

    @staticmethod
    async def fix_overlap(llm: ILLMService, segments: List[str]):
        """Ask the LLM to merge overlapping transcript segments into one text.

        Args:
            llm: Service used for the chat prediction.
            segments: Consecutive transcripts whose edges overlap (e.g. the
                list returned by ``speech_to_text`` for long audio).

        Returns:
            The ``"fixed_text"`` field of the LLM's JSON response.
        """
        messages = [
            {
                "role": "system",
                "content": (
                    'You are a helpful assistant designed to fix transcription segments. You will receive '
                    'a string array with transcriptions segments that have overlap, your job is to only '
                    'remove duplicated words between segments and join them into one single text. You cannot '
                    'correct phrasing or wording, your job is to simply make sure that there is no repeated words '
                    'between the end of a segment and at the start of the next segment. Your response must be formatted '
                    'as JSON in the following format: {"fixed_text": ""}'
                )
            },
            {
                "role": "user",
                # Serialize with json so quotes, backslashes and newlines inside
                # a transcript are escaped and the array stays valid.
                "content": json.dumps(segments, indent=1, ensure_ascii=False)
            }
        ]
        response = await llm.prediction(
            GPTModels.GPT_4_O, messages, ["fixed_text"], 0.1
        )
        return response["fixed_text"]