# miku-discord/stt-parakeet/asr/asr_pipeline.py
"""
ASR Pipeline using onnx-asr library with Parakeet TDT 0.6B V3 model
"""
# Third-party dependencies: numpy for in-memory audio buffers, onnx_asr for
# the Parakeet recognizer and the optional Silero VAD wrapper.
import numpy as np
import onnx_asr
from typing import Union, Optional
import logging
# NOTE(review): basicConfig() at import time mutates the ROOT logger as a
# module side effect — fine for an application entry point, but surprising if
# this module is imported as a library; consider moving to the app's main.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ASRPipeline:
    """
    ASR Pipeline wrapper for the onnx-asr Parakeet TDT model.

    Supports GPU acceleration via ONNX Runtime with CUDA/TensorRT, with an
    automatic CPU fallback. Optionally wraps the recognizer with a Silero
    Voice Activity Detection model so long audio is split into segments.
    """

    @staticmethod
    def _default_providers() -> list:
        """Return the default ONNX Runtime provider list: CUDA first, CPU fallback."""
        return [
            (
                "CUDAExecutionProvider",
                {
                    "device_id": 0,
                    "arena_extend_strategy": "kNextPowerOfTwo",
                    "gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # cap GPU arena at 6 GB
                    "cudnn_conv_algo_search": "EXHAUSTIVE",
                    "do_copy_in_default_stream": True,
                },
            ),
            "CPUExecutionProvider",
        ]

    @staticmethod
    def _as_float32(audio: np.ndarray) -> np.ndarray:
        """
        Convert an audio array to float32, normalizing integer PCM to [-1.0, 1.0].

        A bare astype() on e.g. int16 PCM would hand values in [-32768, 32767]
        to a model that expects [-1, 1] and produce garbage transcriptions, so
        integer dtypes are scaled by their full-scale value. float32 input is
        returned unchanged; other float dtypes are cast without rescaling.
        """
        if np.issubdtype(audio.dtype, np.integer):
            full_scale = float(np.iinfo(audio.dtype).max) + 1.0  # e.g. 32768 for int16
            return audio.astype(np.float32) / full_scale
        if audio.dtype != np.float32:
            return audio.astype(np.float32)
        return audio

    def __init__(
        self,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        quantization: Optional[str] = None,
        providers: Optional[list] = None,
        use_vad: bool = False,
    ):
        """
        Initialize ASR Pipeline.

        Args:
            model_name: Name of the model to load (default: "nemo-parakeet-tdt-0.6b-v3")
            model_path: Optional local path to model files (default uses models/parakeet)
            quantization: Optional quantization ("int8", "fp16", etc.)
            providers: Optional ONNX Runtime providers list for GPU acceleration;
                when None, defaults to CUDA with a CPU fallback.
            use_vad: Whether to wrap the recognizer with Voice Activity Detection.

        Raises:
            Exception: re-raised from onnx_asr if model (or VAD) loading fails.
        """
        self.model_name = model_name
        self.model_path = model_path or "models/parakeet"
        self.quantization = quantization
        self.use_vad = use_vad
        self.providers = providers if providers is not None else self._default_providers()

        # Lazy %-style args: no string formatting work when INFO is disabled.
        logger.info("Initializing ASR Pipeline with model: %s", model_name)
        logger.info("Model path: %s", self.model_path)
        logger.info("Quantization: %s", quantization)
        logger.info("Providers: %s", self.providers)

        try:
            self.model = onnx_asr.load_model(
                model_name,
                self.model_path,
                quantization=quantization,
                providers=self.providers,
            )
            logger.info("Model loaded successfully")
            if use_vad:
                # with_vad() wraps the recognizer so recognize() yields
                # per-segment results instead of a single string.
                logger.info("Loading VAD model...")
                vad = onnx_asr.load_vad("silero", providers=self.providers)
                self.model = self.model.with_vad(vad)
                logger.info("VAD enabled")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        sample_rate: int = 16000,
    ) -> Union[str, list]:
        """
        Transcribe audio to text.

        Args:
            audio: Audio as a numpy array (float32 in [-1, 1]; integer PCM
                arrays are normalized automatically) or a path to a WAV file.
            sample_rate: Sample rate of the array input (default: 16000 Hz);
                ignored for file paths.

        Returns:
            Transcribed text string, or a list of per-segment results when
            VAD is enabled.

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            if isinstance(audio, str):
                # File input: onnx_asr reads and decodes the WAV itself.
                result = self.model.recognize(audio)
            else:
                result = self.model.recognize(
                    self._as_float32(audio), sample_rate=sample_rate
                )
            # With VAD the recognizer yields segment results lazily;
            # materialize them so callers always get a plain list.
            if self.use_vad:
                return list(result)
            return result
        except Exception as e:
            logger.error("Transcription failed: %s", e)
            raise

    def transcribe_batch(
        self,
        audio_files: list,
    ) -> list:
        """
        Transcribe multiple audio files in batch.

        Args:
            audio_files: List of paths to WAV files.

        Returns:
            List of transcribed text strings, one per input file.

        Raises:
            Exception: re-raised from the underlying recognizer on failure.
        """
        try:
            # onnx_asr accepts a list of paths and batches internally.
            return self.model.recognize(audio_files)
        except Exception as e:
            logger.error("Batch transcription failed: %s", e)
            raise

    def transcribe_stream(
        self,
        audio_chunk: np.ndarray,
        sample_rate: int = 16000,
    ) -> str:
        """
        Transcribe a streaming audio chunk.

        Thin wrapper over :meth:`transcribe`; each chunk is recognized
        independently (no cross-chunk decoder state is kept here).

        Args:
            audio_chunk: Audio chunk as a numpy array (float32 preferred;
                integer PCM is normalized automatically).
            sample_rate: Sample rate of the chunk.

        Returns:
            Transcribed text for the chunk.
        """
        return self.transcribe(audio_chunk, sample_rate=sample_rate)
# Convenience function for backward compatibility
def load_pipeline(**kwargs) -> ASRPipeline:
    """Build an ASRPipeline from the given keyword configuration and return it."""
    pipeline = ASRPipeline(**kwargs)
    return pipeline