Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all
bot/utils/stt_client.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""
STT Client for Discord Bot

WebSocket client that connects to the STT server and handles:
- Audio streaming to STT
- Receiving VAD events
- Receiving partial/final transcripts
- Interruption detection
"""

import aiohttp
import asyncio
import logging
from typing import Optional, Callable
import json

logger = logging.getLogger('stt_client')


class STTClient:
    """
    WebSocket client for STT server communication.

    Handles audio streaming and receives transcription events.
    """

    def __init__(
        self,
        user_id: str,
        stt_url: str = "ws://miku-stt:8000/ws/stt",
        on_vad_event: Optional[Callable] = None,
        on_partial_transcript: Optional[Callable] = None,
        on_final_transcript: Optional[Callable] = None,
        on_interruption: Optional[Callable] = None
    ):
        """
        Initialize STT client.

        Args:
            user_id: Discord user ID
            stt_url: Base WebSocket URL for STT server
            on_vad_event: Callback for VAD events (event_dict)
            on_partial_transcript: Callback for partial transcripts (text, timestamp)
            on_final_transcript: Callback for final transcripts (text, timestamp)
            on_interruption: Callback for interruption detection (probability)
        """
        self.user_id = user_id
        self.stt_url = f"{stt_url}/{user_id}"

        # Callbacks
        self.on_vad_event = on_vad_event
        self.on_partial_transcript = on_partial_transcript
        self.on_final_transcript = on_final_transcript
        self.on_interruption = on_interruption

        # Connection state
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.session: Optional[aiohttp.ClientSession] = None
        self.connected = False
        self.running = False

        # Receive task
        self._receive_task: Optional[asyncio.Task] = None

        logger.info(f"STT client initialized for user {user_id}")

    async def connect(self):
        """Connect to STT WebSocket server."""
        if self.connected:
            logger.warning(f"Already connected for user {self.user_id}")
            return

        try:
            self.session = aiohttp.ClientSession()
            self.websocket = await self.session.ws_connect(
                self.stt_url,
                heartbeat=30
            )

            # Wait for ready message
            ready_msg = await self.websocket.receive_json()
            logger.info(f"STT connected for user {self.user_id}: {ready_msg}")

            self.connected = True
            self.running = True

            # Start receive task
            self._receive_task = asyncio.create_task(self._receive_events())

            logger.info(f"✓ STT WebSocket connected for user {self.user_id}")

        except Exception as e:
            logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
            await self.disconnect()
            raise

    async def disconnect(self):
        """Disconnect from STT WebSocket."""
        logger.info(f"Disconnecting STT for user {self.user_id}")

        self.running = False
        self.connected = False

        # Cancel receive task
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        # Close WebSocket
        if self.websocket:
            await self.websocket.close()
            self.websocket = None

        # Close session
        if self.session:
            await self.session.close()
            self.session = None

        logger.info(f"✓ STT disconnected for user {self.user_id}")

    async def send_audio(self, audio_data: bytes):
        """
        Send audio chunk to STT server.

        Args:
            audio_data: PCM audio (int16, 16kHz mono)
        """
        if not self.connected or not self.websocket:
            logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
            return

        try:
            await self.websocket.send_bytes(audio_data)
            logger.debug(f"Sent {len(audio_data)} bytes to STT")

        except Exception as e:
            logger.error(f"Failed to send audio to STT: {e}")
            self.connected = False

    async def _receive_events(self):
        """Background task to receive events from STT server."""
        try:
            while self.running and self.websocket:
                try:
                    msg = await self.websocket.receive()

                    if msg.type == aiohttp.WSMsgType.TEXT:
                        event = json.loads(msg.data)
                        await self._handle_event(event)

                    elif msg.type == aiohttp.WSMsgType.CLOSED:
                        logger.info(f"STT WebSocket closed for user {self.user_id}")
                        break

                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"STT WebSocket error for user {self.user_id}")
                        break

                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Error receiving STT event: {e}", exc_info=True)

        finally:
            self.connected = False
            logger.info(f"STT receive task ended for user {self.user_id}")

    async def _handle_event(self, event: dict):
        """
        Handle incoming STT event.

        Args:
            event: Event dictionary from STT server
        """
        event_type = event.get('type')

        if event_type == 'vad':
            # VAD event: speech detection
            logger.debug(f"VAD event: {event}")
            if self.on_vad_event:
                await self.on_vad_event(event)

        elif event_type == 'partial':
            # Partial transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Partial transcript [{self.user_id}]: {text}")
            if self.on_partial_transcript:
                await self.on_partial_transcript(text, timestamp)

        elif event_type == 'final':
            # Final transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Final transcript [{self.user_id}]: {text}")
            if self.on_final_transcript:
                await self.on_final_transcript(text, timestamp)

        elif event_type == 'interruption':
            # Interruption detected
            probability = event.get('probability', 0)
            logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
            if self.on_interruption:
                await self.on_interruption(probability)

        else:
            logger.warning(f"Unknown STT event type: {event_type}")

    def is_connected(self) -> bool:
        """Check if STT client is connected."""
        return self.connected
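
For reference, a minimal sketch of how an STTClient instance is meant to be driven end to end, using only the methods defined above (the demo coroutine, its pcm_chunks input, and the user ID are illustrative placeholders, not part of this commit):

import asyncio

from utils.stt_client import STTClient


async def demo(pcm_chunks):
    # pcm_chunks: iterable of 1024-byte frames (512 samples, int16, 16 kHz mono),
    # matching the chunk size the STT server's Silero VAD stage expects.
    async def on_final(text, timestamp):
        print(f"final @ {timestamp}: {text}")

    client = STTClient(user_id="123456789012345678", on_final_transcript=on_final)
    await client.connect()
    try:
        for chunk in pcm_chunks:
            await client.send_audio(chunk)
            await asyncio.sleep(0.032)  # pace roughly like real-time 32 ms frames
    finally:
        await client.disconnect()

The hunks that follow (VoiceSessionManager / VoiceSession) and the VoiceReceiverSink added further below drive the client the same way, except that chunks arrive from Discord and the remaining callbacks are wired to the voice manager.
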
@@ -19,6 +19,7 @@ import json
 import os
 from typing import Optional
 import discord
+from discord.ext import voice_recv
 import globals
 from utils.logger import get_logger

@@ -97,12 +98,12 @@ class VoiceSessionManager:
         # 10. Create voice session
         self.active_session = VoiceSession(guild_id, voice_channel, text_channel)

-        # 11. Connect to Discord voice channel
+        # 11. Connect to Discord voice channel with VoiceRecvClient
         try:
-            voice_client = await voice_channel.connect()
+            voice_client = await voice_channel.connect(cls=voice_recv.VoiceRecvClient)
             self.active_session.voice_client = voice_client
             self.active_session.active = True
-            logger.info(f"✓ Connected to voice channel: {voice_channel.name}")
+            logger.info(f"✓ Connected to voice channel: {voice_channel.name} (with audio receiving)")
         except Exception as e:
             logger.error(f"Failed to connect to voice channel: {e}", exc_info=True)
             raise
@@ -387,7 +388,9 @@ class VoiceSession:
         self.voice_client: Optional[discord.VoiceClient] = None
         self.audio_source: Optional['MikuVoiceSource'] = None  # Forward reference
         self.tts_streamer: Optional['TTSTokenStreamer'] = None  # Forward reference
+        self.voice_receiver: Optional['VoiceReceiver'] = None  # STT receiver
         self.active = False
+        self.miku_speaking = False  # Track if Miku is currently speaking

         logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")

@@ -433,6 +436,207 @@ class VoiceSession:

        except Exception as e:
            logger.error(f"Error stopping audio streaming: {e}", exc_info=True)

    async def start_listening(self, user: discord.User):
        """
        Start listening to a user's voice (STT).

        Args:
            user: Discord user to listen to
        """
        from utils.voice_receiver import VoiceReceiverSink

        try:
            # Create receiver if not exists
            if not self.voice_receiver:
                self.voice_receiver = VoiceReceiverSink(self)

            # Start receiving audio from Discord using discord-ext-voice-recv
            if self.voice_client:
                self.voice_client.listen(self.voice_receiver)
                logger.info("✓ Discord voice receive started (discord-ext-voice-recv)")

            # Start listening to specific user
            await self.voice_receiver.start_listening(user.id, user)
            logger.info(f"✓ Started listening to {user.name}")

        except Exception as e:
            logger.error(f"Failed to start listening to {user.name}: {e}", exc_info=True)
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a user.

        Args:
            user_id: Discord user ID
        """
        if self.voice_receiver:
            await self.voice_receiver.stop_listening(user_id)
            logger.info(f"✓ Stopped listening to user {user_id}")

    async def stop_all_listening(self):
        """Stop listening to all users."""
        if self.voice_receiver:
            await self.voice_receiver.stop_all()
            self.voice_receiver = None
            logger.info("✓ Stopped all listening")

    async def on_user_vad_event(self, user_id: int, event: dict):
        """Called when VAD detects speech state change."""
        event_type = event.get('event')
        logger.debug(f"User {user_id} VAD: {event_type}")

    async def on_partial_transcript(self, user_id: int, text: str):
        """Called when partial transcript is received."""
        logger.info(f"Partial from user {user_id}: {text}")
        # Could show "User is saying..." in chat

    async def on_final_transcript(self, user_id: int, text: str):
        """
        Called when final transcript is received.
        This triggers LLM response and TTS.
        """
        logger.info(f"Final from user {user_id}: {text}")

        # Get user info
        user = self.voice_channel.guild.get_member(user_id)
        if not user:
            logger.warning(f"User {user_id} not found in guild")
            return

        # Show what user said
        await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")

        # Generate LLM response and speak it
        await self._generate_voice_response(user, text)

    async def on_user_interruption(self, user_id: int, probability: float):
        """
        Called when user interrupts Miku's speech.
        Cancel TTS and switch to listening.
        """
        if not self.miku_speaking:
            return

        logger.info(f"User {user_id} interrupted Miku (prob={probability:.3f})")

        # Cancel Miku's speech
        await self._cancel_tts()

        # Show interruption in chat
        user = self.voice_channel.guild.get_member(user_id)
        await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")

    async def _generate_voice_response(self, user: discord.User, text: str):
        """
        Generate LLM response and speak it.

        Args:
            user: User who spoke
            text: Transcribed text
        """
        try:
            self.miku_speaking = True

            # Show processing
            await self.text_channel.send(f"💭 *Miku is thinking...*")

            # Import here to avoid circular imports
            from utils.llm import get_current_gpu_url
            import aiohttp
            import globals

            # Simple system prompt for voice
            system_prompt = """You are Hatsune Miku, the virtual singer.
Respond naturally and concisely as Miku would in a voice conversation.
Keep responses short (1-3 sentences) since they will be spoken aloud."""

            payload = {
                "model": globals.TEXT_MODEL,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": text}
                ],
                "stream": True,
                "temperature": 0.8,
                "max_tokens": 200
            }

            headers = {'Content-Type': 'application/json'}
            llama_url = get_current_gpu_url()

            # Stream LLM response to TTS
            full_response = ""
            async with aiohttp.ClientSession() as http_session:
                async with http_session.post(
                    f"{llama_url}/v1/chat/completions",
                    json=payload,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=60)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"LLM error {response.status}: {error_text}")

                    # Stream tokens to TTS
                    async for line in response.content:
                        if not self.miku_speaking:
                            # Interrupted
                            break

                        line = line.decode('utf-8').strip()
                        if line.startswith('data: '):
                            data_str = line[6:]
                            if data_str == '[DONE]':
                                break

                            try:
                                import json
                                data = json.loads(data_str)
                                if 'choices' in data and len(data['choices']) > 0:
                                    delta = data['choices'][0].get('delta', {})
                                    content = delta.get('content', '')
                                    if content:
                                        await self.audio_source.send_token(content)
                                        full_response += content
                            except json.JSONDecodeError:
                                continue

            # Flush TTS
            if self.miku_speaking:
                await self.audio_source.flush()

            # Show response
            await self.text_channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
            logger.info(f"✓ Voice response complete: {full_response.strip()}")

        except Exception as e:
            logger.error(f"Voice response failed: {e}", exc_info=True)
            await self.text_channel.send(f"❌ Sorry, I had trouble responding")

        finally:
            self.miku_speaking = False

    async def _cancel_tts(self):
        """Cancel current TTS synthesis."""
        logger.info("Canceling TTS synthesis")

        # Stop Discord playback
        if self.voice_client and self.voice_client.is_playing():
            self.voice_client.stop()

        # Send interrupt to RVC
        try:
            import aiohttp
            async with aiohttp.ClientSession() as session:
                async with session.post("http://172.25.0.1:8765/interrupt") as resp:
                    if resp.status == 200:
                        logger.info("✓ TTS interrupted")
        except Exception as e:
            logger.error(f"Failed to interrupt TTS: {e}")

        self.miku_speaking = False


# Global singleton instance
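
The _generate_voice_response method above assumes the server at llama_url exposes an OpenAI-compatible /v1/chat/completions endpoint with stream=True, i.e. one SSE "data:" line per token delta followed by a [DONE] sentinel. A minimal illustration of the lines the parser consumes and what it extracts (payload contents are hypothetical, not captured output):

import json

# Shape of the streamed SSE lines the loop above iterates over:
#   data: {"choices": [{"delta": {"content": "Hi"}}]}
#   data: {"choices": [{"delta": {"content": " there!"}}]}
#   data: [DONE]
line = 'data: {"choices": [{"delta": {"content": "Hi"}}]}'
if line.startswith('data: ') and line[6:] != '[DONE]':
    delta = json.loads(line[6:])['choices'][0].get('delta', {})
    print(delta.get('content', ''))  # prints: Hi

Each non-empty delta.content fragment is pushed to audio_source.send_token() as it arrives and appended to full_response, so the finished sentence can also be posted back to the text channel.
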
bot/utils/voice_receiver.py (new file, 411 lines)
@@ -0,0 +1,411 @@
"""
Discord Voice Receiver using discord-ext-voice-recv

Captures audio from Discord voice channels and streams to STT.
Uses the discord-ext-voice-recv extension for proper audio receiving support.
"""

import asyncio
import audioop
import logging
from typing import Dict, Optional
from collections import deque

import discord
from discord.ext import voice_recv

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiverSink(voice_recv.AudioSink):
    """
    Audio sink that receives Discord audio and forwards to STT.

    This sink processes incoming audio from Discord voice channels,
    decodes/resamples as needed, and sends to STT clients for transcription.
    """

    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8000/ws/stt"):
        """
        Initialize voice receiver sink.

        Args:
            voice_manager: Reference to VoiceManager for callbacks
            stt_url: Base URL for STT WebSocket server with path (port 8000 inside container)
        """
        super().__init__()
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Store event loop for thread-safe async calls
        # Use get_running_loop() in async context, or store it when available
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Fallback if not in async context yet
            self.loop = asyncio.get_event_loop()

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Audio buffers per user (for resampling state)
        self.audio_buffers: Dict[int, deque] = {}

        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        # Active flag
        self.active = False

        logger.info("VoiceReceiverSink initialized")

    def wants_opus(self) -> bool:
        """
        Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM.

        We'll decode it ourselves to avoid decoder errors from discord-ext-voice-recv.

        Returns:
            True - we want Opus packets, we'll handle decoding
        """
        return True  # Get Opus, decode ourselves to avoid packet router errors

    def write(self, user: Optional[discord.User], data: voice_recv.VoiceData):
        """
        Called by discord-ext-voice-recv when audio is received.

        This is the main callback that receives audio packets from Discord.
        We get Opus data, decode it ourselves, resample, and forward to STT.

        Args:
            user: Discord user who sent the audio (None if unknown)
            data: Voice data container with pcm, opus, and packet info
        """
        if not user:
            return  # Skip packets from unknown users

        user_id = user.id

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        try:
            # Get Opus data (we decode ourselves to avoid PacketRouter errors)
            opus_data = data.opus

            if not opus_data:
                return

            # Decode Opus to PCM (48kHz stereo int16)
            # Use discord.py's opus decoder with proper error handling
            import discord.opus
            if not hasattr(self, '_opus_decoders'):
                self._opus_decoders = {}

            # Create decoder for this user if needed
            if user_id not in self._opus_decoders:
                self._opus_decoders[user_id] = discord.opus.Decoder()

            decoder = self._opus_decoders[user_id]

            # Decode opus -> PCM (this can fail on corrupt packets, so catch it)
            try:
                pcm_data = decoder.decode(opus_data, fec=False)
            except discord.opus.OpusError as e:
                # Skip corrupted packets silently (common at stream start)
                logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}")
                return

            if not pcm_data:
                return

            # PCM from Discord is 48kHz stereo int16
            # Convert stereo to mono
            if len(pcm_data) % 4 == 0:  # Stereo (2 channels * 2 bytes per sample)
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48kHz to 16kHz for STT
            # Discord sends 20ms chunks: 960 samples @ 48kHz → 320 samples @ 16kHz
            pcm_16k, _ = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)

            # Send to STT client (schedule on event loop thread-safely)
            asyncio.run_coroutine_threadsafe(
                self._send_audio_chunk(user_id, pcm_16k),
                self.loop
            )

        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True)

    def cleanup(self):
        """
        Called when the sink is stopped.
        Cleanup any resources.
        """
        logger.info("VoiceReceiverSink cleanup")
        # Async cleanup handled separately in stop_all()

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Creates an STT client connection for this user and registers callbacks.

        Args:
            user_id: Discord user ID
            user: Discord user object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize audio buffer
        self.audio_buffers[user_id] = deque(maxlen=1000)

        # Create STT client with callbacks
        stt_client = STTClient(
            user_id=user_id,
            stt_url=self.stt_url,
            on_vad_event=lambda event: asyncio.create_task(
                self._on_vad_event(user_id, event)
            ),
            on_partial_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_partial_transcript(user_id, text)
            ),
            on_final_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_final_transcript(user_id, text, user)
            ),
            on_interruption=lambda prob: asyncio.create_task(
                self._on_interruption(user_id, prob)
            )
        )

        # Connect to STT server
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            self.active = True
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True)
            # Cleanup partial state
            if user_id in self.audio_buffers:
                del self.audio_buffers[user_id]
            if user_id in self.users:
                del self.users[user_id]
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Disconnects the STT client and cleans up resources for this user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            logger.warning(f"Not listening to user {user_id}")
            return

        user = self.users.get(user_id)
        logger.info(f"Stopping listening to user {user.name if user else user_id}")

        # Disconnect STT client
        stt_client = self.stt_clients[user_id]
        await stt_client.disconnect()

        # Cleanup
        del self.stt_clients[user_id]
        if user_id in self.audio_buffers:
            del self.audio_buffers[user_id]
        if user_id in self.users:
            del self.users[user_id]

        # Cleanup opus decoder for this user
        if hasattr(self, '_opus_decoders') and user_id in self._opus_decoders:
            del self._opus_decoders[user_id]

        # Update active flag
        if not self.stt_clients:
            self.active = False

        logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")

    async def stop_all(self):
        """Stop listening to all users and cleanup all resources."""
        logger.info("Stopping all voice receivers")

        user_ids = list(self.stt_clients.keys())
        for user_id in user_ids:
            await self.stop_listening(user_id)

        self.active = False
        logger.info("✓ All voice receivers stopped")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send audio chunk to STT client.

        Buffers audio until we have 512 samples (32ms @ 16kHz) which is what
        Silero VAD expects. Discord sends 320 samples (20ms), so we buffer
        2 chunks and send 640 samples, then the STT server can split it.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16kHz mono, 320 samples = 640 bytes)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.is_connected():
            return

        try:
            # Get or create buffer for this user
            if user_id not in self.audio_buffers:
                self.audio_buffers[user_id] = deque()

            buffer = self.audio_buffers[user_id]
            buffer.append(audio_data)

            # Silero VAD expects 512 samples @ 16kHz (1024 bytes)
            # Discord gives us 320 samples (640 bytes) every 20ms
            # Buffer 2 chunks = 640 samples = 1280 bytes, send as one chunk
            SAMPLES_NEEDED = 512  # What VAD wants
            BYTES_NEEDED = SAMPLES_NEEDED * 2  # int16 = 2 bytes per sample

            # Check if we have enough buffered audio
            total_bytes = sum(len(chunk) for chunk in buffer)

            if total_bytes >= BYTES_NEEDED:
                # Concatenate buffered chunks
                combined = b''.join(buffer)
                buffer.clear()

                # Send in 512-sample (1024-byte) chunks
                for i in range(0, len(combined), BYTES_NEEDED):
                    chunk = combined[i:i+BYTES_NEEDED]
                    if len(chunk) == BYTES_NEEDED:
                        await stt_client.send_audio(chunk)
                    else:
                        # Put remaining partial chunk back in buffer
                        buffer.append(chunk)

        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")

    async def _on_vad_event(self, user_id: int, event: dict):
        """
        Handle VAD event from STT.

        Args:
            user_id: Discord user ID
            event: VAD event dictionary with 'event' and 'probability' keys
        """
        user = self.users.get(user_id)
        event_type = event.get('event', 'unknown')
        probability = event.get('probability', 0.0)

        logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")

        # Notify voice manager - pass the full event dict
        if hasattr(self.voice_manager, 'on_user_vad_event'):
            await self.voice_manager.on_user_vad_event(user_id, event)

    async def _on_partial_transcript(self, user_id: int, text: str):
        """
        Handle partial transcript from STT.

        Args:
            user_id: Discord user ID
            text: Partial transcript text
        """
        user = self.users.get(user_id)
        logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")
        print(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}")  # Extra debug

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, user: discord.User):
        """
        Handle final transcript from STT.

        This triggers the LLM response generation.

        Args:
            user_id: Discord user ID
            text: Final transcript text
            user: Discord user object
        """
        logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")
        print(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}")  # Extra debug

        # Notify voice manager - THIS TRIGGERS LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    async def _on_interruption(self, user_id: int, probability: float):
        """
        Handle interruption detection from STT.

        This cancels Miku's current speech if user interrupts.

        Args:
            user_id: Discord user ID
            probability: Interruption confidence probability
        """
        user = self.users.get(user_id)
        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")

        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
        if hasattr(self.voice_manager, 'on_user_interruption'):
            await self.voice_manager.on_user_interruption(user_id, probability)

    def get_listening_users(self) -> list:
        """
        Get list of users currently being listened to.

        Returns:
            List of dicts with user_id, username, and connection status
        """
        return [
            {
                'user_id': user_id,
                'username': user.name if user else 'Unknown',
                'connected': client.is_connected()
            }
            for user_id, (user, client) in
            [(uid, (self.users.get(uid), self.stt_clients.get(uid)))
             for uid in self.stt_clients.keys()]
        ]

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_start(self, member: discord.Member):
        """
        Called when a member starts speaking (green circle appears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🎤 {member.name} started speaking")

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_stop(self, member: discord.Member):
        """
        Called when a member stops speaking (green circle disappears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🔇 {member.name} stopped speaking")
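
A quick sanity check of the frame math used by write() and _send_audio_chunk() above: one 20 ms Discord frame is 960 samples at 48 kHz stereo, which downmixes and resamples to 320 mono samples (640 bytes) at 16 kHz, so two frames cover one 512-sample Silero VAD window with 128 samples left over. A small sketch with the same audioop calls (audioop is deprecated since Python 3.11 and removed in 3.13, so this mirrors the sink's approach rather than endorsing the module):

import audioop

# One synthetic 20 ms Discord frame: 960 samples x 2 channels x 2 bytes of silence.
frame_48k_stereo = b"\x00\x00" * 960 * 2

mono = audioop.tomono(frame_48k_stereo, 2, 0.5, 0.5)         # 960 samples, 1920 bytes
pcm_16k, _ = audioop.ratecv(mono, 2, 1, 48000, 16000, None)  # 320 samples expected, 640 bytes
print(len(mono), len(pcm_16k))  # expected: 1920 640

# Two such 640-byte frames give 1280 bytes; _send_audio_chunk() forwards one
# 1024-byte (512-sample) window to the STT server and keeps the 256-byte
# remainder buffered for the next window.
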
bot/utils/voice_receiver.py.old (new file, 419 lines)
@@ -0,0 +1,419 @@
"""
Discord Voice Receiver

Captures audio from Discord voice channels and streams to STT.
Handles opus decoding and audio preprocessing.
"""

import discord
import audioop
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from collections import deque

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiver(discord.sinks.Sink):
    """
    Voice Receiver for Discord Audio Capture

    Captures audio from Discord voice channels using discord.py's voice websocket.
    Processes Opus audio, decodes to PCM, resamples to 16kHz mono for STT.

    Note: Standard discord.py doesn't have built-in audio receiving.
    This implementation hooks into the voice websocket directly.
    """

import asyncio
import struct
import audioop
import logging
from typing import Dict, Optional, Callable
import discord

# Import opus decoder
try:
    import discord.opus as opus
    if not opus.is_loaded():
        opus.load_opus('opus')
except Exception as e:
    logging.error(f"Failed to load opus: {e}")

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiver:
    """
    Receives and processes audio from Discord voice channel.

    This class monkey-patches the VoiceClient to intercept received RTP packets,
    decodes Opus audio, and forwards to STT clients.
    """

    def __init__(
        self,
        voice_client: discord.VoiceClient,
        voice_manager,
        stt_url: str = "ws://miku-stt:8001"
    ):
        """
        Initialize voice receiver.

        Args:
            voice_client: Discord VoiceClient to receive audio from
            voice_manager: Voice manager instance for callbacks
            stt_url: Base URL for STT WebSocket server
        """
        self.voice_client = voice_client
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Opus decoder instances per SSRC (one per user)
        self.opus_decoders: Dict[int, any] = {}

        # Resampler state per user (for 48kHz → 16kHz)
        self.resample_state: Dict[int, tuple] = {}

        # Original receive method (for restoration)
        self._original_receive = None

        # Active flag
        self.active = False

        logger.info("VoiceReceiver initialized")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user's audio.

        Args:
            user_id: Discord user ID
            user: Discord User object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user_id}")
            return

        try:
            # Create STT client for this user
            stt_client = STTClient(
                user_id=user_id,
                stt_url=self.stt_url,
                on_vad_event=lambda event, prob: asyncio.create_task(
                    self.voice_manager.on_user_vad_event(user_id, event)
                ),
                on_partial_transcript=lambda text: asyncio.create_task(
                    self.voice_manager.on_partial_transcript(user_id, text)
                ),
                on_final_transcript=lambda text: asyncio.create_task(
                    self.voice_manager.on_final_transcript(user_id, text, user)
                ),
                on_interruption=lambda prob: asyncio.create_task(
                    self.voice_manager.on_user_interruption(user_id, prob)
                )
            )

            # Connect to STT server
            await stt_client.connect()

            # Store client
            self.stt_clients[user_id] = stt_client

            # Initialize opus decoder for this user if needed
            # (Will be done when we receive their SSRC)

            # Patch voice client to receive audio if not already patched
            if not self.active:
                await self._patch_voice_client()

            logger.info(f"✓ Started listening to user {user_id} ({user.name})")

        except Exception as e:
            logger.error(f"Failed to start listening to user {user_id}: {e}", exc_info=True)
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            logger.warning(f"Not listening to user {user_id}")
            return

        try:
            # Disconnect STT client
            stt_client = self.stt_clients.pop(user_id)
            await stt_client.disconnect()

            # Clean up decoder and resampler state
            # Note: We don't know the SSRC here, so we'll just remove by user_id
            # Actual cleanup happens in _process_audio when we match SSRC to user_id

            # If no more clients, unpatch voice client
            if not self.stt_clients:
                await self._unpatch_voice_client()

            logger.info(f"✓ Stopped listening to user {user_id}")

        except Exception as e:
            logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)
            raise

    async def _patch_voice_client(self):
        """Patch VoiceClient to intercept received audio packets."""
        logger.warning("⚠️ Audio receiving not yet implemented - discord.py doesn't support receiving by default")
        logger.warning("⚠️ You need discord.py-self or a custom fork with receiving support")
        logger.warning("⚠️ STT will not receive any audio until this is implemented")
        self.active = True
        # TODO: Implement RTP packet receiving
        # This requires either:
        # 1. Using discord.py-self which has receiving support
        # 2. Monkey-patching voice_client.ws to intercept packets
        # 3. Using a separate UDP socket listener

    async def _unpatch_voice_client(self):
        """Restore original VoiceClient behavior."""
        self.active = False
        logger.info("Unpatch voice client (receiving disabled)")

    async def _process_audio(self, ssrc: int, opus_data: bytes):
        """
        Process received Opus audio packet.

        Args:
            ssrc: RTP SSRC (identifies the audio source/user)
            opus_data: Opus-encoded audio data
        """
        # TODO: Map SSRC to user_id (requires tracking voice state updates)
        # For now, this is a placeholder
        pass

    async def cleanup(self):
        """Clean up all resources."""
        # Disconnect all STT clients
        for user_id in list(self.stt_clients.keys()):
            await self.stop_listening(user_id)

        # Unpatch voice client
        if self.active:
            await self._unpatch_voice_client()

        logger.info("VoiceReceiver cleanup complete")

    def __init__(self, voice_manager):
        """
        Initialize voice receiver.

        Args:
            voice_manager: Reference to VoiceManager for callbacks
        """
        super().__init__()
        self.voice_manager = voice_manager

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Audio buffers per user (for resampling)
        self.audio_buffers: Dict[int, deque] = {}

        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        logger.info("Voice receiver initialized")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Args:
            user_id: Discord user ID
            user: Discord user object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize audio buffer
        self.audio_buffers[user_id] = deque(maxlen=1000)  # Max 1000 chunks

        # Create STT client with callbacks
        stt_client = STTClient(
            user_id=str(user_id),
            on_vad_event=lambda event: self._on_vad_event(user_id, event),
            on_partial_transcript=lambda text, ts: self._on_partial_transcript(user_id, text, ts),
            on_final_transcript=lambda text, ts: self._on_final_transcript(user_id, text, ts),
            on_interruption=lambda prob: self._on_interruption(user_id, prob)
        )

        # Connect to STT
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}")

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            return

        user = self.users.get(user_id)
        logger.info(f"Stopping listening to user {user.name if user else user_id}")

        # Disconnect STT client
        stt_client = self.stt_clients[user_id]
        await stt_client.disconnect()

        # Cleanup
        del self.stt_clients[user_id]
        if user_id in self.audio_buffers:
            del self.audio_buffers[user_id]
        if user_id in self.users:
            del self.users[user_id]

        logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")

    async def stop_all(self):
        """Stop listening to all users."""
        logger.info("Stopping all voice receivers")

        user_ids = list(self.stt_clients.keys())
        for user_id in user_ids:
            await self.stop_listening(user_id)

        logger.info("✓ All voice receivers stopped")

    def write(self, data: discord.sinks.core.AudioData):
        """
        Called by discord.py when audio is received.

        Args:
            data: Audio data from Discord
        """
        # Get user ID from SSRC
        user_id = data.user.id if data.user else None

        if not user_id:
            return

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        # Process audio
        try:
            # Decode opus to PCM (48kHz stereo)
            pcm_data = data.pcm

            # Convert stereo to mono if needed
            if len(pcm_data) % 4 == 0:  # Stereo int16 (2 channels * 2 bytes)
                # Average left and right channels
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48kHz to 16kHz
            # Discord sends 20ms chunks at 48kHz = 960 samples
            # We need 320 samples at 16kHz (20ms)
            pcm_16k = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)[0]

            # Send to STT
            asyncio.create_task(self._send_audio_chunk(user_id, pcm_16k))

        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send audio chunk to STT client.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16kHz mono)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.is_connected():
            return

        try:
            await stt_client.send_audio(audio_data)
        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")

    async def _on_vad_event(self, user_id: int, event: dict):
        """Handle VAD event from STT."""
        user = self.users.get(user_id)
        event_type = event.get('event')
        probability = event.get('probability', 0)

        logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_user_vad_event'):
            await self.voice_manager.on_user_vad_event(user_id, event)

    async def _on_partial_transcript(self, user_id: int, text: str, timestamp: float):
        """Handle partial transcript from STT."""
        user = self.users.get(user_id)
        logger.info(f"Partial [{user.name if user else user_id}]: {text}")

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, timestamp: float):
        """Handle final transcript from STT."""
        user = self.users.get(user_id)
        logger.info(f"Final [{user.name if user else user_id}]: {text}")

        # Notify voice manager - THIS TRIGGERS LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    async def _on_interruption(self, user_id: int, probability: float):
        """Handle interruption detection from STT."""
        user = self.users.get(user_id)
        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")

        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
        if hasattr(self.voice_manager, 'on_user_interruption'):
            await self.voice_manager.on_user_interruption(user_id, probability)

    def cleanup(self):
        """Cleanup resources."""
        logger.info("Cleaning up voice receiver")
        # Async cleanup will be called separately

    def get_listening_users(self) -> list:
        """Get list of users currently being listened to."""
        return [
            {
                'user_id': user_id,
                'username': user.name if user else 'Unknown',
                'connected': client.is_connected()
            }
            for user_id, (user, client) in
            [(uid, (self.users.get(uid), self.stt_clients.get(uid)))
             for uid in self.stt_clients.keys()]
        ]