"""
STT Client for Discord Bot

WebSocket client that connects to the STT server and handles:
- Audio streaming to STT
- Receiving VAD events
- Receiving partial/final transcripts
- Interruption detection
"""
import asyncio
import inspect
import json
import logging
from typing import Optional, Callable

import aiohttp
# Module-wide logger used by STTClient for its connection and transcript messages.
logger = logging.getLogger(name='stt_client')
class STTClient:
    """
    WebSocket client for STT server communication.

    Streams PCM audio to the STT server and dispatches the events the server
    sends back (VAD, partial/final transcripts, interruptions) to the
    callbacks supplied at construction time.
    """

    def __init__(
        self,
        user_id: str,
        stt_url: str = "ws://miku-stt:8000/ws/stt",
        on_vad_event: Optional[Callable] = None,
        on_partial_transcript: Optional[Callable] = None,
        on_final_transcript: Optional[Callable] = None,
        on_interruption: Optional[Callable] = None
    ):
        """
        Initialize STT client.

        No network I/O happens here; call ``connect()`` to open the socket.

        Args:
            user_id: Discord user ID; appended to ``stt_url`` as the last path segment.
            stt_url: Base WebSocket URL for STT server (a trailing slash is tolerated).
            on_vad_event: Callback for VAD events (event_dict)
            on_partial_transcript: Callback for partial transcripts (text, timestamp)
            on_final_transcript: Callback for final transcripts (text, timestamp)
            on_interruption: Callback for interruption detection (probability)

        Callbacks may be coroutine functions or plain functions; if a callback
        returns an awaitable, it is awaited.
        """
        self.user_id = user_id
        # Strip a trailing slash so the per-user URL never contains "//".
        self.stt_url = f"{stt_url.rstrip('/')}/{user_id}"

        # Callbacks
        self.on_vad_event = on_vad_event
        self.on_partial_transcript = on_partial_transcript
        self.on_final_transcript = on_final_transcript
        self.on_interruption = on_interruption

        # Connection state
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        self.session: Optional[aiohttp.ClientSession] = None
        self.connected = False
        self.running = False

        # Background task that reads server events (see _receive_events)
        self._receive_task: Optional[asyncio.Task] = None

        logger.info(f"STT client initialized for user {user_id}")

    async def connect(self):
        """
        Connect to STT WebSocket server.

        Opens a new ``aiohttp.ClientSession``, performs the WebSocket handshake
        (30 s heartbeat), and waits for the server's first message, which is
        read as JSON. It then starts the background receive task. Does nothing
        if already connected.

        Raises:
            Exception: whatever the handshake or the first read raised; the
                client is torn down via ``disconnect()`` before re-raising.
        """
        if self.connected:
            logger.warning(f"Already connected for user {self.user_id}")
            return

        # A previous connection may have dropped (the receive task or a failed
        # send clears ``connected``) without disconnect() being called. Release
        # its socket, session and receive task first. Otherwise the session
        # leaks, and the old task would start reading from the new socket.
        if self.session is not None or self.websocket is not None:
            await self.disconnect()

        try:
            self.session = aiohttp.ClientSession()
            self.websocket = await self.session.ws_connect(
                self.stt_url,
                heartbeat=30
            )

            # Wait for ready message
            ready_msg = await self.websocket.receive_json()
            logger.info(f"STT connected for user {self.user_id}: {ready_msg}")

            self.connected = True
            self.running = True

            # Start receive task
            self._receive_task = asyncio.create_task(self._receive_events())

            logger.info(f"✓ STT WebSocket connected for user {self.user_id}")

        except asyncio.CancelledError:
            # Cancelled mid-handshake (e.g. by an outer timeout). Still release
            # the session so it is not leaked, then propagate the cancellation.
            await self.disconnect()
            raise
        except Exception as e:
            logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
            await self.disconnect()
            raise

    async def disconnect(self):
        """
        Disconnect from STT WebSocket.

        Stops the receive task, then closes the WebSocket and the HTTP session.
        Safe to call more than once and on a partially-connected client.
        """
        logger.info(f"Disconnecting STT for user {self.user_id}")

        self.running = False
        self.connected = False

        # Cancel receive task. When disconnect() runs inside that very task
        # (e.g. from an event callback), it must neither cancel nor await
        # itself. Clearing ``running`` above makes its loop exit on its own.
        task, self._receive_task = self._receive_task, None
        if task is not None and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Detach before closing so concurrent send_audio() calls see "not
        # connected". Always close the session, even if closing the socket fails.
        websocket, self.websocket = self.websocket, None
        session, self.session = self.session, None
        try:
            if websocket is not None:
                await websocket.close()
        finally:
            if session is not None:
                await session.close()

        logger.info(f"✓ STT disconnected for user {self.user_id}")

    async def send_audio(self, audio_data: bytes):
        """
        Send audio chunk to STT server.

        When not connected, the chunk is dropped with a warning. A send
        failure is logged and marks the client as disconnected; it is not
        raised to the caller.

        Args:
            audio_data: PCM audio (int16, 16kHz mono)
        """
        websocket = self.websocket
        if not self.connected or websocket is None:
            logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
            return

        try:
            await websocket.send_bytes(audio_data)
        except Exception as e:
            logger.error(f"Failed to send audio to STT: {e}")
            self.connected = False
            return

        # Lazy %-args: this runs for every chunk and is usually filtered out.
        logger.debug("Sent %d bytes to STT", len(audio_data))

    async def _receive_events(self):
        """
        Background task to receive events from STT server.

        The loop ends in any of these cases: ``running`` is cleared, the socket
        this task was started for is replaced, the server closes or errors, or
        reading fails. Bad JSON and callback exceptions are logged and do not
        stop the loop. On exit, ``connected`` is cleared if this task's socket
        is still the current one.
        """
        # Bind the socket this task serves, so a later reconnect never ends up
        # with two tasks calling receive() on the same new socket.
        websocket = self.websocket
        try:
            while self.running and websocket is not None and websocket is self.websocket:
                try:
                    msg = await websocket.receive()
                except Exception as e:
                    # A failing read keeps failing; stop instead of spinning.
                    logger.error(f"Error receiving STT event: {e}", exc_info=True)
                    break

                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        await self._handle_event(json.loads(msg.data))
                    except Exception as e:
                        # Malformed JSON or a failing callback: log, keep listening.
                        logger.error(f"Error handling STT event: {e}", exc_info=True)

                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    logger.info(f"STT WebSocket closed for user {self.user_id}")
                    break

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(
                        f"STT WebSocket error for user {self.user_id}: {websocket.exception()}"
                    )
                    break

        finally:
            # Only report "disconnected" for our own socket, so that a task
            # outlived by a reconnect cannot flip the new connection's state.
            if websocket is self.websocket:
                self.connected = False
            logger.info(f"STT receive task ended for user {self.user_id}")

    @staticmethod
    async def _invoke_callback(callback: Optional[Callable], *args):
        """Call ``callback(*args)`` if set; await the result if it is awaitable."""
        if callback is None:
            return
        result = callback(*args)
        if inspect.isawaitable(result):
            await result

    async def _handle_event(self, event: dict):
        """
        Handle incoming STT event.

        Dispatches on ``event['type']``:
        - ``'vad'``: the whole dict goes to ``on_vad_event``.
        - ``'partial'`` / ``'final'``: ``(text, timestamp)`` goes to the
          transcript callbacks. Missing keys default to ``''`` and ``0``.
        - ``'interruption'``: ``probability`` (default ``0``) goes to
          ``on_interruption``.

        Unknown types are logged and ignored.

        Args:
            event: Event dictionary from STT server
        """
        event_type = event.get('type')

        if event_type == 'vad':
            # VAD event: speech detection
            logger.debug(f"VAD event: {event}")
            await self._invoke_callback(self.on_vad_event, event)

        elif event_type == 'partial':
            # Partial transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Partial transcript [{self.user_id}]: {text}")
            await self._invoke_callback(self.on_partial_transcript, text, timestamp)

        elif event_type == 'final':
            # Final transcript
            text = event.get('text', '')
            timestamp = event.get('timestamp', 0)
            logger.info(f"Final transcript [{self.user_id}]: {text}")
            await self._invoke_callback(self.on_final_transcript, text, timestamp)

        elif event_type == 'interruption':
            # Interruption detected. NOTE(review): a non-numeric probability makes
            # the :.3f format raise before the callback runs.
            probability = event.get('probability', 0)
            logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
            await self._invoke_callback(self.on_interruption, probability)

        else:
            logger.warning(f"Unknown STT event type: {event_type}")

    def is_connected(self) -> bool:
        """Check if STT client is connected."""
        return self.connected