Files
encoach_backend/app/services/impl/third_parties/whisper.py
2024-11-25 16:41:38 +00:00

67 lines
2.2 KiB
Python

import asyncio
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

import whisper
from whisper import Whisper

from app.services.abc import ISpeechToTextService
"""
The whisper model is not thread safe, a thread pool
with 4 whisper models will be created so it can
process up to 4 transcriptions at a time.
The base model requires ~1GB so 4 instances is the safe bet:
https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
"""
class OpenAIWhisper(ISpeechToTextService):
    """Speech-to-text service backed by a fixed pool of local Whisper models.

    Whisper models are not thread safe. Each transcription therefore checks a
    model out of an internal queue for its exclusive use and returns it when
    done. Transcriptions run on a thread pool with one worker per model, so at
    most ``num_models`` transcriptions execute concurrently.

    Args:
        model_name: Model name passed to ``whisper.load_model``.
        num_models: Number of model instances, and worker threads, to create.

    Raises:
        ValueError: If ``num_models`` is less than 1.
    """

    def __init__(self, model_name: str = "base", num_models: int = 4):
        # Validate before loading any (large) models. ThreadPoolExecutor would
        # also raise ValueError here, but only after the loading work is done.
        if num_models < 1:
            raise ValueError(f"num_models must be at least 1, got {num_models}")
        self._model_name = model_name
        self._num_models = num_models
        self._models: Dict[int, 'Whisper'] = {}
        # Guards the round-robin cursor used by get_model() and the closed flag.
        self._lock = threading.Lock()
        self._next_model_id = 0
        self._is_closed = False
        # Models not currently in use by a transcription. A worker takes one
        # out and puts it back afterwards, so no two threads share a model.
        self._available: 'queue.Queue[Whisper]' = queue.Queue()
        for i in range(num_models):
            model = whisper.load_model(self._model_name)
            self._models[i] = model
            self._available.put(model)
        self._executor = ThreadPoolExecutor(
            max_workers=num_models,
            thread_name_prefix="whisper_worker",
        )

    def get_model(self) -> 'Whisper':
        """Return the next model instance in round-robin order.

        NOTE: this does NOT reserve the model. Another thread may be using
        the returned instance at the same time. ``speech_to_text`` does not
        use this method; it checks models out exclusively instead.
        """
        with self._lock:
            model_id = self._next_model_id
            self._next_model_id = (self._next_model_id + 1) % self._num_models
        return self._models[model_id]

    async def speech_to_text(self, file_path: str) -> str:
        """Transcribe the audio file at ``file_path`` with the language set to English.

        The blocking Whisper call runs on the internal thread pool, so the
        event loop is not blocked while transcribing.

        Args:
            file_path: Path to the audio file to transcribe.

        Returns:
            The ``"text"`` field of Whisper's transcription result.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            RuntimeError: If the service has already been closed.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} not found.")
        if self._is_closed:
            raise RuntimeError("OpenAIWhisper service has been closed.")

        def transcribe() -> str:
            # Blocks until a model is free. With one worker per model, a
            # running worker always finds one available.
            model = self._available.get()
            try:
                return model.transcribe(
                    file_path,
                    fp16=False,
                    language='English',
                    verbose=False,
                )["text"]
            finally:
                # Always return the model, even if transcription failed.
                self._available.put(model)

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, transcribe)

    def close(self) -> None:
        """Shut down the worker pool. Safe to call more than once.

        Waits for running transcriptions to finish and cancels queued ones.
        """
        with self._lock:
            if self._is_closed:
                return
            self._is_closed = True
        # Shut down OUTSIDE the lock. Waiting for workers while holding a lock
        # that a worker may need would deadlock.
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=True, cancel_futures=True)

    def __del__(self):
        # __init__ may have raised (e.g. bad num_models or a model load
        # failure) before the attributes close() relies on were set.
        if not hasattr(self, "_lock"):
            return
        self.close()