#!/usr/bin/env python3
"""
RealtimeSTT WebSocket Server

Provides real-time speech-to-text transcription using Faster-Whisper.
Receives audio chunks via WebSocket and streams back partial/final transcripts.

Protocol:
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
- Client sends: JSON {"command": "reset"} to reset state
- Client sends: JSON {"command": "ping"}; server replies {"type": "pong", "timestamp": float}
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}

A separate HTTP endpoint (GET /health) reports whether model warmup finished
(200 when ready, 503 while warming up).
"""

import asyncio
import json
import logging
import os
import queue
import threading
import time
from typing import Any, Dict, Optional

import numpy as np
import websockets
from websockets.server import serve
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger('stt-realtime')

# Import RealtimeSTT. Kept after logging.basicConfig() — presumably so our
# logging setup wins (basicConfig is a no-op once the root logger has handlers).
from RealtimeSTT import AudioToTextRecorder

# Single source of truth for the recognizer configuration. It is used by the
# WebSocket server, the warmup routine and the health endpoint so they cannot
# drift apart.
DEFAULT_CONFIG: Dict[str, Any] = {
    # Model - using small.en (English-only, more accurate than multilingual small)
    'model': 'small.en',
    'language': 'en',
    'compute_type': 'float16',  # FP16 for GPU efficiency
    'device': 'cuda',
    # VAD settings. silero/webrtc sensitivity are only read by warmup_model();
    # session recorders hard-code their own (more permissive) values.
    'silero_sensitivity': 0.6,
    'webrtc_sensitivity': 3,
    'silence_duration': 0.8,  # Shorter to improve responsiveness
    'min_recording_length': 0.5,
    'min_gap': 0.3,
}

# Global warmup state (written from an executor thread, read by /health).
warmup_complete = False
warmup_lock = threading.Lock()
warmup_recorder = None


