import os
|
|
import threading
|
|
import whisper
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Dict
|
|
from whisper import Whisper
|
|
|
|
from app.services.abc import ISpeechToTextService
|
|
|
|
"""
|
|
The whisper model is not thread safe, a thread pool
|
|
with 4 whisper models will be created so it can
|
|
process up to 4 transcriptions at a time.
|
|
|
|
The base model requires ~1GB so 4 instances is the safe bet:
|
|
https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
|
|
"""
|
|
class OpenAIWhisper(ISpeechToTextService):
    """Speech-to-text service backed by a pool of OpenAI Whisper models.

    Whisper models are not thread safe, so ``num_models`` independent
    model instances are loaded and handed out round-robin to a thread
    pool of the same size, allowing up to ``num_models`` concurrent
    transcriptions.
    """

    def __init__(self, model_name: str = "base", num_models: int = 4):
        """Load ``num_models`` copies of the named whisper model.

        Args:
            model_name: Model size/name passed to ``whisper.load_model``.
            num_models: Number of model instances (and worker threads).
        """
        self._model_name = model_name
        self._num_models = num_models
        self._models: Dict[int, 'Whisper'] = {}
        self._lock = threading.Lock()
        self._next_model_id = 0
        self._is_closed = False
        # Assigned before the (slow, fallible) model loads below so that
        # close()/__del__ remain safe even if __init__ never finishes.
        self._executor = None

        for i in range(num_models):
            self._models[i] = whisper.load_model(self._model_name)

        self._executor = ThreadPoolExecutor(
            max_workers=num_models,
            thread_name_prefix="whisper_worker"
        )

    def get_model(self) -> 'Whisper':
        """Return the next model in round-robin order.

        Because the executor has exactly ``num_models`` workers, at most
        that many tasks run at once, and consecutive round-robin picks are
        distinct — so no model is shared between concurrently running tasks.
        """
        with self._lock:
            model_id = self._next_model_id
            self._next_model_id = (self._next_model_id + 1) % self._num_models
            return self._models[model_id]

    async def speech_to_text(self, file_path: str) -> str:
        """Transcribe the audio file at ``file_path`` and return its text.

        Args:
            file_path: Path to an audio file readable by whisper.

        Returns:
            The transcribed text.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            RuntimeError: If the service has already been closed.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} not found.")
        # Fail fast with a clear message instead of the executor's opaque
        # "cannot schedule new futures after shutdown" RuntimeError.
        if self._is_closed:
            raise RuntimeError("OpenAIWhisper service is closed.")

        def transcribe():
            model = self.get_model()
            return model.transcribe(
                file_path,
                fp16=False,
                language='English',
                verbose=False
            )["text"]

        # transcribe() blocks; run it on the pool so the event loop stays free.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, transcribe)

    def close(self):
        """Shut down the worker pool. Idempotent; safe after a failed __init__."""
        with self._lock:
            if not self._is_closed:
                self._is_closed = True
                # getattr guard: __init__ may have raised before _executor
                # was assigned (e.g. a model failed to load).
                if getattr(self, "_executor", None):
                    self._executor.shutdown(wait=True, cancel_futures=True)

    def __del__(self):
        # Best-effort cleanup: attributes may be missing entirely if
        # __init__ failed before _lock/_is_closed were set.
        try:
            self.close()
        except AttributeError:
            pass
|