""" Discord Voice Receiver using discord-ext-voice-recv Captures audio from Discord voice channels and streams to STT. Uses the discord-ext-voice-recv extension for proper audio receiving support. """ import asyncio import audioop import logging from typing import Dict, Optional from collections import deque import discord from discord.ext import voice_recv from utils.stt_client import STTClient logger = logging.getLogger('voice_receiver') class VoiceReceiverSink(voice_recv.AudioSink): """ Audio sink that receives Discord audio and forwards to STT. This sink processes incoming audio from Discord voice channels, decodes/resamples as needed, and sends to STT clients for transcription. """ def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766/ws/stt"): """ Initialize Voice Receiver. Args: voice_manager: The voice manager instance stt_url: Base URL for STT WebSocket server with path (port 8766 inside container) """ super().__init__() self.voice_manager = voice_manager self.stt_url = stt_url # Store event loop for thread-safe async calls # Use get_running_loop() in async context, or store it when available try: self.loop = asyncio.get_running_loop() except RuntimeError: # Fallback if not in async context yet self.loop = asyncio.get_event_loop() # Per-user STT clients self.stt_clients: Dict[int, STTClient] = {} # Audio buffers per user (for resampling state) self.audio_buffers: Dict[int, deque] = {} # User info (for logging) self.users: Dict[int, discord.User] = {} # Silence tracking for detecting end of speech self.last_audio_time: Dict[int, float] = {} self.silence_tasks: Dict[int, asyncio.Task] = {} self.silence_timeout = 1.0 # seconds of silence before sending "final" # Interruption detection self.interruption_start_time: Dict[int, float] = {} self.interruption_audio_count: Dict[int, int] = {} self.interruption_threshold_time = 0.8 # seconds of speech to count as interruption self.interruption_threshold_chunks = 8 # minimum audio chunks to count as interruption # Active flag self.active = False logger.info("VoiceReceiverSink initialized") def wants_opus(self) -> bool: """ Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM. We'll decode it ourselves to avoid decoder errors from discord-ext-voice-recv. Returns: True - we want Opus packets, we'll handle decoding """ return True # Get Opus, decode ourselves to avoid packet router errors def write(self, user: Optional[discord.User], data: voice_recv.VoiceData): """ Called by discord-ext-voice-recv when audio is received. This is the main callback that receives audio packets from Discord. We get Opus data, decode it ourselves, resample, and forward to STT. Args: user: Discord user who sent the audio (None if unknown) data: Voice data container with pcm, opus, and packet info """ if not user: return # Skip packets from unknown users user_id = user.id # Check if we're listening to this user if user_id not in self.stt_clients: return try: # Get Opus data (we decode ourselves to avoid PacketRouter errors) opus_data = data.opus if not opus_data: return # Decode Opus to PCM (48kHz stereo int16) # Use discord.py's opus decoder with proper error handling import discord.opus if not hasattr(self, '_opus_decoders'): self._opus_decoders = {} # Create decoder for this user if needed if user_id not in self._opus_decoders: self._opus_decoders[user_id] = discord.opus.Decoder() decoder = self._opus_decoders[user_id] # Decode opus -> PCM (this can fail on corrupt packets, so catch it) try: pcm_data = decoder.decode(opus_data, fec=False) except discord.opus.OpusError as e: # Skip corrupted packets silently (common at stream start) logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}") return if not pcm_data: return # PCM from Discord is 48kHz stereo int16 # Convert stereo to mono if len(pcm_data) % 4 == 0: # Stereo (2 channels * 2 bytes per sample) pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5) else: pcm_mono = pcm_data # Resample from 48kHz to 16kHz for STT # Discord sends 20ms chunks: 960 samples @ 48kHz → 320 samples @ 16kHz pcm_16k, _ = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None) # Send to STT client (schedule on event loop thread-safely) asyncio.run_coroutine_threadsafe( self._send_audio_chunk(user_id, pcm_16k), self.loop ) except Exception as e: logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True) def cleanup(self): """ Called when the sink is stopped. Cleanup any resources. """ logger.info("VoiceReceiverSink cleanup") # Async cleanup handled separately in stop_all() async def start_listening(self, user_id: int, user: discord.User): """ Start listening to a specific user. Creates an STT client connection for this user and registers callbacks. Args: user_id: Discord user ID user: Discord user object """ if user_id in self.stt_clients: logger.warning(f"Already listening to user {user.name} ({user_id})") return logger.info(f"Starting to listen to user {user.name} ({user_id})") # Store user info self.users[user_id] = user # Initialize audio buffer self.audio_buffers[user_id] = deque(maxlen=1000) # Create STT client with callbacks stt_client = STTClient( user_id=user_id, stt_url=self.stt_url, on_vad_event=lambda event: asyncio.create_task( self._on_vad_event(user_id, event) ), on_partial_transcript=lambda text, timestamp: asyncio.create_task( self._on_partial_transcript(user_id, text) ), on_final_transcript=lambda text, timestamp: asyncio.create_task( self._on_final_transcript(user_id, text, user) ), on_interruption=lambda prob: asyncio.create_task( self._on_interruption(user_id, prob) ) ) # Connect to STT server try: await stt_client.connect() self.stt_clients[user_id] = stt_client self.active = True logger.info(f"✓ STT connected for user {user.name}") except Exception as e: logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True) # Cleanup partial state if user_id in self.audio_buffers: del self.audio_buffers[user_id] if user_id in self.users: del self.users[user_id] raise async def stop_listening(self, user_id: int): """ Stop listening to a specific user. Disconnects the STT client and cleans up resources for this user. Args: user_id: Discord user ID """ if user_id not in self.stt_clients: logger.warning(f"Not listening to user {user_id}") return user = self.users.get(user_id) logger.info(f"Stopping listening to user {user.name if user else user_id}") # Disconnect STT client stt_client = self.stt_clients[user_id] await stt_client.disconnect() # Cleanup del self.stt_clients[user_id] if user_id in self.audio_buffers: del self.audio_buffers[user_id] if user_id in self.users: del self.users[user_id] # Cancel silence detection task if user_id in self.silence_tasks and not self.silence_tasks[user_id].done(): self.silence_tasks[user_id].cancel() del self.silence_tasks[user_id] if user_id in self.last_audio_time: del self.last_audio_time[user_id] # Clear interruption tracking self.interruption_start_time.pop(user_id, None) self.interruption_audio_count.pop(user_id, None) # Cleanup opus decoder for this user if hasattr(self, '_opus_decoders') and user_id in self._opus_decoders: del self._opus_decoders[user_id] # Update active flag if not self.stt_clients: self.active = False logger.info(f"✓ Stopped listening to user {user.name if user else user_id}") async def stop_all(self): """Stop listening to all users and cleanup all resources.""" logger.info("Stopping all voice receivers") user_ids = list(self.stt_clients.keys()) for user_id in user_ids: await self.stop_listening(user_id) self.active = False logger.info("✓ All voice receivers stopped") async def _send_audio_chunk(self, user_id: int, audio_data: bytes): """ Send audio chunk to STT client. Buffers audio until we have 512 samples (32ms @ 16kHz) which is what Silero VAD expects. Discord sends 320 samples (20ms), so we buffer 2 chunks and send 640 samples, then the STT server can split it. Args: user_id: Discord user ID audio_data: PCM audio (int16, 16kHz mono, 320 samples = 640 bytes) """ stt_client = self.stt_clients.get(user_id) if not stt_client or not stt_client.is_connected(): return try: # Get or create buffer for this user if user_id not in self.audio_buffers: self.audio_buffers[user_id] = deque() buffer = self.audio_buffers[user_id] buffer.append(audio_data) # Silero VAD expects 512 samples @ 16kHz (1024 bytes) # Discord gives us 320 samples (640 bytes) every 20ms # Buffer 2 chunks = 640 samples = 1280 bytes, send as one chunk SAMPLES_NEEDED = 512 # What VAD wants BYTES_NEEDED = SAMPLES_NEEDED * 2 # int16 = 2 bytes per sample # Check if we have enough buffered audio total_bytes = sum(len(chunk) for chunk in buffer) if total_bytes >= BYTES_NEEDED: # Concatenate buffered chunks combined = b''.join(buffer) buffer.clear() # Send in 512-sample (1024-byte) chunks for i in range(0, len(combined), BYTES_NEEDED): chunk = combined[i:i+BYTES_NEEDED] if len(chunk) == BYTES_NEEDED: await stt_client.send_audio(chunk) else: # Put remaining partial chunk back in buffer buffer.append(chunk) # Track audio time for silence detection import time current_time = time.time() self.last_audio_time[user_id] = current_time # ===== INTERRUPTION DETECTION ===== # Check if Miku is speaking and user is interrupting # Note: self.voice_manager IS the VoiceSession, not the VoiceManager singleton miku_speaking = self.voice_manager.miku_speaking logger.debug(f"[INTERRUPTION CHECK] user={user_id}, miku_speaking={miku_speaking}") if miku_speaking: # Track interruption if user_id not in self.interruption_start_time: # First chunk during Miku's speech self.interruption_start_time[user_id] = current_time self.interruption_audio_count[user_id] = 1 else: # Increment chunk count self.interruption_audio_count[user_id] += 1 # Calculate interruption duration interruption_duration = current_time - self.interruption_start_time[user_id] chunk_count = self.interruption_audio_count[user_id] # Check if interruption threshold is met if (interruption_duration >= self.interruption_threshold_time and chunk_count >= self.interruption_threshold_chunks): # Trigger interruption! logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count})") logger.info(f" → Stopping Miku's TTS and LLM, will process user's speech when finished") # Reset interruption tracking self.interruption_start_time.pop(user_id, None) self.interruption_audio_count.pop(user_id, None) # Call interruption handler (this sets miku_speaking=False) asyncio.create_task( self.voice_manager.on_user_interruption(user_id) ) else: # Miku not speaking, clear interruption tracking self.interruption_start_time.pop(user_id, None) self.interruption_audio_count.pop(user_id, None) # Cancel existing silence task if any if user_id in self.silence_tasks and not self.silence_tasks[user_id].done(): self.silence_tasks[user_id].cancel() # Start new silence detection task self.silence_tasks[user_id] = asyncio.create_task( self._detect_silence(user_id) ) except Exception as e: logger.error(f"Failed to send audio chunk for user {user_id}: {e}") async def _detect_silence(self, user_id: int): """ Wait for silence timeout and send 'final' command to STT. This is called after each audio chunk. If no more audio arrives within the silence_timeout period, we send the 'final' command to get the complete transcription. Args: user_id: Discord user ID """ try: # Wait for silence timeout await asyncio.sleep(self.silence_timeout) # Check if we still have an active STT client stt_client = self.stt_clients.get(user_id) if not stt_client or not stt_client.is_connected(): return # Send final command to get complete transcription logger.debug(f"Silence detected for user {user_id}, requesting final transcript") await stt_client.send_final() except asyncio.CancelledError: # Task was cancelled because new audio arrived pass except Exception as e: logger.error(f"Error in silence detection for user {user_id}: {e}") async def _on_vad_event(self, user_id: int, event: dict): """ Handle VAD event from STT. Args: user_id: Discord user ID event: VAD event dictionary with 'event' and 'probability' keys """ user = self.users.get(user_id) event_type = event.get('event', 'unknown') probability = event.get('probability', 0.0) logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})") # Notify voice manager - pass the full event dict if hasattr(self.voice_manager, 'on_user_vad_event'): await self.voice_manager.on_user_vad_event(user_id, event) async def _on_partial_transcript(self, user_id: int, text: str): """ Handle partial transcript from STT. Args: user_id: Discord user ID text: Partial transcript text """ user = self.users.get(user_id) logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}") print(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}") # Extra debug # Notify voice manager if hasattr(self.voice_manager, 'on_partial_transcript'): await self.voice_manager.on_partial_transcript(user_id, text) async def _on_final_transcript(self, user_id: int, text: str, user: discord.User): """ Handle final transcript from STT. This triggers the LLM response generation. Args: user_id: Discord user ID text: Final transcript text user: Discord user object """ logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}") print(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}") # Extra debug # Notify voice manager - THIS TRIGGERS LLM RESPONSE if hasattr(self.voice_manager, 'on_final_transcript'): await self.voice_manager.on_final_transcript(user_id, text) async def _on_interruption(self, user_id: int, probability: float): """ Handle interruption detection from STT. This cancels Miku's current speech if user interrupts. Args: user_id: Discord user ID probability: Interruption confidence probability """ user = self.users.get(user_id) logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})") # Notify voice manager - THIS CANCELS MIKU'S SPEECH if hasattr(self.voice_manager, 'on_user_interruption'): await self.voice_manager.on_user_interruption(user_id, probability) def get_listening_users(self) -> list: """ Get list of users currently being listened to. Returns: List of dicts with user_id, username, and connection status """ return [ { 'user_id': user_id, 'username': user.name if user else 'Unknown', 'connected': client.is_connected() } for user_id, (user, client) in [(uid, (self.users.get(uid), self.stt_clients.get(uid))) for uid in self.stt_clients.keys()] ] @voice_recv.AudioSink.listener() def on_voice_member_speaking_start(self, member: discord.Member): """ Called when a member starts speaking (green circle appears). This is a virtual event from discord-ext-voice-recv based on packet activity. """ if member.id in self.stt_clients: logger.debug(f"🎤 {member.name} started speaking") @voice_recv.AudioSink.listener() def on_voice_member_speaking_stop(self, member: discord.Member): """ Called when a member stops speaking (green circle disappears). This is a virtual event from discord-ext-voice-recv based on packet activity. """ if member.id in self.stt_clients: logger.debug(f"🔇 {member.name} stopped speaking")