refactor: Implement low-latency STT pipeline with speculative transcription
Major architectural overhaul of the speech-to-text pipeline for real-time voice chat: STT Server Rewrite: - Replaced RealtimeSTT dependency with direct Silero VAD + Faster-Whisper integration - Achieved sub-second latency by eliminating unnecessary abstractions - Uses small.en Whisper model for fast transcription (~850ms) Speculative Transcription (NEW): - Start transcribing at 150ms silence (speculative) while still listening - If speech continues, discard speculative result and keep buffering - If 400ms silence confirmed, use pre-computed speculative result immediately - Reduces latency by ~250-850ms for typical utterances with clear pauses VAD Implementation: - Silero VAD with ONNX (CPU-efficient) for 32ms chunk processing - Direct speech boundary detection without RealtimeSTT overhead - Configurable thresholds for silence detection (400ms final, 150ms speculative) Architecture: - Single Whisper model loaded once, shared across sessions - VAD runs on every 512-sample chunk for immediate speech detection - Background transcription worker thread for non-blocking processing - Greedy decoding (beam_size=1) for maximum speed Performance: - Previous: 400ms silence wait + ~850ms transcription = ~1.25s total latency - Current: 400ms silence wait + 0ms (speculative ready) = ~400ms (best case) - Single model reduces VRAM usage, prevents OOM on GTX 1660 Container Manager Updates: - Updated health check logic to work with new response format - Changed from checking 'warmed_up' flag to just 'status: ready' - Improved terminology from 'warmup' to 'models loading' Files Changed: - stt-realtime/stt_server.py: Complete rewrite with Silero VAD + speculative transcription - stt-realtime/requirements.txt: Removed RealtimeSTT, using torch.hub for Silero VAD - bot/utils/container_manager.py: Updated health check for new STT response format - bot/api.py: Updated docstring to reflect new architecture - backups/: Archived old RealtimeSTT-based implementation This addresses low latency requirements while maintaining accuracy with configurable speech detection thresholds.
This commit is contained in:
510
backups/stt_server_realtimestt_based_2025-01-22.py
Normal file
510
backups/stt_server_realtimestt_based_2025-01-22.py
Normal file
@@ -0,0 +1,510 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RealtimeSTT WebSocket Server
|
||||
|
||||
Provides real-time speech-to-text transcription using Faster-Whisper.
|
||||
Receives audio chunks via WebSocket and streams back partial/final transcripts.
|
||||
|
||||
Protocol:
|
||||
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
|
||||
- Client sends: JSON {"command": "reset"} to reset state
|
||||
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
|
||||
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
from typing import Optional, Dict, Any
|
||||
import numpy as np
|
||||
import websockets
|
||||
from websockets.server import serve
|
||||
from aiohttp import web
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger('stt-realtime')
|
||||
|
||||
# Import RealtimeSTT
|
||||
from RealtimeSTT import AudioToTextRecorder
|
||||
|
||||
# Global warmup state
|
||||
warmup_complete = False
|
||||
warmup_lock = threading.Lock()
|
||||
warmup_recorder = None
|
||||
|
||||
|
||||
class STTSession:
|
||||
"""
|
||||
Manages a single STT session for a WebSocket client.
|
||||
|
||||
Key architectural point: We own the audio buffer and decoder.
|
||||
RealtimeSTT is used ONLY for VAD, not for transcription ownership.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
|
||||
self.websocket = websocket
|
||||
self.session_id = session_id
|
||||
self.config = config
|
||||
self.recorder: Optional[AudioToTextRecorder] = None
|
||||
self.running = False
|
||||
self.audio_queue = queue.Queue()
|
||||
self.feed_thread: Optional[threading.Thread] = None
|
||||
|
||||
# OUR audio buffer - we own this, not RealtimeSTT
|
||||
self.float_buffer = [] # Rolling float32 buffer (0.0 to 1.0 range)
|
||||
self.max_buffer_duration = 30.0 # Keep max 30 seconds
|
||||
|
||||
# Decode state
|
||||
self.last_decode_text = ""
|
||||
self.recording_active = False
|
||||
self.recording_stop_time = 0
|
||||
self.last_decode_time = 0
|
||||
self.final_sent = False # Track if we've sent final for this utterance
|
||||
self.last_audio_time = 0 # Track when we last received audio with speech
|
||||
self.speech_detected = False # Track if we've detected any speech
|
||||
|
||||
logger.info(f"[{session_id}] Session created")
|
||||
|
||||
def _on_recording_stop(self):
|
||||
"""Called when recording stops (silence detected)."""
|
||||
logger.info(f"[{self.session_id}] ⏹️ Recording stopped - will emit final in decode loop")
|
||||
self.recording_active = False
|
||||
self.recording_stop_time = time.time() # Track when recording stopped
|
||||
|
||||
def _on_recording_start(self):
|
||||
"""Called when recording starts (speech detected)."""
|
||||
logger.info(f"[{self.session_id}] 🎙️ Recording started")
|
||||
self.recording_active = True
|
||||
self.float_buffer = [] # Reset buffer for new utterance
|
||||
self.last_decode_text = ""
|
||||
self.last_decode_time = 0
|
||||
self.final_sent = False # Reset final flag for new utterance
|
||||
|
||||
async def _send_transcript(self, transcript_type: str, text: str):
|
||||
"""Send transcript to client via WebSocket."""
|
||||
try:
|
||||
message = {
|
||||
"type": transcript_type,
|
||||
"text": text,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
await self.websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.session_id}] Failed to send transcript: {e}")
|
||||
|
||||
def _feed_audio_thread(self):
|
||||
"""Thread that feeds audio to the recorder."""
|
||||
logger.info(f"[{self.session_id}] Audio feed thread started")
|
||||
while self.running:
|
||||
try:
|
||||
# Get audio chunk with timeout
|
||||
audio_chunk = self.audio_queue.get(timeout=0.1)
|
||||
if audio_chunk is not None and self.recorder:
|
||||
self.recorder.feed_audio(audio_chunk)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.session_id}] Error feeding audio: {e}")
|
||||
logger.info(f"[{self.session_id}] Audio feed thread stopped")
|
||||
|
||||
async def start(self, loop: asyncio.AbstractEventLoop):
|
||||
"""Start the STT session."""
|
||||
self.loop = loop
|
||||
self.running = True
|
||||
|
||||
logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...")
|
||||
logger.info(f"[{self.session_id}] Model: {self.config['model']}")
|
||||
logger.info(f"[{self.session_id}] Device: {self.config['device']}")
|
||||
|
||||
try:
|
||||
# Create recorder in a thread to avoid blocking
|
||||
def init_recorder():
|
||||
# Build initialization kwargs
|
||||
recorder_kwargs = {
|
||||
# Model settings - ONLY turbo model, no dual-model setup
|
||||
'model': self.config['model'],
|
||||
'language': self.config['language'],
|
||||
'compute_type': self.config['compute_type'],
|
||||
'device': self.config['device'],
|
||||
|
||||
# Disable microphone - we feed audio manually
|
||||
'use_microphone': False,
|
||||
|
||||
# DISABLE realtime partials - we'll use incremental utterance decoding instead
|
||||
'enable_realtime_transcription': False, # ← KEY CHANGE: No streaming partials
|
||||
|
||||
# VAD settings - optimized for longer utterances (per ChatGPT)
|
||||
'silero_sensitivity': self.config['silero_sensitivity'],
|
||||
'silero_use_onnx': True, # Faster
|
||||
'webrtc_sensitivity': self.config['webrtc_sensitivity'],
|
||||
'post_speech_silence_duration': self.config['silence_duration'],
|
||||
'min_length_of_recording': self.config['min_recording_length'],
|
||||
'min_gap_between_recordings': self.config['min_gap'],
|
||||
'pre_recording_buffer_duration': 1.2, # ChatGPT: ~1.2s before first decode
|
||||
|
||||
# Callbacks
|
||||
'on_recording_start': self._on_recording_start,
|
||||
'on_recording_stop': self._on_recording_stop,
|
||||
'on_vad_detect_start': lambda: logger.debug(f"[{self.session_id}] VAD listening"),
|
||||
'on_vad_detect_stop': lambda: logger.debug(f"[{self.session_id}] VAD stopped"),
|
||||
|
||||
# Other settings
|
||||
'spinner': False, # No spinner in container
|
||||
'level': logging.WARNING, # Reduce internal logging
|
||||
|
||||
# Beam search settings - optimized for accuracy
|
||||
'beam_size': 5,
|
||||
|
||||
# Batch sizes
|
||||
'batch_size': 16,
|
||||
|
||||
'initial_prompt': "",
|
||||
}
|
||||
|
||||
self.recorder = AudioToTextRecorder(**recorder_kwargs)
|
||||
logger.info(f"[{self.session_id}] ✅ Recorder initialized (incremental mode, transcript-stability silence detection)")
|
||||
|
||||
# Run initialization in thread pool
|
||||
await asyncio.get_event_loop().run_in_executor(None, init_recorder)
|
||||
|
||||
# Start audio feed thread
|
||||
self.feed_thread = threading.Thread(target=self._feed_audio_thread, daemon=True)
|
||||
self.feed_thread.start()
|
||||
|
||||
# NOTE: We don't call recorder.start() - VAD callbacks don't work with use_microphone=False
|
||||
# Instead, we detect silence ourselves via transcript stability in the decode loop
|
||||
|
||||
# Start CORRECT incremental decoding loop
|
||||
# Since RealtimeSTT VAD callbacks don't work with use_microphone=False,
|
||||
# we detect silence ourselves via transcript stability
|
||||
def run_decode_loop():
|
||||
"""
|
||||
Decode buffer periodically. Detect end-of-speech when:
|
||||
1. We have a transcript AND
|
||||
2. Transcript hasn't changed for silence_threshold seconds
|
||||
"""
|
||||
decode_interval = 0.7 # Re-decode every 700ms
|
||||
min_audio_before_first_decode = 1.2 # Wait 1.2s before first decode
|
||||
silence_threshold = 1.5 # If transcript stable for 1.5s, consider it final
|
||||
|
||||
last_transcript_change_time = 0
|
||||
has_transcript = False
|
||||
|
||||
logger.info(f"[{self.session_id}] Decode loop ready (silence detection: {silence_threshold}s)")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = time.time()
|
||||
buffer_duration = len(self.float_buffer) / 16000.0 if self.float_buffer else 0
|
||||
|
||||
# Only decode if we have enough audio
|
||||
if buffer_duration >= min_audio_before_first_decode:
|
||||
# Check if enough time since last decode
|
||||
if (current_time - self.last_decode_time) >= decode_interval:
|
||||
try:
|
||||
audio_array = np.array(self.float_buffer, dtype=np.float32)
|
||||
|
||||
logger.debug(f"[{self.session_id}] 🔄 Decode (buffer: {buffer_duration:.1f}s)")
|
||||
|
||||
result = self.recorder.perform_final_transcription(audio_array)
|
||||
text = result.strip() if result else ""
|
||||
|
||||
if text:
|
||||
if text != self.last_decode_text:
|
||||
# Transcript changed - update and reset stability timer
|
||||
self.last_decode_text = text
|
||||
last_transcript_change_time = current_time
|
||||
has_transcript = True
|
||||
|
||||
logger.info(f"[{self.session_id}] 📝 Partial: {text}")
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_transcript("partial", text),
|
||||
self.loop
|
||||
)
|
||||
# else: transcript same, stability timer continues
|
||||
|
||||
self.last_decode_time = current_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.session_id}] Decode error: {e}", exc_info=True)
|
||||
|
||||
# Check for silence (transcript stable for threshold)
|
||||
if has_transcript and last_transcript_change_time > 0:
|
||||
time_since_change = current_time - last_transcript_change_time
|
||||
if time_since_change >= silence_threshold:
|
||||
# Transcript has been stable - emit final
|
||||
logger.info(f"[{self.session_id}] ✅ Final (stable {time_since_change:.1f}s): {self.last_decode_text}")
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_transcript("final", self.last_decode_text),
|
||||
self.loop
|
||||
)
|
||||
|
||||
# Reset for next utterance
|
||||
self.float_buffer = []
|
||||
self.last_decode_text = ""
|
||||
self.last_decode_time = 0
|
||||
last_transcript_change_time = 0
|
||||
has_transcript = False
|
||||
|
||||
time.sleep(0.1) # Check every 100ms
|
||||
|
||||
except Exception as e:
|
||||
if self.running:
|
||||
logger.error(f"[{self.session_id}] Decode loop error: {e}", exc_info=True)
|
||||
break
|
||||
|
||||
self.text_thread = threading.Thread(target=run_decode_loop, daemon=True)
|
||||
self.text_thread.start()
|
||||
|
||||
logger.info(f"[{self.session_id}] ✅ Session started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.session_id}] Failed to start session: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def feed_audio(self, audio_data: bytes):
|
||||
"""Feed audio data to the recorder AND our buffer."""
|
||||
if self.running:
|
||||
# Convert bytes to numpy array (16-bit PCM)
|
||||
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
||||
|
||||
# Feed to RealtimeSTT for VAD only
|
||||
self.audio_queue.put(audio_np)
|
||||
|
||||
# Also add to OUR float32 buffer (normalized to -1.0 to 1.0)
|
||||
float_audio = audio_np.astype(np.float32) / 32768.0
|
||||
self.float_buffer.extend(float_audio)
|
||||
|
||||
# Keep buffer size bounded (max 30 seconds at 16kHz = 480k samples)
|
||||
max_samples = int(self.max_buffer_duration * 16000)
|
||||
if len(self.float_buffer) > max_samples:
|
||||
self.float_buffer = self.float_buffer[-max_samples:]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the session state."""
|
||||
logger.info(f"[{self.session_id}] Resetting session")
|
||||
self.float_buffer = []
|
||||
self.last_decode_text = ""
|
||||
# Clear audio queue
|
||||
while not self.audio_queue.empty():
|
||||
try:
|
||||
self.audio_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the session and cleanup."""
|
||||
logger.info(f"[{self.session_id}] Stopping session...")
|
||||
self.running = False
|
||||
|
||||
# Wait for threads to finish
|
||||
if self.feed_thread and self.feed_thread.is_alive():
|
||||
self.feed_thread.join(timeout=2)
|
||||
|
||||
# Shutdown recorder
|
||||
if self.recorder:
|
||||
try:
|
||||
self.recorder.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")
|
||||
|
||||
logger.info(f"[{self.session_id}] Session stopped")
|
||||
|
||||
|
||||
class STTServer:
|
||||
"""
|
||||
WebSocket server for RealtimeSTT.
|
||||
Handles multiple concurrent clients (one per Discord user).
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8766, config: Dict[str, Any] = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sessions: Dict[str, STTSession] = {}
|
||||
self.session_counter = 0
|
||||
|
||||
# Config must be provided
|
||||
if not config:
|
||||
raise ValueError("Configuration dict must be provided to STTServer")
|
||||
|
||||
self.config = config
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("RealtimeSTT Server Configuration:")
|
||||
logger.info(f" Host: {host}:{port}")
|
||||
logger.info(f" Model: {self.config['model']}")
|
||||
logger.info(f" Language: {self.config.get('language', 'auto-detect')}")
|
||||
logger.info(f" Device: {self.config['device']}")
|
||||
logger.info(f" Compute Type: {self.config['compute_type']}")
|
||||
logger.info(f" Silence Duration: {self.config['silence_duration']}s")
|
||||
logger.info(f" Realtime Pause: {self.config.get('realtime_processing_pause', 'N/A')}s")
|
||||
logger.info("=" * 60)
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""Handle a WebSocket client connection."""
|
||||
self.session_counter += 1
|
||||
session_id = f"session_{self.session_counter}"
|
||||
session = None
|
||||
|
||||
try:
|
||||
logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")
|
||||
|
||||
# Create session
|
||||
session = STTSession(websocket, session_id, self.config)
|
||||
self.sessions[session_id] = session
|
||||
|
||||
# Start session
|
||||
await session.start(asyncio.get_event_loop())
|
||||
|
||||
# Process messages
|
||||
async for message in websocket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
# Binary audio data
|
||||
session.feed_audio(message)
|
||||
else:
|
||||
# JSON command
|
||||
data = json.loads(message)
|
||||
command = data.get('command', '')
|
||||
|
||||
if command == 'reset':
|
||||
session.reset()
|
||||
elif command == 'ping':
|
||||
await websocket.send(json.dumps({
|
||||
'type': 'pong',
|
||||
'timestamp': time.time()
|
||||
}))
|
||||
else:
|
||||
logger.warning(f"[{session_id}] Unknown command: {command}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"[{session_id}] Invalid JSON message")
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Error processing message: {e}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info(f"[{session_id}] Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Error: {e}", exc_info=True)
|
||||
finally:
|
||||
# Cleanup
|
||||
if session:
|
||||
await session.stop()
|
||||
del self.sessions[session_id]
|
||||
|
||||
async def run(self):
|
||||
"""Run the WebSocket server."""
|
||||
logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")
|
||||
|
||||
async with serve(
|
||||
self.handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
ping_interval=30,
|
||||
ping_timeout=10,
|
||||
max_size=10 * 1024 * 1024, # 10MB max message size
|
||||
):
|
||||
logger.info("✅ Server ready and listening for connections")
|
||||
await asyncio.Future() # Run forever
|
||||
|
||||
|
||||
async def warmup_model(config: Dict[str, Any]):
|
||||
"""
|
||||
Warmup is DISABLED - it wastes memory by loading a model that's never reused.
|
||||
The first session will load the model when needed.
|
||||
"""
|
||||
global warmup_complete
|
||||
|
||||
logger.info("⚠️ Warmup disabled to save VRAM - model will load on first connection")
|
||||
warmup_complete = True # Mark as complete so health check passes
|
||||
|
||||
|
||||
async def health_handler(request):
|
||||
"""HTTP health check endpoint"""
|
||||
if warmup_complete:
|
||||
return web.json_response({
|
||||
"status": "ready",
|
||||
"warmed_up": True,
|
||||
"model": "small.en",
|
||||
"device": "cuda"
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
"status": "warming_up",
|
||||
"warmed_up": False,
|
||||
"model": "small.en",
|
||||
"device": "cuda"
|
||||
}, status=503)
|
||||
|
||||
|
||||
async def start_http_server(host: str, http_port: int):
|
||||
"""Start HTTP server for health checks"""
|
||||
app = web.Application()
|
||||
app.router.add_get('/health', health_handler)
|
||||
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, host, http_port)
|
||||
await site.start()
|
||||
|
||||
logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
import os
|
||||
|
||||
# Get configuration from environment
|
||||
host = os.environ.get('STT_HOST', '0.0.0.0')
|
||||
port = int(os.environ.get('STT_PORT', '8766'))
|
||||
http_port = int(os.environ.get('STT_HTTP_PORT', '8767')) # HTTP health check port
|
||||
|
||||
# Configuration - ChatGPT's incremental utterance decoding approach
|
||||
config = {
|
||||
'model': 'turbo', # Fast multilingual model
|
||||
'language': 'en', # SET LANGUAGE! Auto-detect adds 4+ seconds latency (change to 'ja', 'bg' as needed)
|
||||
'compute_type': 'float16',
|
||||
'device': 'cuda',
|
||||
|
||||
# VAD settings - ChatGPT: "minimum speech ~600ms, end-of-speech silence ~400-600ms"
|
||||
'silero_sensitivity': 0.6,
|
||||
'webrtc_sensitivity': 3,
|
||||
'silence_duration': 0.5, # 500ms end-of-speech silence
|
||||
'min_recording_length': 0.6, # 600ms minimum speech
|
||||
'min_gap': 0.3,
|
||||
}
|
||||
|
||||
# Create and run server
|
||||
server = STTServer(host=host, port=port, config=config)
|
||||
|
||||
async def run_all():
|
||||
# Start warmup in background
|
||||
asyncio.create_task(warmup_model(config))
|
||||
|
||||
# Start HTTP health server
|
||||
asyncio.create_task(start_http_server(host, http_port))
|
||||
|
||||
# Start WebSocket server
|
||||
await server.run()
|
||||
|
||||
try:
|
||||
asyncio.run(run_all())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server shutdown requested")
|
||||
except Exception as e:
|
||||
logger.error(f"Server error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user