Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all

This commit is contained in:
2026-01-17 03:14:40 +02:00
parent 3e59e5d2f6
commit d1e6b21508
30 changed files with 156595 additions and 8 deletions

View File

@@ -125,7 +125,7 @@ async def on_message(message):
if message.author == globals.client.user:
return
# Check for voice commands first (!miku join, !miku leave, !miku voice-status, !miku test, !miku say)
# Check for voice commands first (!miku join, !miku leave, !miku voice-status, !miku test, !miku say, !miku listen, !miku stop-listening)
if not isinstance(message.channel, discord.DMChannel) and message.content.strip().lower().startswith('!miku '):
from commands.voice import handle_voice_command
@@ -134,7 +134,7 @@ async def on_message(message):
cmd = parts[1].lower()
args = parts[2:] if len(parts) > 2 else []
if cmd in ['join', 'leave', 'voice-status', 'test', 'say']:
if cmd in ['join', 'leave', 'voice-status', 'test', 'say', 'listen', 'stop-listening']:
await handle_voice_command(message, cmd, args)
return

View File

@@ -39,6 +39,12 @@ async def handle_voice_command(message, cmd, args):
elif cmd == 'say':
await _handle_say(message, args)
elif cmd == 'listen':
await _handle_listen(message, args)
elif cmd == 'stop-listening':
await _handle_stop_listening(message, args)
else:
await message.channel.send(f"❌ Unknown voice command: `{cmd}`")
@@ -366,8 +372,97 @@ Keep responses short (1-3 sentences) since they will be spoken aloud."""
await message.channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
logger.info(f"✓ Voice say complete: {full_response.strip()}")
await message.add_reaction("")
except Exception as e:
logger.error(f"Voice say failed: {e}", exc_info=True)
await message.channel.send(f"Voice say failed: {str(e)}")
logger.error(f"Failed to generate voice response: {e}", exc_info=True)
await message.channel.send(f"Error generating voice response: {e}")
async def _handle_listen(message, args):
"""
Handle !miku listen command.
Start listening to a user's voice for STT.
Usage:
!miku listen - Start listening to command author
!miku listen @user - Start listening to mentioned user
"""
# Check if Miku is in voice channel
session = voice_manager.active_session
if not session or not session.voice_client or not session.voice_client.is_connected():
await message.channel.send("❌ I'm not in a voice channel! Use `!miku join` first.")
return
# Determine target user
target_user = None
if args and len(message.mentions) > 0:
# Listen to mentioned user
target_user = message.mentions[0]
else:
# Listen to command author
target_user = message.author
# Check if user is in voice channel
if not target_user.voice or not target_user.voice.channel:
await message.channel.send(f"{target_user.mention} is not in a voice channel!")
return
# Check if user is in same channel as Miku
if target_user.voice.channel.id != session.voice_client.channel.id:
await message.channel.send(
f"{target_user.mention} must be in the same voice channel as me!"
)
return
try:
# Start listening to user
await session.start_listening(target_user)
await message.channel.send(
f"👂 Now listening to {target_user.mention}'s voice! "
f"Speak to me and I'll respond. Use `!miku stop-listening` to stop."
)
await message.add_reaction("👂")
logger.info(f"Started listening to user {target_user.id} ({target_user.name})")
except Exception as e:
logger.error(f"Failed to start listening: {e}", exc_info=True)
await message.channel.send(f"❌ Failed to start listening: {str(e)}")
async def _handle_stop_listening(message, args):
"""
Handle !miku stop-listening command.
Stop listening to a user's voice.
Usage:
!miku stop-listening - Stop listening to command author
!miku stop-listening @user - Stop listening to mentioned user
"""
# Check if Miku is in voice channel
session = voice_manager.active_session
if not session:
await message.channel.send("❌ I'm not in a voice channel!")
return
# Determine target user
target_user = None
if args and len(message.mentions) > 0:
# Stop listening to mentioned user
target_user = message.mentions[0]
else:
# Stop listening to command author
target_user = message.author
try:
# Stop listening to user
await session.stop_listening(target_user.id)
await message.channel.send(f"🔇 Stopped listening to {target_user.mention}.")
await message.add_reaction("🔇")
logger.info(f"Stopped listening to user {target_user.id} ({target_user.name})")
except Exception as e:
logger.error(f"Failed to stop listening: {e}", exc_info=True)
await message.channel.send(f"❌ Failed to stop listening: {str(e)}")

View File

@@ -22,3 +22,4 @@ transformers
torch
PyNaCl>=1.5.0
websockets>=12.0
discord-ext-voice-recv

214
bot/utils/stt_client.py Normal file
View File

@@ -0,0 +1,214 @@
"""
STT Client for Discord Bot
WebSocket client that connects to the STT server and handles:
- Audio streaming to STT
- Receiving VAD events
- Receiving partial/final transcripts
- Interruption detection
"""
import aiohttp
import asyncio
import logging
from typing import Optional, Callable
import json
logger = logging.getLogger('stt_client')
class STTClient:
"""
WebSocket client for STT server communication.
Handles audio streaming and receives transcription events.
"""
def __init__(
self,
user_id: str,
stt_url: str = "ws://miku-stt:8000/ws/stt",
on_vad_event: Optional[Callable] = None,
on_partial_transcript: Optional[Callable] = None,
on_final_transcript: Optional[Callable] = None,
on_interruption: Optional[Callable] = None
):
"""
Initialize STT client.
Args:
user_id: Discord user ID
stt_url: Base WebSocket URL for STT server
on_vad_event: Callback for VAD events (event_dict)
on_partial_transcript: Callback for partial transcripts (text, timestamp)
on_final_transcript: Callback for final transcripts (text, timestamp)
on_interruption: Callback for interruption detection (probability)
"""
self.user_id = user_id
self.stt_url = f"{stt_url}/{user_id}"
# Callbacks
self.on_vad_event = on_vad_event
self.on_partial_transcript = on_partial_transcript
self.on_final_transcript = on_final_transcript
self.on_interruption = on_interruption
# Connection state
self.websocket: Optional[aiohttp.ClientWebSocket] = None
self.session: Optional[aiohttp.ClientSession] = None
self.connected = False
self.running = False
# Receive task
self._receive_task: Optional[asyncio.Task] = None
logger.info(f"STT client initialized for user {user_id}")
async def connect(self):
"""Connect to STT WebSocket server."""
if self.connected:
logger.warning(f"Already connected for user {self.user_id}")
return
try:
self.session = aiohttp.ClientSession()
self.websocket = await self.session.ws_connect(
self.stt_url,
heartbeat=30
)
# Wait for ready message
ready_msg = await self.websocket.receive_json()
logger.info(f"STT connected for user {self.user_id}: {ready_msg}")
self.connected = True
self.running = True
# Start receive task
self._receive_task = asyncio.create_task(self._receive_events())
logger.info(f"✓ STT WebSocket connected for user {self.user_id}")
except Exception as e:
logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
await self.disconnect()
raise
async def disconnect(self):
"""Disconnect from STT WebSocket."""
logger.info(f"Disconnecting STT for user {self.user_id}")
self.running = False
self.connected = False
# Cancel receive task
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
# Close WebSocket
if self.websocket:
await self.websocket.close()
self.websocket = None
# Close session
if self.session:
await self.session.close()
self.session = None
logger.info(f"✓ STT disconnected for user {self.user_id}")
async def send_audio(self, audio_data: bytes):
"""
Send audio chunk to STT server.
Args:
audio_data: PCM audio (int16, 16kHz mono)
"""
if not self.connected or not self.websocket:
logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
return
try:
await self.websocket.send_bytes(audio_data)
logger.debug(f"Sent {len(audio_data)} bytes to STT")
except Exception as e:
logger.error(f"Failed to send audio to STT: {e}")
self.connected = False
async def _receive_events(self):
"""Background task to receive events from STT server."""
try:
while self.running and self.websocket:
try:
msg = await self.websocket.receive()
if msg.type == aiohttp.WSMsgType.TEXT:
event = json.loads(msg.data)
await self._handle_event(event)
elif msg.type == aiohttp.WSMsgType.CLOSED:
logger.info(f"STT WebSocket closed for user {self.user_id}")
break
elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(f"STT WebSocket error for user {self.user_id}")
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error receiving STT event: {e}", exc_info=True)
finally:
self.connected = False
logger.info(f"STT receive task ended for user {self.user_id}")
async def _handle_event(self, event: dict):
"""
Handle incoming STT event.
Args:
event: Event dictionary from STT server
"""
event_type = event.get('type')
if event_type == 'vad':
# VAD event: speech detection
logger.debug(f"VAD event: {event}")
if self.on_vad_event:
await self.on_vad_event(event)
elif event_type == 'partial':
# Partial transcript
text = event.get('text', '')
timestamp = event.get('timestamp', 0)
logger.info(f"Partial transcript [{self.user_id}]: {text}")
if self.on_partial_transcript:
await self.on_partial_transcript(text, timestamp)
elif event_type == 'final':
# Final transcript
text = event.get('text', '')
timestamp = event.get('timestamp', 0)
logger.info(f"Final transcript [{self.user_id}]: {text}")
if self.on_final_transcript:
await self.on_final_transcript(text, timestamp)
elif event_type == 'interruption':
# Interruption detected
probability = event.get('probability', 0)
logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
if self.on_interruption:
await self.on_interruption(probability)
else:
logger.warning(f"Unknown STT event type: {event_type}")
def is_connected(self) -> bool:
"""Check if STT client is connected."""
return self.connected

View File

@@ -19,6 +19,7 @@ import json
import os
from typing import Optional
import discord
from discord.ext import voice_recv
import globals
from utils.logger import get_logger
@@ -97,12 +98,12 @@ class VoiceSessionManager:
# 10. Create voice session
self.active_session = VoiceSession(guild_id, voice_channel, text_channel)
# 11. Connect to Discord voice channel
# 11. Connect to Discord voice channel with VoiceRecvClient
try:
voice_client = await voice_channel.connect()
voice_client = await voice_channel.connect(cls=voice_recv.VoiceRecvClient)
self.active_session.voice_client = voice_client
self.active_session.active = True
logger.info(f"✓ Connected to voice channel: {voice_channel.name}")
logger.info(f"✓ Connected to voice channel: {voice_channel.name} (with audio receiving)")
except Exception as e:
logger.error(f"Failed to connect to voice channel: {e}", exc_info=True)
raise
@@ -387,7 +388,9 @@ class VoiceSession:
self.voice_client: Optional[discord.VoiceClient] = None
self.audio_source: Optional['MikuVoiceSource'] = None # Forward reference
self.tts_streamer: Optional['TTSTokenStreamer'] = None # Forward reference
self.voice_receiver: Optional['VoiceReceiver'] = None # STT receiver
self.active = False
self.miku_speaking = False # Track if Miku is currently speaking
logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")
@@ -433,6 +436,207 @@ class VoiceSession:
except Exception as e:
logger.error(f"Error stopping audio streaming: {e}", exc_info=True)
async def start_listening(self, user: discord.User):
"""
Start listening to a user's voice (STT).
Args:
user: Discord user to listen to
"""
from utils.voice_receiver import VoiceReceiverSink
try:
# Create receiver if not exists
if not self.voice_receiver:
self.voice_receiver = VoiceReceiverSink(self)
# Start receiving audio from Discord using discord-ext-voice-recv
if self.voice_client:
self.voice_client.listen(self.voice_receiver)
logger.info("✓ Discord voice receive started (discord-ext-voice-recv)")
# Start listening to specific user
await self.voice_receiver.start_listening(user.id, user)
logger.info(f"✓ Started listening to {user.name}")
except Exception as e:
logger.error(f"Failed to start listening to {user.name}: {e}", exc_info=True)
raise
async def stop_listening(self, user_id: int):
"""
Stop listening to a user.
Args:
user_id: Discord user ID
"""
if self.voice_receiver:
await self.voice_receiver.stop_listening(user_id)
logger.info(f"✓ Stopped listening to user {user_id}")
async def stop_all_listening(self):
"""Stop listening to all users."""
if self.voice_receiver:
await self.voice_receiver.stop_all()
self.voice_receiver = None
logger.info("✓ Stopped all listening")
async def on_user_vad_event(self, user_id: int, event: dict):
"""Called when VAD detects speech state change."""
event_type = event.get('event')
logger.debug(f"User {user_id} VAD: {event_type}")
async def on_partial_transcript(self, user_id: int, text: str):
"""Called when partial transcript is received."""
logger.info(f"Partial from user {user_id}: {text}")
# Could show "User is saying..." in chat
async def on_final_transcript(self, user_id: int, text: str):
"""
Called when final transcript is received.
This triggers LLM response and TTS.
"""
logger.info(f"Final from user {user_id}: {text}")
# Get user info
user = self.voice_channel.guild.get_member(user_id)
if not user:
logger.warning(f"User {user_id} not found in guild")
return
# Show what user said
await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
# Generate LLM response and speak it
await self._generate_voice_response(user, text)
async def on_user_interruption(self, user_id: int, probability: float):
"""
Called when user interrupts Miku's speech.
Cancel TTS and switch to listening.
"""
if not self.miku_speaking:
return
logger.info(f"User {user_id} interrupted Miku (prob={probability:.3f})")
# Cancel Miku's speech
await self._cancel_tts()
# Show interruption in chat
user = self.voice_channel.guild.get_member(user_id)
await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")
async def _generate_voice_response(self, user: discord.User, text: str):
"""
Generate LLM response and speak it.
Args:
user: User who spoke
text: Transcribed text
"""
try:
self.miku_speaking = True
# Show processing
await self.text_channel.send(f"💭 *Miku is thinking...*")
# Import here to avoid circular imports
from utils.llm import get_current_gpu_url
import aiohttp
import globals
# Simple system prompt for voice
system_prompt = """You are Hatsune Miku, the virtual singer.
Respond naturally and concisely as Miku would in a voice conversation.
Keep responses short (1-3 sentences) since they will be spoken aloud."""
payload = {
"model": globals.TEXT_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text}
],
"stream": True,
"temperature": 0.8,
"max_tokens": 200
}
headers = {'Content-Type': 'application/json'}
llama_url = get_current_gpu_url()
# Stream LLM response to TTS
full_response = ""
async with aiohttp.ClientSession() as http_session:
async with http_session.post(
f"{llama_url}/v1/chat/completions",
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=60)
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"LLM error {response.status}: {error_text}")
# Stream tokens to TTS
async for line in response.content:
if not self.miku_speaking:
# Interrupted
break
line = line.decode('utf-8').strip()
if line.startswith('data: '):
data_str = line[6:]
if data_str == '[DONE]':
break
try:
import json
data = json.loads(data_str)
if 'choices' in data and len(data['choices']) > 0:
delta = data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
await self.audio_source.send_token(content)
full_response += content
except json.JSONDecodeError:
continue
# Flush TTS
if self.miku_speaking:
await self.audio_source.flush()
# Show response
await self.text_channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
logger.info(f"✓ Voice response complete: {full_response.strip()}")
except Exception as e:
logger.error(f"Voice response failed: {e}", exc_info=True)
await self.text_channel.send(f"❌ Sorry, I had trouble responding")
finally:
self.miku_speaking = False
async def _cancel_tts(self):
"""Cancel current TTS synthesis."""
logger.info("Canceling TTS synthesis")
# Stop Discord playback
if self.voice_client and self.voice_client.is_playing():
self.voice_client.stop()
# Send interrupt to RVC
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.post("http://172.25.0.1:8765/interrupt") as resp:
if resp.status == 200:
logger.info("✓ TTS interrupted")
except Exception as e:
logger.error(f"Failed to interrupt TTS: {e}")
self.miku_speaking = False
# Global singleton instance

411
bot/utils/voice_receiver.py Normal file
View File

@@ -0,0 +1,411 @@
"""
Discord Voice Receiver using discord-ext-voice-recv
Captures audio from Discord voice channels and streams to STT.
Uses the discord-ext-voice-recv extension for proper audio receiving support.
"""
import asyncio
import audioop
import logging
from typing import Dict, Optional
from collections import deque
import discord
from discord.ext import voice_recv
from utils.stt_client import STTClient
logger = logging.getLogger('voice_receiver')
class VoiceReceiverSink(voice_recv.AudioSink):
"""
Audio sink that receives Discord audio and forwards to STT.
This sink processes incoming audio from Discord voice channels,
decodes/resamples as needed, and sends to STT clients for transcription.
"""
def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8000/ws/stt"):
"""
Initialize voice receiver sink.
Args:
voice_manager: Reference to VoiceManager for callbacks
stt_url: Base URL for STT WebSocket server with path (port 8000 inside container)
"""
super().__init__()
self.voice_manager = voice_manager
self.stt_url = stt_url
# Store event loop for thread-safe async calls
# Use get_running_loop() in async context, or store it when available
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
# Fallback if not in async context yet
self.loop = asyncio.get_event_loop()
# Per-user STT clients
self.stt_clients: Dict[int, STTClient] = {}
# Audio buffers per user (for resampling state)
self.audio_buffers: Dict[int, deque] = {}
# User info (for logging)
self.users: Dict[int, discord.User] = {}
# Active flag
self.active = False
logger.info("VoiceReceiverSink initialized")
def wants_opus(self) -> bool:
"""
Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM.
We'll decode it ourselves to avoid decoder errors from discord-ext-voice-recv.
Returns:
True - we want Opus packets, we'll handle decoding
"""
return True # Get Opus, decode ourselves to avoid packet router errors
def write(self, user: Optional[discord.User], data: voice_recv.VoiceData):
"""
Called by discord-ext-voice-recv when audio is received.
This is the main callback that receives audio packets from Discord.
We get Opus data, decode it ourselves, resample, and forward to STT.
Args:
user: Discord user who sent the audio (None if unknown)
data: Voice data container with pcm, opus, and packet info
"""
if not user:
return # Skip packets from unknown users
user_id = user.id
# Check if we're listening to this user
if user_id not in self.stt_clients:
return
try:
# Get Opus data (we decode ourselves to avoid PacketRouter errors)
opus_data = data.opus
if not opus_data:
return
# Decode Opus to PCM (48kHz stereo int16)
# Use discord.py's opus decoder with proper error handling
import discord.opus
if not hasattr(self, '_opus_decoders'):
self._opus_decoders = {}
# Create decoder for this user if needed
if user_id not in self._opus_decoders:
self._opus_decoders[user_id] = discord.opus.Decoder()
decoder = self._opus_decoders[user_id]
# Decode opus -> PCM (this can fail on corrupt packets, so catch it)
try:
pcm_data = decoder.decode(opus_data, fec=False)
except discord.opus.OpusError as e:
# Skip corrupted packets silently (common at stream start)
logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}")
return
if not pcm_data:
return
# PCM from Discord is 48kHz stereo int16
# Convert stereo to mono
if len(pcm_data) % 4 == 0: # Stereo (2 channels * 2 bytes per sample)
pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
else:
pcm_mono = pcm_data
# Resample from 48kHz to 16kHz for STT
# Discord sends 20ms chunks: 960 samples @ 48kHz → 320 samples @ 16kHz
pcm_16k, _ = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)
# Send to STT client (schedule on event loop thread-safely)
asyncio.run_coroutine_threadsafe(
self._send_audio_chunk(user_id, pcm_16k),
self.loop
)
except Exception as e:
logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True)
def cleanup(self):
"""
Called when the sink is stopped.
Cleanup any resources.
"""
logger.info("VoiceReceiverSink cleanup")
# Async cleanup handled separately in stop_all()
async def start_listening(self, user_id: int, user: discord.User):
"""
Start listening to a specific user.
Creates an STT client connection for this user and registers callbacks.
Args:
user_id: Discord user ID
user: Discord user object
"""
if user_id in self.stt_clients:
logger.warning(f"Already listening to user {user.name} ({user_id})")
return
logger.info(f"Starting to listen to user {user.name} ({user_id})")
# Store user info
self.users[user_id] = user
# Initialize audio buffer
self.audio_buffers[user_id] = deque(maxlen=1000)
# Create STT client with callbacks
stt_client = STTClient(
user_id=user_id,
stt_url=self.stt_url,
on_vad_event=lambda event: asyncio.create_task(
self._on_vad_event(user_id, event)
),
on_partial_transcript=lambda text, timestamp: asyncio.create_task(
self._on_partial_transcript(user_id, text)
),
on_final_transcript=lambda text, timestamp: asyncio.create_task(
self._on_final_transcript(user_id, text, user)
),
on_interruption=lambda prob: asyncio.create_task(
self._on_interruption(user_id, prob)
)
)
# Connect to STT server
try:
await stt_client.connect()
self.stt_clients[user_id] = stt_client
self.active = True
logger.info(f"✓ STT connected for user {user.name}")
except Exception as e:
logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True)
# Cleanup partial state
if user_id in self.audio_buffers:
del self.audio_buffers[user_id]
if user_id in self.users:
del self.users[user_id]
raise
async def stop_listening(self, user_id: int):
"""
Stop listening to a specific user.
Disconnects the STT client and cleans up resources for this user.
Args:
user_id: Discord user ID
"""
if user_id not in self.stt_clients:
logger.warning(f"Not listening to user {user_id}")
return
user = self.users.get(user_id)
logger.info(f"Stopping listening to user {user.name if user else user_id}")
# Disconnect STT client
stt_client = self.stt_clients[user_id]
await stt_client.disconnect()
# Cleanup
del self.stt_clients[user_id]
if user_id in self.audio_buffers:
del self.audio_buffers[user_id]
if user_id in self.users:
del self.users[user_id]
# Cleanup opus decoder for this user
if hasattr(self, '_opus_decoders') and user_id in self._opus_decoders:
del self._opus_decoders[user_id]
# Update active flag
if not self.stt_clients:
self.active = False
logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")
async def stop_all(self):
"""Stop listening to all users and cleanup all resources."""
logger.info("Stopping all voice receivers")
user_ids = list(self.stt_clients.keys())
for user_id in user_ids:
await self.stop_listening(user_id)
self.active = False
logger.info("✓ All voice receivers stopped")
async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
"""
Send audio chunk to STT client.
Buffers audio until we have 512 samples (32ms @ 16kHz) which is what
Silero VAD expects. Discord sends 320 samples (20ms), so we buffer
2 chunks and send 640 samples, then the STT server can split it.
Args:
user_id: Discord user ID
audio_data: PCM audio (int16, 16kHz mono, 320 samples = 640 bytes)
"""
stt_client = self.stt_clients.get(user_id)
if not stt_client or not stt_client.is_connected():
return
try:
# Get or create buffer for this user
if user_id not in self.audio_buffers:
self.audio_buffers[user_id] = deque()
buffer = self.audio_buffers[user_id]
buffer.append(audio_data)
# Silero VAD expects 512 samples @ 16kHz (1024 bytes)
# Discord gives us 320 samples (640 bytes) every 20ms
# Buffer 2 chunks = 640 samples = 1280 bytes, send as one chunk
SAMPLES_NEEDED = 512 # What VAD wants
BYTES_NEEDED = SAMPLES_NEEDED * 2 # int16 = 2 bytes per sample
# Check if we have enough buffered audio
total_bytes = sum(len(chunk) for chunk in buffer)
if total_bytes >= BYTES_NEEDED:
# Concatenate buffered chunks
combined = b''.join(buffer)
buffer.clear()
# Send in 512-sample (1024-byte) chunks
for i in range(0, len(combined), BYTES_NEEDED):
chunk = combined[i:i+BYTES_NEEDED]
if len(chunk) == BYTES_NEEDED:
await stt_client.send_audio(chunk)
else:
# Put remaining partial chunk back in buffer
buffer.append(chunk)
except Exception as e:
logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
async def _on_vad_event(self, user_id: int, event: dict):
"""
Handle VAD event from STT.
Args:
user_id: Discord user ID
event: VAD event dictionary with 'event' and 'probability' keys
"""
user = self.users.get(user_id)
event_type = event.get('event', 'unknown')
probability = event.get('probability', 0.0)
logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
# Notify voice manager - pass the full event dict
if hasattr(self.voice_manager, 'on_user_vad_event'):
await self.voice_manager.on_user_vad_event(user_id, event)
async def _on_partial_transcript(self, user_id: int, text: str):
"""
Handle partial transcript from STT.
Args:
user_id: Discord user ID
text: Partial transcript text
"""
user = self.users.get(user_id)
logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")
print(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}") # Extra debug
# Notify voice manager
if hasattr(self.voice_manager, 'on_partial_transcript'):
await self.voice_manager.on_partial_transcript(user_id, text)
async def _on_final_transcript(self, user_id: int, text: str, user: discord.User):
"""
Handle final transcript from STT.
This triggers the LLM response generation.
Args:
user_id: Discord user ID
text: Final transcript text
user: Discord user object
"""
logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")
print(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}") # Extra debug
# Notify voice manager - THIS TRIGGERS LLM RESPONSE
if hasattr(self.voice_manager, 'on_final_transcript'):
await self.voice_manager.on_final_transcript(user_id, text)
async def _on_interruption(self, user_id: int, probability: float):
"""
Handle interruption detection from STT.
This cancels Miku's current speech if user interrupts.
Args:
user_id: Discord user ID
probability: Interruption confidence probability
"""
user = self.users.get(user_id)
logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")
# Notify voice manager - THIS CANCELS MIKU'S SPEECH
if hasattr(self.voice_manager, 'on_user_interruption'):
await self.voice_manager.on_user_interruption(user_id, probability)
def get_listening_users(self) -> list:
"""
Get list of users currently being listened to.
Returns:
List of dicts with user_id, username, and connection status
"""
return [
{
'user_id': user_id,
'username': user.name if user else 'Unknown',
'connected': client.is_connected()
}
for user_id, (user, client) in
[(uid, (self.users.get(uid), self.stt_clients.get(uid)))
for uid in self.stt_clients.keys()]
]
@voice_recv.AudioSink.listener()
def on_voice_member_speaking_start(self, member: discord.Member):
"""
Called when a member starts speaking (green circle appears).
This is a virtual event from discord-ext-voice-recv based on packet activity.
"""
if member.id in self.stt_clients:
logger.debug(f"🎤 {member.name} started speaking")
@voice_recv.AudioSink.listener()
def on_voice_member_speaking_stop(self, member: discord.Member):
"""
Called when a member stops speaking (green circle disappears).
This is a virtual event from discord-ext-voice-recv based on packet activity.
"""
if member.id in self.stt_clients:
logger.debug(f"🔇 {member.name} stopped speaking")

View File

@@ -0,0 +1,419 @@
"""
Discord Voice Receiver
Captures audio from Discord voice channels and streams to STT.
Handles opus decoding and audio preprocessing.
"""
import discord
import audioop
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from collections import deque
from utils.stt_client import STTClient
logger = logging.getLogger('voice_receiver')
class VoiceReceiver(discord.sinks.Sink):
"""
Voice Receiver for Discord Audio Capture
Captures audio from Discord voice channels using discord.py's voice websocket.
Processes Opus audio, decodes to PCM, resamples to 16kHz mono for STT.
Note: Standard discord.py doesn't have built-in audio receiving.
This implementation hooks into the voice websocket directly.
"""
import asyncio
import struct
import audioop
import logging
from typing import Dict, Optional, Callable
import discord
# Import opus decoder
try:
import discord.opus as opus
if not opus.is_loaded():
opus.load_opus('opus')
except Exception as e:
logging.error(f"Failed to load opus: {e}")
from utils.stt_client import STTClient
logger = logging.getLogger('voice_receiver')
class VoiceReceiver:
"""
Receives and processes audio from Discord voice channel.
This class monkey-patches the VoiceClient to intercept received RTP packets,
decodes Opus audio, and forwards to STT clients.
"""
def __init__(
self,
voice_client: discord.VoiceClient,
voice_manager,
stt_url: str = "ws://miku-stt:8001"
):
"""
Initialize voice receiver.
Args:
voice_client: Discord VoiceClient to receive audio from
voice_manager: Voice manager instance for callbacks
stt_url: Base URL for STT WebSocket server
"""
self.voice_client = voice_client
self.voice_manager = voice_manager
self.stt_url = stt_url
# Per-user STT clients
self.stt_clients: Dict[int, STTClient] = {}
# Opus decoder instances per SSRC (one per user)
self.opus_decoders: Dict[int, any] = {}
# Resampler state per user (for 48kHz → 16kHz)
self.resample_state: Dict[int, tuple] = {}
# Original receive method (for restoration)
self._original_receive = None
# Active flag
self.active = False
logger.info("VoiceReceiver initialized")
async def start_listening(self, user_id: int, user: discord.User):
"""
Start listening to a specific user's audio.
Args:
user_id: Discord user ID
user: Discord User object
"""
if user_id in self.stt_clients:
logger.warning(f"Already listening to user {user_id}")
return
try:
# Create STT client for this user
stt_client = STTClient(
user_id=user_id,
stt_url=self.stt_url,
on_vad_event=lambda event, prob: asyncio.create_task(
self.voice_manager.on_user_vad_event(user_id, event)
),
on_partial_transcript=lambda text: asyncio.create_task(
self.voice_manager.on_partial_transcript(user_id, text)
),
on_final_transcript=lambda text: asyncio.create_task(
self.voice_manager.on_final_transcript(user_id, text, user)
),
on_interruption=lambda prob: asyncio.create_task(
self.voice_manager.on_user_interruption(user_id, prob)
)
)
# Connect to STT server
await stt_client.connect()
# Store client
self.stt_clients[user_id] = stt_client
# Initialize opus decoder for this user if needed
# (Will be done when we receive their SSRC)
# Patch voice client to receive audio if not already patched
if not self.active:
await self._patch_voice_client()
logger.info(f"✓ Started listening to user {user_id} ({user.name})")
except Exception as e:
logger.error(f"Failed to start listening to user {user_id}: {e}", exc_info=True)
raise
async def stop_listening(self, user_id: int):
"""
Stop listening to a specific user.
Args:
user_id: Discord user ID
"""
if user_id not in self.stt_clients:
logger.warning(f"Not listening to user {user_id}")
return
try:
# Disconnect STT client
stt_client = self.stt_clients.pop(user_id)
await stt_client.disconnect()
# Clean up decoder and resampler state
# Note: We don't know the SSRC here, so we'll just remove by user_id
# Actual cleanup happens in _process_audio when we match SSRC to user_id
# If no more clients, unpatch voice client
if not self.stt_clients:
await self._unpatch_voice_client()
logger.info(f"✓ Stopped listening to user {user_id}")
except Exception as e:
logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)
raise
async def _patch_voice_client(self):
"""Patch VoiceClient to intercept received audio packets."""
logger.warning("⚠️ Audio receiving not yet implemented - discord.py doesn't support receiving by default")
logger.warning("⚠️ You need discord.py-self or a custom fork with receiving support")
logger.warning("⚠️ STT will not receive any audio until this is implemented")
self.active = True
# TODO: Implement RTP packet receiving
# This requires either:
# 1. Using discord.py-self which has receiving support
# 2. Monkey-patching voice_client.ws to intercept packets
# 3. Using a separate UDP socket listener
async def _unpatch_voice_client(self):
"""Restore original VoiceClient behavior."""
self.active = False
logger.info("Unpatch voice client (receiving disabled)")
async def _process_audio(self, ssrc: int, opus_data: bytes):
"""
Process received Opus audio packet.
Args:
ssrc: RTP SSRC (identifies the audio source/user)
opus_data: Opus-encoded audio data
"""
# TODO: Map SSRC to user_id (requires tracking voice state updates)
# For now, this is a placeholder
pass
async def cleanup(self):
"""Clean up all resources."""
# Disconnect all STT clients
for user_id in list(self.stt_clients.keys()):
await self.stop_listening(user_id)
# Unpatch voice client
if self.active:
await self._unpatch_voice_client()
logger.info("VoiceReceiver cleanup complete") def __init__(self, voice_manager):
"""
Initialize voice receiver.
Args:
voice_manager: Reference to VoiceManager for callbacks
"""
super().__init__()
self.voice_manager = voice_manager
# Per-user STT clients
self.stt_clients: Dict[int, STTClient] = {}
# Audio buffers per user (for resampling)
self.audio_buffers: Dict[int, deque] = {}
# User info (for logging)
self.users: Dict[int, discord.User] = {}
logger.info("Voice receiver initialized")
async def start_listening(self, user_id: int, user: discord.User):
"""
Start listening to a specific user.
Args:
user_id: Discord user ID
user: Discord user object
"""
if user_id in self.stt_clients:
logger.warning(f"Already listening to user {user.name} ({user_id})")
return
logger.info(f"Starting to listen to user {user.name} ({user_id})")
# Store user info
self.users[user_id] = user
# Initialize audio buffer
self.audio_buffers[user_id] = deque(maxlen=1000) # Max 1000 chunks
# Create STT client with callbacks
stt_client = STTClient(
user_id=str(user_id),
on_vad_event=lambda event: self._on_vad_event(user_id, event),
on_partial_transcript=lambda text, ts: self._on_partial_transcript(user_id, text, ts),
on_final_transcript=lambda text, ts: self._on_final_transcript(user_id, text, ts),
on_interruption=lambda prob: self._on_interruption(user_id, prob)
)
# Connect to STT
try:
await stt_client.connect()
self.stt_clients[user_id] = stt_client
logger.info(f"✓ STT connected for user {user.name}")
except Exception as e:
logger.error(f"Failed to connect STT for user {user.name}: {e}")
async def stop_listening(self, user_id: int):
"""
Stop listening to a specific user.
Args:
user_id: Discord user ID
"""
if user_id not in self.stt_clients:
return
user = self.users.get(user_id)
logger.info(f"Stopping listening to user {user.name if user else user_id}")
# Disconnect STT client
stt_client = self.stt_clients[user_id]
await stt_client.disconnect()
# Cleanup
del self.stt_clients[user_id]
if user_id in self.audio_buffers:
del self.audio_buffers[user_id]
if user_id in self.users:
del self.users[user_id]
logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")
async def stop_all(self):
"""Stop listening to all users."""
logger.info("Stopping all voice receivers")
user_ids = list(self.stt_clients.keys())
for user_id in user_ids:
await self.stop_listening(user_id)
logger.info("✓ All voice receivers stopped")
def write(self, data: discord.sinks.core.AudioData):
"""
Called by discord.py when audio is received.
Args:
data: Audio data from Discord
"""
# Get user ID from SSRC
user_id = data.user.id if data.user else None
if not user_id:
return
# Check if we're listening to this user
if user_id not in self.stt_clients:
return
# Process audio
try:
# Decode opus to PCM (48kHz stereo)
pcm_data = data.pcm
# Convert stereo to mono if needed
if len(pcm_data) % 4 == 0: # Stereo int16 (2 channels * 2 bytes)
# Average left and right channels
pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
else:
pcm_mono = pcm_data
# Resample from 48kHz to 16kHz
# Discord sends 20ms chunks at 48kHz = 960 samples
# We need 320 samples at 16kHz (20ms)
pcm_16k = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)[0]
# Send to STT
asyncio.create_task(self._send_audio_chunk(user_id, pcm_16k))
except Exception as e:
logger.error(f"Error processing audio for user {user_id}: {e}")
async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
"""
Send audio chunk to STT client.
Args:
user_id: Discord user ID
audio_data: PCM audio (int16, 16kHz mono)
"""
stt_client = self.stt_clients.get(user_id)
if not stt_client or not stt_client.is_connected():
return
try:
await stt_client.send_audio(audio_data)
except Exception as e:
logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
async def _on_vad_event(self, user_id: int, event: dict):
"""Handle VAD event from STT."""
user = self.users.get(user_id)
event_type = event.get('event')
probability = event.get('probability', 0)
logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
# Notify voice manager
if hasattr(self.voice_manager, 'on_user_vad_event'):
await self.voice_manager.on_user_vad_event(user_id, event)
async def _on_partial_transcript(self, user_id: int, text: str, timestamp: float):
"""Handle partial transcript from STT."""
user = self.users.get(user_id)
logger.info(f"Partial [{user.name if user else user_id}]: {text}")
# Notify voice manager
if hasattr(self.voice_manager, 'on_partial_transcript'):
await self.voice_manager.on_partial_transcript(user_id, text)
async def _on_final_transcript(self, user_id: int, text: str, timestamp: float):
"""Handle final transcript from STT."""
user = self.users.get(user_id)
logger.info(f"Final [{user.name if user else user_id}]: {text}")
# Notify voice manager - THIS TRIGGERS LLM RESPONSE
if hasattr(self.voice_manager, 'on_final_transcript'):
await self.voice_manager.on_final_transcript(user_id, text)
async def _on_interruption(self, user_id: int, probability: float):
"""Handle interruption detection from STT."""
user = self.users.get(user_id)
logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")
# Notify voice manager - THIS CANCELS MIKU'S SPEECH
if hasattr(self.voice_manager, 'on_user_interruption'):
await self.voice_manager.on_user_interruption(user_id, probability)
def cleanup(self):
"""Cleanup resources."""
logger.info("Cleaning up voice receiver")
# Async cleanup will be called separately
def get_listening_users(self) -> list:
"""Get list of users currently being listened to."""
return [
{
'user_id': user_id,
'username': user.name if user else 'Unknown',
'connected': client.is_connected()
}
for user_id, (user, client) in
[(uid, (self.users.get(uid), self.stt_clients.get(uid)))
for uid in self.stt_clients.keys()]
]