class STTSession:
    """
    Manages a single STT session for a WebSocket client.

    Uses RealtimeSTT's AudioToTextRecorder with its feed_audio() method.
    Recorder callbacks are assumed to run on threads other than the event loop,
    which is why transcripts are handed back via run_coroutine_threadsafe().
    """

    def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
        self.websocket = websocket
        self.session_id = session_id
        self.config = config
        self.recorder: Optional[AudioToTextRecorder] = None
        self.running = False
        self.audio_queue = queue.Queue()  # int16 numpy chunks awaiting feed_audio()
        self.feed_thread: Optional[threading.Thread] = None
        self.text_thread: Optional[threading.Thread] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None  # set by start()
        self.last_partial = ""
        self.last_stabilized = ""  # Track last stabilized partial
        self.last_text_was_stabilized = False  # Track which came last
        self.recording_active = False  # Track if currently recording
        logger.info(f"[{session_id}] Session created")

    # ------------------------------------------------------------------ helpers

    def _schedule_send(self, transcript_type: str, text: str):
        """Schedule _send_transcript() on the session's event loop from any thread.

        Drops the message (with a debug log) when no usable loop is available,
        e.g. a callback firing before start() or after the loop closed.
        """
        loop = self.loop
        if loop is None or loop.is_closed():
            logger.debug(f"[{self.session_id}] Dropping {transcript_type} transcript: event loop unavailable")
            return
        asyncio.run_coroutine_threadsafe(
            self._send_transcript(transcript_type, text),
            loop
        )

    def _clear_transcript_state(self):
        """Forget all partial/stabilized text collected for the current utterance."""
        self.last_partial = ""
        self.last_stabilized = ""
        self.last_text_was_stabilized = False

    def _drain_audio_queue(self):
        """Discard every audio chunk still waiting in the queue."""
        while True:
            try:
                self.audio_queue.get_nowait()
            except queue.Empty:
                return

    # -------------------------------------------------------- recorder callbacks

    def _on_realtime_transcription(self, text: str):
        """Called when partial transcription is available."""
        if text and text != self.last_partial:
            self.last_partial = text
            self.last_text_was_stabilized = False  # Partial came after stabilized
            logger.info(f"[{self.session_id}] 📝 Partial: {text}")
            self._schedule_send("partial", text)

    def _on_realtime_stabilized(self, text: str):
        """Called when a stabilized partial is available (high confidence).

        Sent to the client as a "partial"; it only becomes a "final" via
        _on_recording_stop().
        """
        if text and text.strip():
            self.last_stabilized = text
            self.last_text_was_stabilized = True  # Stabilized came after partial
            logger.info(f"[{self.session_id}] 🔒 Stabilized: {text}")
            self._schedule_send("partial", text)

    def _on_recording_stop(self):
        """Called when recording stops (silence detected); emits the final transcript."""
        logger.info(f"[{self.session_id}] ⏹️ Recording stopped")
        self.recording_active = False

        # Use the most recent text: prioritize whichever came last
        if self.last_text_was_stabilized:
            final_text = self.last_stabilized or self.last_partial
            source = "stabilized" if self.last_stabilized else "partial"
        else:
            final_text = self.last_partial or self.last_stabilized
            source = "partial" if self.last_partial else "stabilized"

        if final_text:
            logger.info(f"[{self.session_id}] ✅ Final (from {source}): {final_text}")
            self._schedule_send("final", final_text)
        else:
            # No transcript means VAD false positive (detected "speech" in pure noise)
            logger.warning(f"[{self.session_id}] ⚠️ Recording stopped but no transcript available (VAD false positive)")
            logger.info(f"[{self.session_id}] 🔄 Clearing audio buffer to recover")
            # Clear the audio queue to prevent stale data
            self._drain_audio_queue()

        self._clear_transcript_state()

    def _on_recording_start(self):
        """Called when recording starts (speech detected); starts a fresh utterance."""
        logger.info(f"[{self.session_id}] 🎙️ Recording started")
        self.recording_active = True
        self._clear_transcript_state()

    def _on_transcription(self, text: str):
        """Not used - we use stabilized partials as finals."""
        pass

    async def _send_transcript(self, transcript_type: str, text: str):
        """Send transcript to client via WebSocket (errors are logged, not raised)."""
        try:
            message = {
                "type": transcript_type,
                "text": text,
                "timestamp": time.time()
            }
            await self.websocket.send(json.dumps(message))
        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to send transcript: {e}")

    # ------------------------------------------------------------ worker threads

    def _feed_audio_thread(self):
        """Thread that moves queued audio chunks into the recorder until stopped."""
        logger.info(f"[{self.session_id}] Audio feed thread started")
        while self.running:
            try:
                # Short timeout so the loop re-checks self.running regularly
                audio_chunk = self.audio_queue.get(timeout=0.1)
                if audio_chunk is not None and self.recorder:
                    self.recorder.feed_audio(audio_chunk)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"[{self.session_id}] Error feeding audio: {e}")
        logger.info(f"[{self.session_id}] Audio feed thread stopped")

    def _text_loop_thread(self):
        """Thread that keeps the recorder's blocking text() loop running until stopped."""
        while self.running:
            try:
                # Blocks until speech is detected and transcribed. The finished
                # text goes to _on_transcription (a no-op); finals are emitted
                # by _on_recording_stop instead.
                self.recorder.text(self._on_transcription)
            except Exception as e:
                if self.running:
                    logger.error(f"[{self.session_id}] Text loop error: {e}")
                break

    def _init_recorder(self):
        """Create the AudioToTextRecorder (blocking: loads the model)."""
        self.recorder = AudioToTextRecorder(
            # Model settings - using same model for both partial and final
            model=self.config['model'],
            language=self.config['language'],
            compute_type=self.config['compute_type'],
            device=self.config['device'],

            # Disable microphone - we feed audio manually
            use_microphone=False,

            # Real-time transcription - use same model for everything
            enable_realtime_transcription=True,
            realtime_model_type=self.config['model'],  # Use same model
            realtime_processing_pause=0.05,  # 50ms between updates
            on_realtime_transcription_update=self._on_realtime_transcription,
            on_realtime_transcription_stabilized=self._on_realtime_stabilized,

            # VAD settings - very permissive, rely on Discord's VAD for speech detection
            # Our VAD is only for silence detection, not filtering audio content
            silero_sensitivity=0.05,  # Very low = barely filters anything
            silero_use_onnx=True,  # Faster
            webrtc_sensitivity=3,
            post_speech_silence_duration=self.config['silence_duration'],
            min_length_of_recording=self.config['min_recording_length'],
            min_gap_between_recordings=self.config['min_gap'],
            pre_recording_buffer_duration=1.0,  # Capture more audio before/after speech

            # Callbacks
            on_recording_start=self._on_recording_start,
            on_recording_stop=self._on_recording_stop,
            on_vad_detect_start=lambda: logger.debug(f"[{self.session_id}] VAD listening"),
            on_vad_detect_stop=lambda: logger.debug(f"[{self.session_id}] VAD stopped"),

            # Other settings
            spinner=False,  # No spinner in container
            level=logging.WARNING,  # Reduce internal logging

            # Beam search settings
            beam_size=5,  # Higher beam = better accuracy (used for final processing)
            beam_size_realtime=5,  # Increased from 3 for better real-time accuracy

            # Batch sizes
            batch_size=16,
            realtime_batch_size=8,

            initial_prompt="",  # Can add context here if needed
        )
        logger.info(f"[{self.session_id}] ✅ Recorder initialized")

    # ---------------------------------------------------------------- lifecycle

    async def start(self, loop: asyncio.AbstractEventLoop):
        """Start the STT session.

        Args:
            loop: The running event loop; transcripts are scheduled onto it.

        Raises:
            Exception: Any error from recorder construction is logged and re-raised.
        """
        self.loop = loop
        self.running = True

        logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...")
        logger.info(f"[{self.session_id}] Model: {self.config['model']}")
        logger.info(f"[{self.session_id}] Device: {self.config['device']}")

        try:
            # Model loading blocks, so build the recorder in the thread pool
            await loop.run_in_executor(None, self._init_recorder)

            # Start audio feed thread
            self.feed_thread = threading.Thread(target=self._feed_audio_thread, daemon=True)
            self.feed_thread.start()

            # Start the recorder's text processing loop in a thread
            self.text_thread = threading.Thread(target=self._text_loop_thread, daemon=True)
            self.text_thread.start()

            logger.info(f"[{self.session_id}] ✅ Session started successfully")

        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to start session: {e}", exc_info=True)
            raise

    def feed_audio(self, audio_data: bytes):
        """Queue raw 16-bit PCM bytes for the recorder (ignored once stopped).

        Raises:
            ValueError: If ``audio_data`` has an odd length (not whole int16 samples).
        """
        if self.running:
            # Convert bytes to numpy array (16-bit PCM)
            audio_np = np.frombuffer(audio_data, dtype=np.int16)
            self.audio_queue.put(audio_np)

    def reset(self):
        """Reset the session state: drop collected transcript text and queued audio."""
        logger.info(f"[{self.session_id}] Resetting session")
        self._clear_transcript_state()
        self._drain_audio_queue()

    def _shutdown_blocking(self):
        """Join worker threads and shut the recorder down (blocking)."""
        if self.feed_thread and self.feed_thread.is_alive():
            self.feed_thread.join(timeout=2)

        if self.recorder:
            try:
                self.recorder.shutdown()
            except Exception as e:
                logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")

        # After shutdown the blocking text() call is expected to return/raise,
        # letting the text loop observe running == False.
        if self.text_thread and self.text_thread.is_alive():
            self.text_thread.join(timeout=2)

    async def stop(self):
        """Stop the session and cleanup without blocking the event loop."""
        logger.info(f"[{self.session_id}] Stopping session...")
        self.running = False

        # Thread joins and recorder shutdown can take seconds; keep other
        # sessions responsive by doing them in the thread pool.
        await asyncio.get_running_loop().run_in_executor(None, self._shutdown_blocking)

        logger.info(f"[{self.session_id}] Session stopped")


