Decided on Parakeet via ONNX Runtime. It works well — real-time voice chat is now possible, though the UX is still lacking.

This commit is contained in:
2026-01-19 00:29:44 +02:00
parent 0a8910fff8
commit 362108f4b0
34 changed files with 4593 additions and 73 deletions

View File

@@ -0,0 +1,416 @@
#!/usr/bin/env python3
"""
ASR WebSocket Server with VAD - Optimized for Discord Bots
This server uses Voice Activity Detection (VAD) to:
- Detect speech start and end automatically
- Only transcribe speech segments (ignore silence)
- Provide clean boundaries for Discord message formatting
- Minimize processing of silence/noise
"""
import asyncio
import websockets
import numpy as np
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
from collections import deque
from dataclasses import dataclass
from typing import Optional
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from asr.asr_pipeline import ASRPipeline
# Configure logging: every record goes to both vad_server.log and the console,
# using one shared timestamped format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [
    logging.FileHandler('vad_server.log'),
    logging.StreamHandler(),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)
@dataclass
class SpeechSegment:
    """Represents a segment of detected speech.

    NOTE(review): this dataclass is not referenced anywhere in this file;
    field semantics below are inferred from the names — confirm against
    whatever external code constructs it.
    """
    # Audio samples belonging to the segment.
    audio: np.ndarray
    # Segment start (presumably seconds relative to the stream — verify).
    start_time: float
    # Segment end; None while the segment is still open.
    end_time: Optional[float] = None
    # Whether the segment has been finalized.
    is_complete: bool = False
