#!/usr/bin/env python3
"""
RealtimeSTT WebSocket Server

Provides real-time speech-to-text transcription using Faster-Whisper.
Receives audio chunks via WebSocket and streams back partial/final transcripts.

Protocol:
- Client sends: binary audio data (16kHz, 16-bit mono PCM)
- Client sends: JSON {"command": "reset"} to reset state
- Server sends: JSON {"type": "partial", "text": "...", "timestamp": float}
- Server sends: JSON {"type": "final", "text": "...", "timestamp": float}
"""
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import threading
|
|
import queue
|
|
from typing import Optional, Dict, Any
|
|
import numpy as np
|
|
import websockets
|
|
from websockets.server import serve
|
|
from aiohttp import web
|
|
|
|
# Logging setup: timestamped single-line records for container log collectors.
LOG_FORMAT = '%(asctime)s %(levelname)s [%(name)s] %(message)s'
LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=LOG_DATEFMT)

logger = logging.getLogger('stt-realtime')
|
|
|
|
# Import RealtimeSTT
|
|
from RealtimeSTT import AudioToTextRecorder
|
|
|
|
# Module-level warmup bookkeeping shared between warmup_model() and the
# HTTP health endpoint. warmup_lock guards the "already warmed up" check;
# warmup_recorder keeps the pre-loaded model alive for the process lifetime.
warmup_complete = False
warmup_lock = threading.Lock()
warmup_recorder = None
|
|
|
|
|
|
class STTSession:
    """
    Manages a single STT session for a WebSocket client.

    Audio arrives as raw 16 kHz, 16-bit mono PCM frames, is buffered in a
    thread-safe queue, and a background thread feeds it into RealtimeSTT's
    AudioToTextRecorder via feed_audio(). Recorder callbacks fire on
    RealtimeSTT's own threads, so transcripts are marshalled back onto the
    asyncio loop with run_coroutine_threadsafe().
    """

    def __init__(self, websocket, session_id: str, config: Dict[str, Any]):
        self.websocket = websocket
        self.session_id = session_id
        self.config = config
        self.recorder: Optional[AudioToTextRecorder] = None
        self.running = False
        self.audio_queue = queue.Queue()
        self.feed_thread: Optional[threading.Thread] = None
        self.text_thread: Optional[threading.Thread] = None  # created in start()
        self.loop: Optional[asyncio.AbstractEventLoop] = None  # set in start()
        self.last_partial = ""
        self.last_stabilized = ""  # Track last stabilized partial
        self.last_text_was_stabilized = False  # Track which came last
        self.recording_active = False  # Track if currently recording

        logger.info(f"[{session_id}] Session created")

    def _on_realtime_transcription(self, text: str):
        """Called when partial transcription is available."""
        if text and text != self.last_partial:
            self.last_partial = text
            self.last_text_was_stabilized = False  # Partial came after stabilized
            logger.info(f"[{self.session_id}] 📝 Partial: {text}")
            # Callback runs on a recorder thread; hop to the asyncio loop to send.
            asyncio.run_coroutine_threadsafe(
                self._send_transcript("partial", text),
                self.loop
            )

    def _on_realtime_stabilized(self, text: str):
        """Called when a stabilized partial is available (high confidence)."""
        if text and text.strip():
            self.last_stabilized = text
            self.last_text_was_stabilized = True  # Stabilized came after partial
            logger.info(f"[{self.session_id}] 🔒 Stabilized: {text}")
            asyncio.run_coroutine_threadsafe(
                self._send_transcript("partial", text),
                self.loop
            )

    def _on_recording_stop(self):
        """Called when recording stops (silence detected).

        Emits the final transcript, preferring whichever of the partial /
        stabilized texts arrived most recently, then resets per-utterance state.
        """
        logger.info(f"[{self.session_id}] ⏹️ Recording stopped")
        self.recording_active = False

        # Use the most recent text: prioritize whichever came last
        if self.last_text_was_stabilized:
            final_text = self.last_stabilized or self.last_partial
            source = "stabilized" if self.last_stabilized else "partial"
        else:
            final_text = self.last_partial or self.last_stabilized
            source = "partial" if self.last_partial else "stabilized"

        if final_text:
            logger.info(f"[{self.session_id}] ✅ Final (from {source}): {final_text}")
            asyncio.run_coroutine_threadsafe(
                self._send_transcript("final", final_text),
                self.loop
            )
        else:
            # No transcript means VAD false positive (detected "speech" in pure noise)
            logger.warning(f"[{self.session_id}] ⚠️ Recording stopped but no transcript available (VAD false positive)")
            logger.info(f"[{self.session_id}] 🔄 Clearing audio buffer to recover")

            # Drain the queue until Empty — race-free, unlike polling empty().
            try:
                while True:
                    self.audio_queue.get_nowait()
            except queue.Empty:
                pass

        # Reset state
        self.last_stabilized = ""
        self.last_partial = ""
        self.last_text_was_stabilized = False

    def _on_recording_start(self):
        """Called when recording starts (speech detected)."""
        logger.info(f"[{self.session_id}] 🎙️ Recording started")
        self.recording_active = True
        self.last_stabilized = ""
        self.last_partial = ""

    def _on_transcription(self, text: str):
        """Not used - we use stabilized partials as finals."""
        pass

    async def _send_transcript(self, transcript_type: str, text: str):
        """Send a transcript message to the client via WebSocket.

        Failures are logged, not raised: the send races against disconnects.
        """
        try:
            message = {
                "type": transcript_type,
                "text": text,
                "timestamp": time.time()
            }
            await self.websocket.send(json.dumps(message))
        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to send transcript: {e}")

    def _feed_audio_thread(self):
        """Thread body: drain the audio queue into the recorder until stopped."""
        logger.info(f"[{self.session_id}] Audio feed thread started")
        while self.running:
            try:
                # Get audio chunk with timeout so self.running is re-checked.
                audio_chunk = self.audio_queue.get(timeout=0.1)
                if audio_chunk is not None and self.recorder:
                    self.recorder.feed_audio(audio_chunk)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"[{self.session_id}] Error feeding audio: {e}")
        logger.info(f"[{self.session_id}] Audio feed thread stopped")

    async def start(self, loop: asyncio.AbstractEventLoop):
        """Start the STT session: load the recorder and spin up worker threads.

        Args:
            loop: The asyncio loop that transcript sends are scheduled onto.

        Raises:
            Exception: Propagates any recorder initialization failure.
        """
        self.loop = loop
        self.running = True

        logger.info(f"[{self.session_id}] Starting RealtimeSTT recorder...")
        logger.info(f"[{self.session_id}] Model: {self.config['model']}")
        logger.info(f"[{self.session_id}] Device: {self.config['device']}")

        try:
            # Create recorder in a thread to avoid blocking the event loop.
            def init_recorder():
                self.recorder = AudioToTextRecorder(
                    # Model settings - using same model for both partial and final
                    model=self.config['model'],
                    language=self.config['language'],
                    compute_type=self.config['compute_type'],
                    device=self.config['device'],

                    # Disable microphone - we feed audio manually
                    use_microphone=False,

                    # Real-time transcription - use same model for everything
                    enable_realtime_transcription=True,
                    realtime_model_type=self.config['model'],  # Use same model
                    realtime_processing_pause=0.05,  # 50ms between updates
                    on_realtime_transcription_update=self._on_realtime_transcription,
                    on_realtime_transcription_stabilized=self._on_realtime_stabilized,

                    # VAD settings - very permissive, rely on Discord's VAD for speech detection
                    # Our VAD is only for silence detection, not filtering audio content
                    silero_sensitivity=0.05,  # Very low = barely filters anything
                    silero_use_onnx=True,  # Faster
                    webrtc_sensitivity=3,
                    post_speech_silence_duration=self.config['silence_duration'],
                    min_length_of_recording=self.config['min_recording_length'],
                    min_gap_between_recordings=self.config['min_gap'],
                    pre_recording_buffer_duration=1.0,  # Capture more audio before/after speech

                    # Callbacks
                    on_recording_start=self._on_recording_start,
                    on_recording_stop=self._on_recording_stop,
                    on_vad_detect_start=lambda: logger.debug(f"[{self.session_id}] VAD listening"),
                    on_vad_detect_stop=lambda: logger.debug(f"[{self.session_id}] VAD stopped"),

                    # Other settings
                    spinner=False,  # No spinner in container
                    level=logging.WARNING,  # Reduce internal logging

                    # Beam search settings
                    beam_size=5,  # Higher beam = better accuracy (used for final processing)
                    beam_size_realtime=5,  # Increased from 3 for better real-time accuracy

                    # Batch sizes
                    batch_size=16,
                    realtime_batch_size=8,

                    initial_prompt="",  # Can add context here if needed
                )
                logger.info(f"[{self.session_id}] ✅ Recorder initialized")

            # Run initialization on the loop we were handed (get_event_loop()
            # inside a coroutine is deprecated).
            await loop.run_in_executor(None, init_recorder)

            # Start audio feed thread
            self.feed_thread = threading.Thread(target=self._feed_audio_thread, daemon=True)
            self.feed_thread.start()

            # Start the recorder's text processing loop in a thread
            def run_text_loop():
                while self.running:
                    try:
                        # This blocks until speech is detected and transcribed
                        text = self.recorder.text(self._on_transcription)
                    except Exception as e:
                        if self.running:
                            logger.error(f"[{self.session_id}] Text loop error: {e}")
                        break

            self.text_thread = threading.Thread(target=run_text_loop, daemon=True)
            self.text_thread.start()

            logger.info(f"[{self.session_id}] ✅ Session started successfully")

        except Exception as e:
            logger.error(f"[{self.session_id}] Failed to start session: {e}", exc_info=True)
            raise

    def feed_audio(self, audio_data: bytes):
        """Queue raw 16-bit PCM bytes for the feed thread to hand to the recorder."""
        if self.running:
            # Convert bytes to numpy array (16-bit PCM)
            audio_np = np.frombuffer(audio_data, dtype=np.int16)
            self.audio_queue.put(audio_np)

    def reset(self):
        """Reset per-utterance state and discard any buffered audio."""
        logger.info(f"[{self.session_id}] Resetting session")
        # Clear ALL transcript state, not just the partial — a stale
        # stabilized text would otherwise be emitted as the next "final".
        self.last_partial = ""
        self.last_stabilized = ""
        self.last_text_was_stabilized = False
        # Clear audio queue
        try:
            while True:
                self.audio_queue.get_nowait()
        except queue.Empty:
            pass

    async def stop(self):
        """Stop the session: halt worker threads and shut the recorder down."""
        logger.info(f"[{self.session_id}] Stopping session...")
        self.running = False

        # Wait for threads to finish
        if self.feed_thread and self.feed_thread.is_alive():
            self.feed_thread.join(timeout=2)
        # The text thread may be blocked inside recorder.text(); give it a
        # short grace period (shutdown() below unblocks it if needed).
        if self.text_thread and self.text_thread.is_alive():
            self.text_thread.join(timeout=2)

        # Shutdown recorder
        if self.recorder:
            try:
                self.recorder.shutdown()
            except Exception as e:
                logger.error(f"[{self.session_id}] Error shutting down recorder: {e}")

        logger.info(f"[{self.session_id}] Session stopped")
