"""
STT Server

FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.

Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Parakeet transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket with word-level tokens
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from datetime import datetime

from vad_processor import VADProcessor
from parakeet_transcriber import ParakeetTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Audio format expected from clients: int16 PCM, mono, 16 kHz.
SAMPLE_RATE = 16000
# Minimum amount of *new* audio (seconds) between two partial transcriptions.
PARTIAL_INTERVAL_S = 1.0
# VAD probability above which a chunk counts as an interruption.
INTERRUPTION_THRESHOLD = 0.7
# Model identifier passed to ParakeetTranscriber and reported by /health.
PARAKEET_MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"

# Global instances (initialized on startup)
# NOTE(review): a single VADProcessor is shared by every session. If it keeps
# per-stream state internally, concurrent users will interfere with each
# other - confirm against VADProcessor and consider one instance per session.
vad_processor: Optional[VADProcessor] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None

# User session tracking: user_id -> active session
user_sessions: Dict[str, "UserSTTSession"] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []  # int16 chunks of the utterance in progress
        self.is_speaking = False
        self.timestamp_ms = 0.0  # Stream position derived from received samples
        self.transcript_buffer = []
        self.last_transcript = ""
        self.last_partial_duration = 0.0  # Buffered seconds at the last partial
        self.last_speech_timestamp = 0.0  # Track last time we detected speech
        self.speech_timeout_ms = 3000  # Force finalization after 3s of no new speech

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Runs VAD on the chunk, forwards any VAD event to the client, buffers
        audio while the user is speaking, and triggers partial / final
        transcription as the utterance progresses.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # An int16 sample is two bytes; np.frombuffer raises on a dangling
        # odd byte, which would otherwise end the whole session.
        if len(audio_data) % 2:
            audio_data = audio_data[:-1]
        if not audio_data:
            return

        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Advance the stream clock by the real duration of this chunk
        chunk_duration_ms = (len(audio_np) / SAMPLE_RATE) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                self.last_partial_duration = 0.0
                self.last_speech_timestamp = self.timestamp_ms
                logger.info(f"[STT] User {self.user_id} SPEECH START")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)
                    self.last_speech_timestamp = self.timestamp_ms  # Update speech timestamp

                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / SAMPLE_RATE

                    # Only re-transcribe once another PARTIAL_INTERVAL_S of
                    # audio has accumulated since the previous partial;
                    # otherwise every chunk past the first second would
                    # trigger a full transcription of the buffer.
                    if duration_s - self.last_partial_duration >= PARTIAL_INTERVAL_S:
                        logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
                        await self._transcribe_partial()
                        # Keep buffer for final transcription, but mark progress
                        self.last_partial_duration = duration_s

            elif event_type == "speech_end":
                self.is_speaking = False
                logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")
                await self._finalize_utterance()
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # No VAD event - still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

                # Check for timeout
                time_since_speech = self.timestamp_ms - self.last_speech_timestamp

                if time_since_speech >= self.speech_timeout_ms:
                    # Timeout - user probably stopped but VAD didn't detect it
                    logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
                    self.is_speaking = False
                    await self._finalize_utterance()

    async def _finalize_utterance(self):
        """Send the final transcript for the buffered audio, then reset the buffer."""
        await self._transcribe_final()
        self.audio_buffer = []
        self.last_partial_duration = 0.0

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send partial result (no timestamps to save VRAM)."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously WITHOUT timestamps for partials
        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=SAMPLE_RATE,
                return_timestamps=False  # Disable timestamps for partials to reduce VRAM usage
            )

            # Accept either a plain string or a dict carrying a "text" key
            text = result if isinstance(result, str) else result.get("text", "")

            # Skip empty results and results identical to the last one sent
            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript without word tokens (saves memory)
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "words": [],  # No word tokens for partials
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            # Best effort: a failed partial must not end the session
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe final accumulated audio with word tokens."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            result = await parakeet_transcriber.transcribe_async(
                audio_full,
                sample_rate=SAMPLE_RATE,
                return_timestamps=True
            )

            if result and result.get("text"):
                self.last_transcript = result["text"]

                # Send final transcript with word tokens
                await self.websocket.send_json({
                    "type": "final",
                    "text": result["text"],
                    "words": result.get("words", []),  # Word-level tokens for LLM
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {result['text']}")

        except Exception as e:
            # Best effort: a failed final must not end the session
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if user is interrupting (for use during Miku's speech).

        Sends an "interruption" event to the client when the speech
        probability of this single chunk exceeds INTERRUPTION_THRESHOLD.

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if interruption detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, _ = vad_processor.process_chunk(audio_np)

        # Higher threshold than normal speech detection. Decided per chunk;
        # no sustained-duration check is applied here.
        if speech_prob > INTERRUPTION_THRESHOLD:
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


@app.on_event("startup")
async def startup_event():
    """Initialize models on server startup."""
    global vad_processor, parakeet_transcriber

    logger.info("=" * 50)
    logger.info("Initializing Miku STT Server")
    logger.info("=" * 50)

    # Initialize VAD (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_processor = VADProcessor(
        sample_rate=SAMPLE_RATE,
        threshold=0.5,
        min_speech_duration_ms=250,  # Conservative - wait 250ms before starting
        min_silence_duration_ms=300  # Reduced from 500ms - detect silence faster
    )
    logger.info("✓ VAD ready")

    # Initialize Parakeet (GPU)
    logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
    parakeet_transcriber = ParakeetTranscriber(
        model_name=PARAKEET_MODEL_NAME,
        device="cuda",
        language="en"
    )
    logger.info("✓ Parakeet ready")

    logger.info("=" * 50)
    logger.info("STT Server ready to accept connections")
    logger.info("=" * 50)


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on server shutdown."""
    logger.info("Shutting down STT server...")

    if parakeet_transcriber:
        parakeet_transcriber.cleanup()

    logger.info("STT server shutdown complete")


@app.get("/")
async def root():
    """Health check endpoint."""
    transcriber_ready = parakeet_transcriber is not None
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "parakeet_ready": transcriber_ready,
        # Legacy key name retained in case existing clients read it.
        "whisper_ready": transcriber_ready,
        "active_sessions": len(user_sessions)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu"
            },
            "parakeet": {
                "loaded": parakeet_transcriber is not None,
                "model": PARAKEET_MODEL_NAME,
                "device": "cuda"
            }
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys())
        }
    }


@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
        - {"type": "ready", "user_id": "...", "message": "..."}
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session (replaces any existing entry for the same user_id)
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Only remove the registry entry if it still belongs to this
        # connection; a newer connection with the same user_id may have
        # replaced it, and that session must stay registered.
        if user_sessions.get(user_id) is session:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check if user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}

    Raises:
        HTTPException: 404 if there is no active session for user_id.
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state
    # NOTE(review): this reads the shared global VAD processor, not anything
    # specific to `session` - confirm the state really reflects this user.
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")