# miku-discord/stt/vad_processor.py
"""
Silero VAD Processor
Lightweight CPU-based Voice Activity Detection for real-time speech detection.
Runs continuously on audio chunks to determine when users are speaking.
"""
import torch
import numpy as np
from typing import Tuple, Optional
import logging

logger = logging.getLogger('vad')

class VADProcessor:
    """
    Voice Activity Detection using the Silero VAD model.

    Processes audio chunks and returns speech probability.
    Conservative defaults are used to avoid cutting off speech.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        speech_pad_ms: int = 30
    ):
        """
        Initialize VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 8000 or 16000)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger (conservative)
            min_silence_duration_ms: Minimum silence to end speech (conservative)
            speech_pad_ms: Padding around speech segments
        """
        self.sample_rate = sample_rate
        self.threshold = threshold
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        # Load Silero VAD model (CPU only)
        logger.info("Loading Silero VAD model (CPU)...")
        self.model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False  # Use the PyTorch model
        )

        # Extract utility functions bundled with the model
        (self.get_speech_timestamps,
         self.save_audio,
         self.read_audio,
         self.VADIterator,
         self.collect_chunks) = utils

        # State tracking
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer = []

        # Chunk buffer for VAD (Silero needs at least 512 samples at 16 kHz)
        self.vad_buffer = []
        self.min_vad_samples = 512  # Minimum samples for VAD processing

        logger.info(f"VAD initialized: threshold={threshold}, "
                    f"min_speech={min_speech_duration_ms}ms, "
                    f"min_silence={min_silence_duration_ms}ms")
    def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[float, bool]:
        """
        Process a single audio chunk and return its speech probability.

        Buffers small chunks to meet the VAD minimum size requirement.

        Args:
            audio_chunk: Audio data as numpy array (int16 or float32)

        Returns:
            (speech_probability, is_speaking): Probability and current speaking state
        """
        # Convert int16 PCM to normalized float32
        if audio_chunk.dtype == np.int16:
            audio_chunk = audio_chunk.astype(np.float32) / 32768.0

        # Add to buffer
        self.vad_buffer.append(audio_chunk)

        # Check if we have enough samples
        total_samples = sum(len(chunk) for chunk in self.vad_buffer)
        if total_samples < self.min_vad_samples:
            # Not enough samples yet; report silence until the model can run
            return 0.0, False

        # Concatenate buffer and run the VAD model.
        # Note: recent silero-vad releases (v5+) expect exactly 512-sample
        # chunks at 16 kHz; this variable-length call targets earlier releases.
        audio_full = np.concatenate(self.vad_buffer)
        audio_tensor = torch.from_numpy(audio_full)
        with torch.no_grad():
            speech_prob = self.model(audio_tensor, self.sample_rate).item()

        # Clear buffer after processing
        self.vad_buffer = []

        # Update speaking state based on probability
        is_speaking = speech_prob > self.threshold
        return speech_prob, is_speaking
    def detect_speech_segment(
        self,
        audio_chunk: np.ndarray,
        timestamp_ms: float
    ) -> Optional[dict]:
        """
        Process a chunk and detect speech start/end events.

        Args:
            audio_chunk: Audio data
            timestamp_ms: Current timestamp in milliseconds

        Returns:
            Event dict or None:
            - {"event": "speech_start", "timestamp": float, "probability": float}
            - {"event": "speech_end", "timestamp": float, "probability": float}
            - {"event": "speaking", "timestamp": float, "probability": float}  # Ongoing speech
        """
        speech_prob, is_speaking = self.process_chunk(audio_chunk)

        # Speech candidate: voice detected while not yet marked as speaking
        if is_speaking and not self.speaking:
            if self.speech_start_time is None:
                self.speech_start_time = timestamp_ms

            # Only trigger once speech has lasted the minimum duration
            speech_duration = timestamp_ms - self.speech_start_time
            if speech_duration >= self.min_speech_duration_ms:
                self.speaking = True
                self.silence_start_time = None
                logger.debug(f"Speech started at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_start",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # Speech ongoing
        elif is_speaking and self.speaking:
            self.silence_start_time = None  # Reset silence timer
            return {
                "event": "speaking",
                "probability": speech_prob,
                "timestamp": timestamp_ms
            }

        # Silence detected during speech
        elif not is_speaking and self.speaking:
            if self.silence_start_time is None:
                self.silence_start_time = timestamp_ms

            # Only end once silence has lasted the minimum duration
            silence_duration = timestamp_ms - self.silence_start_time
            if silence_duration >= self.min_silence_duration_ms:
                self.speaking = False
                self.speech_start_time = None
                self.silence_start_time = None
                logger.debug(f"Speech ended at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_end",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # Silence while not speaking: discard any speech candidate
        else:
            self.speech_start_time = None

        return None
    def reset(self):
        """Reset VAD state."""
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer.clear()
        self.vad_buffer.clear()  # Also drop any partially buffered VAD samples
        logger.debug("VAD state reset")
    def get_state(self) -> dict:
        """Get current VAD state."""
        return {
            "speaking": self.speaking,
            "speech_start_time": self.speech_start_time,
            "silence_start_time": self.silence_start_time
        }
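

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# feeds synthetic 32 ms chunks through the processor and prints start/end
# events. The chunk size, loop length, and random int16 input below are
# assumptions for demonstration; real callers would pass microphone PCM.
# Note that the first run downloads the Silero model via torch.hub.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vad = VADProcessor(sample_rate=16000, threshold=0.5)

    chunk_samples = 512  # 32 ms at 16 kHz, matching the VAD minimum
    timestamp_ms = 0.0

    for _ in range(50):
        # Random noise stands in for real microphone audio (int16 PCM)
        chunk = (np.random.randn(chunk_samples) * 1000).astype(np.int16)
        event = vad.detect_speech_segment(chunk, timestamp_ms)
        if event is not None and event["event"] != "speaking":
            print(event)
        timestamp_ms += chunk_samples / 16000 * 1000  # advance by chunk duration

    vad.reset()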