Files
miku-discord/bot/utils/voice_receiver.py.old

420 lines
14 KiB
Python

"""
Discord Voice Receiver
Captures audio from Discord voice channels and streams to STT.
Handles opus decoding and audio preprocessing.
"""
import discord
import audioop
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from collections import deque
from utils.stt_client import STTClient
logger = logging.getLogger('voice_receiver')
class VoiceReceiver(discord.sinks.Sink):
    # NOTE(review): the docstring below and the lines that follow it (a second
    # import block, an opus-loading try/except, and a plain
    # `class VoiceReceiver:`) look like a separate, older module pasted into
    # this class. The Sink-based implementation continues at
    # `def __init__(self, voice_manager)`.
    """
    Voice Receiver for Discord Audio Capture

    Takes PCM audio received from Discord voice channels (48kHz stereo
    according to the comments in write()), downmixes it to mono, resamples
    it to 16kHz and streams it to one STT client per user.

    Note: Standard discord.py doesn't have built-in audio receiving.
    The plain VoiceReceiver variant in this file was meant to hook into the
    voice websocket directly; that hook is still an unimplemented stub.
    """
import asyncio
import struct
import audioop
import logging
from typing import Dict, Optional, Callable
import discord
# Load the Opus codec library needed to decode voice packets.
# ctypes only finds a library by the bare name "opus" if a file with exactly
# that name is on the loader path; ctypes.util.find_library() maps the short
# name to the platform's real file name. Fall back to the bare name when the
# lookup finds nothing, which is what the code did before.
try:
    import ctypes.util
    import discord.opus as opus

    if not opus.is_loaded():
        opus.load_opus(ctypes.util.find_library('opus') or 'opus')
except Exception as e:
    # Best-effort: failures are logged and module import continues.
    logging.error(f"Failed to load opus: {e}")
from utils.stt_client import STTClient
logger = logging.getLogger('voice_receiver')
class VoiceReceiver:
    # NOTE(review): a Sink-based class with the same name also appears in this
    # file; only one of the two implementations should survive.
    """
    Receives audio from a Discord voice channel and forwards it to STT.

    Intended design: monkey-patch the VoiceClient to intercept received RTP
    packets, decode the Opus audio, and stream it to one STTClient per user.

    Current state: only per-user STT connection management is implemented.
    `_patch_voice_client` just logs warnings and sets `active`, and
    `_process_audio` is a no-op, so no audio is actually forwarded yet.
    """
