refactor: Implement low-latency STT pipeline with speculative transcription

Major architectural overhaul of the speech-to-text pipeline for real-time voice chat:

STT Server Rewrite:
- Replaced RealtimeSTT dependency with direct Silero VAD + Faster-Whisper integration
- Achieved sub-second latency by eliminating unnecessary abstractions
- Uses small.en Whisper model for fast transcription (~850ms)

Speculative Transcription (NEW):
- Start transcribing at 150ms silence (speculative) while still listening
- If speech continues, discard speculative result and keep buffering
- If 400ms silence confirmed, use pre-computed speculative result immediately
- Reduces latency by ~250-850ms for typical utterances with clear pauses

VAD Implementation:
- Silero VAD with ONNX (CPU-efficient) for 32ms chunk processing
- Direct speech boundary detection without RealtimeSTT overhead
- Configurable thresholds for silence detection (400ms final, 150ms speculative)

Architecture:
- Single Whisper model loaded once, shared across sessions
- VAD runs on every 512-sample chunk for immediate speech detection
- Background transcription worker thread for non-blocking processing
- Greedy decoding (beam_size=1) for maximum speed

Performance:
- Previous: 400ms silence wait + ~850ms transcription = ~1.25s total latency
- Current: 400ms silence wait + 0ms (speculative ready) = ~400ms (best case)
- Single model reduces VRAM usage, prevents OOM on GTX 1660

Container Manager Updates:
- Updated health check logic to work with new response format
- Changed from checking 'warmed_up' flag to just 'status: ready'
- Improved terminology from 'warmup' to 'models loading'

Files Changed:
- stt-realtime/stt_server.py: Complete rewrite with Silero VAD + speculative transcription
- stt-realtime/requirements.txt: Removed RealtimeSTT, using torch.hub for Silero VAD
- bot/utils/container_manager.py: Updated health check for new STT response format
- bot/api.py: Updated docstring to reflect new architecture
- backups/: Archived old RealtimeSTT-based implementation

This addresses the low-latency requirements while maintaining accuracy through
configurable speech-detection thresholds.
This commit is contained in:
2026-01-22 22:08:07 +02:00
parent 2934efba22
commit eb03dfce4d
5 changed files with 850 additions and 400 deletions

View File

@@ -0,0 +1,510 @@
#!/usr/bin/env python3
"""
RealtimeSTT WebSocket Server
Provides real-time speech-to-text transcription using Faster-Whisper.
Receives audio chunks via WebSocket and streams back partial/final transcripts.
Protocol:
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
- Client sends: JSON {"command": "reset"} to reset state
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
"""
import asyncio
import json
import logging
import time
import threading
import queue
from typing import Optional, Dict, Any
import numpy as np
import websockets
from websockets.server import serve
from aiohttp import web
# Configure logging: timestamped lines, INFO level, shared by all sessions.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger('stt-realtime')

# Import RealtimeSTT (heavy import: pulls in the Whisper/VAD stack).
from RealtimeSTT import AudioToTextRecorder

