# miku-discord/stt/whisper_transcriber.py

"""
Faster-Whisper Transcriber
GPU-accelerated speech-to-text using faster-whisper (CTranslate2).
Supports streaming transcription with partial results.
"""

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, List, Optional

import numpy as np
from faster_whisper import WhisperModel

logger = logging.getLogger('whisper')


class WhisperTranscriber:
    """
    Faster-Whisper based transcription with streaming support.

    Runs on GPU (GTX 1660) with the small model for a balance of speed and quality.
    """

    def __init__(
        self,
        model_size: str = "small",
        device: str = "cuda",
        compute_type: str = "float16",
        language: str = "en",
        beam_size: int = 5
    ):
        """
        Initialize Whisper transcriber.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
            device: Device to run on (cuda or cpu)
            compute_type: Compute precision (float16, int8, int8_float16)
            language: Language code for transcription
            beam_size: Beam search size (higher = better quality, slower)
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self.language = language
        self.beam_size = beam_size

        logger.info(f"Loading Faster-Whisper model: {model_size} on {device}...")

        # Load model
        self.model = WhisperModel(
            model_size,
            device=device,
            compute_type=compute_type,
            download_root="/models"
        )

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Whisper model loaded: {model_size} ({compute_type})")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        initial_prompt: Optional[str] = None
    ) -> str:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32)
            sample_rate: Audio sample rate
            initial_prompt: Optional prompt to guide transcription

        Returns:
            Transcribed text
        """
        loop = asyncio.get_running_loop()

        # Run transcription in thread pool to avoid blocking the event loop
        result = await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            initial_prompt
        )
        return result
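
    # Usage sketch (illustrative; must be awaited inside a running event loop,
    # e.g. a Discord voice handler; `audio_f32` is a hypothetical float32 array):
    #
    #   text = await transcriber.transcribe_async(audio_f32, sample_rate=16000)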

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        initial_prompt: Optional[str]
    ) -> str:
        """
        Blocking transcription call (runs in thread pool).
        """
        # Convert to float32 in [-1, 1] if needed (assumes non-float32 input
        # is int16 PCM; faster-whisper expects float32 samples at 16 kHz)
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32) / 32768.0

        # Transcribe
        segments, info = self.model.transcribe(
            audio,
            language=self.language,
            beam_size=self.beam_size,
            initial_prompt=initial_prompt,
            vad_filter=False,       # We handle VAD separately
            word_timestamps=False   # Can enable for word-level timing
        )
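
        # Note: faster-whisper returns `segments` as a lazy generator, so the
        # heavy decoding happens while iterating below, still inside this
        # worker thread rather than on the event loop.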

        # Collect all segments
        text_parts = []
        for segment in segments:
            text_parts.append(segment.text.strip())

        full_text = " ".join(text_parts).strip()
        logger.debug(f"Transcribed: '{full_text}' (language: {info.language}, "
                     f"probability: {info.language_probability:.2f})")
        return full_text

    async def transcribe_streaming(
        self,
        audio_stream: AsyncIterator[np.ndarray],
        sample_rate: int = 16000,
        chunk_duration_s: float = 2.0
    ) -> AsyncIterator[dict]:
        """
        Transcribe audio stream with partial results.

        Args:
            audio_stream: Async iterator yielding audio chunks
            sample_rate: Audio sample rate
            chunk_duration_s: Minimum duration of new audio between partial
                transcriptions

        Yields:
            {"type": "partial", "text": "...", "duration": seconds_so_far}
            {"type": "final", "text": "...", "duration": total_seconds}
        """
        accumulated_audio = []
        chunk_samples = int(chunk_duration_s * sample_rate)
        next_transcribe_at = chunk_samples

        async for audio_chunk in audio_stream:
            accumulated_audio.append(audio_chunk)

            # Re-transcribe the growing buffer once enough new audio has
            # arrived (without this threshold, every chunk after the first
            # window would trigger a full re-transcription)
            total_samples = sum(len(chunk) for chunk in accumulated_audio)
            if total_samples >= next_transcribe_at:
                next_transcribe_at = total_samples + chunk_samples

                # Concatenate and transcribe everything accumulated so far
                audio_data = np.concatenate(accumulated_audio)
                text = await self.transcribe_async(audio_data, sample_rate)
                if text:
                    yield {
                        "type": "partial",
                        "text": text,
                        "duration": total_samples / sample_rate
                    }

        # Final transcription of the full accumulated audio
        if accumulated_audio:
            audio_data = np.concatenate(accumulated_audio)
            text = await self.transcribe_async(audio_data, sample_rate)
            if text:
                yield {
                    "type": "final",
                    "text": text,
                    "duration": len(audio_data) / sample_rate
                }
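
    # Consumption sketch (illustrative; `mic_chunks` and `update_caption` are
    # hypothetical stand-ins for a real audio source and UI callback):
    #
    #   async for result in transcriber.transcribe_streaming(mic_chunks):
    #       if result["type"] == "partial":
    #           update_caption(result["text"])
    #       else:  # "final"
    #           print(result["text"])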

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        return [
            "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr",
            "pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi",
            "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no"
        ]

    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)
        logger.info("Whisper transcriber cleaned up")