def __init__(
self,
voice_client: discord.VoiceClient,
voice_manager,
stt_url: str = "ws://miku-stt:8001"
):
"""
Initialize voice receiver.
Args:
voice_client: Discord VoiceClient to receive audio from
voice_manager: Voice manager instance for callbacks
stt_url: Base URL for STT WebSocket server
"""
self.voice_client = voice_client
self.voice_manager = voice_manager
self.stt_url = stt_url
# Per-user STT clients
self.stt_clients: Dict[int, STTClient] = {}
# Opus decoder instances per SSRC (one per user)
self.opus_decoders: Dict[int, any] = {}
# Resampler state per user (for 48kHz → 16kHz)
self.resample_state: Dict[int, tuple] = {}
# Original receive method (for restoration)
self._original_receive = None
# Active flag
self.active = False
logger.info("VoiceReceiver initialized")
async def start_listening(self, user_id: int, user: discord.User):
"""
Start listening to a specific user's audio.
Args:
user_id: Discord user ID
user: Discord User object
"""
if user_id in self.stt_clients:
logger.warning(f"Already listening to user {user_id}")
return
try:
# Create STT client for this user
stt_client = STTClient(
user_id=user_id,
stt_url=self.stt_url,
on_vad_event=lambda event, prob: asyncio.create_task(
self.voice_manager.on_user_vad_event(user_id, event)
),
on_partial_transcript=lambda text: asyncio.create_task(
self.voice_manager.on_partial_transcript(user_id, text)
),
on_final_transcript=lambda text: asyncio.create_task(
self.voice_manager.on_final_transcript(user_id, text, user)
),
on_interruption=lambda prob: asyncio.create_task(
self.voice_manager.on_user_interruption(user_id, prob)
)
)
# Connect to STT server
await stt_client.connect()
# Store client
self.stt_clients[user_id] = stt_client
# Initialize opus decoder for this user if needed
# (Will be done when we receive their SSRC)
# Patch voice client to receive audio if not already patched
if not self.active:
await self._patch_voice_client()
logger.info(f"✓ Started listening to user {user_id} ({user.name})")
except Exception as e:
logger.error(f"Failed to start listening to user {user_id}: {e}", exc_info=True)
raise
async def stop_listening(self, user_id: int):
"""
Stop listening to a specific user.
Args:
user_id: Discord user ID
"""
if user_id not in self.stt_clients:
logger.warning(f"Not listening to user {user_id}")
return
try:
# Disconnect STT client
stt_client = self.stt_clients.pop(user_id)
await stt_client.disconnect()
# Clean up decoder and resampler state
# Note: We don't know the SSRC here, so we'll just remove by user_id
# Actual cleanup happens in _process_audio when we match SSRC to user_id
# If no more clients, unpatch voice client
if not self.stt_clients:
await self._unpatch_voice_client()
logger.info(f"✓ Stopped listening to user {user_id}")
except Exception as e:
logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)
raise
async def _patch_voice_client(self):
"""Patch VoiceClient to intercept received audio packets."""
logger.warning("⚠️ Audio receiving not yet implemented - discord.py doesn't support receiving by default")
logger.warning("⚠️ You need discord.py-self or a custom fork with receiving support")
logger.warning("⚠️ STT will not receive any audio until this is implemented")
self.active = True
# TODO: Implement RTP packet receiving
# This requires either:
# 1. Using discord.py-self which has receiving support
# 2. Monkey-patching voice_client.ws to intercept packets
# 3. Using a separate UDP socket listener
async def _unpatch_voice_client(self):
"""Restore original VoiceClient behavior."""
self.active = False
logger.info("Unpatch voice client (receiving disabled)")
async def _process_audio(self, ssrc: int, opus_data: bytes):
"""
Process received Opus audio packet.
Args:
ssrc: RTP SSRC (identifies the audio source/user)
opus_data: Opus-encoded audio data
"""
# TODO: Map SSRC to user_id (requires tracking voice state updates)
# For now, this is a placeholder
pass
async def cleanup(self):
"""Clean up all resources."""
# Disconnect all STT clients
for user_id in list(self.stt_clients.keys()):
await self.stop_listening(user_id)
# Unpatch voice client
if self.active:
await self._unpatch_voice_client()
logger.info("VoiceReceiver cleanup complete") def __init__(self, voice_manager):
"""
Initialize voice receiver.
Args:
voice_manager: Reference to VoiceManager for callbacks
"""
super().__init__()
self.voice_manager = voice_manager
# Per-user STT clients
self.stt_clients: Dict[int, STTClient] = {}
# Audio buffers per user (for resampling)
self.audio_buffers: Dict[int, deque] = {}
# User info (for logging)
self.users: Dict[int, discord.User] = {}
logger.info("Voice receiver initialized")
async def start_listening(self, user_id: int, user: discord.User):
"""
Start listening to a specific user.
Args:
user_id: Discord user ID
user: Discord user object
"""
if user_id in self.stt_clients:
logger.warning(f"Already listening to user {user.name} ({user_id})")
return
logger.info(f"Starting to listen to user {user.name} ({user_id})")
# Store user info
self.users[user_id] = user
# Initialize audio buffer
self.audio_buffers[user_id] = deque(maxlen=1000) # Max 1000 chunks
# Create STT client with callbacks
stt_client = STTClient(
user_id=str(user_id),
on_vad_event=lambda event: self._on_vad_event(user_id, event),
on_partial_transcript=lambda text, ts: self._on_partial_transcript(user_id, text, ts),
on_final_transcript=lambda text, ts: self._on_final_transcript(user_id, text, ts),
on_interruption=lambda prob: self._on_interruption(user_id, prob)
)
# Connect to STT
try:
await stt_client.connect()
self.stt_clients[user_id] = stt_client
logger.info(f"✓ STT connected for user {user.name}")
except Exception as e:
logger.error(f"Failed to connect STT for user {user.name}: {e}")
async def stop_listening(self, user_id: int):
"""
Stop listening to a specific user.
Args:
user_id: Discord user ID
"""
if user_id not in self.stt_clients:
return
user = self.users.get(user_id)
logger.info(f"Stopping listening to user {user.name if user else user_id}")
# Disconnect STT client
stt_client = self.stt_clients[user_id]
await stt_client.disconnect()
# Cleanup
del self.stt_clients[user_id]
if user_id in self.audio_buffers:
del self.audio_buffers[user_id]
if user_id in self.users:
del self.users[user_id]
logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")
async def stop_all(self):
"""Stop listening to all users."""
logger.info("Stopping all voice receivers")
user_ids = list(self.stt_clients.keys())
for user_id in user_ids:
await self.stop_listening(user_id)
logger.info("✓ All voice receivers stopped")
def write(self, data: discord.sinks.core.AudioData):
"""
Called by discord.py when audio is received.
Args:
data: Audio data from Discord
"""
# Get user ID from SSRC
user_id = data.user.id if data.user else None
if not user_id:
return
# Check if we're listening to this user
if user_id not in self.stt_clients:
return
# Process audio
try:
# Decode opus to PCM (48kHz stereo)
pcm_data = data.pcm
# Convert stereo to mono if needed
if len(pcm_data) % 4 == 0: # Stereo int16 (2 channels * 2 bytes)
# Average left and right channels
pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
else:
pcm_mono = pcm_data
# Resample from 48kHz to 16kHz
# Discord sends 20ms chunks at 48kHz = 960 samples
# We need 320 samples at 16kHz (20ms)
pcm_16k = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)[0]
# Send to STT
asyncio.create_task(self._send_audio_chunk(user_id, pcm_16k))
except Exception as e:
logger.error(f"Error processing audio for user {user_id}: {e}")
async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
"""
Send audio chunk to STT client.
Args:
user_id: Discord user ID
audio_data: PCM audio (int16, 16kHz mono)
"""
stt_client = self.stt_clients.get(user_id)
if not stt_client or not stt_client.is_connected():
return
try:
await stt_client.send_audio(audio_data)
except Exception as e:
logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
async def _on_vad_event(self, user_id: int, event: dict):
"""Handle VAD event from STT."""
user = self.users.get(user_id)
event_type = event.get('event')
probability = event.get('probability', 0)
logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
# Notify voice manager
if hasattr(self.voice_manager, 'on_user_vad_event'):
await self.voice_manager.on_user_vad_event(user_id, event)
async def _on_partial_transcript(self, user_id: int, text: str, timestamp: float):
"""Handle partial transcript from STT."""
user = self.users.get(user_id)
logger.info(f"Partial [{user.name if user else user_id}]: {text}")
# Notify voice manager
if hasattr(self.voice_manager, 'on_partial_transcript'):
await self.voice_manager.on_partial_transcript(user_id, text)
async def _on_final_transcript(self, user_id: int, text: str, timestamp: float):
"""Handle final transcript from STT."""
user = self.users.get(user_id)
logger.info(f"Final [{user.name if user else user_id}]: {text}")
# Notify voice manager - THIS TRIGGERS LLM RESPONSE
if hasattr(self.voice_manager, 'on_final_transcript'):
await self.voice_manager.on_final_transcript(user_id, text)
async def _on_interruption(self, user_id: int, probability: float):
"""Handle interruption detection from STT."""
user = self.users.get(user_id)
logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")
# Notify voice manager - THIS CANCELS MIKU'S SPEECH
if hasattr(self.voice_manager, 'on_user_interruption'):
await self.voice_manager.on_user_interruption(user_id, probability)
def cleanup(self):
"""Cleanup resources."""
logger.info("Cleaning up voice receiver")
# Async cleanup will be called separately
def get_listening_users(self) -> list:
"""Get list of users currently being listened to."""
return [
{
'user_id': user_id,
'username': user.name if user else 'Unknown',
'connected': client.is_connected()
}
for user_id, (user, client) in
[(uid, (self.users.get(uid), self.stt_clients.get(uid)))
for uid in self.stt_clients.keys()]
]