# Global warmup state, read by health_handler() and set by warmup_model().
warmup_complete = False          # True once the server reports itself ready
warmup_lock = threading.Lock()   # NOTE(review): not used anywhere in this file — confirm before removing
warmup_recorder = None           # NOTE(review): never assigned in this file — confirm before removing
class STTSession:
    """
    Manages a single STT session for a WebSocket client.

    Key architectural point: We own the audio buffer and decoder.
    RealtimeSTT is used ONLY for VAD, not for transcription ownership.

    Threading model — three contexts cooperate:
      * the asyncio event loop handles WebSocket I/O (start/stop/_send_transcript),
      * feed_thread drains audio_queue into the RealtimeSTT recorder,
      * text_thread (run_decode_loop) periodically re-decodes float_buffer and
        posts partial/final transcripts back onto the loop via
        asyncio.run_coroutine_threadsafe.
    """

    def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
        """Create lightweight session state; heavy recorder init happens in start()."""
        self.websocket = websocket
        self.session_id = session_id
        self.config = config
        self.recorder: Optional[AudioToTextRecorder] = None
        self.running = False
        # Hand-off queue between feed_audio() (event loop) and the feed thread.
        self.audio_queue = queue.Queue()
        self.feed_thread: Optional[threading.Thread] = None
        # OUR audio buffer - we own this, not RealtimeSTT
        self.float_buffer = []  # Rolling float32 buffer (normalized -1.0 to 1.0; see feed_audio)
        self.max_buffer_duration = 30.0  # Keep max 30 seconds
        # Decode state
        self.last_decode_text = ""
        self.recording_active = False
        self.recording_stop_time = 0
        self.last_decode_time = 0
        self.final_sent = False  # Track if we've sent final for this utterance
        self.last_audio_time = 0  # Track when we last received audio with speech
        self.speech_detected = False  # Track if we've detected any speech
        logger.info(f"[{session_id}] Session created")

    def _on_recording_stop(self):
        """Recorder callback: recording stopped (silence detected).

        NOTE(review): per the comments in start(), RealtimeSTT VAD callbacks do
        not fire with use_microphone=False, so this may never run — confirm.
        """
        logger.info(f"[{self.session_id}] ⏹️ Recording stopped - will emit final in decode loop")
        self.recording_active = False
        self.recording_stop_time = time.time()  # Track when recording stopped

    def _on_recording_start(self):
        """Recorder callback: recording started; resets per-utterance state.

        NOTE(review): same caveat as _on_recording_stop — may never fire with
        use_microphone=False.
        """
        logger.info(f"[{self.session_id}] 🎙️ Recording started")
        self.recording_active = True
        self.float_buffer = []  # Reset buffer for new utterance
        self.last_decode_text = ""
        self.last_decode_time = 0
        self.final_sent = False  # Reset final flag for new utterance

    async def _send_transcript(self, transcript_type: str, text: str):
        """Send one transcript message ("partial" or "final") to the client.

        Errors are logged, not raised — a dropped transcript must not kill the
        decode loop that scheduled this coroutine.
        """
        try:
            message = {
                "type": transcript_type,
                "text": text,
                "timestamp": time.time()
            }
            await self.websocket.send(json.dumps(message))
        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to send transcript: {e}")

    def _feed_audio_thread(self):
        """Thread body: drain audio_queue and push chunks into the recorder (VAD)."""
        logger.info(f"[{self.session_id}] Audio feed thread started")
        while self.running:
            try:
                # Get audio chunk with timeout so self.running is rechecked ~10x/s.
                audio_chunk = self.audio_queue.get(timeout=0.1)
                if audio_chunk is not None and self.recorder:
                    self.recorder.feed_audio(audio_chunk)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"[{self.session_id}] Error feeding audio: {e}")
        logger.info(f"[{self.session_id}] Audio feed thread stopped")

    async def start(self, loop: asyncio.AbstractEventLoop):
        """Start the session: initialize the recorder and launch worker threads.

        Args:
            loop: event loop the decode thread uses to schedule transcript sends.

        Raises:
            Exception: re-raised if recorder initialization fails.
        """
        self.loop = loop
        self.running = True
        logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...")
        logger.info(f"[{self.session_id}] Model: {self.config['model']}")
        logger.info(f"[{self.session_id}] Device: {self.config['device']}")
        try:
            # Create recorder in a thread to avoid blocking the event loop
            # (model load can take seconds).
            def init_recorder():
                # Build initialization kwargs
                recorder_kwargs = {
                    # Model settings - ONLY turbo model, no dual-model setup
                    'model': self.config['model'],
                    'language': self.config['language'],
                    'compute_type': self.config['compute_type'],
                    'device': self.config['device'],
                    # Disable microphone - we feed audio manually
                    'use_microphone': False,
                    # DISABLE realtime partials - we'll use incremental utterance decoding instead
                    'enable_realtime_transcription': False,  # ← KEY CHANGE: No streaming partials
                    # VAD settings - optimized for longer utterances (per ChatGPT)
                    'silero_sensitivity': self.config['silero_sensitivity'],
                    'silero_use_onnx': True,  # Faster
                    'webrtc_sensitivity': self.config['webrtc_sensitivity'],
                    'post_speech_silence_duration': self.config['silence_duration'],
                    'min_length_of_recording': self.config['min_recording_length'],
                    'min_gap_between_recordings': self.config['min_gap'],
                    'pre_recording_buffer_duration': 1.2,  # ChatGPT: ~1.2s before first decode
                    # Callbacks
                    'on_recording_start': self._on_recording_start,
                    'on_recording_stop': self._on_recording_stop,
                    'on_vad_detect_start': lambda: logger.debug(f"[{self.session_id}] VAD listening"),
                    'on_vad_detect_stop': lambda: logger.debug(f"[{self.session_id}] VAD stopped"),
                    # Other settings
                    'spinner': False,  # No spinner in container
                    'level': logging.WARNING,  # Reduce internal logging
                    # Beam search settings - optimized for accuracy
                    'beam_size': 5,
                    # Batch sizes
                    'batch_size': 16,
                    'initial_prompt': "",
                }
                self.recorder = AudioToTextRecorder(**recorder_kwargs)
                logger.info(f"[{self.session_id}] ✅ Recorder initialized (incremental mode, transcript-stability silence detection)")
            # Run initialization in thread pool
            await asyncio.get_event_loop().run_in_executor(None, init_recorder)
            # Start audio feed thread
            self.feed_thread = threading.Thread(target=self._feed_audio_thread, daemon=True)
            self.feed_thread.start()
            # NOTE: We don't call recorder.start() - VAD callbacks don't work with use_microphone=False
            # Instead, we detect silence ourselves via transcript stability in the decode loop
            # Start CORRECT incremental decoding loop
            # Since RealtimeSTT VAD callbacks don't work with use_microphone=False,
            # we detect silence ourselves via transcript stability
            def run_decode_loop():
                """
                Decode buffer periodically. Detect end-of-speech when:
                1. We have a transcript AND
                2. Transcript hasn't changed for silence_threshold seconds
                """
                decode_interval = 0.7  # Re-decode every 700ms
                min_audio_before_first_decode = 1.2  # Wait 1.2s before first decode
                silence_threshold = 1.5  # If transcript stable for 1.5s, consider it final
                last_transcript_change_time = 0
                has_transcript = False
                logger.info(f"[{self.session_id}] Decode loop ready (silence detection: {silence_threshold}s)")
                while self.running:
                    try:
                        current_time = time.time()
                        # Buffer length in seconds (16 kHz mono samples).
                        buffer_duration = len(self.float_buffer) / 16000.0 if self.float_buffer else 0
                        # Only decode if we have enough audio
                        if buffer_duration >= min_audio_before_first_decode:
                            # Check if enough time since last decode
                            if (current_time - self.last_decode_time) >= decode_interval:
                                try:
                                    # Whole-utterance re-decode of OUR buffer each pass.
                                    audio_array = np.array(self.float_buffer, dtype=np.float32)
                                    logger.debug(f"[{self.session_id}] 🔄 Decode (buffer: {buffer_duration:.1f}s)")
                                    result = self.recorder.perform_final_transcription(audio_array)
                                    text = result.strip() if result else ""
                                    if text:
                                        if text != self.last_decode_text:
                                            # Transcript changed - update and reset stability timer
                                            self.last_decode_text = text
                                            last_transcript_change_time = current_time
                                            has_transcript = True
                                            logger.info(f"[{self.session_id}] 📝 Partial: {text}")
                                            asyncio.run_coroutine_threadsafe(
                                                self._send_transcript("partial", text),
                                                self.loop
                                            )
                                        # else: transcript same, stability timer continues
                                    self.last_decode_time = current_time
                                except Exception as e:
                                    logger.error(f"[{self.session_id}] Decode error: {e}", exc_info=True)
                        # Check for silence (transcript stable for threshold)
                        if has_transcript and last_transcript_change_time > 0:
                            time_since_change = current_time - last_transcript_change_time
                            if time_since_change >= silence_threshold:
                                # Transcript has been stable - emit final
                                logger.info(f"[{self.session_id}] ✅ Final (stable {time_since_change:.1f}s): {self.last_decode_text}")
                                asyncio.run_coroutine_threadsafe(
                                    self._send_transcript("final", self.last_decode_text),
                                    self.loop
                                )
                                # Reset for next utterance
                                self.float_buffer = []
                                self.last_decode_text = ""
                                self.last_decode_time = 0
                                last_transcript_change_time = 0
                                has_transcript = False
                        time.sleep(0.1)  # Check every 100ms
                    except Exception as e:
                        if self.running:
                            logger.error(f"[{self.session_id}] Decode loop error: {e}", exc_info=True)
                        break
            self.text_thread = threading.Thread(target=run_decode_loop, daemon=True)
            self.text_thread.start()
            logger.info(f"[{self.session_id}] ✅ Session started successfully")
        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to start session: {e}", exc_info=True)
            raise

    def feed_audio(self, audio_data: bytes):
        """Feed raw 16 kHz 16-bit mono PCM to the recorder (VAD) AND our buffer."""
        if self.running:
            # Convert bytes to numpy array (16-bit PCM)
            audio_np = np.frombuffer(audio_data, dtype=np.int16)
            # Feed to RealtimeSTT for VAD only
            self.audio_queue.put(audio_np)
            # Also add to OUR float32 buffer (normalized to -1.0 to 1.0)
            float_audio = audio_np.astype(np.float32) / 32768.0
            self.float_buffer.extend(float_audio)
            # Keep buffer size bounded (max 30 seconds at 16kHz = 480k samples)
            max_samples = int(self.max_buffer_duration * 16000)
            if len(self.float_buffer) > max_samples:
                self.float_buffer = self.float_buffer[-max_samples:]

    def reset(self):
        """Reset per-utterance state and drop any queued, unprocessed audio."""
        logger.info(f"[{self.session_id}] Resetting session")
        self.float_buffer = []
        self.last_decode_text = ""
        # Clear audio queue
        while not self.audio_queue.empty():
            try:
                self.audio_queue.get_nowait()
            except queue.Empty:
                break

    async def stop(self):
        """Stop the session: halt worker threads and shut down the recorder."""
        logger.info(f"[{self.session_id}] Stopping session...")
        self.running = False
        # Wait for threads to finish
        if self.feed_thread and self.feed_thread.is_alive():
            self.feed_thread.join(timeout=2)
        # Shutdown recorder
        if self.recorder:
            try:
                self.recorder.shutdown()
            except Exception as e:
                logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")
        logger.info(f"[{self.session_id}] Session stopped")