|
|
|
|
|
|
class STTServer:
    """
    WebSocket server for RealtimeSTT.
    Handles multiple concurrent clients (one per Discord user), each with
    its own STTSession keyed by a monotonically increasing session id.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8766):
        self.host = host
        self.port = port
        self.sessions: Dict[str, STTSession] = {}
        self.session_counter = 0

        # Default configuration
        self.config = {
            # Model - using small.en (English-only, more accurate than multilingual small)
            'model': 'small.en',
            'language': 'en',
            'compute_type': 'float16',  # FP16 for GPU efficiency
            'device': 'cuda',

            # VAD settings
            'silero_sensitivity': 0.6,
            'webrtc_sensitivity': 3,
            'silence_duration': 0.8,  # Shorter to improve responsiveness
            'min_recording_length': 0.5,
            'min_gap': 0.3,
        }

        logger.info("=" * 60)
        logger.info("RealtimeSTT Server Configuration:")
        logger.info(f"  Host: {host}:{port}")
        logger.info(f"  Model: {self.config['model']} (English-only, optimized)")
        logger.info(f"  Beam size: 5 (higher accuracy)")
        logger.info(f"  Strategy: Use last partial as final (instant response)")
        logger.info(f"  Language: {self.config['language']}")
        logger.info(f"  Device: {self.config['device']}")
        logger.info(f"  Compute Type: {self.config['compute_type']}")
        logger.info(f"  Silence Duration: {self.config['silence_duration']}s")
        logger.info("=" * 60)

    async def handle_client(self, websocket):
        """Handle one WebSocket client connection for its full lifetime.

        Binary frames are audio; text frames are JSON commands
        ('reset', 'ping'). The session is always cleaned up on exit.
        """
        self.session_counter += 1
        session_id = f"session_{self.session_counter}"
        session = None

        try:
            logger.info(f"[{session_id}] Client connected from {websocket.remote_address}")

            # Create session
            session = STTSession(websocket, session_id, self.config)
            self.sessions[session_id] = session

            # Start session (get_running_loop: we are inside the loop here;
            # get_event_loop() is deprecated in this context).
            await session.start(asyncio.get_running_loop())

            # Process messages
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data
                        session.feed_audio(message)
                    else:
                        # JSON command
                        data = json.loads(message)
                        command = data.get('command', '')

                        if command == 'reset':
                            session.reset()
                        elif command == 'ping':
                            await websocket.send(json.dumps({
                                'type': 'pong',
                                'timestamp': time.time()
                            }))
                        else:
                            logger.warning(f"[{session_id}] Unknown command: {command}")

                except json.JSONDecodeError:
                    logger.warning(f"[{session_id}] Invalid JSON message")
                except Exception as e:
                    logger.error(f"[{session_id}] Error processing message: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"[{session_id}] Client disconnected")
        except Exception as e:
            logger.error(f"[{session_id}] Error: {e}", exc_info=True)
        finally:
            # Cleanup — drop the registry entry even if stop() raises;
            # pop(..., None) also tolerates the entry already being gone.
            if session:
                try:
                    await session.stop()
                finally:
                    self.sessions.pop(session_id, None)

    async def run(self):
        """Run the WebSocket server until the surrounding task is cancelled."""
        logger.info(f"Starting RealtimeSTT server on ws://{self.host}:{self.port}")

        async with serve(
            self.handle_client,
            self.host,
            self.port,
            ping_interval=30,
            ping_timeout=10,
            max_size=10 * 1024 * 1024,  # 10MB max message size
        ):
            logger.info("✅ Server ready and listening for connections")
            await asyncio.Future()  # Run forever
|
|
|
|
|
|
async def warmup_model(config: Dict[str, Any]):
    """
    Warm up the STT model by loading it before handling real requests.

    The recorder construction (model download/load + CUDA init) is blocking,
    so it is pushed onto a worker thread — otherwise this coroutine would
    freeze the event loop and the /health endpoint for the whole warmup.
    Sets the module-level warmup_complete flag that /health reports.
    """
    global warmup_complete, warmup_recorder

    with warmup_lock:
        if warmup_complete:
            logger.info("Model already warmed up")
            return

    logger.info("🔥 Starting model warmup...")
    try:
        # Initialize a temporary recorder to load the model
        logger.info("Loading Faster-Whisper model...")

        def dummy_callback(text):
            pass

        def _build_recorder():
            # This will trigger model loading and compilation.
            return AudioToTextRecorder(
                model=config['model'],
                language=config['language'],
                compute_type=config['compute_type'],
                device=config['device'],
                # No microphone in the container — match the session recorders.
                use_microphone=False,
                silero_sensitivity=config['silero_sensitivity'],
                webrtc_sensitivity=config['webrtc_sensitivity'],
                post_speech_silence_duration=config['silence_duration'],
                min_length_of_recording=config['min_recording_length'],
                min_gap_between_recordings=config['min_gap'],
                enable_realtime_transcription=True,
                realtime_processing_pause=0.1,
                on_realtime_transcription_update=dummy_callback,
                on_realtime_transcription_stabilized=dummy_callback,
                spinner=False,
                level=logging.WARNING,
                beam_size=5,
                beam_size_realtime=5,
                batch_size=16,
                realtime_batch_size=8,
                initial_prompt="",
            )

        # Blocking load runs in the default executor; the loop stays responsive.
        warmup_recorder = await asyncio.get_running_loop().run_in_executor(
            None, _build_recorder
        )

        logger.info("✅ Model loaded and warmed up successfully")
        warmup_complete = True

    except Exception as e:
        logger.error(f"❌ Warmup failed: {e}", exc_info=True)
        warmup_complete = False
|
|
|
|
|
|
async def health_handler(request):
    """HTTP health check endpoint"""
    # One payload, two statuses: 200 once warmup finished, 503 before.
    ready = warmup_complete
    body = {
        "status": "ready" if ready else "warming_up",
        "warmed_up": ready,
        "model": "small.en",
        "device": "cuda",
    }
    if ready:
        return web.json_response(body)
    return web.json_response(body, status=503)
|
|
|
|
|
|
async def start_http_server(host: str, http_port: int):
    """Start HTTP server for health checks"""
    # Minimal aiohttp app exposing only GET /health.
    application = web.Application()
    application.router.add_get('/health', health_handler)

    app_runner = web.AppRunner(application)
    await app_runner.setup()
    await web.TCPSite(app_runner, host, http_port).start()

    logger.info(f"✅ HTTP health server listening on http://{host}:{http_port}")
|
|
|
|
|
|
def main():
    """Main entry point: read env config, then run warmup, health, and WS servers."""
    import os

    # Get configuration from environment
    host = os.environ.get('STT_HOST', '0.0.0.0')
    port = int(os.environ.get('STT_PORT', '8766'))
    http_port = int(os.environ.get('STT_HTTP_PORT', '8767'))  # HTTP health check port

    # Configuration (used by warmup; STTServer builds its own matching copy)
    config = {
        'model': 'small.en',
        'language': 'en',
        'compute_type': 'float16',
        'device': 'cuda',
        'silero_sensitivity': 0.6,
        'webrtc_sensitivity': 3,
        'silence_duration': 0.8,
        'min_recording_length': 0.5,
        'min_gap': 0.3,
    }

    # Create and run server
    server = STTServer(host=host, port=port)

    async def run_all():
        # Keep strong references to the background tasks: asyncio holds only
        # weak references to tasks, so an unreferenced create_task() result
        # can be garbage-collected before it finishes.
        background_tasks = [
            asyncio.create_task(warmup_model(config)),       # model warmup
            asyncio.create_task(start_http_server(host, http_port)),  # health endpoint
        ]

        # Start WebSocket server (runs until cancelled/interrupted)
        await server.run()

    try:
        asyncio.run(run_all())
    except KeyboardInterrupt:
        logger.info("Server shutdown requested")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        raise
|
|
|
|
|
|
# Script entry point — only run the servers when executed directly.
if __name__ == '__main__':
    main()
|