Major architectural overhaul of the speech-to-text pipeline for real-time voice chat: STT Server Rewrite: - Replaced RealtimeSTT dependency with direct Silero VAD + Faster-Whisper integration - Achieved sub-second latency by eliminating unnecessary abstractions - Uses small.en Whisper model for fast transcription (~850ms) Speculative Transcription (NEW): - Start transcribing at 150ms silence (speculative) while still listening - If speech continues, discard speculative result and keep buffering - If 400ms silence confirmed, use pre-computed speculative result immediately - Reduces latency by ~250-850ms for typical utterances with clear pauses VAD Implementation: - Silero VAD with ONNX (CPU-efficient) for 32ms chunk processing - Direct speech boundary detection without RealtimeSTT overhead - Configurable thresholds for silence detection (400ms final, 150ms speculative) Architecture: - Single Whisper model loaded once, shared across sessions - VAD runs on every 512-sample chunk for immediate speech detection - Background transcription worker thread for non-blocking processing - Greedy decoding (beam_size=1) for maximum speed Performance: - Previous: 400ms silence wait + ~850ms transcription = ~1.25s total latency - Current: 400ms silence wait + 0ms (speculative ready) = ~400ms (best case) - Single model reduces VRAM usage, prevents OOM on GTX 1660 Container Manager Updates: - Updated health check logic to work with new response format - Changed from checking 'warmed_up' flag to just 'status: ready' - Improved terminology from 'warmup' to 'models loading' Files Changed: - stt-realtime/stt_server.py: Complete rewrite with Silero VAD + speculative transcription - stt-realtime/requirements.txt: Removed RealtimeSTT, using torch.hub for Silero VAD - bot/utils/container_manager.py: Updated health check for new STT response format - bot/api.py: Updated docstring to reflect new architecture - backups/: Archived old RealtimeSTT-based implementation This addresses low latency requirements while maintaining accuracy with configurable speech detection thresholds.
482 lines
17 KiB
Python
482 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Low-Latency STT WebSocket Server
|
|
|
|
Uses Silero VAD for speech detection + Faster-Whisper turbo for transcription.
|
|
Achieves sub-second latency after speech ends.
|
|
|
|
Architecture:
|
|
1. Silero VAD runs on every audio chunk to detect speech boundaries
|
|
2. When speech ends (silence detected), immediately transcribe the buffer
|
|
3. Send final transcript - no waiting for stability
|
|
|
|
Protocol:
|
|
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
|
|
- Client sends: JSON {"command": "reset"} to reset state
|
|
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
|
|
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import threading
|
|
import queue
|
|
from typing import Optional, Dict, Any
|
|
import numpy as np
|
|
import websockets
|
|
from websockets.server import serve
|
|
from aiohttp import web
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
|
datefmt='%Y-%m-%d %H:%M:%S'
|
|
)
|
|
logger = logging.getLogger('stt-realtime')
|
|
|
|
# Silero VAD
|
|
import torch
|
|
torch.set_num_threads(1) # Prevent thread contention
|
|
|
|
# Faster-Whisper for transcription
|
|
from faster_whisper import WhisperModel
|
|
|
|
# Global model (shared across sessions for memory efficiency)
|
|
whisper_model: Optional[WhisperModel] = None
|
|
vad_model = None
|
|
warmup_complete = False
|
|
|
|
|
|
def load_vad_model():
|
|
"""Load Silero VAD model."""
|
|
global vad_model
|
|
model, _ = torch.hub.load(
|
|
repo_or_dir='snakers4/silero-vad',
|
|
model='silero_vad',
|
|
force_reload=False,
|
|
onnx=True # Use ONNX for speed
|
|
)
|
|
vad_model = model
|
|
logger.info("Silero VAD loaded (ONNX)")
|
|
return model
|
|
|
|
|
|
def load_whisper_model(config: Dict[str, Any]):
|
|
"""Load Faster-Whisper model."""
|
|
global whisper_model
|
|
whisper_model = WhisperModel(
|
|
config['model'],
|
|
device=config['device'],
|
|
compute_type=config['compute_type'],
|
|
)
|
|
logger.info(f"Faster-Whisper '{config['model']}' loaded on {config['device']}")
|
|
return whisper_model
|
|
|
|
|
|
class STTSession:
|
|
"""
|
|
Low-latency STT session using Silero VAD + Faster-Whisper.
|
|
"""
|
|
|
|
SAMPLE_RATE = 16000
|
|
VAD_CHUNK_MS = 32 # Silero needs 512 samples at 16kHz = 32ms
|
|
VAD_CHUNK_SAMPLES = 512 # Fixed: Silero requires exactly 512 samples at 16kHz
|
|
|
|
def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
|
|
self.websocket = websocket
|
|
self.session_id = session_id
|
|
self.config = config
|
|
self.running = False
|
|
self.loop = None
|
|
|
|
# Audio state
|
|
self.audio_buffer = [] # Float32 samples for current utterance
|
|
self.vad_buffer = [] # Small buffer for VAD chunk alignment
|
|
|
|
# Speech detection state
|
|
self.is_speaking = False
|
|
self.silence_start_time = 0
|
|
self.speech_start_time = 0
|
|
|
|
# Configurable thresholds
|
|
self.vad_threshold = config.get('vad_threshold', 0.5)
|
|
self.silence_duration_ms = config.get('silence_duration_ms', 400)
|
|
self.min_speech_ms = config.get('min_speech_ms', 250)
|
|
self.max_speech_duration = config.get('max_speech_duration', 30.0)
|
|
|
|
# Speculative transcription settings
|
|
self.speculative_silence_ms = config.get('speculative_silence_ms', 150) # Start transcribing early
|
|
self.speculative_pending = False # Is a speculative transcription in flight?
|
|
self.speculative_audio_snapshot = None # Audio buffer snapshot for speculative
|
|
self.speculative_result = None # Result from speculative transcription
|
|
self.speculative_result_ready = threading.Event()
|
|
|
|
# Transcription queue
|
|
self.transcribe_queue = queue.Queue()
|
|
self.transcribe_thread = None
|
|
|
|
logger.info(f"[{session_id}] Session created (speculative: {self.speculative_silence_ms}ms, final: {self.silence_duration_ms}ms)")
|
|
|
|
async def start(self, loop: asyncio.AbstractEventLoop):
|
|
"""Start the session."""
|
|
self.loop = loop
|
|
self.running = True
|
|
|
|
self.transcribe_thread = threading.Thread(target=self._transcription_worker, daemon=True)
|
|
self.transcribe_thread.start()
|
|
|
|
logger.info(f"[{self.session_id}] Session started")
|
|
|
|
def _transcription_worker(self):
|
|
"""Background thread that processes transcription requests."""
|
|
while self.running:
|
|
try:
|
|
item = self.transcribe_queue.get(timeout=0.1)
|
|
if item is None:
|
|
continue
|
|
|
|
audio_array, is_final, is_speculative = item
|
|
start_time = time.time()
|
|
|
|
segments, info = whisper_model.transcribe(
|
|
audio_array,
|
|
language=self.config.get('language', 'en'),
|
|
beam_size=1,
|
|
best_of=1,
|
|
temperature=0.0,
|
|
vad_filter=False,
|
|
without_timestamps=True,
|
|
)
|
|
|
|
text = " ".join(seg.text for seg in segments).strip()
|
|
elapsed = time.time() - start_time
|
|
|
|
if is_speculative:
|
|
# Store result for potential use
|
|
self.speculative_result = (text, elapsed)
|
|
self.speculative_result_ready.set()
|
|
logger.debug(f"[{self.session_id}] SPECULATIVE ({elapsed:.2f}s): {text}")
|
|
elif text:
|
|
transcript_type = "final" if is_final else "partial"
|
|
logger.info(f"[{self.session_id}] {transcript_type.upper()} ({elapsed:.2f}s): {text}")
|
|
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._send_transcript(transcript_type, text),
|
|
self.loop
|
|
)
|
|
|
|
except queue.Empty:
|
|
continue
|
|
except Exception as e:
|
|
logger.error(f"[{self.session_id}] Transcription error: {e}", exc_info=True)
|
|
|
|
async def _send_transcript(self, transcript_type: str, text: str):
|
|
"""Send transcript to client."""
|
|
try:
|
|
await self.websocket.send(json.dumps({
|
|
"type": transcript_type,
|
|
"text": text,
|
|
"timestamp": time.time()
|
|
}))
|
|
except Exception as e:
|
|
logger.error(f"[{self.session_id}] Send error: {e}")
|
|
|
|
def feed_audio(self, audio_data: bytes):
|
|
"""Process incoming audio data."""
|
|
if not self.running:
|
|
return
|
|
|
|
audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
|
|
audio_float = audio_int16.astype(np.float32) / 32768.0
|
|
|
|
self.vad_buffer.extend(audio_float)
|
|
|
|
while len(self.vad_buffer) >= self.VAD_CHUNK_SAMPLES:
|
|
chunk = np.array(self.vad_buffer[:self.VAD_CHUNK_SAMPLES], dtype=np.float32)
|
|
self.vad_buffer = self.vad_buffer[self.VAD_CHUNK_SAMPLES:]
|
|
self._process_vad_chunk(chunk)
|
|
|
|
def _process_vad_chunk(self, chunk: np.ndarray):
|
|
"""Process a single VAD chunk."""
|
|
current_time = time.time()
|
|
|
|
chunk_tensor = torch.from_numpy(chunk)
|
|
speech_prob = vad_model(chunk_tensor, self.SAMPLE_RATE).item()
|
|
|
|
is_speech = speech_prob >= self.vad_threshold
|
|
|
|
if is_speech:
|
|
if not self.is_speaking:
|
|
self.is_speaking = True
|
|
self.speech_start_time = current_time
|
|
self.audio_buffer = []
|
|
logger.debug(f"[{self.session_id}] Speech started")
|
|
|
|
self.audio_buffer.extend(chunk)
|
|
self.silence_start_time = 0
|
|
|
|
# Cancel any speculative transcription if speech resumed
|
|
if self.speculative_pending:
|
|
logger.debug(f"[{self.session_id}] Speech resumed, canceling speculative")
|
|
self.speculative_pending = False
|
|
self.speculative_result = None
|
|
self.speculative_result_ready.clear()
|
|
|
|
speech_duration = current_time - self.speech_start_time
|
|
if speech_duration >= self.max_speech_duration:
|
|
logger.info(f"[{self.session_id}] Max duration reached")
|
|
self._finalize_utterance()
|
|
|
|
else:
|
|
if self.is_speaking:
|
|
self.audio_buffer.extend(chunk)
|
|
|
|
if self.silence_start_time == 0:
|
|
self.silence_start_time = current_time
|
|
|
|
silence_duration_ms = (current_time - self.silence_start_time) * 1000
|
|
speech_duration_ms = (self.silence_start_time - self.speech_start_time) * 1000
|
|
|
|
# Trigger speculative transcription early
|
|
if (not self.speculative_pending and
|
|
silence_duration_ms >= self.speculative_silence_ms and
|
|
speech_duration_ms >= self.min_speech_ms):
|
|
self._start_speculative_transcription()
|
|
|
|
# Final silence threshold reached
|
|
if silence_duration_ms >= self.silence_duration_ms:
|
|
if speech_duration_ms >= self.min_speech_ms:
|
|
logger.debug(f"[{self.session_id}] Speech ended ({speech_duration_ms:.0f}ms)")
|
|
self._finalize_utterance()
|
|
else:
|
|
logger.debug(f"[{self.session_id}] Discarding short utterance")
|
|
self._reset_state()
|
|
|
|
def _start_speculative_transcription(self):
|
|
"""Start speculative transcription without waiting for full silence."""
|
|
if self.audio_buffer:
|
|
self.speculative_pending = True
|
|
self.speculative_result = None
|
|
self.speculative_result_ready.clear()
|
|
|
|
# Snapshot current buffer
|
|
audio_array = np.array(self.audio_buffer, dtype=np.float32)
|
|
duration = len(audio_array) / self.SAMPLE_RATE
|
|
|
|
logger.debug(f"[{self.session_id}] Starting speculative transcription ({duration:.1f}s)")
|
|
# is_speculative=True
|
|
self.transcribe_queue.put((audio_array, False, True))
|
|
|
|
def _finalize_utterance(self):
|
|
"""Finalize current utterance and send transcript."""
|
|
if not self.audio_buffer:
|
|
self._reset_state()
|
|
return
|
|
|
|
audio_array = np.array(self.audio_buffer, dtype=np.float32)
|
|
duration = len(audio_array) / self.SAMPLE_RATE
|
|
|
|
# Check if we have a speculative result ready
|
|
if self.speculative_pending and self.speculative_result_ready.wait(timeout=0.05):
|
|
# Use speculative result immediately!
|
|
text, elapsed = self.speculative_result
|
|
if text:
|
|
logger.info(f"[{self.session_id}] FINAL [speculative] ({elapsed:.2f}s): {text}")
|
|
asyncio.run_coroutine_threadsafe(
|
|
self._send_transcript("final", text),
|
|
self.loop
|
|
)
|
|
self._reset_state()
|
|
return
|
|
|
|
# No speculative result, do regular transcription
|
|
logger.info(f"[{self.session_id}] Queuing transcription ({duration:.1f}s)")
|
|
self.transcribe_queue.put((audio_array, True, False))
|
|
|
|
self._reset_state()
|
|
|
|
def _reset_state(self):
|
|
"""Reset speech detection state."""
|
|
self.is_speaking = False
|
|
self.audio_buffer = []
|
|
self.silence_start_time = 0
|
|
self.speech_start_time = 0
|
|
self.speculative_pending = False
|
|
self.speculative_result = None
|
|
self.speculative_result_ready.clear()
|
|
|
|
def reset(self):
|
|
"""Reset session state."""
|
|
logger.info(f"[{self.session_id}] Resetting")
|
|
self._reset_state()
|
|
self.vad_buffer = []
|
|
|
|
async def stop(self):
|
|
"""Stop the session."""
|
|
logger.info(f"[{self.session_id}] Stopping...")
|
|
self.running = False
|
|
|
|
if self.audio_buffer and self.is_speaking:
|
|
self._finalize_utterance()
|
|
|
|
if self.transcribe_thread and self.transcribe_thread.is_alive():
|
|
self.transcribe_thread.join(timeout=2)
|
|
|
|
logger.info(f"[{self.session_id}] Stopped")
|
|
|
|
|
|
class STTServer:
|
|
"""WebSocket server for low-latency STT."""
|
|
|
|
def __init__(self, host: str, port: int, config: Dict[str, Any]):
|
|
self.host = host
|
|
self.port = port
|
|
self.config = config
|
|
self.sessions: Dict[str, STTSession] = {}
|
|
self.session_counter = 0
|
|
|
|
logger.info("=" * 60)
|
|
logger.info("Low-Latency STT Server")
|
|
logger.info(f" Host: {host}:{port}")
|
|
logger.info(f" Model: {config['model']}")
|
|
logger.info(f" Language: {config.get('language', 'en')}")
|
|
logger.info(f" Silence: {config.get('silence_duration_ms', 400)}ms")
|
|
logger.info("=" * 60)
|
|
|
|
async def handle_client(self, websocket):
|
|
"""Handle WebSocket client."""
|
|
self.session_counter += 1
|
|
session_id = f"session_{self.session_counter}"
|
|
session = None
|
|
|
|
try:
|
|
logger.info(f"[{session_id}] Client connected")
|
|
|
|
session = STTSession(websocket, session_id, self.config)
|
|
self.sessions[session_id] = session
|
|
await session.start(asyncio.get_event_loop())
|
|
|
|
async for message in websocket:
|
|
if isinstance(message, bytes):
|
|
session.feed_audio(message)
|
|
else:
|
|
try:
|
|
data = json.loads(message)
|
|
cmd = data.get('command', '')
|
|
if cmd == 'reset':
|
|
session.reset()
|
|
elif cmd == 'ping':
|
|
await websocket.send(json.dumps({
|
|
'type': 'pong',
|
|
'timestamp': time.time()
|
|
}))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
except websockets.exceptions.ConnectionClosed:
|
|
logger.info(f"[{session_id}] Client disconnected")
|
|
except Exception as e:
|
|
logger.error(f"[{session_id}] Error: {e}", exc_info=True)
|
|
finally:
|
|
if session:
|
|
await session.stop()
|
|
del self.sessions[session_id]
|
|
|
|
async def run(self):
|
|
"""Run the server."""
|
|
logger.info(f"Starting server on ws://{self.host}:{self.port}")
|
|
|
|
async with serve(
|
|
self.handle_client,
|
|
self.host,
|
|
self.port,
|
|
ping_interval=30,
|
|
ping_timeout=10,
|
|
max_size=10 * 1024 * 1024,
|
|
):
|
|
logger.info("Server ready")
|
|
await asyncio.Future()
|
|
|
|
|
|
async def warmup(config: Dict[str, Any]):
|
|
"""Load models at startup."""
|
|
global warmup_complete
|
|
|
|
logger.info("Loading models...")
|
|
|
|
load_vad_model()
|
|
load_whisper_model(config)
|
|
|
|
logger.info("Warming up transcription...")
|
|
dummy_audio = np.zeros(16000, dtype=np.float32)
|
|
segments, _ = whisper_model.transcribe(
|
|
dummy_audio,
|
|
language=config.get('language', 'en'),
|
|
beam_size=1,
|
|
)
|
|
list(segments)
|
|
|
|
warmup_complete = True
|
|
logger.info("Warmup complete")
|
|
|
|
|
|
async def health_handler(request):
|
|
"""Health check endpoint."""
|
|
if warmup_complete:
|
|
return web.json_response({"status": "ready"})
|
|
return web.json_response({"status": "warming_up"}, status=503)
|
|
|
|
|
|
async def start_http_server(host: str, port: int):
|
|
"""Start HTTP health server."""
|
|
app = web.Application()
|
|
app.router.add_get('/health', health_handler)
|
|
runner = web.AppRunner(app)
|
|
await runner.setup()
|
|
site = web.TCPSite(runner, host, port)
|
|
await site.start()
|
|
logger.info(f"Health server on http://{host}:{port}")
|
|
|
|
|
|
def main():
|
|
"""Main entry point."""
|
|
import os
|
|
|
|
host = os.environ.get('STT_HOST', '0.0.0.0')
|
|
port = int(os.environ.get('STT_PORT', '8766'))
|
|
http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))
|
|
|
|
config = {
|
|
'model': 'small.en',
|
|
'language': 'en',
|
|
'compute_type': 'float16',
|
|
'device': 'cuda',
|
|
'vad_threshold': 0.5,
|
|
'silence_duration_ms': 400, # Final silence threshold
|
|
'speculative_silence_ms': 150, # Start transcribing early at 150ms
|
|
'min_speech_ms': 250,
|
|
'max_speech_duration': 30.0,
|
|
}
|
|
|
|
server = STTServer(host, port, config)
|
|
|
|
async def run_all():
|
|
await warmup(config)
|
|
asyncio.create_task(start_http_server(host, http_port))
|
|
await server.run()
|
|
|
|
try:
|
|
asyncio.run(run_all())
|
|
except KeyboardInterrupt:
|
|
logger.info("Shutdown requested")
|
|
except Exception as e:
|
|
logger.error(f"Server error: {e}", exc_info=True)
|
|
raise
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|