class STTServer:
    """
    WebSocket server for RealtimeSTT.

    Handles multiple concurrent clients (one per Discord user); each client
    gets its own STTSession, keyed by a monotonically increasing session id.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8766, config: Dict[str, Any] = None):
        """Store endpoint + config and log the effective configuration.

        Args:
            host: interface to bind the WebSocket server on.
            port: WebSocket port.
            config: recorder configuration dict (required).

        Raises:
            ValueError: if config is missing or empty.
        """
        self.host = host
        self.port = port
        self.sessions: Dict[str, STTSession] = {}
        self.session_counter = 0
        # Config must be provided
        if not config:
            raise ValueError("Configuration dict must be provided to STTServer")
        self.config = config
        logger.info("=" * 60)
        logger.info("RealtimeSTT Server Configuration:")
        logger.info(f" Host: {host}:{port}")
        logger.info(f" Model: {self.config['model']}")
        logger.info(f" Language: {self.config.get('language', 'auto-detect')}")
        logger.info(f" Device: {self.config['device']}")
        logger.info(f" Compute Type: {self.config['compute_type']}")
        logger.info(f" Silence Duration: {self.config['silence_duration']}s")
        logger.info(f" Realtime Pause: {self.config.get('realtime_processing_pause', 'N/A')}s")
        logger.info("=" * 60)

    async def handle_client(self, websocket):
        """Handle one WebSocket client for its full lifetime.

        Binary frames are treated as raw PCM audio; text frames are JSON
        commands ('reset', 'ping'). The session is always stopped and removed
        in the finally block, even on error or disconnect.
        """
        self.session_counter += 1
        session_id = f"session_{self.session_counter}"
        session = None
        try:
            logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")
            # Create session
            session = STTSession(websocket, session_id, self.config)
            self.sessions[session_id] = session
            # Start session (model load + worker threads). Use get_running_loop():
            # get_event_loop() is deprecated inside a coroutine and returns the
            # same loop here anyway.
            await session.start(asyncio.get_running_loop())
            # Process messages
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data
                        session.feed_audio(message)
                    else:
                        # JSON command
                        data = json.loads(message)
                        command = data.get('command', '')
                        if command == 'reset':
                            session.reset()
                        elif command == 'ping':
                            await websocket.send(json.dumps({
                                'type': 'pong',
                                'timestamp': time.time()
                            }))
                        else:
                            logger.warning(f"[{session_id}] Unknown command: {command}")
                except json.JSONDecodeError:
                    logger.warning(f"[{session_id}] Invalid JSON message")
                except Exception as e:
                    logger.error(f"[{session_id}] Error processing message: {e}")
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"[{session_id}] Client disconnected")
        except Exception as e:
            logger.error(f"[{session_id}] Error: {e}", exc_info=True)
        finally:
            # Cleanup. pop() instead of del: if the session was never registered
            # (or was already removed), cleanup must not raise KeyError.
            if session:
                await session.stop()
                self.sessions.pop(session_id, None)

    async def run(self):
        """Run the WebSocket server until cancelled (never returns normally)."""
        logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")
        async with serve(
            self.handle_client,
            self.host,
            self.port,
            ping_interval=30,
            ping_timeout=10,
            max_size=10 * 1024 * 1024,  # 10MB max message size
        ):
            logger.info("✅ Server ready and listening for connections")
            await asyncio.Future()  # Run forever
async def warmup_model(config: Dict[str, Any]):
    """Mark the server ready without loading any model.

    Warmup is deliberately disabled: pre-loading a throwaway model only wastes
    VRAM, so the first client session pays the model-load cost instead. This
    coroutine just flips the global readiness flag the health check reads.
    """
    global warmup_complete
    warmup_complete = True  # health check now reports ready
    logger.info("⚠️ Warmup disabled to save VRAM - model will load on first connection")
async def health_handler(request):
    """Serve GET /health: 200 once warmup_complete is set, else 503.

    The 'model' and 'device' fields are hardcoded here rather than read from
    the server configuration.
    """
    ready = warmup_complete
    payload = {
        "status": "ready" if ready else "warming_up",
        "warmed_up": ready,
        "model": "small.en",
        "device": "cuda",
    }
    if ready:
        return web.json_response(payload)
    return web.json_response(payload, status=503)
async def start_http_server(host: str, http_port: int):
    """Bring up the aiohttp application that serves GET /health for probes."""
    application = web.Application()
    application.router.add_get('/health', health_handler)
    app_runner = web.AppRunner(application)
    await app_runner.setup()
    tcp_site = web.TCPSite(app_runner, host, http_port)
    await tcp_site.start()
    logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")
def main():
    """Entry point: build configuration from the environment and run the servers."""
    import os

    # Network endpoints come from the environment, with defaults.
    host = os.environ.get('STT_HOST', '0.0.0.0')
    port = int(os.environ.get('STT_PORT', '8766'))
    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))  # health-check endpoint

    # Recorder configuration (incremental utterance decoding approach).
    config = {
        'model': 'turbo',             # fast multilingual model
        'language': 'en',             # fixed language — auto-detect adds seconds of latency
        'compute_type': 'float16',
        'device': 'cuda',
        # VAD tuning: ~600ms minimum speech, ~400-600ms end-of-speech silence.
        'silero_sensitivity': 0.6,
        'webrtc_sensitivity': 3,
        'silence_duration': 0.5,      # end-of-speech silence (500ms)
        'min_recording_length': 0.6,  # minimum speech length (600ms)
        'min_gap': 0.3,
    }

    stt_server = STTServer(host=host, port=port, config=config)

    async def serve_forever():
        # Warmup (a no-op that flips the readiness flag) and the HTTP health
        # server run as background tasks; the WebSocket server runs foreground.
        asyncio.create_task(warmup_model(config))
        asyncio.create_task(start_http_server(host, http_port))
        await stt_server.run()

    try:
        asyncio.run(serve_forever())
    except KeyboardInterrupt:
        logger.info("Server shutdown requested")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        raise


if __name__ == '__main__':
    main()