Decided on the Parakeet ONNX Runtime model; it works well. Realtime voice chat is now possible, though the UX is still lacking.
This commit is contained in:
162
stt-parakeet/asr/asr_pipeline.py
Normal file
162
stt-parakeet/asr/asr_pipeline.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
ASR Pipeline using onnx-asr library with Parakeet TDT 0.6B V3 model
|
||||
"""
|
||||
import numpy as np
|
||||
import onnx_asr
|
||||
from typing import Union, Optional
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ASRPipeline:
    """
    ASR Pipeline wrapper for the onnx-asr Parakeet TDT model.

    Supports GPU acceleration via ONNX Runtime with CUDA/TensorRT
    execution providers, with automatic CPU fallback.
    """

    def __init__(
        self,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        quantization: Optional[str] = None,
        providers: Optional[list] = None,
        use_vad: bool = False,
    ):
        """
        Initialize ASR Pipeline.

        Args:
            model_name: Name of the model to load (default: "nemo-parakeet-tdt-0.6b-v3")
            model_path: Optional local path to model files (default uses models/parakeet)
            quantization: Optional quantization ("int8", "fp16", etc.)
            providers: Optional ONNX Runtime providers list for GPU acceleration.
                When None, defaults to CUDA (6GB arena limit) with CPU fallback.
            use_vad: Whether to wrap the model with Silero Voice Activity Detection

        Raises:
            Exception: re-raised from onnx_asr if model or VAD loading fails
                (logged with full traceback first).
        """
        self.model_name = model_name
        self.model_path = model_path or "models/parakeet"
        self.quantization = quantization
        self.use_vad = use_vad

        # Configure providers for GPU acceleration.
        if providers is None:
            # Default: try CUDA first, then fall back to CPU.
            providers = [
                (
                    "CUDAExecutionProvider",
                    {
                        "device_id": 0,
                        "arena_extend_strategy": "kNextPowerOfTwo",
                        "gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 6GB
                        "cudnn_conv_algo_search": "EXHAUSTIVE",
                        "do_copy_in_default_stream": True,
                    },
                ),
                "CPUExecutionProvider",
            ]

        self.providers = providers
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info("Initializing ASR Pipeline with model: %s", model_name)
        logger.info("Model path: %s", self.model_path)
        logger.info("Quantization: %s", quantization)
        logger.info("Providers: %s", providers)

        # Load the model; any failure is logged with traceback and re-raised
        # so callers see the original exception.
        try:
            self.model = onnx_asr.load_model(
                model_name,
                self.model_path,
                quantization=quantization,
                providers=providers,
            )
            logger.info("Model loaded successfully")

            # Optionally wrap the recognizer with VAD segmentation.
            if use_vad:
                logger.info("Loading VAD model...")
                vad = onnx_asr.load_vad("silero", providers=providers)
                self.model = self.model.with_vad(vad)
                logger.info("VAD enabled")

        except Exception:
            # logger.exception records the full traceback at ERROR level.
            logger.exception("Failed to load model")
            raise

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        sample_rate: int = 16000,
    ) -> Union[str, list]:
        """
        Transcribe audio to text.

        Args:
            audio: Audio data as numpy array (converted to float32 if needed)
                or path to a WAV file
            sample_rate: Sample rate of a raw audio array (default: 16000 Hz);
                ignored for file paths, which carry their own sample rate

        Returns:
            Transcribed text string, or list of segment results if VAD is enabled
            (with VAD the underlying call yields a generator, which is
            materialized into a list here).

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            if isinstance(audio, str):
                # Load from file; the file header supplies the sample rate.
                result = self.model.recognize(audio)
            else:
                # Raw samples: the model expects float32 input.
                if audio.dtype != np.float32:
                    audio = audio.astype(np.float32)
                result = self.model.recognize(audio, sample_rate=sample_rate)

            # If VAD is enabled, result is a generator of per-segment results.
            if self.use_vad:
                return list(result)

            return result

        except Exception:
            logger.exception("Transcription failed")
            raise

    def transcribe_batch(
        self,
        audio_files: list,
    ) -> list:
        """
        Transcribe multiple audio files in batch.

        Args:
            audio_files: List of paths to WAV files

        Returns:
            List of transcribed text strings, one per input file

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            # onnx_asr accepts a list of paths and batches internally.
            results = self.model.recognize(audio_files)
            return results
        except Exception:
            logger.exception("Batch transcription failed")
            raise

    def transcribe_stream(
        self,
        audio_chunk: np.ndarray,
        sample_rate: int = 16000,
    ) -> str:
        """
        Transcribe a streaming audio chunk.

        NOTE(review): this is stateless — each chunk is transcribed
        independently with no carry-over context between calls.

        Args:
            audio_chunk: Audio chunk as numpy array (converted to float32 if needed)
            sample_rate: Sample rate of audio

        Returns:
            Transcribed text for the chunk
        """
        return self.transcribe(audio_chunk, sample_rate=sample_rate)
|
||||
|
||||
|
||||
# Convenience function for backward compatibility
def load_pipeline(**kwargs) -> ASRPipeline:
    """Build and return an :class:`ASRPipeline` configured via keyword arguments."""
    pipeline = ASRPipeline(**kwargs)
    return pipeline
|
||||
Reference in New Issue
Block a user