Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all
This commit is contained in:
419
bot/utils/voice_receiver.py.old
Normal file
419
bot/utils/voice_receiver.py.old
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Discord Voice Receiver
|
||||
|
||||
Captures audio from Discord voice channels and streams to STT.
|
||||
Handles opus decoding and audio preprocessing.
|
||||
"""
|
||||
|
||||
import discord
|
||||
import audioop
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from collections import deque
|
||||
|
||||
from utils.stt_client import STTClient
|
||||
|
||||
logger = logging.getLogger('voice_receiver')
|
||||
|
||||
|
||||
class VoiceReceiver(discord.sinks.Sink):
|
||||
"""
|
||||
Voice Receiver for Discord Audio Capture
|
||||
|
||||
Captures audio from Discord voice channels using discord.py's voice websocket.
|
||||
Processes Opus audio, decodes to PCM, resamples to 16kHz mono for STT.
|
||||
|
||||
Note: Standard discord.py doesn't have built-in audio receiving.
|
||||
This implementation hooks into the voice websocket directly.
|
||||
"""
|
||||
import asyncio
|
||||
import struct
|
||||
import audioop
|
||||
import logging
|
||||
from typing import Dict, Optional, Callable
|
||||
import discord
|
||||
|
||||
# Import opus decoder
|
||||
try:
|
||||
import discord.opus as opus
|
||||
if not opus.is_loaded():
|
||||
opus.load_opus('opus')
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load opus: {e}")
|
||||
|
||||
from utils.stt_client import STTClient
|
||||
|
||||
logger = logging.getLogger('voice_receiver')
|
||||
|
||||
|
||||
class VoiceReceiver:
|
||||
"""
|
||||
Receives and processes audio from Discord voice channel.
|
||||
|
||||
This class monkey-patches the VoiceClient to intercept received RTP packets,
|
||||
decodes Opus audio, and forwards to STT clients.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
voice_client: discord.VoiceClient,
|
||||
voice_manager,
|
||||
stt_url: str = "ws://miku-stt:8001"
|
||||
):
|
||||
"""
|
||||
Initialize voice receiver.
|
||||
|
||||
Args:
|
||||
voice_client: Discord VoiceClient to receive audio from
|
||||
voice_manager: Voice manager instance for callbacks
|
||||
stt_url: Base URL for STT WebSocket server
|
||||
"""
|
||||
self.voice_client = voice_client
|
||||
self.voice_manager = voice_manager
|
||||
self.stt_url = stt_url
|
||||
|
||||
# Per-user STT clients
|
||||
self.stt_clients: Dict[int, STTClient] = {}
|
||||
|
||||
# Opus decoder instances per SSRC (one per user)
|
||||
self.opus_decoders: Dict[int, any] = {}
|
||||
|
||||
# Resampler state per user (for 48kHz → 16kHz)
|
||||
self.resample_state: Dict[int, tuple] = {}
|
||||
|
||||
# Original receive method (for restoration)
|
||||
self._original_receive = None
|
||||
|
||||
# Active flag
|
||||
self.active = False
|
||||
|
||||
logger.info("VoiceReceiver initialized")
|
||||
|
||||
async def start_listening(self, user_id: int, user: discord.User):
|
||||
"""
|
||||
Start listening to a specific user's audio.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
user: Discord User object
|
||||
"""
|
||||
if user_id in self.stt_clients:
|
||||
logger.warning(f"Already listening to user {user_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Create STT client for this user
|
||||
stt_client = STTClient(
|
||||
user_id=user_id,
|
||||
stt_url=self.stt_url,
|
||||
on_vad_event=lambda event, prob: asyncio.create_task(
|
||||
self.voice_manager.on_user_vad_event(user_id, event)
|
||||
),
|
||||
on_partial_transcript=lambda text: asyncio.create_task(
|
||||
self.voice_manager.on_partial_transcript(user_id, text)
|
||||
),
|
||||
on_final_transcript=lambda text: asyncio.create_task(
|
||||
self.voice_manager.on_final_transcript(user_id, text, user)
|
||||
),
|
||||
on_interruption=lambda prob: asyncio.create_task(
|
||||
self.voice_manager.on_user_interruption(user_id, prob)
|
||||
)
|
||||
)
|
||||
|
||||
# Connect to STT server
|
||||
await stt_client.connect()
|
||||
|
||||
# Store client
|
||||
self.stt_clients[user_id] = stt_client
|
||||
|
||||
# Initialize opus decoder for this user if needed
|
||||
# (Will be done when we receive their SSRC)
|
||||
|
||||
# Patch voice client to receive audio if not already patched
|
||||
if not self.active:
|
||||
await self._patch_voice_client()
|
||||
|
||||
logger.info(f"✓ Started listening to user {user_id} ({user.name})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start listening to user {user_id}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def stop_listening(self, user_id: int):
|
||||
"""
|
||||
Stop listening to a specific user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
"""
|
||||
if user_id not in self.stt_clients:
|
||||
logger.warning(f"Not listening to user {user_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Disconnect STT client
|
||||
stt_client = self.stt_clients.pop(user_id)
|
||||
await stt_client.disconnect()
|
||||
|
||||
# Clean up decoder and resampler state
|
||||
# Note: We don't know the SSRC here, so we'll just remove by user_id
|
||||
# Actual cleanup happens in _process_audio when we match SSRC to user_id
|
||||
|
||||
# If no more clients, unpatch voice client
|
||||
if not self.stt_clients:
|
||||
await self._unpatch_voice_client()
|
||||
|
||||
logger.info(f"✓ Stopped listening to user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _patch_voice_client(self):
|
||||
"""Patch VoiceClient to intercept received audio packets."""
|
||||
logger.warning("⚠️ Audio receiving not yet implemented - discord.py doesn't support receiving by default")
|
||||
logger.warning("⚠️ You need discord.py-self or a custom fork with receiving support")
|
||||
logger.warning("⚠️ STT will not receive any audio until this is implemented")
|
||||
self.active = True
|
||||
# TODO: Implement RTP packet receiving
|
||||
# This requires either:
|
||||
# 1. Using discord.py-self which has receiving support
|
||||
# 2. Monkey-patching voice_client.ws to intercept packets
|
||||
# 3. Using a separate UDP socket listener
|
||||
|
||||
async def _unpatch_voice_client(self):
|
||||
"""Restore original VoiceClient behavior."""
|
||||
self.active = False
|
||||
logger.info("Unpatch voice client (receiving disabled)")
|
||||
|
||||
async def _process_audio(self, ssrc: int, opus_data: bytes):
|
||||
"""
|
||||
Process received Opus audio packet.
|
||||
|
||||
Args:
|
||||
ssrc: RTP SSRC (identifies the audio source/user)
|
||||
opus_data: Opus-encoded audio data
|
||||
"""
|
||||
# TODO: Map SSRC to user_id (requires tracking voice state updates)
|
||||
# For now, this is a placeholder
|
||||
pass
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up all resources."""
|
||||
# Disconnect all STT clients
|
||||
for user_id in list(self.stt_clients.keys()):
|
||||
await self.stop_listening(user_id)
|
||||
|
||||
# Unpatch voice client
|
||||
if self.active:
|
||||
await self._unpatch_voice_client()
|
||||
|
||||
logger.info("VoiceReceiver cleanup complete") def __init__(self, voice_manager):
|
||||
"""
|
||||
Initialize voice receiver.
|
||||
|
||||
Args:
|
||||
voice_manager: Reference to VoiceManager for callbacks
|
||||
"""
|
||||
super().__init__()
|
||||
self.voice_manager = voice_manager
|
||||
|
||||
# Per-user STT clients
|
||||
self.stt_clients: Dict[int, STTClient] = {}
|
||||
|
||||
# Audio buffers per user (for resampling)
|
||||
self.audio_buffers: Dict[int, deque] = {}
|
||||
|
||||
# User info (for logging)
|
||||
self.users: Dict[int, discord.User] = {}
|
||||
|
||||
logger.info("Voice receiver initialized")
|
||||
|
||||
async def start_listening(self, user_id: int, user: discord.User):
|
||||
"""
|
||||
Start listening to a specific user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
user: Discord user object
|
||||
"""
|
||||
if user_id in self.stt_clients:
|
||||
logger.warning(f"Already listening to user {user.name} ({user_id})")
|
||||
return
|
||||
|
||||
logger.info(f"Starting to listen to user {user.name} ({user_id})")
|
||||
|
||||
# Store user info
|
||||
self.users[user_id] = user
|
||||
|
||||
# Initialize audio buffer
|
||||
self.audio_buffers[user_id] = deque(maxlen=1000) # Max 1000 chunks
|
||||
|
||||
# Create STT client with callbacks
|
||||
stt_client = STTClient(
|
||||
user_id=str(user_id),
|
||||
on_vad_event=lambda event: self._on_vad_event(user_id, event),
|
||||
on_partial_transcript=lambda text, ts: self._on_partial_transcript(user_id, text, ts),
|
||||
on_final_transcript=lambda text, ts: self._on_final_transcript(user_id, text, ts),
|
||||
on_interruption=lambda prob: self._on_interruption(user_id, prob)
|
||||
)
|
||||
|
||||
# Connect to STT
|
||||
try:
|
||||
await stt_client.connect()
|
||||
self.stt_clients[user_id] = stt_client
|
||||
logger.info(f"✓ STT connected for user {user.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect STT for user {user.name}: {e}")
|
||||
|
||||
async def stop_listening(self, user_id: int):
|
||||
"""
|
||||
Stop listening to a specific user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
"""
|
||||
if user_id not in self.stt_clients:
|
||||
return
|
||||
|
||||
user = self.users.get(user_id)
|
||||
logger.info(f"Stopping listening to user {user.name if user else user_id}")
|
||||
|
||||
# Disconnect STT client
|
||||
stt_client = self.stt_clients[user_id]
|
||||
await stt_client.disconnect()
|
||||
|
||||
# Cleanup
|
||||
del self.stt_clients[user_id]
|
||||
if user_id in self.audio_buffers:
|
||||
del self.audio_buffers[user_id]
|
||||
if user_id in self.users:
|
||||
del self.users[user_id]
|
||||
|
||||
logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")
|
||||
|
||||
async def stop_all(self):
|
||||
"""Stop listening to all users."""
|
||||
logger.info("Stopping all voice receivers")
|
||||
|
||||
user_ids = list(self.stt_clients.keys())
|
||||
for user_id in user_ids:
|
||||
await self.stop_listening(user_id)
|
||||
|
||||
logger.info("✓ All voice receivers stopped")
|
||||
|
||||
def write(self, data: discord.sinks.core.AudioData):
|
||||
"""
|
||||
Called by discord.py when audio is received.
|
||||
|
||||
Args:
|
||||
data: Audio data from Discord
|
||||
"""
|
||||
# Get user ID from SSRC
|
||||
user_id = data.user.id if data.user else None
|
||||
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
# Check if we're listening to this user
|
||||
if user_id not in self.stt_clients:
|
||||
return
|
||||
|
||||
# Process audio
|
||||
try:
|
||||
# Decode opus to PCM (48kHz stereo)
|
||||
pcm_data = data.pcm
|
||||
|
||||
# Convert stereo to mono if needed
|
||||
if len(pcm_data) % 4 == 0: # Stereo int16 (2 channels * 2 bytes)
|
||||
# Average left and right channels
|
||||
pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
|
||||
else:
|
||||
pcm_mono = pcm_data
|
||||
|
||||
# Resample from 48kHz to 16kHz
|
||||
# Discord sends 20ms chunks at 48kHz = 960 samples
|
||||
# We need 320 samples at 16kHz (20ms)
|
||||
pcm_16k = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)[0]
|
||||
|
||||
# Send to STT
|
||||
asyncio.create_task(self._send_audio_chunk(user_id, pcm_16k))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio for user {user_id}: {e}")
|
||||
|
||||
async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
|
||||
"""
|
||||
Send audio chunk to STT client.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
audio_data: PCM audio (int16, 16kHz mono)
|
||||
"""
|
||||
stt_client = self.stt_clients.get(user_id)
|
||||
if not stt_client or not stt_client.is_connected():
|
||||
return
|
||||
|
||||
try:
|
||||
await stt_client.send_audio(audio_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
|
||||
|
||||
async def _on_vad_event(self, user_id: int, event: dict):
|
||||
"""Handle VAD event from STT."""
|
||||
user = self.users.get(user_id)
|
||||
event_type = event.get('event')
|
||||
probability = event.get('probability', 0)
|
||||
|
||||
logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
|
||||
|
||||
# Notify voice manager
|
||||
if hasattr(self.voice_manager, 'on_user_vad_event'):
|
||||
await self.voice_manager.on_user_vad_event(user_id, event)
|
||||
|
||||
async def _on_partial_transcript(self, user_id: int, text: str, timestamp: float):
|
||||
"""Handle partial transcript from STT."""
|
||||
user = self.users.get(user_id)
|
||||
logger.info(f"Partial [{user.name if user else user_id}]: {text}")
|
||||
|
||||
# Notify voice manager
|
||||
if hasattr(self.voice_manager, 'on_partial_transcript'):
|
||||
await self.voice_manager.on_partial_transcript(user_id, text)
|
||||
|
||||
async def _on_final_transcript(self, user_id: int, text: str, timestamp: float):
|
||||
"""Handle final transcript from STT."""
|
||||
user = self.users.get(user_id)
|
||||
logger.info(f"Final [{user.name if user else user_id}]: {text}")
|
||||
|
||||
# Notify voice manager - THIS TRIGGERS LLM RESPONSE
|
||||
if hasattr(self.voice_manager, 'on_final_transcript'):
|
||||
await self.voice_manager.on_final_transcript(user_id, text)
|
||||
|
||||
async def _on_interruption(self, user_id: int, probability: float):
|
||||
"""Handle interruption detection from STT."""
|
||||
user = self.users.get(user_id)
|
||||
logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")
|
||||
|
||||
# Notify voice manager - THIS CANCELS MIKU'S SPEECH
|
||||
if hasattr(self.voice_manager, 'on_user_interruption'):
|
||||
await self.voice_manager.on_user_interruption(user_id, probability)
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
logger.info("Cleaning up voice receiver")
|
||||
# Async cleanup will be called separately
|
||||
|
||||
def get_listening_users(self) -> list:
|
||||
"""Get list of users currently being listened to."""
|
||||
return [
|
||||
{
|
||||
'user_id': user_id,
|
||||
'username': user.name if user else 'Unknown',
|
||||
'connected': client.is_connected()
|
||||
}
|
||||
for user_id, (user, client) in
|
||||
[(uid, (self.users.get(uid), self.stt_clients.get(uid)))
|
||||
for uid in self.stt_clients.keys()]
|
||||
]
|
||||
Reference in New Issue
Block a user