class VADState:
    """Manages VAD state for speech detection.

    A simple energy (RMS) based voice-activity detector over fixed-size
    audio chunks, with:

    - hysteresis: ``min_speech_frames`` consecutive speech frames are needed
      to open a segment and ``min_silence_frames`` silence frames to close it;
    - a pre-speech ring buffer so the onset of an utterance is not clipped;
    - periodic "partial" segments to drive progressive transcription while
      the speaker is still talking.
    """

    def __init__(self, sample_rate: int = 16000, speech_threshold: float = 0.5):
        """Initialize detector state.

        Args:
            sample_rate: Sample rate (Hz) of the incoming audio chunks.
            speech_threshold: Accepted for backward compatibility but
                currently unused; detection is driven by ``energy_threshold``.
        """
        self.sample_rate = sample_rate
        # Simple energy-based VAD parameters
        self.energy_threshold = 0.005  # Lower threshold for better detection
        self.speech_frames = 0
        self.silence_frames = 0
        self.min_speech_frames = 3  # 3 frames minimum (300ms with 100ms chunks)
        self.min_silence_frames = 5  # 5 frames of silence (500ms)
        self.is_speech = False
        # NOTE: speech_buffer is also mutated directly by the server's
        # "force_transcribe" command handler, so derived quantities (e.g. the
        # buffered sample count) are recomputed on demand rather than cached.
        self.speech_buffer = []
        # Pre-buffer to capture audio BEFORE speech detection.
        # This prevents cutting off the start of speech.
        self.pre_buffer_frames = 5  # Keep 5 frames (500ms) of pre-speech audio
        self.pre_buffer = deque(maxlen=self.pre_buffer_frames)
        # Progressive transcription tracking
        self.last_partial_samples = 0  # Track when we last sent a partial
        self.partial_interval_samples = int(sample_rate * 0.3)  # Partial every 0.3 seconds (near real-time)
        logger.info(f"VAD initialized: energy_threshold={self.energy_threshold}, pre_buffer={self.pre_buffer_frames} frames")

    def calculate_energy(self, audio_chunk: np.ndarray) -> float:
        """Calculate RMS energy of audio chunk.

        Returns 0.0 for an empty chunk: ``np.mean`` of an empty array yields
        NaN (with a RuntimeWarning), and ``NaN > threshold`` is always False,
        so guarding here keeps the comparison well-defined and warning-free.
        """
        if audio_chunk.size == 0:
            return 0.0
        return np.sqrt(np.mean(audio_chunk ** 2))

    def process_audio(self, audio_chunk: np.ndarray) -> tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Process audio chunk and detect speech boundaries.

        Returns:
            (speech_detected, complete_segment, partial_segment)
            - speech_detected: True if currently in speech
            - complete_segment: Audio segment if speech ended, None otherwise
            - partial_segment: Audio for partial transcription, None otherwise
        """
        energy = self.calculate_energy(audio_chunk)
        chunk_is_speech = energy > self.energy_threshold
        logger.debug(f"Energy: {energy:.6f}, Is speech: {chunk_is_speech}")
        partial_segment = None
        if chunk_is_speech:
            self.speech_frames += 1
            self.silence_frames = 0
            if not self.is_speech and self.speech_frames >= self.min_speech_frames:
                # Speech started - add pre-buffer to capture the beginning!
                self.is_speech = True
                logger.info("🎤 Speech started (including pre-buffer)")
                # Add pre-buffered audio to speech buffer
                if self.pre_buffer:
                    logger.debug(f"Adding {len(self.pre_buffer)} pre-buffered frames")
                    self.speech_buffer.extend(list(self.pre_buffer))
                    self.pre_buffer.clear()
            if self.is_speech:
                self.speech_buffer.append(audio_chunk)
            else:
                # Not in speech yet, keep in pre-buffer
                self.pre_buffer.append(audio_chunk)
            # Check if we should send a partial transcription
            current_samples = sum(len(chunk) for chunk in self.speech_buffer)
            samples_since_last_partial = current_samples - self.last_partial_samples
            # Send partial if enough NEW audio accumulated AND we have minimum duration
            min_duration_for_partial = int(self.sample_rate * 0.8)  # At least 0.8s of audio
            if samples_since_last_partial >= self.partial_interval_samples and current_samples >= min_duration_for_partial:
                # Time for a partial update
                partial_segment = np.concatenate(self.speech_buffer)
                self.last_partial_samples = current_samples
                logger.debug(f"📝 Partial update: {current_samples/self.sample_rate:.2f}s")
        else:
            if self.is_speech:
                self.silence_frames += 1
                # Add some trailing silence (up to limit)
                if self.silence_frames < self.min_silence_frames:
                    self.speech_buffer.append(audio_chunk)
                else:
                    # Speech ended
                    logger.info(f"🛑 Speech ended after {self.silence_frames} silence frames")
                    self.is_speech = False
                    self.speech_frames = 0
                    self.silence_frames = 0
                    self.last_partial_samples = 0  # Reset partial counter
                    if self.speech_buffer:
                        complete_segment = np.concatenate(self.speech_buffer)
                        segment_duration = len(complete_segment) / self.sample_rate
                        self.speech_buffer = []
                        self.pre_buffer.clear()  # Clear pre-buffer after speech ends
                        logger.info(f"✅ Complete segment: {segment_duration:.2f}s")
                        return False, complete_segment, None
            else:
                self.speech_frames = 0
                # Keep adding to pre-buffer when not in speech
                self.pre_buffer.append(audio_chunk)
        return self.is_speech, None, partial_segment
class VADServer:
    """
    WebSocket server with VAD for Discord bot integration.

    Binary frames are treated as 16-bit signed PCM audio, run through a
    per-client VADState, and answered with JSON "transcript" messages
    (partial while speaking, final when a segment closes). Text frames are
    parsed as JSON control commands: force_transcribe / reset / set_threshold.
    """
    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """Initialize server.

        Args:
            host: Interface to bind the WebSocket server to.
            port: TCP port to listen on.
            model_path: Directory containing the ASR model files.
            sample_rate: Expected sample rate (Hz) of incoming client audio.
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()
        # Terminal control codes (ANSI escapes for colored console output)
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RED = '\033[91m'
        self.RESET = '\033[0m'
        # Initialize ASR pipeline — the model load happens here, at construction time
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

    def print_header(self) -> None:
        """Print server header."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server with VAD - Discord Bot Ready{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print(f"Model: Parakeet TDT 0.6B V3")
        print(f"VAD: Energy-based speech detection")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, duration: float) -> None:
        """Display transcription in the terminal."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
        print(f"{self.GREEN} 📝 {text}{self.RESET}")
        print(f"{self.BLUE} ⏱️ Duration: {duration:.2f}s{self.RESET}\n")
        sys.stdout.flush()

    async def handle_client(self, websocket) -> None:
        """Handle WebSocket client connection.

        Runs for the lifetime of one client: feeds binary frames through a
        per-client VADState, transcribes the segments it produces, and
        interprets text frames as JSON control commands.
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)
        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()
        # Initialize VAD state for this client
        vad_state = VADState(sample_rate=self.sample_rate)
        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with VAD",
                "sample_rate": self.sample_rate,
                "vad_enabled": True,
            }))
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM scaled to float32 in [-1, 1)
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0
                        # Process through VAD
                        is_speech, complete_segment, partial_segment = vad_state.process_audio(audio_data)
                        # Send VAD status to client (only on state change).
                        # _prev_speech_state is stashed on the VADState instance so
                        # the "reset" command below implicitly clears it as well.
                        prev_speech_state = getattr(vad_state, '_prev_speech_state', False)
                        if is_speech != prev_speech_state:
                            vad_state._prev_speech_state = is_speech
                            await websocket.send(json.dumps({
                                "type": "vad_status",
                                "is_speech": is_speech,
                            }))
                        # Handle partial transcription (progressive updates while speaking)
                        if partial_segment is not None:
                            try:
                                # NOTE(review): transcribe() is a synchronous call on the
                                # event-loop thread; a slow model stalls every client.
                                # Consider asyncio.to_thread(...) — confirm the pipeline
                                # is thread-safe first.
                                text = self.pipeline.transcribe(
                                    partial_segment,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    duration = len(partial_segment) / self.sample_rate
                                    # Display on server
                                    timestamp = datetime.now().strftime("%H:%M:%S")
                                    print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
                                    print(f"{self.YELLOW} → PARTIAL: {text}{self.RESET}\n")
                                    sys.stdout.flush()
                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                # Partial failures are best-effort: log and keep streaming.
                                logger.error(f"Partial transcription error: {e}")
                        # If we have a complete speech segment, transcribe it
                        if complete_segment is not None:
                            try:
                                text = self.pipeline.transcribe(
                                    complete_segment,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    duration = len(complete_segment) / self.sample_rate
                                    # Display on server
                                    self.display_transcription(client_id, text, duration)
                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": True,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                # Unlike partials, a failed final segment is reported to the client.
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))
                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                            if command.get("type") == "force_transcribe":
                                # Force transcribe current buffer.
                                # NOTE(review): this drains speech_buffer and clears
                                # is_speech, but leaves speech_frames / silence_frames /
                                # last_partial_samples untouched — confirm that is intended.
                                if vad_state.speech_buffer:
                                    audio_chunk = np.concatenate(vad_state.speech_buffer)
                                    vad_state.speech_buffer = []
                                    vad_state.is_speech = False
                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )
                                    if text and text.strip():
                                        duration = len(audio_chunk) / self.sample_rate
                                        self.display_transcription(client_id, text, duration)
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                            "duration": duration,
                                        }
                                        await websocket.send(json.dumps(response))
                            elif command.get("type") == "reset":
                                # Reset VAD state — a fresh VADState drops any buffered audio
                                vad_state = VADState(sample_rate=self.sample_rate)
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "VAD state reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] VAD reset{self.RESET}\n")
                                sys.stdout.flush()
                            elif command.get("type") == "set_threshold":
                                # Adjust VAD threshold (no validation on range/type here)
                                threshold = command.get("threshold", 0.01)
                                vad_state.energy_threshold = threshold
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": f"VAD threshold set to {threshold}"
                                }))
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")
                except Exception as e:
                    # NOTE(review): any per-message failure drops the client entirely
                    # (break exits the receive loop) — confirm this is the desired policy.
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            # NOTE(review): on the ConnectionClosed path this logs
            # "Connection closed" a second time (also logged in the except above).
            logger.info(f"Connection closed: {client_id}")

    async def start(self) -> None:
        """Start the WebSocket server and serve until cancelled."""
        self.print_header()
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running with VAD enabled!{self.RESET}")
            print(f"{self.BOLD}Ready for Discord bot connections...{self.RESET}\n")
            sys.stdout.flush()
            # Keep server running: this Future never resolves, so we serve until
            # the surrounding task is cancelled (e.g. KeyboardInterrupt in main()).
            await asyncio.Future()
def _parse_args():
    """Parse the VAD server's command-line options."""
    import argparse
    parser = argparse.ArgumentParser(description="ASR Server with VAD for Discord")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    return parser.parse_args()


def main():
    """Main entry point: build the server from CLI options and run until interrupted."""
    args = _parse_args()
    server = VADServer(
        host=args.host,
        port=args.port,
        model_path=args.model_path,
        sample_rate=args.sample_rate,
    )
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        # Ctrl+C: exit cleanly instead of dumping a traceback.
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()