diff --git a/backups/stt_server_realtimestt_based_2025-01-22.py b/backups/stt_server_realtimestt_based_2025-01-22.py
new file mode 100644
index 0000000..fbc566e
--- /dev/null
+++ b/backups/stt_server_realtimestt_based_2025-01-22.py
@@ -0,0 +1,510 @@
+#!/usr/bin/env python3
+"""
+RealtimeSTT WebSocket Server
+
+Provides real-time speech-to-text transcription using Faster-Whisper.
+Receives audio chunks via WebSocket and streams back partial/final transcripts.
+
+Protocol:
+- Client sends: binary audio data (16kHz, 16-bit mono PCM)
+- Client sends: JSON {"command": "reset"} to reset state
+- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
+- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
+"""
+
+import asyncio
+import json
+import logging
+import time
+import threading
+import queue
+from typing import Optional, Dict, Any
+import numpy as np
+import websockets
+from websockets.server import serve
+from aiohttp import web
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger('stt-realtime')
+
+# Import RealtimeSTT
+from RealtimeSTT import AudioToTextRecorder
+
+# Global warmup state
+warmup_complete = False
+warmup_lock = threading.Lock()
+warmup_recorder = None
+
+
+class STTSession:
+    """
+    Manages a single STT session for a WebSocket client.
+
+    Key architectural point: we own the audio buffer and the decoder.
+    RealtimeSTT is used ONLY for VAD, not for transcription ownership.
+    """
+
+    def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
+        self.websocket = websocket
+        self.session_id = session_id
+        self.config = config
+        self.recorder: Optional[AudioToTextRecorder] = None
+        self.running = False
+        self.audio_queue = queue.Queue()
+        self.feed_thread: Optional[threading.Thread] = None
+
+        # OUR audio buffer - we own this, not RealtimeSTT
+        self.float_buffer = []  # Rolling float32 buffer (-1.0 to 1.0 range)
+        self.max_buffer_duration = 30.0  # Keep max 30 seconds
+
+        # Decode state
+        self.last_decode_text = ""
+        self.recording_active = False
+        self.recording_stop_time = 0
+        self.last_decode_time = 0
+        self.final_sent = False  # Track if we've sent final for this utterance
+        self.last_audio_time = 0  # Track when we last received audio with speech
+        self.speech_detected = False  # Track if we've detected any speech
+
+        logger.info(f"[{session_id}] Session created")
+
+    def _on_recording_stop(self):
+        """Called when recording stops (silence detected)."""
+        logger.info(f"[{self.session_id}] ⏚ī¸ Recording stopped - will emit final in decode loop")
+        self.recording_active = False
+        self.recording_stop_time = time.time()  # Track when recording stopped
+
+    def _on_recording_start(self):
+        """Called when recording starts (speech detected)."""
+        logger.info(f"[{self.session_id}] đŸŽ™ī¸ Recording started")
+        self.recording_active = True
+        self.float_buffer = []  # Reset buffer for new utterance
+        self.last_decode_text = ""
+        self.last_decode_time = 0
+        self.final_sent = False  # Reset final flag for new utterance
+
+    async def _send_transcript(self, transcript_type: str, text: str):
+        """Send transcript to client via WebSocket."""
+        try:
+            message = {
+                "type": transcript_type,
+                "text": text,
+                "timestamp": time.time()
+            }
+            await self.websocket.send(json.dumps(message))
+        except Exception as e:
+            logger.error(f"[{self.session_id}] Failed to send transcript: {e}")
+
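+    # Audio arrives on the asyncio WebSocket loop, but recorder.feed_audio()
+    # can block; chunks are therefore handed off through audio_queue and
+    # drained by the dedicated thread below. The 0.1s get() timeout lets the
+    # loop notice self.running going False during shutdown.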
"""Thread that feeds audio to the recorder.""" + logger.info(f"[{self.session_id}] Audio feed thread started") + while self.running: + try: + # Get audio chunk with timeout + audio_chunk = self.audio_queue.get(timeout=0.1) + if audio_chunk is not None and self.recorder: + self.recorder.feed_audio(audio_chunk) + except queue.Empty: + continue + except Exception as e: + logger.error(f"[{self.session_id}] Error feeding audio: {e}") + logger.info(f"[{self.session_id}] Audio feed thread stopped") + + async def start(self, loop: asyncio.AbstractEventLoop): + """Start the STT session.""" + self.loop = loop + self.running = True + + logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...") + logger.info(f"[{self.session_id}] Model: {self.config['model']}") + logger.info(f"[{self.session_id}] Device: {self.config['device']}") + + try: + # Create recorder in a thread to avoid blocking + def init_recorder(): + # Build initialization kwargs + recorder_kwargs = { + # Model settings - ONLY turbo model, no dual-model setup + 'model': self.config['model'], + 'language': self.config['language'], + 'compute_type': self.config['compute_type'], + 'device': self.config['device'], + + # Disable microphone - we feed audio manually + 'use_microphone': False, + + # DISABLE realtime partials - we'll use incremental utterance decoding instead + 'enable_realtime_transcription': False, # ← KEY CHANGE: No streaming partials + + # VAD settings - optimized for longer utterances (per ChatGPT) + 'silero_sensitivity': self.config['silero_sensitivity'], + 'silero_use_onnx': True, # Faster + 'webrtc_sensitivity': self.config['webrtc_sensitivity'], + 'post_speech_silence_duration': self.config['silence_duration'], + 'min_length_of_recording': self.config['min_recording_length'], + 'min_gap_between_recordings': self.config['min_gap'], + 'pre_recording_buffer_duration': 1.2, # ChatGPT: ~1.2s before first decode + + # Callbacks + 'on_recording_start': self._on_recording_start, + 'on_recording_stop': self._on_recording_stop, + 'on_vad_detect_start': lambda: logger.debug(f"[{self.session_id}] VAD listening"), + 'on_vad_detect_stop': lambda: logger.debug(f"[{self.session_id}] VAD stopped"), + + # Other settings + 'spinner': False, # No spinner in container + 'level': logging.WARNING, # Reduce internal logging + + # Beam search settings - optimized for accuracy + 'beam_size': 5, + + # Batch sizes + 'batch_size': 16, + + 'initial_prompt': "", + } + + self.recorder = AudioToTextRecorder(**recorder_kwargs) + logger.info(f"[{self.session_id}] ✅ Recorder initialized (incremental mode, transcript-stability silence detection)") + + # Run initialization in thread pool + await asyncio.get_event_loop().run_in_executor(None, init_recorder) + + # Start audio feed thread + self.feed_thread = threading.Thread(target=self._feed_audio_thread, daemon=True) + self.feed_thread.start() + + # NOTE: We don't call recorder.start() - VAD callbacks don't work with use_microphone=False + # Instead, we detect silence ourselves via transcript stability in the decode loop + + # Start CORRECT incremental decoding loop + # Since RealtimeSTT VAD callbacks don't work with use_microphone=False, + # we detect silence ourselves via transcript stability + def run_decode_loop(): + """ + Decode buffer periodically. Detect end-of-speech when: + 1. We have a transcript AND + 2. 
Transcript hasn't changed for silence_threshold seconds + """ + decode_interval = 0.7 # Re-decode every 700ms + min_audio_before_first_decode = 1.2 # Wait 1.2s before first decode + silence_threshold = 1.5 # If transcript stable for 1.5s, consider it final + + last_transcript_change_time = 0 + has_transcript = False + + logger.info(f"[{self.session_id}] Decode loop ready (silence detection: {silence_threshold}s)") + + while self.running: + try: + current_time = time.time() + buffer_duration = len(self.float_buffer) / 16000.0 if self.float_buffer else 0 + + # Only decode if we have enough audio + if buffer_duration >= min_audio_before_first_decode: + # Check if enough time since last decode + if (current_time - self.last_decode_time) >= decode_interval: + try: + audio_array = np.array(self.float_buffer, dtype=np.float32) + + logger.debug(f"[{self.session_id}] 🔄 Decode (buffer: {buffer_duration:.1f}s)") + + result = self.recorder.perform_final_transcription(audio_array) + text = result.strip() if result else "" + + if text: + if text != self.last_decode_text: + # Transcript changed - update and reset stability timer + self.last_decode_text = text + last_transcript_change_time = current_time + has_transcript = True + + logger.info(f"[{self.session_id}] 📝 Partial: {text}") + + asyncio.run_coroutine_threadsafe( + self._send_transcript("partial", text), + self.loop + ) + # else: transcript same, stability timer continues + + self.last_decode_time = current_time + + except Exception as e: + logger.error(f"[{self.session_id}] Decode error: {e}", exc_info=True) + + # Check for silence (transcript stable for threshold) + if has_transcript and last_transcript_change_time > 0: + time_since_change = current_time - last_transcript_change_time + if time_since_change >= silence_threshold: + # Transcript has been stable - emit final + logger.info(f"[{self.session_id}] ✅ Final (stable {time_since_change:.1f}s): {self.last_decode_text}") + + asyncio.run_coroutine_threadsafe( + self._send_transcript("final", self.last_decode_text), + self.loop + ) + + # Reset for next utterance + self.float_buffer = [] + self.last_decode_text = "" + self.last_decode_time = 0 + last_transcript_change_time = 0 + has_transcript = False + + time.sleep(0.1) # Check every 100ms + + except Exception as e: + if self.running: + logger.error(f"[{self.session_id}] Decode loop error: {e}", exc_info=True) + break + + self.text_thread = threading.Thread(target=run_decode_loop, daemon=True) + self.text_thread.start() + + logger.info(f"[{self.session_id}] ✅ Session started successfully") + + except Exception as e: + logger.error(f"[{self.session_id}] Failed to start session: {e}", exc_info=True) + raise + + def feed_audio(self, audio_data: bytes): + """Feed audio data to the recorder AND our buffer.""" + if self.running: + # Convert bytes to numpy array (16-bit PCM) + audio_np = np.frombuffer(audio_data, dtype=np.int16) + + # Feed to RealtimeSTT for VAD only + self.audio_queue.put(audio_np) + + # Also add to OUR float32 buffer (normalized to -1.0 to 1.0) + float_audio = audio_np.astype(np.float32) / 32768.0 + self.float_buffer.extend(float_audio) + + # Keep buffer size bounded (max 30 seconds at 16kHz = 480k samples) + max_samples = int(self.max_buffer_duration * 16000) + if len(self.float_buffer) > max_samples: + self.float_buffer = self.float_buffer[-max_samples:] + + def reset(self): + """Reset the session state.""" + logger.info(f"[{self.session_id}] Resetting session") + self.float_buffer = [] + self.last_decode_text = "" + # Clear 
+        while not self.audio_queue.empty():
+            try:
+                self.audio_queue.get_nowait()
+            except queue.Empty:
+                break
+
+    async def stop(self):
+        """Stop the session and cleanup."""
+        logger.info(f"[{self.session_id}] Stopping session...")
+        self.running = False
+
+        # Wait for threads to finish
+        if self.feed_thread and self.feed_thread.is_alive():
+            self.feed_thread.join(timeout=2)
+
+        # Shutdown recorder
+        if self.recorder:
+            try:
+                self.recorder.shutdown()
+            except Exception as e:
+                logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")
+
+        logger.info(f"[{self.session_id}] Session stopped")
+
+
+class STTServer:
+    """
+    WebSocket server for RealtimeSTT.
+    Handles multiple concurrent clients (one per Discord user).
+    """
+
+    def __init__(self, host: str = "0.0.0.0", port: int = 8766, config: Dict[str, Any] = None):
+        self.host = host
+        self.port = port
+        self.sessions: Dict[str, STTSession] = {}
+        self.session_counter = 0
+
+        # Config must be provided
+        if not config:
+            raise ValueError("Configuration dict must be provided to STTServer")
+
+        self.config = config
+
+        logger.info("=" * 60)
+        logger.info("RealtimeSTT Server Configuration:")
+        logger.info(f"  Host: {host}:{port}")
+        logger.info(f"  Model: {self.config['model']}")
+        logger.info(f"  Language: {self.config.get('language', 'auto-detect')}")
+        logger.info(f"  Device: {self.config['device']}")
+        logger.info(f"  Compute Type: {self.config['compute_type']}")
+        logger.info(f"  Silence Duration: {self.config['silence_duration']}s")
+        logger.info(f"  Realtime Pause: {self.config.get('realtime_processing_pause', 'N/A')}s")
+        logger.info("=" * 60)
+
+    async def handle_client(self, websocket):
+        """Handle a WebSocket client connection."""
+        self.session_counter += 1
+        session_id = f"session_{self.session_counter}"
+        session = None
+
+        try:
+            logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")
+
+            # Create session
+            session = STTSession(websocket, session_id, self.config)
+            self.sessions[session_id] = session
+
+            # Start session
+            await session.start(asyncio.get_event_loop())
+
+            # Process messages
+            async for message in websocket:
+                try:
+                    if isinstance(message, bytes):
+                        # Binary audio data
+                        session.feed_audio(message)
+                    else:
+                        # JSON command
+                        data = json.loads(message)
+                        command = data.get('command', '')
+
+                        if command == 'reset':
+                            session.reset()
+                        elif command == 'ping':
+                            await websocket.send(json.dumps({
+                                'type': 'pong',
+                                'timestamp': time.time()
+                            }))
+                        else:
+                            logger.warning(f"[{session_id}] Unknown command: {command}")
+
+                except json.JSONDecodeError:
+                    logger.warning(f"[{session_id}] Invalid JSON message")
+                except Exception as e:
+                    logger.error(f"[{session_id}] Error processing message: {e}")
+
+        except websockets.exceptions.ConnectionClosed:
+            logger.info(f"[{session_id}] Client disconnected")
+        except Exception as e:
+            logger.error(f"[{session_id}] Error: {e}", exc_info=True)
+        finally:
+            # Cleanup
+            if session:
+                await session.stop()
+                del self.sessions[session_id]
+
+    async def run(self):
+        """Run the WebSocket server."""
+        logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")
+
+        async with serve(
+            self.handle_client,
+            self.host,
+            self.port,
+            ping_interval=30,
+            ping_timeout=10,
+            max_size=10 * 1024 * 1024,  # 10MB max message size
+        ):
+            logger.info("✅ Server ready and listening for connections")
+            await asyncio.Future()  # Run forever
+
+
+async def warmup_model(config: Dict[str, Any]):
+    """
+    Warmup is DISABLED - it wastes memory by loading a model that's never reused.
+    The first session will load the model when needed.
+    """
+    global warmup_complete
+
+    logger.info("âš ī¸ Warmup disabled to save VRAM - model will load on first connection")
+    warmup_complete = True  # Mark as complete so the health check passes
+
+
+async def health_handler(request):
+    """HTTP health check endpoint"""
+    if warmup_complete:
+        return web.json_response({
+            "status": "ready",
+            "warmed_up": True,
+            "model": "turbo",
+            "device": "cuda"
+        })
+    else:
+        return web.json_response({
+            "status": "warming_up",
+            "warmed_up": False,
+            "model": "turbo",
+            "device": "cuda"
+        }, status=503)
+
+
+async def start_http_server(host: str, http_port: int):
+    """Start HTTP server for health checks"""
+    app = web.Application()
+    app.router.add_get('/health', health_handler)
+
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, host, http_port)
+    await site.start()
+
+    logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")
+
+
+def main():
+    """Main entry point."""
+    import os
+
+    # Get configuration from environment
+    host = os.environ.get('STT_HOST', '0.0.0.0')
+    port = int(os.environ.get('STT_PORT', '8766'))
+    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))  # HTTP health check port
+
+    # Configuration - incremental utterance decoding approach
+    config = {
+        'model': 'turbo',  # Fast multilingual model
+        'language': 'en',  # Set the language: auto-detect adds 4+ seconds of latency (change to 'ja', 'bg', etc. as needed)
+        'compute_type': 'float16',
+        'device': 'cuda',
+
+        # VAD settings - minimum speech ~600ms, end-of-speech silence ~400-600ms
+        'silero_sensitivity': 0.6,
+        'webrtc_sensitivity': 3,
+        'silence_duration': 0.5,  # 500ms end-of-speech silence
+        'min_recording_length': 0.6,  # 600ms minimum speech
+        'min_gap': 0.3,
+    }
+
+    # Create and run server
+    server = STTServer(host=host, port=port, config=config)
+
+    async def run_all():
+        # Start warmup in background
+        asyncio.create_task(warmup_model(config))
+
+        # Start HTTP health server
+        asyncio.create_task(start_http_server(host, http_port))
+
+        # Start WebSocket server
+        await server.run()
+
+    try:
+        asyncio.run(run_all())
+    except KeyboardInterrupt:
+        logger.info("Server shutdown requested")
+    except Exception as e:
+        logger.error(f"Server error: {e}", exc_info=True)
+        raise
+
+
+if __name__ == '__main__':
+    main()
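+
+# Example invocation (env vars as read in main(); the values shown are the defaults):
+#   STT_HOST=0.0.0.0 STT_PORT=8766 STT_HTTP_PORT=8767 \
+#       python stt_server_realtimestt_based_2025-01-22.py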
diff --git a/bot/api.py b/bot/api.py
index b89755c..eb5dd0e 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -2541,7 +2541,7 @@ async def initiate_voice_call(user_id: str = Form(...), voice_channel_id: str =
 
     Flow:
     1. Start STT and TTS containers
-    2. Wait for warmup
+    2. Wait for models to load (health check)
    3. Join voice channel
    4. Send DM with invite to user
    5. Wait for user to join (30min timeout)
@@ -2642,16 +2642,10 @@ Keep it brief (1-2 sentences).
 Make it feel personal and enthusiastic!"""
 
         sent_message = await user.send(dm_message)
 
-        # Log to DM logger
-        await dm_logger.log_message(
-            user_id=user.id,
-            user_name=user.name,
-            message_content=dm_message,
-            direction="outgoing",
-            message_id=sent_message.id,
-            attachments=[],
-            response_type="voice_call_invite"
-        )
+        # Log to DM logger. log_user_message expects a real discord.Message,
+        # so pass the actual sent_message rather than constructing a mock,
+        # and flag it as a bot-authored message.
+        dm_logger.log_user_message(user, sent_message, is_bot_message=True)
 
         logger.info(f"✓ DM sent to {user.name}")
 
@@ -2701,15 +2695,7 @@ async def _voice_call_timeout_handler(voice_session: 'VoiceSession', user: disco
         sent_message = await user.send(timeout_message)
 
         # Log to DM logger
-        await dm_logger.log_message(
-            user_id=user.id,
-            user_name=user.name,
-            message_content=timeout_message,
-            direction="outgoing",
-            message_id=sent_message.id,
-            attachments=[],
-            response_type="voice_call_timeout"
-        )
+        dm_logger.log_user_message(user, sent_message, is_bot_message=True)
     except:
         pass
 
diff --git a/bot/utils/container_manager.py b/bot/utils/container_manager.py
index 0d42e09..b318784 100644
--- a/bot/utils/container_manager.py
+++ b/bot/utils/container_manager.py
@@ -1,7 +1,7 @@
 # container_manager.py
 """
 Manages Docker containers for STT and TTS services.
-Handles startup, shutdown, and warmup detection.
+Handles startup, shutdown, and readiness detection.
 """
 
 import asyncio
@@ -18,12 +18,12 @@ class ContainerManager:
     STT_CONTAINER = "miku-stt"
     TTS_CONTAINER = "miku-rvc-api"
 
-    # Warmup check endpoints
+    # Health check endpoints
     STT_HEALTH_URL = "http://miku-stt:8767/health"  # HTTP health check endpoint
    TTS_HEALTH_URL = "http://miku-rvc-api:8765/health"
 
-    # Warmup timeouts
-    STT_WARMUP_TIMEOUT = 30  # seconds
+    # Startup timeouts (time to load models and become ready)
+    STT_WARMUP_TIMEOUT = 30  # seconds (Whisper model loading)
     TTS_WARMUP_TIMEOUT = 60  # seconds (RVC takes longer)
 
     @classmethod
@@ -65,17 +65,17 @@ class ContainerManager:
 
         logger.info(f"✓ {cls.TTS_CONTAINER} started")
 
-        # Wait for warmup
-        logger.info("⏳ Waiting for containers to warm up...")
+        # Wait for models to load and become ready
+        logger.info("⏳ Waiting for models to load...")
 
         stt_ready = await cls._wait_for_stt_warmup()
         if not stt_ready:
-            logger.error("STT failed to warm up")
+            logger.error("STT failed to become ready")
             return False
 
         tts_ready = await cls._wait_for_tts_warmup()
         if not tts_ready:
-            logger.error("TTS failed to warm up")
+            logger.error("TTS failed to become ready")
             return False
 
         logger.info("✅ All voice containers ready!")
@@ -130,7 +130,8 @@ class ContainerManager:
             async with session.get(cls.STT_HEALTH_URL, timeout=aiohttp.ClientTimeout(total=2)) as resp:
                 if resp.status == 200:
                     data = await resp.json()
-                    if data.get("status") == "ready" and data.get("warmed_up"):
+                    # The new STT server returns {"status": "ready"} once models are loaded
+                    if data.get("status") == "ready":
                         logger.info("✓ STT is ready")
                         return True
         except Exception:
diff --git a/stt-realtime/requirements.txt b/stt-realtime/requirements.txt
index 9b471eb..3fac6c8 100644
--- a/stt-realtime/requirements.txt
+++ b/stt-realtime/requirements.txt
@@ -1,19 +1,16 @@
-# RealtimeSTT dependencies
-RealtimeSTT>=0.3.104
+# Low-latency STT dependencies
 websockets>=12.0
 numpy>=1.24.0
 
-# For faster-whisper backend (GPU accelerated)
+# Faster-whisper backend (GPU accelerated)
 faster-whisper>=1.0.0
 ctranslate2>=4.4.0
 
 # Audio processing
 soundfile>=0.12.0
-librosa>=0.10.0
 
-# VAD dependencies (included with RealtimeSTT but explicit)
-webrtcvad>=2.0.10
-silero-vad>=5.1
+# VAD - Silero, loaded at runtime via torch.hub
+# (downloaded on first use; no separate pip package required)
 
 # Utilities
 aiohttp>=3.9.0
diff --git a/stt-realtime/stt_server.py b/stt-realtime/stt_server.py
index ec31733..90f054d 100644
--- a/stt-realtime/stt_server.py
+++ b/stt-realtime/stt_server.py
@@ -1,9 +1,14 @@
 #!/usr/bin/env python3
 """
-RealtimeSTT WebSocket Server
+Low-Latency STT WebSocket Server
 
-Provides real-time speech-to-text transcription using Faster-Whisper.
-Receives audio chunks via WebSocket and streams back partial/final transcripts.
+Uses Silero VAD for speech detection + Faster-Whisper for transcription.
+Achieves sub-second latency after speech ends.
+
+Architecture:
+1. Silero VAD runs on every audio chunk to detect speech boundaries
+2. When speech ends (silence detected), immediately transcribe the buffer
+3. Send final transcript - no waiting for stability
 
 Protocol:
 - Client sends: binary audio data (16kHz, 16-bit mono PCM)
@@ -32,352 +37,357 @@ logging.basicConfig(
 )
 logger = logging.getLogger('stt-realtime')
 
-# Import RealtimeSTT
-from RealtimeSTT import AudioToTextRecorder
+# Silero VAD
+import torch
+torch.set_num_threads(1)  # Prevent thread contention
 
-# Global warmup state
+# Faster-Whisper for transcription
+from faster_whisper import WhisperModel
+
+# Global model (shared across sessions for memory efficiency)
+whisper_model: Optional[WhisperModel] = None
+vad_model = None
 warmup_complete = False
-warmup_lock = threading.Lock()
-warmup_recorder = None
+
+
+def load_vad_model():
+    """Load Silero VAD model."""
+    global vad_model
+    model, _ = torch.hub.load(
+        repo_or_dir='snakers4/silero-vad',
+        model='silero_vad',
+        force_reload=False,
+        onnx=True  # Use ONNX for speed
+    )
+    vad_model = model
+    logger.info("Silero VAD loaded (ONNX)")
+    return model
+
+
+def load_whisper_model(config: Dict[str, Any]):
+    """Load Faster-Whisper model."""
+    global whisper_model
+    whisper_model = WhisperModel(
+        config['model'],
+        device=config['device'],
+        compute_type=config['compute_type'],
+    )
+    logger.info(f"Faster-Whisper '{config['model']}' loaded on {config['device']}")
+    return whisper_model
 
 
 class STTSession:
     """
-    Manages a single STT session for a WebSocket client.
-    Uses RealtimeSTT's AudioToTextRecorder with feed_audio() method.
+    Low-latency STT session using Silero VAD + Faster-Whisper.
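+
+    Per-chunk flow (see _process_vad_chunk below): PCM accumulates in
+    vad_buffer; Silero scores each 512-sample window; a speculative decode
+    is queued after speculative_silence_ms of silence; the utterance is
+    finalized once silence_duration_ms elapses.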
""" + SAMPLE_RATE = 16000 + VAD_CHUNK_MS = 32 # Silero needs 512 samples at 16kHz = 32ms + VAD_CHUNK_SAMPLES = 512 # Fixed: Silero requires exactly 512 samples at 16kHz + def __init__(self, websocket, session_id: str, config: Dict[str, Any]): self.websocket = websocket self.session_id = session_id self.config = config - self.recorder: Optional[AudioToTextRecorder] = None self.running = False - self.audio_queue = queue.Queue() - self.feed_thread: Optional[threading.Thread] = None - self.last_partial = "" - self.last_stabilized = "" # Track last stabilized partial - self.last_text_was_stabilized = False # Track which came last - self.recording_active = False # Track if currently recording + self.loop = None - logger.info(f"[{session_id}] Session created") - - def _on_realtime_transcription(self, text: str): - """Called when partial transcription is available.""" - if text and text != self.last_partial: - self.last_partial = text - self.last_text_was_stabilized = False # Partial came after stabilized - logger.info(f"[{self.session_id}] 📝 Partial: {text}") - asyncio.run_coroutine_threadsafe( - self._send_transcript("partial", text), - self.loop - ) - - def _on_realtime_stabilized(self, text: str): - """Called when a stabilized partial is available (high confidence).""" - if text and text.strip(): - self.last_stabilized = text - self.last_text_was_stabilized = True # Stabilized came after partial - logger.info(f"[{self.session_id}] 🔒 Stabilized: {text}") - asyncio.run_coroutine_threadsafe( - self._send_transcript("partial", text), - self.loop - ) - - def _on_recording_stop(self): - """Called when recording stops (silence detected).""" - logger.info(f"[{self.session_id}] âšī¸ Recording stopped") - self.recording_active = False + # Audio state + self.audio_buffer = [] # Float32 samples for current utterance + self.vad_buffer = [] # Small buffer for VAD chunk alignment - # Use the most recent text: prioritize whichever came last - if self.last_text_was_stabilized: - final_text = self.last_stabilized or self.last_partial - source = "stabilized" if self.last_stabilized else "partial" - else: - final_text = self.last_partial or self.last_stabilized - source = "partial" if self.last_partial else "stabilized" + # Speech detection state + self.is_speaking = False + self.silence_start_time = 0 + self.speech_start_time = 0 - if final_text: - logger.info(f"[{self.session_id}] ✅ Final (from {source}): {final_text}") - asyncio.run_coroutine_threadsafe( - self._send_transcript("final", final_text), - self.loop - ) - else: - # No transcript means VAD false positive (detected "speech" in pure noise) - logger.warning(f"[{self.session_id}] âš ī¸ Recording stopped but no transcript available (VAD false positive)") - logger.info(f"[{self.session_id}] 🔄 Clearing audio buffer to recover") - - # Clear the audio queue to prevent stale data - try: - while not self.audio_queue.empty(): - self.audio_queue.get_nowait() - except Exception: - pass + # Configurable thresholds + self.vad_threshold = config.get('vad_threshold', 0.5) + self.silence_duration_ms = config.get('silence_duration_ms', 400) + self.min_speech_ms = config.get('min_speech_ms', 250) + self.max_speech_duration = config.get('max_speech_duration', 30.0) - # Reset state - self.last_stabilized = "" - self.last_partial = "" - self.last_text_was_stabilized = False - - def _on_recording_start(self): - """Called when recording starts (speech detected).""" - logger.info(f"[{self.session_id}] đŸŽ™ī¸ Recording started") - self.recording_active = True - 
self.last_stabilized = "" - self.last_partial = "" - - def _on_transcription(self, text: str): - """Not used - we use stabilized partials as finals.""" - pass - - async def _send_transcript(self, transcript_type: str, text: str): - """Send transcript to client via WebSocket.""" - try: - message = { - "type": transcript_type, - "text": text, - "timestamp": time.time() - } - await self.websocket.send(json.dumps(message)) - except Exception as e: - logger.error(f"[{self.session_id}] Failed to send transcript: {e}") - - def _feed_audio_thread(self): - """Thread that feeds audio to the recorder.""" - logger.info(f"[{self.session_id}] Audio feed thread started") - while self.running: - try: - # Get audio chunk with timeout - audio_chunk = self.audio_queue.get(timeout=0.1) - if audio_chunk is not None and self.recorder: - self.recorder.feed_audio(audio_chunk) - except queue.Empty: - continue - except Exception as e: - logger.error(f"[{self.session_id}] Error feeding audio: {e}") - logger.info(f"[{self.session_id}] Audio feed thread stopped") + # Speculative transcription settings + self.speculative_silence_ms = config.get('speculative_silence_ms', 150) # Start transcribing early + self.speculative_pending = False # Is a speculative transcription in flight? + self.speculative_audio_snapshot = None # Audio buffer snapshot for speculative + self.speculative_result = None # Result from speculative transcription + self.speculative_result_ready = threading.Event() + + # Transcription queue + self.transcribe_queue = queue.Queue() + self.transcribe_thread = None + + logger.info(f"[{session_id}] Session created (speculative: {self.speculative_silence_ms}ms, final: {self.silence_duration_ms}ms)") async def start(self, loop: asyncio.AbstractEventLoop): - """Start the STT session.""" + """Start the session.""" self.loop = loop self.running = True - logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...") - logger.info(f"[{self.session_id}] Model: {self.config['model']}") - logger.info(f"[{self.session_id}] Device: {self.config['device']}") + self.transcribe_thread = threading.Thread(target=self._transcription_worker, daemon=True) + self.transcribe_thread.start() - try: - # Create recorder in a thread to avoid blocking - def init_recorder(): - self.recorder = AudioToTextRecorder( - # Model settings - using same model for both partial and final - model=self.config['model'], - language=self.config['language'], - compute_type=self.config['compute_type'], - device=self.config['device'], - - # Disable microphone - we feed audio manually - use_microphone=False, - - # Real-time transcription - use same model for everything - enable_realtime_transcription=True, - realtime_model_type=self.config['model'], # Use same model - realtime_processing_pause=0.05, # 50ms between updates - on_realtime_transcription_update=self._on_realtime_transcription, - on_realtime_transcription_stabilized=self._on_realtime_stabilized, - - # VAD settings - very permissive, rely on Discord's VAD for speech detection - # Our VAD is only for silence detection, not filtering audio content - silero_sensitivity=0.05, # Very low = barely filters anything - silero_use_onnx=True, # Faster - webrtc_sensitivity=3, - post_speech_silence_duration=self.config['silence_duration'], - min_length_of_recording=self.config['min_recording_length'], - min_gap_between_recordings=self.config['min_gap'], - pre_recording_buffer_duration=1.0, # Capture more audio before/after speech - - # Callbacks - on_recording_start=self._on_recording_start, - 
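+
+    # Work items on transcribe_queue are (audio, is_final, is_speculative):
+    #   (audio, True,  False) -> decode and emit a "final" transcript
+    #   (audio, False, True)  -> speculative decode: the worker stores the text
+    #                            and sets speculative_result_ready instead of
+    #                            sending anything to the client.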
 
     def feed_audio(self, audio_data: bytes):
-        """Feed audio data to the recorder."""
-        if self.running:
-            # Convert bytes to numpy array (16-bit PCM)
-            audio_np = np.frombuffer(audio_data, dtype=np.int16)
-            self.audio_queue.put(audio_np)
+        """Process incoming audio data."""
+        if not self.running:
+            return
+
+        audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
+        audio_float = audio_int16.astype(np.float32) / 32768.0
+
+        self.vad_buffer.extend(audio_float)
+
+        while len(self.vad_buffer) >= self.VAD_CHUNK_SAMPLES:
+            chunk = np.array(self.vad_buffer[:self.VAD_CHUNK_SAMPLES], dtype=np.float32)
+            self.vad_buffer = self.vad_buffer[self.VAD_CHUNK_SAMPLES:]
+            self._process_vad_chunk(chunk)
+
+    def _process_vad_chunk(self, chunk: np.ndarray):
+        """Process a single VAD chunk."""
+        current_time = time.time()
+
+        chunk_tensor = torch.from_numpy(chunk)
+        speech_prob = vad_model(chunk_tensor, self.SAMPLE_RATE).item()
+
+        is_speech = speech_prob >= self.vad_threshold
+
+        if is_speech:
+            if not self.is_speaking:
+                self.is_speaking = True
+                self.speech_start_time = current_time
+                self.audio_buffer = []
+                logger.debug(f"[{self.session_id}] Speech started")
+
+            self.audio_buffer.extend(chunk)
+            self.silence_start_time = 0
+
+            # Cancel any speculative transcription if speech resumed
+            if self.speculative_pending:
+                logger.debug(f"[{self.session_id}] Speech resumed, canceling speculative")
+                self.speculative_pending = False
+                self.speculative_result = None
+                self.speculative_result_ready.clear()
+
+            speech_duration = current_time - self.speech_start_time
+            if speech_duration >= self.max_speech_duration:
+                logger.info(f"[{self.session_id}] Max duration reached")
+                self._finalize_utterance()
+
+        else:
+            if self.is_speaking:
+                self.audio_buffer.extend(chunk)
+
+                if self.silence_start_time == 0:
+                    self.silence_start_time = current_time
+
+                silence_duration_ms = (current_time - self.silence_start_time) * 1000
+                speech_duration_ms = (self.silence_start_time - self.speech_start_time) * 1000
+
+                # Trigger speculative transcription early
+                if (not self.speculative_pending and
+                        silence_duration_ms >= self.speculative_silence_ms and
+                        speech_duration_ms >= self.min_speech_ms):
+                    self._start_speculative_transcription()
+
+                # Final silence threshold reached
+                if silence_duration_ms >= self.silence_duration_ms:
+                    if speech_duration_ms >= self.min_speech_ms:
+                        logger.debug(f"[{self.session_id}] Speech ended ({speech_duration_ms:.0f}ms)")
+                        self._finalize_utterance()
+                    else:
+                        logger.debug(f"[{self.session_id}] Discarding short utterance")
+                        self._reset_state()
+
+    def _start_speculative_transcription(self):
+        """Start speculative transcription without waiting for full silence."""
+        if self.audio_buffer:
+            self.speculative_pending = True
+            self.speculative_result = None
+            self.speculative_result_ready.clear()
+
+            # Snapshot current buffer
+            audio_array = np.array(self.audio_buffer, dtype=np.float32)
+            duration = len(audio_array) / self.SAMPLE_RATE
+
+            logger.debug(f"[{self.session_id}] Starting speculative transcription ({duration:.1f}s)")
+            # is_speculative=True
+            self.transcribe_queue.put((audio_array, False, True))
+
+    def _finalize_utterance(self):
+        """Finalize current utterance and send transcript."""
+        if not self.audio_buffer:
+            self._reset_state()
+            return
+
+        audio_array = np.array(self.audio_buffer, dtype=np.float32)
+        duration = len(audio_array) / self.SAMPLE_RATE
+
+        # Check if we have a speculative result ready
+        if self.speculative_pending and self.speculative_result_ready.wait(timeout=0.05):
+            # Use speculative result immediately!
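+            # Timing with the defaults from main(): silence starts at t=0, the
+            # speculative decode was queued at ~0.15s, and this branch runs at
+            # ~0.40s, so the result is usually already available; the 50ms cap
+            # on wait() keeps the audio path from stalling.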
+            text, elapsed = self.speculative_result
+            if text:
+                logger.info(f"[{self.session_id}] FINAL [speculative] ({elapsed:.2f}s): {text}")
+                asyncio.run_coroutine_threadsafe(
+                    self._send_transcript("final", text),
+                    self.loop
+                )
+            self._reset_state()
+            return
+
+        # No speculative result, do regular transcription
+        logger.info(f"[{self.session_id}] Queuing transcription ({duration:.1f}s)")
+        self.transcribe_queue.put((audio_array, True, False))
+
+        self._reset_state()
+
+    def _reset_state(self):
+        """Reset speech detection state."""
+        self.is_speaking = False
+        self.audio_buffer = []
+        self.silence_start_time = 0
+        self.speech_start_time = 0
+        self.speculative_pending = False
+        self.speculative_result = None
+        self.speculative_result_ready.clear()
 
     def reset(self):
-        """Reset the session state."""
-        logger.info(f"[{self.session_id}] Resetting session")
-        self.last_partial = ""
-        # Clear audio queue
-        while not self.audio_queue.empty():
-            try:
-                self.audio_queue.get_nowait()
-            except queue.Empty:
-                break
+        """Reset session state."""
+        logger.info(f"[{self.session_id}] Resetting")
+        self._reset_state()
+        self.vad_buffer = []
 
     async def stop(self):
-        """Stop the session and cleanup."""
-        logger.info(f"[{self.session_id}] Stopping session...")
-        self.running = False
+        """Stop the session."""
+        logger.info(f"[{self.session_id}] Stopping...")
 
-        # Wait for threads to finish
-        if self.feed_thread and self.feed_thread.is_alive():
-            self.feed_thread.join(timeout=2)
+        # Flush any in-progress utterance before stopping the worker loop,
+        # otherwise the queued final transcription would never be processed.
+        if self.audio_buffer and self.is_speaking:
+            self._finalize_utterance()
+        self.running = False
 
-        # Shutdown recorder
-        if self.recorder:
-            try:
-                self.recorder.shutdown()
-            except Exception as e:
-                logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")
+        if self.transcribe_thread and self.transcribe_thread.is_alive():
+            self.transcribe_thread.join(timeout=2)
 
-        logger.info(f"[{self.session_id}] Session stopped")
+        logger.info(f"[{self.session_id}] Stopped")
 
 
 class STTServer:
-    """
-    WebSocket server for RealtimeSTT.
-    Handles multiple concurrent clients (one per Discord user).
-    """
+    """WebSocket server for low-latency STT."""
 
-    def __init__(self, host: str = "0.0.0.0", port: int = 8766):
+    def __init__(self, host: str, port: int, config: Dict[str, Any]):
         self.host = host
         self.port = port
+        self.config = config
         self.sessions: Dict[str, STTSession] = {}
         self.session_counter = 0
 
-        # Default configuration
-        self.config = {
-            # Model - using small.en (English-only, more accurate than multilingual small)
-            'model': 'small.en',
-            'language': 'en',
-            'compute_type': 'float16',  # FP16 for GPU efficiency
-            'device': 'cuda',
-
-            # VAD settings
-            'silero_sensitivity': 0.6,
-            'webrtc_sensitivity': 3,
-            'silence_duration': 0.8,  # Shorter to improve responsiveness
-            'min_recording_length': 0.5,
-            'min_gap': 0.3,
-        }
-
         logger.info("=" * 60)
-        logger.info("RealtimeSTT Server Configuration:")
+        logger.info("Low-Latency STT Server")
         logger.info(f"  Host: {host}:{port}")
-        logger.info(f"  Model: {self.config['model']} (English-only, optimized)")
-        logger.info(f"  Beam size: 5 (higher accuracy)")
-        logger.info(f"  Strategy: Use last partial as final (instant response)")
-        logger.info(f"  Language: {self.config['language']}")
-        logger.info(f"  Device: {self.config['device']}")
-        logger.info(f"  Compute Type: {self.config['compute_type']}")
-        logger.info(f"  Silence Duration: {self.config['silence_duration']}s")
+        logger.info(f"  Model: {config['model']}")
+        logger.info(f"  Language: {config.get('language', 'en')}")
+        logger.info(f"  Silence: {config.get('silence_duration_ms', 400)}ms")
         logger.info("=" * 60)
 
     async def handle_client(self, websocket):
-        """Handle a WebSocket client connection."""
+        """Handle WebSocket client."""
         self.session_counter += 1
         session_id = f"session_{self.session_counter}"
         session = None
 
         try:
-            logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")
+            logger.info(f"[{session_id}] Client connected")
 
-            # Create session
             session = STTSession(websocket, session_id, self.config)
             self.sessions[session_id] = session
-
-            # Start session
             await session.start(asyncio.get_event_loop())
 
-            # Process messages
             async for message in websocket:
-                try:
-                    if isinstance(message, bytes):
-                        # Binary audio data
-                        session.feed_audio(message)
-                    else:
-                        # JSON command
+                if isinstance(message, bytes):
+                    session.feed_audio(message)
+                else:
+                    try:
                         data = json.loads(message)
-                        command = data.get('command', '')
-
-                        if command == 'reset':
+                        cmd = data.get('command', '')
+                        if cmd == 'reset':
                             session.reset()
-                        elif command == 'ping':
+                        elif cmd == 'ping':
                             await websocket.send(json.dumps({
                                 'type': 'pong',
                                 'timestamp': time.time()
                             }))
-                        else:
-                            logger.warning(f"[{session_id}] Unknown command: {command}")
-
-                except json.JSONDecodeError:
-                    logger.warning(f"[{session_id}] Invalid JSON message")
-                except Exception as e:
-                    logger.error(f"[{session_id}] Error processing message: {e}")
+                    except json.JSONDecodeError:
+                        pass
 
         except websockets.exceptions.ConnectionClosed:
             logger.info(f"[{session_id}] Client disconnected")
         except Exception as e:
             logger.error(f"[{session_id}] Error: {e}", exc_info=True)
         finally:
-            # Cleanup
             if session:
                 await session.stop()
                 del self.sessions[session_id]
 
     async def run(self):
-        """Run the WebSocket server."""
-        logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")
+        """Run the server."""
+        logger.info(f"Starting server on ws://{self.host}:{self.port}")
 
         async with serve(
@@ -385,137 +395,83 @@ class STTServer:
             self.port,
             ping_interval=30,
             ping_timeout=10,
-            max_size=10 * 1024 * 1024,  # 10MB max message size
+            max_size=10 * 1024 * 1024,
         ):
-            logger.info("✅ Server ready and listening for connections")
-            await asyncio.Future()  # Run forever
+            logger.info("Server ready")
+            await asyncio.Future()
 
 
-async def warmup_model(config: Dict[str, Any]):
-    """
-    Warm up the STT model by loading it and processing test audio.
-    This ensures the model is cached in memory before handling real requests.
-    """
-    global warmup_complete, warmup_recorder
+async def warmup(config: Dict[str, Any]):
+    """Load models at startup."""
+    global warmup_complete
 
-    with warmup_lock:
-        if warmup_complete:
-            logger.info("Model already warmed up")
-            return
-
-    logger.info("đŸ”Ĩ Starting model warmup...")
-    try:
-        # Generate silent test audio (1 second of silence, 16kHz)
-        test_audio = np.zeros(16000, dtype=np.int16)
-
-        # Initialize a temporary recorder to load the model
-        logger.info("Loading Faster-Whisper model...")
-
-        def dummy_callback(text):
-            pass
-
-        # This will trigger model loading and compilation
-        warmup_recorder = AudioToTextRecorder(
-            model=config['model'],
-            language=config['language'],
-            compute_type=config['compute_type'],
-            device=config['device'],
-            silero_sensitivity=config['silero_sensitivity'],
-            webrtc_sensitivity=config['webrtc_sensitivity'],
-            post_speech_silence_duration=config['silence_duration'],
-            min_length_of_recording=config['min_recording_length'],
-            min_gap_between_recordings=config['min_gap'],
-            enable_realtime_transcription=True,
-            realtime_processing_pause=0.1,
-            on_realtime_transcription_update=dummy_callback,
-            on_realtime_transcription_stabilized=dummy_callback,
-            spinner=False,
-            level=logging.WARNING,
-            beam_size=5,
-            beam_size_realtime=5,
-            batch_size=16,
-            realtime_batch_size=8,
-            initial_prompt="",
-        )
-
-        logger.info("✅ Model loaded and warmed up successfully")
-        warmup_complete = True
-
-    except Exception as e:
-        logger.error(f"❌ Warmup failed: {e}", exc_info=True)
-        warmup_complete = False
+    logger.info("Loading models...")
+
+    load_vad_model()
+    load_whisper_model(config)
+
+    logger.info("Warming up transcription...")
+    dummy_audio = np.zeros(16000, dtype=np.float32)
+    segments, _ = whisper_model.transcribe(
+        dummy_audio,
+        language=config.get('language', 'en'),
+        beam_size=1,
+    )
+    list(segments)
+
+    warmup_complete = True
+    logger.info("Warmup complete")
 
 
 async def health_handler(request):
-    """HTTP health check endpoint"""
+    """Health check endpoint."""
     if warmup_complete:
-        return web.json_response({
-            "status": "ready",
-            "warmed_up": True,
-            "model": "small.en",
-            "device": "cuda"
-        })
-    else:
-        return web.json_response({
-            "status": "warming_up",
-            "warmed_up": False,
-            "model": "small.en",
-            "device": "cuda"
-        }, status=503)
+        return web.json_response({"status": "ready"})
+    return web.json_response({"status": "warming_up"}, status=503)
 
 
-async def start_http_server(host: str, http_port: int):
-    """Start HTTP server for health checks"""
+async def start_http_server(host: str, port: int):
+    """Start HTTP health server."""
     app = web.Application()
     app.router.add_get('/health', health_handler)
-
     runner = web.AppRunner(app)
     await runner.setup()
-    site = web.TCPSite(runner, host, http_port)
+    site = web.TCPSite(runner, host, port)
     await site.start()
-
-    logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")
+    logger.info(f"Health server on http://{host}:{port}")
 
 
 def main():
     """Main entry point."""
     import os
 
-    # Get configuration from environment
     host = os.environ.get('STT_HOST', '0.0.0.0')
     port = int(os.environ.get('STT_PORT', '8766'))
-    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))  # HTTP health check port
+    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))
 
-    # Configuration
     config = {
         'model': 'small.en',
         'language': 'en',
         'compute_type': 'float16',
         'device': 'cuda',
-        'silero_sensitivity': 0.6,
-        'webrtc_sensitivity': 3,
-        'silence_duration': 0.8,
-        'min_recording_length': 0.5,
-        'min_gap': 0.3,
+        'vad_threshold': 0.5,
+        'silence_duration_ms': 400,  # Final silence threshold
+        'speculative_silence_ms': 150,  # Start transcribing early at 150ms
+        'min_speech_ms': 250,
+        'max_speech_duration': 30.0,
     }
 
-    # Create and run server
-    server = STTServer(host=host, port=port)
+    server = STTServer(host, port, config)
 
     async def run_all():
-        # Start warmup in background
-        asyncio.create_task(warmup_model(config))
-
-        # Start HTTP health server
+        await warmup(config)
         asyncio.create_task(start_http_server(host, http_port))
-
-        # Start WebSocket server
         await server.run()
 
     try:
         asyncio.run(run_all())
     except KeyboardInterrupt:
-        logger.info("Server shutdown requested")
+        logger.info("Shutdown requested")
     except Exception as e:
         logger.error(f"Server error: {e}", exc_info=True)
         raise