class STTServer:
    """
    WebSocket server for RealtimeSTT.

    Handles multiple concurrent clients (one per Discord user).
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8766):
        self.host = host
        self.port = port
        self.sessions: Dict[str, STTSession] = {}
        self.session_counter = 0

        # Copy so per-server changes never mutate the shared defaults
        self.config = dict(DEFAULT_CONFIG)

        logger.info("=" * 60)
        logger.info("RealtimeSTT Server Configuration:")
        logger.info(f" Host: {host}:{port}")
        logger.info(f" Model: {self.config['model']} (English-only, optimized)")
        logger.info(" Beam size: 5 (higher accuracy)")
        logger.info(" Strategy: Use last partial as final (instant response)")
        logger.info(f" Language: {self.config['language']}")
        logger.info(f" Device: {self.config['device']}")
        logger.info(f" Compute Type: {self.config['compute_type']}")
        logger.info(f" Silence Duration: {self.config['silence_duration']}s")
        logger.info("=" * 60)

    async def handle_client(self, websocket):
        """Handle a WebSocket client connection for its whole lifetime."""
        self.session_counter += 1
        session_id = f"session_{self.session_counter}"
        session = None

        try:
            logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")

            # Create session
            session = STTSession(websocket, session_id, self.config)
            self.sessions[session_id] = session

            # Start session
            await session.start(asyncio.get_running_loop())

            # Process messages
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data
                        session.feed_audio(message)
                    else:
                        # JSON command
                        data = json.loads(message)
                        command = data.get('command', '') if isinstance(data, dict) else ''

                        if command == 'reset':
                            session.reset()
                        elif command == 'ping':
                            await websocket.send(json.dumps({
                                'type': 'pong',
                                'timestamp': time.time()
                            }))
                        else:
                            logger.warning(f"[{session_id}] Unknown command: {command}")

                except json.JSONDecodeError:
                    logger.warning(f"[{session_id}] Invalid JSON message")
                except Exception as e:
                    logger.error(f"[{session_id}] Error processing message: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"[{session_id}] Client disconnected")
        except Exception as e:
            logger.error(f"[{session_id}] Error: {e}", exc_info=True)
        finally:
            # Cleanup; always drop the registry entry even if stop() fails
            try:
                if session is not None:
                    await session.stop()
            finally:
                self.sessions.pop(session_id, None)

    async def run(self):
        """Run the WebSocket server forever."""
        logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")

        async with serve(
            self.handle_client,
            self.host,
            self.port,
            ping_interval=30,
            ping_timeout=10,
            max_size=10 * 1024 * 1024,  # 10MB max message size
        ):
            logger.info("✅ Server ready and listening for connections")
            await asyncio.Future()  # Run forever


def _warmup_model_blocking(config: Dict[str, Any]):
    """Load the model by constructing a recorder (blocking; run off the event loop)."""
    global warmup_complete, warmup_recorder

    with warmup_lock:
        if warmup_complete:
            logger.info("Model already warmed up")
            return

        logger.info("🔥 Starting model warmup...")

        try:
            logger.info("Loading Faster-Whisper model...")

            def dummy_callback(text):
                pass

            # Constructing the recorder triggers model loading. The instance is
            # kept in a global (NOTE(review): it holds its own model copy and
            # GPU memory for the process lifetime — confirm this is intended).
            warmup_recorder = AudioToTextRecorder(
                model=config['model'],
                language=config['language'],
                compute_type=config['compute_type'],
                device=config['device'],
                # Audio is fed manually in this service; do not open a microphone
                use_microphone=False,
                silero_sensitivity=config['silero_sensitivity'],
                webrtc_sensitivity=config['webrtc_sensitivity'],
                post_speech_silence_duration=config['silence_duration'],
                min_length_of_recording=config['min_recording_length'],
                min_gap_between_recordings=config['min_gap'],
                enable_realtime_transcription=True,
                # Warm the same realtime model the sessions use
                realtime_model_type=config['model'],
                realtime_processing_pause=0.1,
                on_realtime_transcription_update=dummy_callback,
                on_realtime_transcription_stabilized=dummy_callback,
                spinner=False,
                level=logging.WARNING,
                beam_size=5,
                beam_size_realtime=5,
                batch_size=16,
                realtime_batch_size=8,
                initial_prompt="",
            )

            logger.info("✅ Model loaded and warmed up successfully")
            warmup_complete = True

        except Exception as e:
            logger.error(f"❌ Warmup failed: {e}", exc_info=True)
            warmup_complete = False


async def warmup_model(config: Dict[str, Any]):
    """
    Warm up the STT model by loading it before handling real requests.

    The blocking model load runs in the default thread pool, so the WebSocket
    and health servers stay responsive (and /health can report "warming_up")
    while it is in progress. Failures are logged and leave warmup_complete False.
    """
    await asyncio.get_running_loop().run_in_executor(None, _warmup_model_blocking, config)


async def health_handler(request):
    """HTTP health check endpoint: 200 when warmed up, 503 while warming up."""
    ready = warmup_complete  # read once so status and payload agree
    return web.json_response({
        "status": "ready" if ready else "warming_up",
        "warmed_up": ready,
        "model": DEFAULT_CONFIG['model'],
        "device": DEFAULT_CONFIG['device'],
    }, status=200 if ready else 503)


async def start_http_server(host: str, http_port: int):
    """Start HTTP server for health checks"""
    app = web.Application()
    app.router.add_get('/health', health_handler)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, http_port)
    await site.start()
    logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")


def main():
    """Main entry point: read env config, start warmup, health and WebSocket servers."""
    # Get configuration from environment
    host = os.environ.get('STT_HOST', '0.0.0.0')
    port = int(os.environ.get('STT_PORT', '8766'))
    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))  # HTTP health check port

    # Create and run server
    server = STTServer(host=host, port=port)

    async def run_all():
        # Keep references to background tasks: asyncio only holds weak
        # references, so unreferenced tasks may be garbage-collected mid-run.
        background_tasks = [
            # Start warmup in background (same config the sessions use)
            asyncio.create_task(warmup_model(server.config)),
            # Start HTTP health server
            asyncio.create_task(start_http_server(host, http_port)),
        ]

        try:
            # Start WebSocket server
            await server.run()
        finally:
            for task in background_tasks:
                if not task.done():
                    task.cancel()

    try:
        asyncio.run(run_all())
    except KeyboardInterrupt:
        logger.info("Server shutdown requested")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        raise


if __name__ == '__main__':
    main()