Implemented experimental, production-ready voice chat; relegated the old flow to voice debug mode. New Web UI panel for Voice Chat.
bot/utils/container_manager.py — new file, 205 lines
@@ -0,0 +1,205 @@
# container_manager.py
"""
Manages Docker containers for STT and TTS services.
Handles startup, shutdown, and warmup detection.
"""

import asyncio
import subprocess
import aiohttp
from utils.logger import get_logger

logger = get_logger('container_manager')


class ContainerManager:
    """Manages STT and TTS Docker containers."""

    # Container names from docker-compose.yml
    STT_CONTAINER = "miku-stt"
    TTS_CONTAINER = "miku-rvc-api"

    # Warmup check endpoints
    STT_HEALTH_URL = "http://miku-stt:8767/health"  # HTTP health check endpoint
    TTS_HEALTH_URL = "http://miku-rvc-api:8765/health"

    # Warmup timeouts
    STT_WARMUP_TIMEOUT = 30  # seconds
    TTS_WARMUP_TIMEOUT = 60  # seconds (RVC takes longer)

    @classmethod
    async def start_voice_containers(cls) -> bool:
        """
        Start STT and TTS containers and wait for them to warm up.

        Returns:
            bool: True if both containers started and warmed up successfully
        """
        logger.info("🚀 Starting voice chat containers...")

        try:
            # Start STT container using docker start (assumes container exists)
            logger.info(f"Starting {cls.STT_CONTAINER}...")
            result = subprocess.run(
                ["docker", "start", cls.STT_CONTAINER],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                logger.error(f"Failed to start {cls.STT_CONTAINER}: {result.stderr}")
                return False

            logger.info(f"✓ {cls.STT_CONTAINER} started")

            # Start TTS container
            logger.info(f"Starting {cls.TTS_CONTAINER}...")
            result = subprocess.run(
                ["docker", "start", cls.TTS_CONTAINER],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                logger.error(f"Failed to start {cls.TTS_CONTAINER}: {result.stderr}")
                return False

            logger.info(f"✓ {cls.TTS_CONTAINER} started")

            # Wait for warmup
            logger.info("⏳ Waiting for containers to warm up...")

            stt_ready = await cls._wait_for_stt_warmup()
            if not stt_ready:
                logger.error("STT failed to warm up")
                return False

            tts_ready = await cls._wait_for_tts_warmup()
            if not tts_ready:
                logger.error("TTS failed to warm up")
                return False

            logger.info("✅ All voice containers ready!")
            return True

        except Exception as e:
            logger.error(f"Error starting voice containers: {e}")
            return False

    @classmethod
    async def stop_voice_containers(cls) -> bool:
        """
        Stop STT and TTS containers.

        Returns:
            bool: True if containers stopped successfully
        """
        logger.info("🛑 Stopping voice chat containers...")

        try:
            # Stop both containers
            result = subprocess.run(
                ["docker", "stop", cls.STT_CONTAINER, cls.TTS_CONTAINER],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                logger.error(f"Failed to stop containers: {result.stderr}")
                return False

            logger.info("✓ Voice containers stopped")
            return True

        except Exception as e:
            logger.error(f"Error stopping voice containers: {e}")
            return False

    @classmethod
    async def _wait_for_stt_warmup(cls) -> bool:
        """
        Wait for STT container to be ready by checking health endpoint.

        Returns:
            bool: True if STT is ready within timeout
        """
        start_time = asyncio.get_event_loop().time()

        async with aiohttp.ClientSession() as session:
            while (asyncio.get_event_loop().time() - start_time) < cls.STT_WARMUP_TIMEOUT:
                try:
                    async with session.get(cls.STT_HEALTH_URL, timeout=aiohttp.ClientTimeout(total=2)) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            if data.get("status") == "ready" and data.get("warmed_up"):
                                logger.info("✓ STT is ready")
                                return True
                except Exception:
                    # Not ready yet, wait and retry
                    pass

                await asyncio.sleep(2)

        logger.error(f"STT warmup timeout ({cls.STT_WARMUP_TIMEOUT}s)")
        return False

    @classmethod
    async def _wait_for_tts_warmup(cls) -> bool:
        """
        Wait for TTS container to be ready by checking health endpoint.

        Returns:
            bool: True if TTS is ready within timeout
        """
        start_time = asyncio.get_event_loop().time()

        async with aiohttp.ClientSession() as session:
            while (asyncio.get_event_loop().time() - start_time) < cls.TTS_WARMUP_TIMEOUT:
                try:
                    async with session.get(cls.TTS_HEALTH_URL, timeout=aiohttp.ClientTimeout(total=2)) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            # RVC API returns "status": "healthy", not "ready"
                            status_ok = data.get("status") in ["ready", "healthy"]
                            if status_ok and data.get("warmed_up"):
                                logger.info("✓ TTS is ready")
                                return True
                except Exception:
                    # Not ready yet, wait and retry
                    pass

                await asyncio.sleep(2)

        logger.error(f"TTS warmup timeout ({cls.TTS_WARMUP_TIMEOUT}s)")
        return False

    @classmethod
    async def are_containers_running(cls) -> tuple[bool, bool]:
        """
        Check if STT and TTS containers are currently running.

        Returns:
            tuple[bool, bool]: (stt_running, tts_running)
        """
        try:
            # Check STT
            result = subprocess.run(
                ["docker", "inspect", "-f", "{{.State.Running}}", cls.STT_CONTAINER],
                capture_output=True,
                text=True
            )
            stt_running = result.returncode == 0 and result.stdout.strip() == "true"

            # Check TTS
            result = subprocess.run(
                ["docker", "inspect", "-f", "{{.State.Running}}", cls.TTS_CONTAINER],
                capture_output=True,
                text=True
            )
            tts_running = result.returncode == 0 and result.stdout.strip() == "true"

            return (stt_running, tts_running)

        except Exception as e:
            logger.error(f"Error checking container status: {e}")
            return (False, False)
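The class is all classmethods, so call sites can drive it without instantiating. A minimal usage sketch (hypothetical call sites; the actual session wiring lives elsewhere in this commit):

```python
from utils.container_manager import ContainerManager

async def begin_voice_session():
    # Starts miku-stt and miku-rvc-api, then polls their /health endpoints
    # until they report warmed up (assumed response shape:
    # {"status": "ready" | "healthy", "warmed_up": true} — field names
    # inferred from the warmup checks above).
    if not await ContainerManager.start_voice_containers():
        raise RuntimeError("Voice containers failed to start or warm up")

async def end_voice_session():
    stt_up, tts_up = await ContainerManager.are_containers_running()
    if stt_up or tts_up:
        await ContainerManager.stop_voice_containers()
```

One design note: `subprocess.run` blocks the event loop while each `docker start`/`docker stop` runs. That is usually acceptable for these short commands, but `asyncio.create_subprocess_exec` would keep the loop responsive.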
@@ -62,6 +62,7 @@ COMPONENTS = {
     'voice_manager': 'Voice channel session management',
     'voice_commands': 'Voice channel commands',
     'voice_audio': 'Voice audio streaming and TTS',
+    'container_manager': 'Docker container lifecycle management',
     'error_handler': 'Error detection and webhook notifications',
 }

@@ -1,11 +1,15 @@
 """
-STT Client for Discord Bot
+STT Client for Discord Bot (RealtimeSTT Version)
 
-WebSocket client that connects to the STT server and handles:
+WebSocket client that connects to the RealtimeSTT server and handles:
 - Audio streaming to STT
 - Receiving VAD events
 - Receiving partial/final transcripts
 - Interruption detection
+
+Protocol:
+- Client sends: binary audio data (16kHz, 16-bit mono PCM)
+- Client sends: JSON {"command": "reset"} to reset state
+- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
+- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
 """
 
 import aiohttp
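The protocol above can be smoke-tested without the bot. A minimal standalone client might look like this (hypothetical sketch — URL and message shapes taken from the docstring; audio must be 16 kHz, 16-bit mono PCM):

```python
import asyncio
import json
import aiohttp

async def smoke_test(pcm_chunks):
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://miku-stt:8766") as ws:
            for chunk in pcm_chunks:  # raw 16kHz 16-bit mono PCM frames
                await ws.send_bytes(chunk)
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    event = json.loads(msg.data)
                    print(event["type"], event.get("text", ""))
                    if event["type"] == "final":
                        # clear server-side state before the next utterance
                        await ws.send_json({"command": "reset"})
                        break

# asyncio.run(smoke_test(chunks))  # chunks: iterable of bytes
```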
@@ -19,7 +23,7 @@ logger = logging.getLogger('stt_client')
 class STTClient:
     """
-    WebSocket client for STT server communication.
+    WebSocket client for RealtimeSTT server communication.
 
     Handles audio streaming and receives transcription events.
     """
@@ -27,34 +31,28 @@ class STTClient:
     def __init__(
         self,
         user_id: str,
-        stt_url: str = "ws://miku-stt:8766/ws/stt",
-        on_vad_event: Optional[Callable] = None,
+        stt_url: str = "ws://miku-stt:8766",
         on_partial_transcript: Optional[Callable] = None,
         on_final_transcript: Optional[Callable] = None,
         on_interruption: Optional[Callable] = None
     ):
         """
         Initialize STT client.
 
         Args:
-            user_id: Discord user ID
-            stt_url: Base WebSocket URL for STT server
-            on_vad_event: Callback for VAD events (event_dict)
+            user_id: Discord user ID (for logging purposes)
+            stt_url: WebSocket URL for STT server
             on_partial_transcript: Callback for partial transcripts (text, timestamp)
             on_final_transcript: Callback for final transcripts (text, timestamp)
             on_interruption: Callback for interruption detection (probability)
         """
         self.user_id = user_id
-        self.stt_url = f"{stt_url}/{user_id}"
+        self.stt_url = stt_url
 
         # Callbacks
-        self.on_vad_event = on_vad_event
         self.on_partial_transcript = on_partial_transcript
         self.on_final_transcript = on_final_transcript
         self.on_interruption = on_interruption
 
         # Connection state
-        self.websocket: Optional[aiohttp.ClientWebSocket] = None
+        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
         self.session: Optional[aiohttp.ClientSession] = None
         self.connected = False
         self.running = False
@@ -65,7 +63,7 @@ class STTClient:
         logger.info(f"STT client initialized for user {user_id}")
 
     async def connect(self):
-        """Connect to STT WebSocket server."""
+        """Connect to RealtimeSTT WebSocket server."""
         if self.connected:
             logger.warning(f"Already connected for user {self.user_id}")
             return
@@ -74,202 +72,156 @@ class STTClient:
             self.session = aiohttp.ClientSession()
             self.websocket = await self.session.ws_connect(
                 self.stt_url,
-                heartbeat=30
+                heartbeat=30,
+                receive_timeout=60
             )
 
-            # Wait for ready message
-            ready_msg = await self.websocket.receive_json()
-            logger.info(f"STT connected for user {self.user_id}: {ready_msg}")
-
             self.connected = True
             self.running = True
 
-            # Start receive task
-            self._receive_task = asyncio.create_task(self._receive_events())
+            # Start background task to receive messages
+            self._receive_task = asyncio.create_task(self._receive_loop())
 
-            logger.info(f"✓ STT WebSocket connected for user {self.user_id}")
-
+            logger.info(f"Connected to STT server at {self.stt_url} for user {self.user_id}")
         except Exception as e:
-            logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
-            await self.disconnect()
+            logger.error(f"Failed to connect to STT server: {e}")
+            await self._cleanup()
             raise
 
     async def disconnect(self):
-        """Disconnect from STT WebSocket."""
-        logger.info(f"Disconnecting STT for user {self.user_id}")
-
+        """Disconnect from STT server."""
         self.running = False
+        self.connected = False
 
-        # Cancel receive task
-        if self._receive_task and not self._receive_task.done():
+        if self._receive_task:
             self._receive_task.cancel()
             try:
                 await self._receive_task
             except asyncio.CancelledError:
                 pass
+            self._receive_task = None
 
-        # Close WebSocket
+        await self._cleanup()
+        logger.info(f"Disconnected from STT server for user {self.user_id}")
+
+    async def _cleanup(self):
+        """Clean up WebSocket and session."""
         if self.websocket:
-            await self.websocket.close()
+            try:
+                await self.websocket.close()
+            except Exception:
+                pass
             self.websocket = None
 
         # Close session
         if self.session:
-            await self.session.close()
+            try:
+                await self.session.close()
+            except Exception:
+                pass
             self.session = None
 
-        logger.info(f"✓ STT disconnected for user {self.user_id}")
         self.connected = False
 
     async def send_audio(self, audio_data: bytes):
         """
-        Send audio chunk to STT server.
+        Send raw audio data to STT server.
 
         Args:
-            audio_data: PCM audio (int16, 16kHz mono)
+            audio_data: Raw PCM audio (16kHz, 16-bit mono, little-endian)
         """
         if not self.connected or not self.websocket:
             logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
             return
 
         try:
             await self.websocket.send_bytes(audio_data)
-            logger.debug(f"Sent {len(audio_data)} bytes to STT")
-
         except Exception as e:
-            logger.error(f"Failed to send audio to STT: {e}")
-            self.connected = False
+            logger.error(f"Failed to send audio: {e}")
+            await self._cleanup()
 
-    async def send_final(self):
-        """
-        Request final transcription from STT server.
-
-        Call this when the user stops speaking to get the final transcript.
-        """
+    async def reset(self):
+        """Reset STT state (clear any pending transcription)."""
         if not self.connected or not self.websocket:
-            logger.warning(f"Cannot send final command, not connected for user {self.user_id}")
             return
 
         try:
-            command = json.dumps({"type": "final"})
-            await self.websocket.send_str(command)
-            logger.debug(f"Sent final command to STT")
-
+            await self.websocket.send_json({"command": "reset"})
+            logger.debug(f"Sent reset command for user {self.user_id}")
         except Exception as e:
-            logger.error(f"Failed to send final command to STT: {e}")
-            self.connected = False
+            logger.error(f"Failed to send reset: {e}")
 
-    async def send_reset(self):
-        """
-        Reset the STT server's audio buffer.
-
-        Call this to clear any buffered audio.
-        """
-        if not self.connected or not self.websocket:
-            logger.warning(f"Cannot send reset command, not connected for user {self.user_id}")
-            return
-
-        try:
-            command = json.dumps({"type": "reset"})
-            await self.websocket.send_str(command)
-            logger.debug(f"Sent reset command to STT")
-
-        except Exception as e:
-            logger.error(f"Failed to send reset command to STT: {e}")
-            self.connected = False
+    def is_connected(self) -> bool:
+        """Check if connected to STT server."""
+        return self.connected and self.websocket is not None
 
-    async def _receive_events(self):
-        """Background task to receive events from STT server."""
+    async def _receive_loop(self):
+        """Background task to receive messages from STT server."""
         try:
             while self.running and self.websocket:
                 try:
-                    msg = await self.websocket.receive()
+                    msg = await asyncio.wait_for(
+                        self.websocket.receive(),
+                        timeout=5.0
+                    )
 
                     if msg.type == aiohttp.WSMsgType.TEXT:
-                        event = json.loads(msg.data)
-                        await self._handle_event(event)
-
+                        await self._handle_message(msg.data)
                     elif msg.type == aiohttp.WSMsgType.CLOSED:
-                        logger.info(f"STT WebSocket closed for user {self.user_id}")
+                        logger.warning(f"STT WebSocket closed for user {self.user_id}")
                         break
 
                     elif msg.type == aiohttp.WSMsgType.ERROR:
                         logger.error(f"STT WebSocket error for user {self.user_id}")
                         break
 
-                except asyncio.CancelledError:
-                    break
-                except Exception as e:
-                    logger.error(f"Error receiving STT event: {e}", exc_info=True)
-
+                except asyncio.TimeoutError:
+                    # Timeout is fine, just continue
+                    continue
+
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.error(f"Error in STT receive loop: {e}")
+        finally:
+            self.connected = False
             logger.info(f"STT receive task ended for user {self.user_id}")
 
-    async def _handle_event(self, event: dict):
-        """
-        Handle incoming STT event.
-
-        Args:
-            event: Event dictionary from STT server
-        """
-        event_type = event.get('type')
-
-        if event_type == 'transcript':
-            # New ONNX server protocol: single transcript type with is_final flag
-            text = event.get('text', '')
-            is_final = event.get('is_final', False)
-            timestamp = event.get('timestamp', 0)
-
-            if is_final:
-                logger.info(f"Final transcript [{self.user_id}]: {text}")
-                if self.on_final_transcript:
-                    await self.on_final_transcript(text, timestamp)
-            else:
-                logger.info(f"Partial transcript [{self.user_id}]: {text}")
-                if self.on_partial_transcript:
-                    await self.on_partial_transcript(text, timestamp)
-
-        elif event_type == 'vad':
-            # VAD event: speech detection (legacy support)
-            logger.debug(f"VAD event: {event}")
-            if self.on_vad_event:
-                await self.on_vad_event(event)
-
-        elif event_type == 'partial':
-            # Legacy protocol support: partial transcript
-            text = event.get('text', '')
-            timestamp = event.get('timestamp', 0)
-            logger.info(f"Partial transcript [{self.user_id}]: {text}")
-            if self.on_partial_transcript:
-                await self.on_partial_transcript(text, timestamp)
-
-        elif event_type == 'final':
-            # Legacy protocol support: final transcript
-            text = event.get('text', '')
-            timestamp = event.get('timestamp', 0)
-            logger.info(f"Final transcript [{self.user_id}]: {text}")
-            if self.on_final_transcript:
-                await self.on_final_transcript(text, timestamp)
-
-        elif event_type == 'interruption':
-            # Interruption detected (legacy support)
-            probability = event.get('probability', 0)
-            logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
-            if self.on_interruption:
-                await self.on_interruption(probability)
-
-        elif event_type == 'info':
-            # Info message
-            logger.info(f"STT info: {event.get('message', '')}")
-
-        elif event_type == 'error':
-            # Error message
-            logger.error(f"STT error: {event.get('message', '')}")
-
-        else:
-            logger.warning(f"Unknown STT event type: {event_type}")
+    async def _handle_message(self, data: str):
+        """Handle a message from the STT server."""
+        try:
+            message = json.loads(data)
+            msg_type = message.get("type")
+            text = message.get("text", "")
+            timestamp = message.get("timestamp", 0)
+
+            if msg_type == "partial":
+                if self.on_partial_transcript and text:
+                    await self._call_callback(
+                        self.on_partial_transcript,
+                        text,
+                        timestamp
+                    )
+
+            elif msg_type == "final":
+                if self.on_final_transcript and text:
+                    await self._call_callback(
+                        self.on_final_transcript,
+                        text,
+                        timestamp
+                    )
+
+            elif msg_type == "connected":
+                logger.info(f"STT server confirmed connection for user {self.user_id}")
+
+            elif msg_type == "error":
+                error_msg = message.get("error", "Unknown error")
+                logger.error(f"STT server error: {error_msg}")
+
+        except json.JSONDecodeError:
+            logger.warning(f"Invalid JSON from STT server: {data[:100]}")
+        except Exception as e:
+            logger.error(f"Error handling STT message: {e}")
 
-    def is_connected(self) -> bool:
-        """Check if STT client is connected."""
-        return self.connected
+    async def _call_callback(self, callback, *args):
+        """Safely call a callback, handling both sync and async functions."""
+        try:
+            result = callback(*args)
+            if asyncio.iscoroutine(result):
+                await result
+        except Exception as e:
+            logger.error(f"Error in STT callback: {e}")

@@ -6,6 +6,7 @@ Uses aiohttp for WebSocket communication (compatible with FastAPI).
 import asyncio
 import json
+import re
 import numpy as np
 from typing import Optional
 import discord
@@ -29,6 +30,25 @@ CHANNELS = 2 # Stereo for Discord
 FRAME_LENGTH = 0.02  # 20ms frames
 SAMPLES_PER_FRAME = int(SAMPLE_RATE * FRAME_LENGTH)  # 960 samples
 
+# Emoji pattern for filtering
+# Covers most emoji ranges including emoticons, symbols, pictographs, etc.
+EMOJI_PATTERN = re.compile(
+    "["
+    "\U0001F600-\U0001F64F"  # emoticons
+    "\U0001F300-\U0001F5FF"  # symbols & pictographs
+    "\U0001F680-\U0001F6FF"  # transport & map symbols
+    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
+    "\U00002702-\U000027B0"  # dingbats
+    "\U000024C2-\U0001F251"  # enclosed characters
+    "\U0001F900-\U0001F9FF"  # supplemental symbols and pictographs
+    "\U0001FA00-\U0001FA6F"  # chess symbols
+    "\U0001FA70-\U0001FAFF"  # symbols and pictographs extended-A
+    "\U00002600-\U000026FF"  # miscellaneous symbols
+    "\U00002700-\U000027BF"  # dingbats
+    "]+",
+    flags=re.UNICODE
+)
+
+
 class MikuVoiceSource(discord.AudioSource):
     """
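A quick illustration of the pattern (hypothetical inputs; this is the substitution `send_token` applies below):

```python
>>> EMOJI_PATTERN.sub('', 'see you later 👋🌸')
'see you later '
>>> EMOJI_PATTERN.sub('', '🎉🎊')  # emoji-only tokens collapse to nothing
''
```

Note the trailing whitespace survives — that matters for streamed tokens, which is why `send_token` checks `token.strip()` rather than plain truthiness.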
@@ -38,8 +58,9 @@ class MikuVoiceSource(discord.AudioSource):
     """
 
     def __init__(self):
-        self.websocket_url = "ws://172.25.0.1:8765/ws/stream"
-        self.health_url = "http://172.25.0.1:8765/health"
+        # Use Docker hostname for RVC service (miku-rvc-api is on miku-voice-network)
+        self.websocket_url = "ws://miku-rvc-api:8765/ws/stream"
+        self.health_url = "http://miku-rvc-api:8765/health"
         self.session = None
         self.websocket = None
         self.audio_buffer = bytearray()
@@ -230,11 +251,26 @@ class MikuVoiceSource(discord.AudioSource):
         """
         Send a text token to TTS for voice generation.
         Queues tokens if pipeline is still warming up or connection failed.
+        Filters out emojis to prevent TTS hallucinations.
 
         Args:
             token: Text token to synthesize
             pitch_shift: Pitch adjustment (-12 to +12 semitones)
         """
+        # Filter out emojis from the token (preserve whitespace!)
+        original_token = token
+        token = EMOJI_PATTERN.sub('', token)
+
+        # If token is now empty or only whitespace after emoji removal, skip it
+        if not token or not token.strip():
+            if original_token != token:
+                logger.debug(f"Skipped token (only emojis): '{original_token}'")
+            return
+
+        # Log if we filtered out emojis
+        if original_token != token:
+            logger.debug(f"Filtered emojis from token: '{original_token}' -> '{token}'")
+
         # If not warmed up yet or no connection, queue the token
         if not self.warmed_up or not self.websocket:
             self.token_queue.append((token, pitch_shift))

@@ -398,6 +398,13 @@ class VoiceSession:
         # Voice chat conversation history (last 8 exchanges)
         self.conversation_history = []  # List of {"role": "user"/"assistant", "content": str}
 
+        # Voice call management (for automated calls from web UI)
+        self.call_user_id: Optional[int] = None  # User ID that was called
+        self.call_timeout_task: Optional[asyncio.Task] = None  # 30min timeout task
+        self.user_has_joined = False  # Track if user joined the call
+        self.auto_leave_task: Optional[asyncio.Task] = None  # 45s auto-leave task
+        self.user_leave_time: Optional[float] = None  # When user left the channel
+
         logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")
 
     async def start_audio_streaming(self):
@@ -488,6 +495,57 @@ class VoiceSession:
             self.voice_receiver = None
         logger.info("✓ Stopped all listening")
 
+    async def on_user_join(self, user_id: int):
+        """Called when a user joins the voice channel."""
+        # If this is a voice call and the expected user joined
+        if self.call_user_id and user_id == self.call_user_id:
+            self.user_has_joined = True
+            logger.info(f"✓ Call user {user_id} joined the channel")
+
+            # Cancel timeout task since user joined
+            if self.call_timeout_task:
+                self.call_timeout_task.cancel()
+                self.call_timeout_task = None
+
+            # Cancel auto-leave task if it was running
+            if self.auto_leave_task:
+                self.auto_leave_task.cancel()
+                self.auto_leave_task = None
+                self.user_leave_time = None
+
+    async def on_user_leave(self, user_id: int):
+        """Called when a user leaves the voice channel."""
+        # If this is the call user leaving
+        if self.call_user_id and user_id == self.call_user_id and self.user_has_joined:
+            import time
+            self.user_leave_time = time.time()
+            logger.info(f"📴 Call user {user_id} left - starting 45s auto-leave timer")
+
+            # Start 45s auto-leave timer
+            self.auto_leave_task = asyncio.create_task(self._auto_leave_after_user_disconnect())
+
+    async def _auto_leave_after_user_disconnect(self):
+        """Auto-leave 45s after user disconnects."""
+        try:
+            await asyncio.sleep(45)
+
+            logger.info("⏰ 45s timeout reached - auto-leaving voice channel")
+
+            # End the session (will trigger cleanup)
+            from utils.voice_manager import VoiceSessionManager
+            session_manager = VoiceSessionManager()
+            await session_manager.end_session()
+
+            # Stop containers
+            from utils.container_manager import ContainerManager
+            await ContainerManager.stop_voice_containers()
+
+            logger.info("✓ Auto-leave complete")
+
+        except asyncio.CancelledError:
+            # User rejoined, normal operation
+            logger.info("Auto-leave cancelled - user rejoined")
+
     async def on_user_vad_event(self, user_id: int, event: dict):
         """Called when VAD detects speech state change."""
         event_type = event.get('event')
@@ -515,7 +573,10 @@ class VoiceSession:
             # Get user info for notification
             user = self.voice_channel.guild.get_member(user_id)
             user_name = user.name if user else f"User {user_id}"
-            await self.text_channel.send(f"💬 *{user_name} said: \"{text}\" (interrupted but too brief - talk longer to interrupt)*")
+
+            # Only send message if debug mode is on
+            if globals.VOICE_DEBUG_MODE:
+                await self.text_channel.send(f"💬 *{user_name} said: \"{text}\" (interrupted but too brief - talk longer to interrupt)*")
             return
 
         logger.info(f"✓ Processing final transcript (miku_speaking={self.miku_speaking})")
@@ -530,12 +591,14 @@ class VoiceSession:
         stop_phrases = ["stop talking", "be quiet", "shut up", "stop speaking", "silence"]
         if any(phrase in text.lower() for phrase in stop_phrases):
             logger.info(f"🤫 Stop command detected: {text}")
-            await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
-            await self.text_channel.send(f"🤫 *Miku goes quiet*")
+            if globals.VOICE_DEBUG_MODE:
+                await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
+                await self.text_channel.send(f"🤫 *Miku goes quiet*")
             return
 
-        # Show what user said
-        await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
+        # Show what user said (only in debug mode)
+        if globals.VOICE_DEBUG_MODE:
+            await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
 
         # Generate LLM response and speak it
         await self._generate_voice_response(user, text)
@@ -582,14 +645,15 @@ class VoiceSession:
         logger.info(f"⏸️ Pausing for {self.interruption_silence_duration}s after interruption")
         await asyncio.sleep(self.interruption_silence_duration)
 
-        # 5. Add interruption marker to conversation history
+        # Add interruption marker to conversation history
         self.conversation_history.append({
             "role": "assistant",
             "content": "[INTERRUPTED - user started speaking]"
         })
 
-        # Show interruption in chat
-        await self.text_channel.send(f"⚠️ *{user_name} interrupted Miku*")
+        # Show interruption in chat (only in debug mode)
+        if globals.VOICE_DEBUG_MODE:
+            await self.text_channel.send(f"⚠️ *{user_name} interrupted Miku*")
 
         logger.info(f"✓ Interruption handled, ready for next input")

@@ -599,8 +663,10 @@ class VoiceSession:
         Called when VAD-based interruption detection is used.
         """
         await self.on_user_interruption(user_id)
-        user = self.voice_channel.guild.get_member(user_id)
-        await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")
+        # Only show interruption message in debug mode
+        if globals.VOICE_DEBUG_MODE:
+            user = self.voice_channel.guild.get_member(user_id)
+            await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")
 
     async def _generate_voice_response(self, user: discord.User, text: str):
         """
@@ -624,13 +690,13 @@ class VoiceSession:
         self.miku_speaking = True
         logger.info(f" → miku_speaking is now: {self.miku_speaking}")
 
-        # Show processing
-        await self.text_channel.send(f"💭 *Miku is thinking...*")
+        # Show processing (only in debug mode)
+        if globals.VOICE_DEBUG_MODE:
+            await self.text_channel.send(f"💭 *Miku is thinking...*")
 
         # Import here to avoid circular imports
         from utils.llm import get_current_gpu_url
         import aiohttp
+        import globals
 
         # Load personality and lore
         miku_lore = ""
@@ -657,8 +723,11 @@ VOICE CHAT CONTEXT:
     * Stories/explanations: 4-6 sentences when asked for details
 - Match the user's energy and conversation style
 - IMPORTANT: Only respond in ENGLISH! The TTS system cannot handle Japanese or other languages well.
+- IMPORTANT: Do not include emojis in your response! The TTS system cannot handle them well.
+- IMPORTANT: Do NOT prefix your response with your name (like "Miku:" or "Hatsune Miku:")! Just speak naturally - you're already known to be speaking.
 - Be expressive and use casual language, but stay in character as Miku
 - If user says "stop talking" or "be quiet", acknowledge briefly and stop
+- NOTE: You will automatically disconnect 45 seconds after {user.name} leaves the voice channel, so you can mention this if asked about leaving
 
 Remember: This is a live voice conversation - be natural, not formulaic!"""

@@ -742,15 +811,19 @@ Remember: This is a live voice conversation - be natural, not formulaic!"""
             if self.miku_speaking:
                 await self.audio_source.flush()
 
-                # Add Miku's complete response to history
+                # Filter out self-referential prefixes from response
+                filtered_response = self._filter_name_prefixes(full_response.strip())
+
+                # Add Miku's complete response to history (use filtered version)
                 self.conversation_history.append({
                     "role": "assistant",
-                    "content": full_response.strip()
+                    "content": filtered_response
                 })
 
-                # Show response
-                await self.text_channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
-                logger.info(f"✓ Voice response complete: {full_response.strip()}")
+                # Show response (only in debug mode)
+                if globals.VOICE_DEBUG_MODE:
+                    await self.text_channel.send(f"🎤 Miku: *\"{filtered_response}\"*")
+                logger.info(f"✓ Voice response complete: {filtered_response}")
             else:
                 # Interrupted - don't add incomplete response to history
                 # (interruption marker already added by on_user_interruption)
@@ -763,6 +836,35 @@ Remember: This is a live voice conversation - be natural, not formulaic!"""
         finally:
             self.miku_speaking = False
 
+    def _filter_name_prefixes(self, text: str) -> str:
+        """
+        Filter out self-referential name prefixes from Miku's responses.
+
+        Removes patterns like:
+        - "Miku: rest of text"
+        - "Hatsune Miku: rest of text"
+        - "miku: rest of text" (case insensitive)
+
+        Args:
+            text: Raw response text
+
+        Returns:
+            Filtered text without name prefixes
+        """
+        import re
+
+        # Pattern matches "Miku:" or "Hatsune Miku:" at the start of the text (case insensitive)
+        # Captures any amount of whitespace after the colon
+        pattern = r'^(?:Hatsune\s+)?Miku:\s*'
+
+        filtered = re.sub(pattern, '', text, flags=re.IGNORECASE)
+
+        # Log if we filtered something
+        if filtered != text:
+            logger.info(f"Filtered name prefix: '{text[:30]}...' -> '{filtered[:30]}...'")
+
+        return filtered
+
     async def _cancel_tts(self):
         """
         Immediately cancel TTS synthesis and clear all audio buffers.
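For reference, the prefix regex behaves like this (hypothetical inputs):

```python
>>> import re
>>> pattern = r'^(?:Hatsune\s+)?Miku:\s*'
>>> re.sub(pattern, '', 'Hatsune Miku:  Hi there!', flags=re.IGNORECASE)
'Hi there!'
>>> re.sub(pattern, '', 'I met Miku: she says hi', flags=re.IGNORECASE)
'I met Miku: she says hi'
```

Only a leading prefix is stripped; the `^` anchor leaves a mid-sentence "Miku:" alone.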
@@ -8,6 +8,8 @@ Uses the discord-ext-voice-recv extension for proper audio receiving support.
 import asyncio
 import audioop
 import logging
+import struct
+import array
 from typing import Dict, Optional
 from collections import deque

@@ -27,13 +29,13 @@ class VoiceReceiverSink(voice_recv.AudioSink):
     decodes/resamples as needed, and sends to STT clients for transcription.
     """
 
-    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766/ws/stt"):
+    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8766"):
         """
         Initialize Voice Receiver.
 
         Args:
             voice_manager: The voice manager instance
-            stt_url: Base URL for STT WebSocket server with path (port 8766 inside container)
+            stt_url: WebSocket URL for RealtimeSTT server (port 8766 inside container)
         """
         super().__init__()
         self.voice_manager = voice_manager
@@ -72,6 +74,68 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         logger.info("VoiceReceiverSink initialized")
 
+    @staticmethod
+    def _preprocess_audio(pcm_data: bytes) -> bytes:
+        """
+        Preprocess audio for better STT accuracy.
+
+        Applies:
+        1. DC offset removal
+        2. High-pass filter (80Hz) to remove rumble
+        3. RMS normalization
+
+        Args:
+            pcm_data: Raw PCM audio (16-bit mono, 16kHz)
+
+        Returns:
+            Preprocessed PCM audio
+        """
+        try:
+            # Convert bytes to array of int16 samples
+            samples = array.array('h', pcm_data)
+
+            # 1. Remove DC offset (mean)
+            mean = sum(samples) / len(samples) if samples else 0
+            samples = array.array('h', [int(s - mean) for s in samples])
+
+            # 2. Simple high-pass filter (80Hz @ 16kHz)
+            # Using a simple first-order HPF: y[n] = x[n] - x[n-1] + 0.95 * y[n-1]
+            alpha = 0.95  # Filter coefficient (roughly 80Hz cutoff at 16kHz)
+            filtered = array.array('h')
+            prev_input = 0
+            prev_output = 0
+
+            for sample in samples:
+                output = sample - prev_input + alpha * prev_output
+                filtered.append(int(max(-32768, min(32767, output))))  # Clamp to int16 range
+                prev_input = sample
+                prev_output = output
+
+            # 3. RMS normalization to target level
+            # Calculate RMS
+            sum_squares = sum(s * s for s in filtered)
+            rms = (sum_squares / len(filtered)) ** 0.5 if filtered else 1.0
+
+            # Target RMS (roughly -20dB)
+            target_rms = 3276.8  # 10% of max int16 range
+
+            # Normalize if RMS is too low or too high
+            if rms > 100:  # Only normalize if there's actual signal
+                gain = target_rms / rms
+                # Limit gain to prevent over-amplification of noise
+                gain = min(gain, 4.0)  # Max 12dB boost
+                normalized = array.array('h', [
+                    int(max(-32768, min(32767, s * gain))) for s in filtered
+                ])
+                return normalized.tobytes()
+            else:
+                # Signal too weak, return filtered without normalization
+                return filtered.tobytes()
+
+        except Exception as e:
+            logger.debug(f"Audio preprocessing failed, using raw audio: {e}")
+            return pcm_data
+
     def wants_opus(self) -> bool:
         """
         Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM.
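A quick way to sanity-check the preprocessing chain (hypothetical snippet, using `VoiceReceiverSink._preprocess_audio` as defined above):

```python
import array

# 20 ms of pure DC offset at 16 kHz: 320 samples pinned at +1000
raw = array.array('h', [1000] * 320).tobytes()

out = array.array('h', VoiceReceiverSink._preprocess_audio(raw))
print(sum(out) / len(out))  # ~0.0: DC removal + high-pass kill the offset;
                            # the residual is below the RMS gate (rms > 100),
                            # so no normalization gain is applied
```

One caveat worth noting: the filter state (`prev_input`/`prev_output`) resets on every 20 ms call, so the high-pass is only approximate across chunk boundaries.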
@@ -144,6 +208,10 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         # Discord sends 20ms chunks: 960 samples @ 48kHz → 320 samples @ 16kHz
         pcm_16k, _ = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)
 
+        # Preprocess audio for better STT accuracy
+        # (DC offset removal, high-pass filter, RMS normalization)
+        pcm_16k = self._preprocess_audio(pcm_16k)
+
         # Send to STT client (schedule on event loop thread-safely)
         asyncio.run_coroutine_threadsafe(
             self._send_audio_chunk(user_id, pcm_16k),
@@ -184,21 +252,16 @@ class VoiceReceiverSink(voice_recv.AudioSink):
             self.audio_buffers[user_id] = deque(maxlen=1000)
 
         # Create STT client with callbacks
+        # RealtimeSTT handles VAD internally, so we only need partial/final callbacks
         stt_client = STTClient(
             user_id=user_id,
             stt_url=self.stt_url,
-            on_vad_event=lambda event: asyncio.create_task(
-                self._on_vad_event(user_id, event)
-            ),
             on_partial_transcript=lambda text, timestamp: asyncio.create_task(
                 self._on_partial_transcript(user_id, text)
             ),
             on_final_transcript=lambda text, timestamp: asyncio.create_task(
                 self._on_final_transcript(user_id, text, user)
             ),
             on_interruption=lambda prob: asyncio.create_task(
                 self._on_interruption(user_id, prob)
             )
         )
 
         # Connect to STT server
@@ -279,16 +342,16 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         """
         Send audio chunk to STT client.
 
-        Buffers audio until we have 512 samples (32ms @ 16kHz) which is what
-        Silero VAD expects. Discord sends 320 samples (20ms), so we buffer
-        2 chunks and send 640 samples, then the STT server can split it.
+        RealtimeSTT expects 16kHz mono 16-bit PCM audio.
+        We buffer audio to send larger chunks for efficiency.
+        VAD and silence detection is handled by RealtimeSTT.
 
         Args:
             user_id: Discord user ID
-            audio_data: PCM audio (int16, 16kHz mono, 320 samples = 640 bytes)
+            audio_data: PCM audio (int16, 16kHz mono)
         """
         stt_client = self.stt_clients.get(user_id)
-        if not stt_client or not stt_client.is_connected():
+        if not stt_client or not stt_client.connected:
             return
 
         try:
@@ -299,11 +362,9 @@ class VoiceReceiverSink(voice_recv.AudioSink):
             buffer = self.audio_buffers[user_id]
             buffer.append(audio_data)
 
-            # Silero VAD expects 512 samples @ 16kHz (1024 bytes)
-            # Discord gives us 320 samples (640 bytes) every 20ms
-            # Buffer 2 chunks = 640 samples = 1280 bytes, send as one chunk
-            SAMPLES_NEEDED = 512  # What VAD wants
-            BYTES_NEEDED = SAMPLES_NEEDED * 2  # int16 = 2 bytes per sample
+            # Buffer and send in larger chunks for efficiency
+            # RealtimeSTT will handle VAD internally
+            BYTES_NEEDED = 1024  # 512 samples * 2 bytes
 
             # Check if we have enough buffered audio
             total_bytes = sum(len(chunk) for chunk in buffer)
@@ -313,16 +374,10 @@ class VoiceReceiverSink(voice_recv.AudioSink):
                 combined = b''.join(buffer)
                 buffer.clear()
 
-                # Send in 512-sample (1024-byte) chunks
-                for i in range(0, len(combined), BYTES_NEEDED):
-                    chunk = combined[i:i+BYTES_NEEDED]
-                    if len(chunk) == BYTES_NEEDED:
-                        await stt_client.send_audio(chunk)
-                    else:
-                        # Put remaining partial chunk back in buffer
-                        buffer.append(chunk)
+                # Send all audio to STT (RealtimeSTT handles VAD internally)
+                await stt_client.send_audio(combined)
 
-                # Track audio time for silence detection
+                # Track audio time for interruption detection
                 import time
                 current_time = time.time()
                 self.last_audio_time[user_id] = current_time
@@ -331,103 +386,57 @@ class VoiceReceiverSink(voice_recv.AudioSink):
             # Check if Miku is speaking and user is interrupting
             # Note: self.voice_manager IS the VoiceSession, not the VoiceManager singleton
             miku_speaking = self.voice_manager.miku_speaking
-            logger.debug(f"[INTERRUPTION CHECK] user={user_id}, miku_speaking={miku_speaking}")
 
             if miku_speaking:
-                # Track interruption
-                if user_id not in self.interruption_start_time:
-                    # First chunk during Miku's speech
-                    self.interruption_start_time[user_id] = current_time
-                    self.interruption_audio_count[user_id] = 1
-                else:
-                    # Increment chunk count
-                    self.interruption_audio_count[user_id] += 1
-
-                    # Calculate interruption duration
-                    interruption_duration = current_time - self.interruption_start_time[user_id]
-                    chunk_count = self.interruption_audio_count[user_id]
-
-                    # Check if interruption threshold is met
-                    if (interruption_duration >= self.interruption_threshold_time and
-                        chunk_count >= self.interruption_threshold_chunks):
-
-                        # Trigger interruption!
-                        logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count})")
-                        logger.info(f"  → Stopping Miku's TTS and LLM, will process user's speech when finished")
-
-                        # Reset interruption tracking
-                        self.interruption_start_time.pop(user_id, None)
-                        self.interruption_audio_count.pop(user_id, None)
-
-                        # Call interruption handler (this sets miku_speaking=False)
-                        asyncio.create_task(
-                            self.voice_manager.on_user_interruption(user_id)
-                        )
+                # Calculate RMS to detect if user is actually speaking
+                # (not just silence/background noise)
+                rms = audioop.rms(combined, 2)
+                RMS_THRESHOLD = 500  # Adjust threshold - higher = less sensitive
+
+                if rms > RMS_THRESHOLD:
+                    # User is actually speaking - track as potential interruption
+                    if user_id not in self.interruption_start_time:
+                        # First chunk during Miku's speech with actual audio
+                        self.interruption_start_time[user_id] = current_time
+                        self.interruption_audio_count[user_id] = 1
+                        logger.debug(f"Potential interruption start (rms={rms})")
+                    else:
+                        # Increment chunk count
+                        self.interruption_audio_count[user_id] += 1
+
+                        # Calculate interruption duration
+                        interruption_duration = current_time - self.interruption_start_time[user_id]
+                        chunk_count = self.interruption_audio_count[user_id]
+
+                        # Check if interruption threshold is met
+                        if (interruption_duration >= self.interruption_threshold_time and
+                            chunk_count >= self.interruption_threshold_chunks):
+
+                            # Trigger interruption!
+                            logger.info(f"🛑 User {user_id} interrupted Miku (duration={interruption_duration:.2f}s, chunks={chunk_count}, rms={rms})")
+                            logger.info(f"  → Stopping Miku's TTS and LLM, will process user's speech when finished")
+
+                            # Reset interruption tracking
+                            self.interruption_start_time.pop(user_id, None)
+                            self.interruption_audio_count.pop(user_id, None)
+
+                            # Call interruption handler (this sets miku_speaking=False)
+                            asyncio.create_task(
+                                self.voice_manager.on_user_interruption(user_id)
+                            )
+                else:
+                    # Audio below RMS threshold (silence) - reset interruption tracking
+                    # This ensures brief pauses in speech reset the counter
+                    self.interruption_start_time.pop(user_id, None)
+                    self.interruption_audio_count.pop(user_id, None)
             else:
                 # Miku not speaking, clear interruption tracking
                 self.interruption_start_time.pop(user_id, None)
                 self.interruption_audio_count.pop(user_id, None)
 
-            # Cancel existing silence task if any
-            if user_id in self.silence_tasks and not self.silence_tasks[user_id].done():
-                self.silence_tasks[user_id].cancel()
-
-            # Start new silence detection task
-            self.silence_tasks[user_id] = asyncio.create_task(
-                self._detect_silence(user_id)
-            )
-
         except Exception as e:
             logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
 
-    async def _detect_silence(self, user_id: int):
-        """
-        Wait for silence timeout and send 'final' command to STT.
-
-        This is called after each audio chunk. If no more audio arrives within
-        the silence_timeout period, we send the 'final' command to get the
-        complete transcription.
-
-        Args:
-            user_id: Discord user ID
-        """
-        try:
-            # Wait for silence timeout
-            await asyncio.sleep(self.silence_timeout)
-
-            # Check if we still have an active STT client
-            stt_client = self.stt_clients.get(user_id)
-            if not stt_client or not stt_client.is_connected():
-                return
-
-            # Send final command to get complete transcription
-            logger.debug(f"Silence detected for user {user_id}, requesting final transcript")
-            await stt_client.send_final()
-
-        except asyncio.CancelledError:
-            # Task was cancelled because new audio arrived
-            pass
-        except Exception as e:
-            logger.error(f"Error in silence detection for user {user_id}: {e}")
-
-    async def _on_vad_event(self, user_id: int, event: dict):
-        """
-        Handle VAD event from STT.
-
-        Args:
-            user_id: Discord user ID
-            event: VAD event dictionary with 'event' and 'probability' keys
-        """
-        user = self.users.get(user_id)
-        event_type = event.get('event', 'unknown')
-        probability = event.get('probability', 0.0)
-
-        logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
-
-        # Notify voice manager - pass the full event dict
-        if hasattr(self.voice_manager, 'on_user_vad_event'):
-            await self.voice_manager.on_user_vad_event(user_id, event)
-
     async def _on_partial_transcript(self, user_id: int, text: str):
         """
         Handle partial transcript from STT.
@@ -438,7 +447,6 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         """
         user = self.users.get(user_id)
         logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")
-        print(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}")  # Extra debug
 
         # Notify voice manager
         if hasattr(self.voice_manager, 'on_partial_transcript'):
@@ -456,29 +464,11 @@ class VoiceReceiverSink(voice_recv.AudioSink):
             user: Discord user object
         """
         logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")
-        print(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}")  # Extra debug
 
         # Notify voice manager - THIS TRIGGERS LLM RESPONSE
         if hasattr(self.voice_manager, 'on_final_transcript'):
             await self.voice_manager.on_final_transcript(user_id, text)
 
-    async def _on_interruption(self, user_id: int, probability: float):
-        """
-        Handle interruption detection from STT.
-
-        This cancels Miku's current speech if user interrupts.
-
-        Args:
-            user_id: Discord user ID
-            probability: Interruption confidence probability
-        """
-        user = self.users.get(user_id)
-        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")
-
-        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
-        if hasattr(self.voice_manager, 'on_user_interruption'):
-            await self.voice_manager.on_user_interruption(user_id, probability)
-
     def get_listening_users(self) -> list:
         """
         Get list of users currently being listened to.
@@ -489,30 +479,10 @@ class VoiceReceiverSink(voice_recv.AudioSink):
         return [
             {
                 'user_id': user_id,
-                'username': user.name if user else 'Unknown',
-                'connected': client.is_connected()
+                'username': self.users.get(user_id, {}).name if self.users.get(user_id) else 'Unknown',
+                'connected': self.stt_clients.get(user_id, {}).connected if self.stt_clients.get(user_id) else False
             }
-            for user_id, (user, client) in
-            [(uid, (self.users.get(uid), self.stt_clients.get(uid)))
-             for uid in self.stt_clients.keys()]
+            for user_id in self.stt_clients.keys()
         ]
 
-    @voice_recv.AudioSink.listener()
-    def on_voice_member_speaking_start(self, member: discord.Member):
-        """
-        Called when a member starts speaking (green circle appears).
-
-        This is a virtual event from discord-ext-voice-recv based on packet activity.
-        """
-        if member.id in self.stt_clients:
-            logger.debug(f"🎤 {member.name} started speaking")
-
-    @voice_recv.AudioSink.listener()
-    def on_voice_member_speaking_stop(self, member: discord.Member):
-        """
-        Called when a member stops speaking (green circle disappears).
-
-        This is a virtual event from discord-ext-voice-recv based on packet activity.
-        """
-        if member.id in self.stt_clients:
-            logger.debug(f"🔇 {member.name} stopped speaking")
+    # Discord VAD events removed - we rely entirely on RealtimeSTT's VAD for speech detection