374 lines
15 KiB
Python
374 lines
15 KiB
Python
|
|
# voice_audio.py
"""
Audio streaming bridge between RVC TTS and Discord voice.
Uses aiohttp for WebSocket communication (compatible with FastAPI).
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import numpy as np
|
||
|
|
from typing import Optional
|
||
|
|
import discord
|
||
|
|
import aiohttp
|
||
|
|
from utils.logger import get_logger
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import struct
|
||
|
|
import json
|
||
|
|
import websockets
|
||
|
|
import discord
|
||
|
|
import numpy as np
|
||
|
|
from typing import Optional
|
||
|
|
from utils.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger shared by everything in this file.
logger = get_logger('voice_audio')

# PCM layout used for the frames handed to Discord.
SAMPLE_RATE = 48000  # samples per second (48 kHz)
CHANNELS = 2  # stereo output
FRAME_LENGTH = 0.02  # duration of one frame in seconds (20 ms)
SAMPLES_PER_FRAME = int(FRAME_LENGTH * SAMPLE_RATE)  # 960 samples per channel
|
||
|
|
|
||
|
|
|
||
|
|
class MikuVoiceSource(discord.AudioSource):
    """
    Audio source that receives audio from RVC TTS WebSocket and feeds it to Discord voice.

    Single WebSocket connection handles both token sending and audio receiving.
    Uses aiohttp for WebSocket communication (compatible with FastAPI).

    Tokens that cannot be sent right away (service still warming up, no
    connection, or an older backlog still being flushed) are kept in
    ``token_queue`` and delivered in their original order once a connection
    is available.
    """

    DEFAULT_WEBSOCKET_URL = "ws://172.25.0.1:8765/ws/stream"
    DEFAULT_HEALTH_URL = "http://172.25.0.1:8765/health"

    def __init__(
        self,
        websocket_url: str = DEFAULT_WEBSOCKET_URL,
        health_url: str = DEFAULT_HEALTH_URL,
    ):
        """
        Create a disconnected source; call connect() to start streaming.

        Args:
            websocket_url: WebSocket endpoint used for tokens and audio
            health_url: HTTP endpoint polled to learn whether RVC is warmed up
        """
        self.websocket_url = websocket_url
        self.health_url = health_url
        self.session: Optional[aiohttp.ClientSession] = None
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.audio_buffer = bytearray()
        # Kept for backward compatibility; nothing in this class acquires it.
        self.buffer_lock = asyncio.Lock()
        self.running = False
        self._receive_task: Optional[asyncio.Task] = None  # audio receive loop
        self._connect_task: Optional[asyncio.Task] = None  # warmup poll / reconnect
        self._connect_started = False  # True once connect() has been through its setup
        self._flushing = False  # True while the queued backlog is being sent
        self.warmed_up = False  # Track if TTS pipeline is warmed up
        self.token_queue = []  # Queue tokens while warming up or connecting

    async def _check_rvc_ready(self) -> bool:
        """Check if RVC is initialized and warmed up via health endpoint.

        Returns:
            True only when the endpoint answers 200 with a truthy
            ``warmed_up`` field; any error or other status yields False.
        """
        try:
            async with aiohttp.ClientSession() as session:
                timeout = aiohttp.ClientTimeout(total=2)
                async with session.get(self.health_url, timeout=timeout) as response:
                    if response.status != 200:
                        return False
                    data = await response.json()
                    return bool(data.get("warmed_up", False))
        except Exception as e:
            logger.debug(f"Health check failed: {e}")
            return False

    async def connect(self):
        """Connect to RVC TTS WebSocket using aiohttp.

        If RVC is already warmed up the connection is made immediately and
        any tokens queued beforehand are flushed. Otherwise a background task
        polls the health endpoint and connects once RVC is ready.

        Raises:
            Exception: whatever the connection attempt raised, after logging.
        """
        try:
            # First, check if RVC is warmed up
            logger.info("Checking if RVC is ready...")
            if not await self._check_rvc_ready():
                logger.warning("⏳ RVC is warming up, will queue tokens until ready")
                self.warmed_up = False
                # Don't connect yet - poll in the background and connect when ready
                self._connect_task = asyncio.create_task(
                    self._wait_for_warmup_and_connect()
                )
                self._connect_started = True
                return

            # RVC is ready, connect immediately
            logger.info("RVC is ready, connecting...")
            await self._do_connect()
            self.warmed_up = True
            self._connect_started = True
            # Tokens may have been queued before/while connecting; don't strand them.
            await self._flush_token_queue("Sending")

        except Exception as e:
            logger.error(f"Failed to initialize connection: {e}", exc_info=True)
            raise

    async def _do_connect(self):
        """Actually establish the WebSocket connection and start the receive loop."""
        session = aiohttp.ClientSession()
        try:
            websocket = await session.ws_connect(self.websocket_url)
        except BaseException:
            # Don't leak the session when the handshake fails or is cancelled.
            await session.close()
            raise

        self.session = session
        self.websocket = websocket
        self.running = True
        logger.info("✓ Connected to RVC TTS WebSocket")

        # Always start a fresh receive loop for the new socket. A loop left
        # over from a previous socket stops on its own once that socket closes.
        self._receive_task = asyncio.create_task(self._receive_audio())

    async def _close_connection(self) -> None:
        """Best-effort close of the current WebSocket and HTTP session."""
        websocket, self.websocket = self.websocket, None
        session, self.session = self.session, None
        try:
            if websocket is not None:
                try:
                    await websocket.close()
                except Exception as e:
                    logger.debug(f"Ignoring error while closing WebSocket: {e}")
        finally:
            # Runs even if closing the socket was cancelled, so the session
            # is not leaked.
            if session is not None:
                try:
                    await session.close()
                except Exception as e:
                    logger.debug(f"Ignoring error while closing session: {e}")

    async def _flush_token_queue(self, action: str = "Sending") -> None:
        """Send every queued token in order, followed by a flush command.

        While this runs, send_token() keeps appending to the queue instead of
        sending directly, so newer tokens can never overtake older ones. Each
        token is removed only after it was sent, so a failure part-way through
        leaves exactly the unsent tokens queued for the next attempt.

        Args:
            action: Verb phrase used in the log line announcing the flush
        """
        websocket = self.websocket
        if not self.token_queue or websocket is None:
            return

        self._flushing = True
        try:
            logger.info(f"{action} {len(self.token_queue)} queued tokens")
            # Outer loop: tokens may arrive while the flush command is in flight.
            while self.token_queue:
                while self.token_queue:
                    token, pitch_shift = self.token_queue[0]
                    await websocket.send_json({
                        "token": token,
                        "pitch_shift": pitch_shift
                    })
                    del self.token_queue[0]
                    # Small delay to avoid overwhelming RVC
                    await asyncio.sleep(0.05)

                # Send flush command to ensure any buffered text is synthesized
                await websocket.send_json({"flush": True})
            logger.info("✓ Queue flushed with explicit flush command")
        finally:
            self._flushing = False

    async def _wait_for_warmup_and_connect(self):
        """Poll RVC health until warmed up, then connect and flush queue."""
        try:
            logger.info("Polling RVC for warmup completion...")
            max_wait = 60  # Number of seconds spent sleeping between polls
            poll_interval = 1.0  # Check every second

            for _ in range(int(max_wait / poll_interval)):
                if await self._check_rvc_ready():
                    logger.info("✅ RVC warmup complete! Connecting and flushing queue...")
                    await self._do_connect()
                    self.warmed_up = True
                    await self._flush_token_queue("Sending")
                    return

                await asyncio.sleep(poll_interval)

            # Timeout: connect anyway and still deliver what was queued
            logger.error("❌ RVC warmup timeout! Connecting anyway...")
            await self._do_connect()
            self.warmed_up = True
            await self._flush_token_queue("Sending")

        except asyncio.CancelledError:
            logger.debug("Warmup polling cancelled")
        except Exception as e:
            logger.error(f"Error during warmup wait: {e}", exc_info=True)

    async def _reconnect(self):
        """Attempt to reconnect after connection failure, then flush the queue."""
        try:
            logger.info("Reconnection attempt starting...")
            max_retries = 5
            retry_delay = 3.0

            for attempt in range(max_retries):
                try:
                    # Clean up old connection
                    await self._close_connection()

                    # Wait before retry
                    if attempt > 0:
                        logger.info(f"Retry {attempt + 1}/{max_retries} in {retry_delay}s...")
                        await asyncio.sleep(retry_delay)

                    # Check if RVC is ready
                    if not await self._check_rvc_ready():
                        logger.warning("RVC not ready, will retry...")
                        continue

                    # Try to connect
                    await self._do_connect()
                    self.warmed_up = True

                    # Flush queued tokens
                    await self._flush_token_queue("✓ Reconnected! Flushing")
                    logger.info("✓ Reconnection successful")
                    return

                except Exception as e:
                    logger.error(f"Reconnection attempt {attempt + 1} failed: {e}")

            logger.error(f"❌ Failed to reconnect after {max_retries} attempts")

        except asyncio.CancelledError:
            logger.debug("Reconnection cancelled")
        except Exception as e:
            logger.error(f"Error during reconnection: {e}", exc_info=True)

    def _schedule_reconnect(self) -> None:
        """Start a background reconnect unless a connect task is already running."""
        task = self._connect_task
        if task is not None and not task.done():
            return
        logger.info("Attempting to reconnect to RVC...")
        self._connect_task = asyncio.create_task(self._reconnect())

    async def disconnect(self):
        """Disconnect from WebSocket and stop all background tasks."""
        self.running = False

        for attr in ("_connect_task", "_receive_task"):
            task = getattr(self, attr)
            setattr(self, attr, None)
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # A failed background task must not prevent closing the socket.
                logger.debug(f"Background task ended with error: {e}")

        await self._close_connection()

        logger.info("Disconnected from RVC TTS WebSocket")

    async def send_token(self, token: str, pitch_shift: int = 0):
        """
        Send a text token to TTS for voice generation.
        Queues tokens if pipeline is still warming up or connection failed.

        Once connect() has been called, a queued token also starts a
        background reconnect when no connection attempt is in progress, so
        the queue cannot get stuck after a failure.

        Args:
            token: Text token to synthesize
            pitch_shift: Pitch adjustment (-12 to +12 semitones)
        """
        # If flushing, not warmed up yet, or no connection, queue the token
        if self._flushing or not self.warmed_up or not self.websocket:
            self.token_queue.append((token, pitch_shift))
            if self._flushing:
                reason = "flushing"
            elif not self.warmed_up:
                reason = "warming up"
            else:
                reason = "no connection"
            logger.debug(
                f"Queued token ({reason}): '{token}' (queue size: {len(self.token_queue)})"
            )
            # Try to reconnect in background if not already trying
            if self._connect_started and not self._flushing:
                self._schedule_reconnect()
            return

        websocket = self.websocket
        try:
            message = {
                "token": token,
                "pitch_shift": pitch_shift
            }
            await websocket.send_json(message)
            logger.debug(f"Sent token to TTS: '{token}'")

        except Exception as e:
            logger.error(f"Failed to send token: {e}")
            # Queue the failed token, drop the broken socket and reconnect
            self.token_queue.append((token, pitch_shift))
            self.warmed_up = False
            if self.websocket is websocket:
                # Clear it before awaiting so concurrent callers queue instead
                # of writing to the broken socket.
                self.websocket = None
            try:
                await websocket.close()
            except Exception as close_error:
                logger.debug(f"Ignoring error while closing WebSocket: {close_error}")
            self._schedule_reconnect()

    async def stream_text(self, text: str, pitch_shift: int = 0):
        """
        Stream entire text to TTS word-by-word.

        Args:
            text: Full text to synthesize
            pitch_shift: Pitch adjustment
        """
        for word in text.split():
            await self.send_token(word + " ", pitch_shift)
            # Small delay to avoid overwhelming the TTS
            await asyncio.sleep(0.05)

    async def _receive_audio(self):
        """Background task to receive audio from WebSocket and buffer it."""
        # Bind to the socket that was current when this task started, so a
        # stale loop never reads from a socket opened by a later reconnect.
        websocket = self.websocket
        closed_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )
        try:
            while self.running and websocket is not None and self.websocket is websocket:
                try:
                    # Receive message from WebSocket
                    msg = await websocket.receive()

                    if msg.type == aiohttp.WSMsgType.BINARY:
                        # Convert float32 mono → int16 stereo
                        converted = self._convert_audio(msg.data)
                        self.audio_buffer.extend(converted)

                        logger.debug(f"Received {len(msg.data)} bytes, buffer: {len(self.audio_buffer)} bytes")

                    elif msg.type in closed_types:
                        logger.warning("TTS WebSocket connection closed")
                        break
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {websocket.exception()}")
                        break

                except Exception as e:
                    logger.error(f"Error receiving audio: {e}", exc_info=True)
                    break

        except asyncio.CancelledError:
            logger.debug("Audio receive task cancelled")

    def _convert_audio(self, float32_mono: bytes) -> bytes:
        """
        Convert float32 mono PCM to int16 stereo PCM.

        Trailing bytes that do not form a complete float32 are ignored, and
        NaN samples are treated as silence.

        Args:
            float32_mono: Raw PCM audio (float32 values in native byte order,
                mono channel)

        Returns:
            int16 stereo PCM bytes
        """
        num_samples = len(float32_mono) // 4  # 4 bytes per float32
        if num_samples == 0:
            return b""

        # Read the samples straight from the buffer without a Python-level copy
        audio_np = np.frombuffer(float32_mono, dtype=np.float32, count=num_samples)

        # NaN cannot be cast to int16 meaningfully; play it as silence
        audio_np = np.nan_to_num(audio_np, nan=0.0)

        # Clamp to [-1.0, 1.0] range
        audio_np = np.clip(audio_np, -1.0, 1.0)

        # Convert to int16 range [-32767, 32767]
        audio_int16 = (audio_np * 32767).astype(np.int16)

        # Duplicate mono channel to stereo (L and R same)
        stereo = np.repeat(audio_int16, 2)

        # Convert to bytes
        return stereo.tobytes()

    def read(self) -> bytes:
        """
        Read 20ms of audio (required by discord.py AudioSource interface).

        Discord expects exactly 960 samples per channel (1920 samples total for stereo),
        which equals 3840 bytes (1920 samples * 2 bytes per int16).

        Returns:
            3840 bytes of int16 stereo PCM. When less than one full frame is
            buffered, a frame of silence is returned and the partial data
            stays buffered, so this source never signals end-of-stream.
        """
        # NOTE(review): discord.py is expected to call this from its player
        # thread while _receive_audio appends on the event loop - confirm.
        # Calculate required bytes for 20ms frame
        bytes_needed = SAMPLES_PER_FRAME * CHANNELS * 2  # 960 * 2 * 2 = 3840 bytes

        if len(self.audio_buffer) < bytes_needed:
            # Not enough audio yet, return silence
            return b'\x00' * bytes_needed

        # Extract frame from buffer
        frame = bytes(self.audio_buffer[:bytes_needed])
        del self.audio_buffer[:bytes_needed]
        return frame

    def is_opus(self) -> bool:
        """Return False since we're providing raw PCM."""
        return False

    def cleanup(self):
        """Cleanup resources when AudioSource is done."""
        logger.info("MikuVoiceSource cleanup called")
        # Actual disconnect happens via disconnect() method
|