Decided on Parakeet with ONNX Runtime. It works well — real-time voice chat is now possible — but the UX is still lacking.

This commit is contained in:
2026-01-19 00:29:44 +02:00
parent 0a8910fff8
commit 362108f4b0
34 changed files with 4593 additions and 73 deletions

View File

@@ -0,0 +1,6 @@
"""
WebSocket server module for streaming ASR
"""
from .ws_server import ASRWebSocketServer
__all__ = ["ASRWebSocketServer"]

View File

@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""
ASR WebSocket Server with Live Transcription Display
This version displays transcriptions in real-time on the server console
while clients stream audio from remote machines.
"""
import asyncio
import websockets
import numpy as np
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from asr.asr_pipeline import ASRPipeline
# Configure logging: INFO and above go both to display_server.log and to the
# console (StreamHandler writes to stderr by default).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('display_server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)  # module-level logger
class DisplayServer:
    """
    WebSocket server that transcribes streamed audio and mirrors every
    transcription to the server console with ANSI-colored formatting.

    Protocol (per connection):
      * binary frames: int16 PCM audio at ``self.sample_rate`` Hz
      * text frames:   JSON commands ``{"type": "final"}`` or
                       ``{"type": "reset"}``
    Responses are JSON objects whose ``type`` is ``info``, ``transcript``
    or ``error``.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """
        Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate expected from clients (Hz)
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()  # currently open websocket connections
        # Terminal control codes (ANSI escape sequences) for the live display.
        # NOTE(review): CLEAR_LINE and CURSOR_UP are never used in this class.
        self.CLEAR_LINE = '\033[2K'
        self.CURSOR_UP = '\033[1A'
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RESET = '\033[0m'
        # Initialize ASR pipeline up front so the model is loaded before the
        # first client connects.
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")
        # Client sessions
        # NOTE(review): never read or written anywhere in this class — appears unused.
        self.sessions = {}

    def print_header(self):
        """Print a one-time banner describing the server configuration."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print(f"Model: Parakeet TDT 0.6B V3")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
        """
        Display a transcription on the server terminal.

        Args:
            client_id: Client identifier ("host:port")
            text: Transcribed text
            is_final: Whether this is the final transcription (bold green)
            is_progressive: Whether this is a progressive update (yellow);
                ignored when ``is_final`` is True
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        if is_final:
            # Final transcription - bold green
            print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.GREEN} ✓ FINAL: {text}{self.RESET}\n")
        elif is_progressive:
            # Progressive update - yellow
            print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.YELLOW}{text}{self.RESET}\n")
        else:
            # Regular transcription
            print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
            print(f" {text}\n")
        # Flush to ensure immediate display
        sys.stdout.flush()

    async def handle_client(self, websocket):
        """
        Handle one WebSocket client connection for its whole lifetime.

        Accumulates every audio frame the client sends; once at least
        ``min_chunk_duration`` seconds of *new* audio has arrived, the whole
        buffer is re-transcribed and sent back as a non-final transcript.
        A ``{"type": "final"}`` command transcribes the full buffer one last
        time and clears it; ``{"type": "reset"}`` just clears it.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)
        # Display connection banner
        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()
        # Audio buffer for accumulating ALL audio received so far
        all_audio = []
        # Sample count already covered by the last progressive transcription
        last_transcribed_samples = 0
        # For progressive transcription: wait for this much NEW audio
        min_chunk_duration = 2.0  # Minimum 2 seconds before transcribing
        min_chunk_samples = int(self.sample_rate * min_chunk_duration)
        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with live display",
                "sample_rate": self.sample_rate,
            }))
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM normalized to float32 [-1, 1)
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0
                        # Accumulate all audio
                        all_audio.append(audio_data)
                        total_samples = sum(len(chunk) for chunk in all_audio)
                        # Transcribe periodically when we have enough NEW audio
                        samples_since_last = total_samples - last_transcribed_samples
                        if samples_since_last >= min_chunk_samples:
                            # Re-transcribe the ENTIRE accumulated buffer (not
                            # just the new tail) each time.
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples
                            # Transcribe the accumulated audio
                            try:
                                text = self.pipeline.transcribe(
                                    audio_chunk,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    # Display on server
                                    self.display_transcription(client_id, text, is_final=False, is_progressive=True)
                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))
                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                            if command.get("type") == "final":
                                # Process all accumulated audio (final transcription)
                                if all_audio:
                                    audio_chunk = np.concatenate(all_audio)
                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )
                                    if text and text.strip():
                                        # Display on server
                                        self.display_transcription(client_id, text, is_final=True)
                                        # Send to client
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                        }
                                        await websocket.send(json.dumps(response))
                                    # Clear buffer after final transcription
                                    all_audio = []
                                    last_transcribed_samples = 0
                            elif command.get("type") == "reset":
                                # Reset buffer without transcribing anything
                                all_audio = []
                                last_transcribed_samples = 0
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "Buffer reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
                                sys.stdout.flush()
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")
                except Exception as e:
                    # Any unexpected per-message failure abandons the session.
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            # Always unregister the connection and show the disconnect banner.
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until the task is cancelled."""
        self.print_header()
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
            print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
            sys.stdout.flush()
            # Keep server running forever (await never completes)
            await asyncio.Future()
def main():
    """Command-line entry point: parse options and run the display server."""
    import argparse

    cli = argparse.ArgumentParser(description="ASR Server with Live Display")
    cli.add_argument("--host", default="0.0.0.0", help="Host address")
    cli.add_argument("--port", type=int, default=8766, help="Port number")
    cli.add_argument("--model-path", default="models/parakeet", help="Model directory")
    cli.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    opts = cli.parse_args()

    server = DisplayServer(
        host=opts.host,
        port=opts.port,
        model_path=opts.model_path,
        sample_rate=opts.sample_rate,
    )
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        # Ctrl-C: report shutdown on the console and in the log.
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,416 @@
#!/usr/bin/env python3
"""
ASR WebSocket Server with VAD - Optimized for Discord Bots
This server uses Voice Activity Detection (VAD) to:
- Detect speech start and end automatically
- Only transcribe speech segments (ignore silence)
- Provide clean boundaries for Discord message formatting
- Minimize processing of silence/noise
"""
import asyncio
import websockets
import numpy as np
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
from collections import deque
from dataclasses import dataclass
from typing import Optional
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from asr.asr_pipeline import ASRPipeline
# Configure logging: INFO and above go both to vad_server.log and to the
# console (StreamHandler writes to stderr by default).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('vad_server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)  # module-level logger
@dataclass
class SpeechSegment:
    """
    Represents a segment of detected speech.

    NOTE(review): defined but not referenced anywhere in this file — confirm
    whether it is used by other modules or can be removed.
    """
    audio: np.ndarray  # audio samples of the segment
    start_time: float  # segment start (presumably seconds — not used here, verify)
    end_time: Optional[float] = None  # segment end; None while the segment is ongoing
    is_complete: bool = False  # True once the segment has ended
class VADState:
    """
    Per-connection state for simple energy-based voice activity detection.

    Feed audio chunks to :meth:`process_audio`; it tracks runs of speech and
    silence, keeps a pre-buffer so the start of an utterance is not cut off,
    emits periodic partial segments while speech is ongoing, and emits one
    complete segment when speech ends.
    """

    def __init__(self, sample_rate: int = 16000, speech_threshold: float = 0.5):
        """
        Args:
            sample_rate: Audio sample rate in Hz.
            speech_threshold: Accepted for API compatibility.
                NOTE(review): currently unused — detection relies solely on
                ``energy_threshold`` below; confirm intent before removing.
        """
        self.sample_rate = sample_rate
        # Simple energy-based VAD parameters
        self.energy_threshold = 0.005  # RMS threshold; lower = more sensitive
        self.speech_frames = 0   # consecutive chunks classified as speech
        self.silence_frames = 0  # consecutive silent chunks while in speech
        self.min_speech_frames = 3  # 3 frames minimum (300ms with 100ms chunks)
        self.min_silence_frames = 5  # 5 frames of silence (500ms) ends speech
        self.is_speech = False  # True while inside a detected utterance
        self.speech_buffer = []  # chunks of the current utterance
        # Pre-buffer to capture audio BEFORE speech detection triggers.
        # This prevents cutting off the start of speech.
        self.pre_buffer_frames = 5  # Keep 5 frames (500ms) of pre-speech audio
        self.pre_buffer = deque(maxlen=self.pre_buffer_frames)
        # Progressive transcription tracking
        self.last_partial_samples = 0  # buffered sample count at the last partial
        self.partial_interval_samples = int(sample_rate * 0.3)  # Partial every 0.3 seconds (near real-time)
        logger.info(f"VAD initialized: energy_threshold={self.energy_threshold}, pre_buffer={self.pre_buffer_frames} frames")

    def calculate_energy(self, audio_chunk: np.ndarray) -> float:
        """Return the RMS energy of ``audio_chunk`` (0.0 for an empty chunk)."""
        if audio_chunk.size == 0:
            # np.mean of an empty array emits a RuntimeWarning and yields NaN,
            # which would silently compare as "not speech"; treat an empty
            # chunk as silence explicitly instead.
            return 0.0
        return float(np.sqrt(np.mean(audio_chunk ** 2)))

    def process_audio(self, audio_chunk: np.ndarray) -> tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Process one audio chunk and detect speech boundaries.

        Returns:
            (speech_detected, complete_segment, partial_segment)
            - speech_detected: True if currently in speech
            - complete_segment: audio segment if speech just ended, else None
            - partial_segment: audio for a partial transcription, else None
        """
        energy = self.calculate_energy(audio_chunk)
        chunk_is_speech = energy > self.energy_threshold
        logger.debug(f"Energy: {energy:.6f}, Is speech: {chunk_is_speech}")
        partial_segment = None
        if chunk_is_speech:
            self.speech_frames += 1
            self.silence_frames = 0
            if not self.is_speech and self.speech_frames >= self.min_speech_frames:
                # Speech started - prepend the pre-buffer so the very first
                # frames of the utterance are not lost.
                self.is_speech = True
                logger.info("🎤 Speech started (including pre-buffer)")
                if self.pre_buffer:
                    logger.debug(f"Adding {len(self.pre_buffer)} pre-buffered frames")
                    self.speech_buffer.extend(list(self.pre_buffer))
                    self.pre_buffer.clear()
            if self.is_speech:
                self.speech_buffer.append(audio_chunk)
            else:
                # Not confirmed as speech yet, keep in the pre-buffer
                self.pre_buffer.append(audio_chunk)
            # Check if we should emit a partial transcription: enough NEW audio
            # since the last partial AND a minimum total duration buffered.
            current_samples = sum(len(chunk) for chunk in self.speech_buffer)
            samples_since_last_partial = current_samples - self.last_partial_samples
            min_duration_for_partial = int(self.sample_rate * 0.8)  # At least 0.8s of audio
            if samples_since_last_partial >= self.partial_interval_samples and current_samples >= min_duration_for_partial:
                partial_segment = np.concatenate(self.speech_buffer)
                self.last_partial_samples = current_samples
                logger.debug(f"📝 Partial update: {current_samples/self.sample_rate:.2f}s")
        else:
            if self.is_speech:
                self.silence_frames += 1
                if self.silence_frames < self.min_silence_frames:
                    # Keep some trailing silence inside the segment (up to limit)
                    self.speech_buffer.append(audio_chunk)
                else:
                    # Speech ended: reset state and hand back the whole segment
                    logger.info(f"🛑 Speech ended after {self.silence_frames} silence frames")
                    self.is_speech = False
                    self.speech_frames = 0
                    self.silence_frames = 0
                    self.last_partial_samples = 0  # Reset partial counter
                    if self.speech_buffer:
                        complete_segment = np.concatenate(self.speech_buffer)
                        segment_duration = len(complete_segment) / self.sample_rate
                        self.speech_buffer = []
                        self.pre_buffer.clear()  # Clear pre-buffer after speech ends
                        logger.info(f"✅ Complete segment: {segment_duration:.2f}s")
                        return False, complete_segment, None
            else:
                self.speech_frames = 0
                # Keep adding to pre-buffer when not in speech
                self.pre_buffer.append(audio_chunk)
        return self.is_speech, None, partial_segment
class VADServer:
    """
    WebSocket server that runs energy-based VAD on incoming audio so only
    detected speech segments are transcribed (intended for Discord bots).

    Protocol (per connection):
      * binary frames: int16 PCM audio at ``self.sample_rate`` Hz
      * text frames:   JSON commands ``force_transcribe``, ``reset``,
                       ``set_threshold``
    Responses are JSON objects whose ``type`` is ``info``, ``vad_status``,
    ``transcript`` or ``error``.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
    ):
        """Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate expected from clients (Hz)
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.active_connections = set()  # currently open websocket connections
        # Terminal control codes (ANSI escape sequences) for colored output.
        # NOTE(review): RED is never used in this class.
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RED = '\033[91m'
        self.RESET = '\033[0m'
        # Initialize ASR pipeline up front (model loads before serving)
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

    def print_header(self):
        """Print a one-time banner describing the server configuration."""
        print("\n" + "=" * 80)
        print(f"{self.BOLD}{self.BLUE}ASR Server with VAD - Discord Bot Ready{self.RESET}")
        print("=" * 80)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print(f"Model: Parakeet TDT 0.6B V3")
        print(f"VAD: Energy-based speech detection")
        print("=" * 80 + "\n")

    def display_transcription(self, client_id: str, text: str, duration: float):
        """Display a final transcription (with its duration) on the terminal."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
        print(f"{self.GREEN} 📝 {text}{self.RESET}")
        print(f"{self.BLUE} ⏱️ Duration: {duration:.2f}s{self.RESET}\n")
        sys.stdout.flush()

    async def handle_client(self, websocket):
        """
        Handle one WebSocket client connection for its whole lifetime.

        Each connection gets its own :class:`VADState`. Binary frames are fed
        through the VAD: partial transcripts are emitted while speech is
        ongoing and a final transcript when the VAD decides speech ended.
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)
        print(f"\n{self.BOLD}{'='*80}{self.RESET}")
        print(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")
        print(f"{self.BOLD}{'='*80}{self.RESET}\n")
        sys.stdout.flush()
        # Initialize VAD state for this client (independent per connection)
        vad_state = VADState(sample_rate=self.sample_rate)
        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server with VAD",
                "sample_rate": self.sample_rate,
                "vad_enabled": True,
            }))
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data: int16 PCM normalized to float32 [-1, 1)
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0
                        # Process through VAD
                        is_speech, complete_segment, partial_segment = vad_state.process_audio(audio_data)
                        # Send VAD status to client (only on state change).
                        # _prev_speech_state is stashed on the VADState object
                        # and defaults to False for a fresh state.
                        prev_speech_state = getattr(vad_state, '_prev_speech_state', False)
                        if is_speech != prev_speech_state:
                            vad_state._prev_speech_state = is_speech
                            await websocket.send(json.dumps({
                                "type": "vad_status",
                                "is_speech": is_speech,
                            }))
                        # Handle partial transcription (progressive updates while speaking)
                        if partial_segment is not None:
                            try:
                                text = self.pipeline.transcribe(
                                    partial_segment,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    duration = len(partial_segment) / self.sample_rate
                                    # Display on server (yellow = partial)
                                    timestamp = datetime.now().strftime("%H:%M:%S")
                                    print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
                                    print(f"{self.YELLOW} → PARTIAL: {text}{self.RESET}\n")
                                    sys.stdout.flush()
                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                # Partial failures are logged but not reported
                                # to the client; the final pass may still work.
                                logger.error(f"Partial transcription error: {e}")
                        # If we have a complete speech segment, transcribe it
                        if complete_segment is not None:
                            try:
                                text = self.pipeline.transcribe(
                                    complete_segment,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    duration = len(complete_segment) / self.sample_rate
                                    # Display on server
                                    self.display_transcription(client_id, text, duration)
                                    # Send to client
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": True,
                                        "duration": duration,
                                    }
                                    await websocket.send(json.dumps(response))
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))
                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                            if command.get("type") == "force_transcribe":
                                # Force transcribe whatever is buffered now,
                                # without waiting for the VAD to end the segment
                                if vad_state.speech_buffer:
                                    audio_chunk = np.concatenate(vad_state.speech_buffer)
                                    vad_state.speech_buffer = []
                                    vad_state.is_speech = False
                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )
                                    if text and text.strip():
                                        duration = len(audio_chunk) / self.sample_rate
                                        self.display_transcription(client_id, text, duration)
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                            "duration": duration,
                                        }
                                        await websocket.send(json.dumps(response))
                            elif command.get("type") == "reset":
                                # Reset VAD state (drops any buffered audio)
                                vad_state = VADState(sample_rate=self.sample_rate)
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "VAD state reset"
                                }))
                                print(f"{self.YELLOW}[{client_id}] VAD reset{self.RESET}\n")
                                sys.stdout.flush()
                            elif command.get("type") == "set_threshold":
                                # Adjust VAD energy threshold at runtime
                                threshold = command.get("threshold", 0.01)
                                vad_state.energy_threshold = threshold
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": f"VAD threshold set to {threshold}"
                                }))
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")
                except Exception as e:
                    # Any unexpected per-message failure abandons the session.
                    logger.error(f"Error processing message from {client_id}: {e}")
                    break
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed: {client_id}")
        except Exception as e:
            logger.error(f"Unexpected error with {client_id}: {e}")
        finally:
            # Always unregister the connection and show the disconnect banner.
            self.active_connections.discard(websocket)
            print(f"\n{self.BOLD}{'='*80}{self.RESET}")
            print(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            print(f"{self.BOLD}{'='*80}{self.RESET}\n")
            sys.stdout.flush()
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until the task is cancelled."""
        self.print_header()
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running with VAD enabled!{self.RESET}")
            print(f"{self.BOLD}Ready for Discord bot connections...{self.RESET}\n")
            sys.stdout.flush()
            # Keep server running forever (await never completes)
            await asyncio.Future()
def main():
    """Command-line entry point: parse options and run the VAD server."""
    import argparse

    cli = argparse.ArgumentParser(description="ASR Server with VAD for Discord")
    cli.add_argument("--host", default="0.0.0.0", help="Host address")
    cli.add_argument("--port", type=int, default=8766, help="Port number")
    cli.add_argument("--model-path", default="models/parakeet", help="Model directory")
    cli.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    opts = cli.parse_args()

    server = VADServer(
        host=opts.host,
        port=opts.port,
        model_path=opts.model_path,
        sample_rate=opts.sample_rate,
    )
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        # Ctrl-C: report shutdown on the console and in the log.
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,231 @@
"""
WebSocket server for streaming ASR using onnx-asr
"""
import asyncio
import websockets
import numpy as np
import json
import logging
from asr.asr_pipeline import ASRPipeline
from typing import Optional
# Configure logging: console only (no file handler for this module).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)  # module-level logger
class ASRWebSocketServer:
    """
    WebSocket server for real-time speech recognition.

    Protocol (per connection):
      * binary frames: int16 PCM audio at ``self.sample_rate`` Hz
      * text frames:   JSON commands ``{"type": "final"}`` or
                       ``{"type": "reset"}``
    Responses are JSON objects whose ``type`` is ``info``, ``transcript``
    or ``error``.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_name: str = "nemo-parakeet-tdt-0.6b-v3",
        model_path: Optional[str] = None,
        use_vad: bool = False,
        sample_rate: int = 16000,
    ):
        """
        Initialize WebSocket server.

        Args:
            host: Server host address
            port: Server port
            model_name: ASR model name
            model_path: Optional local model path
            use_vad: Whether to use VAD
            sample_rate: Expected audio sample rate (Hz)
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        # Load the model up front so the server is ready before accepting clients.
        logger.info("Initializing ASR Pipeline...")
        self.pipeline = ASRPipeline(
            model_name=model_name,
            model_path=model_path,
            use_vad=use_vad,
        )
        logger.info("ASR Pipeline ready")
        self.active_connections = set()  # currently open websocket connections

    async def handle_client(self, websocket):
        """
        Handle one WebSocket client connection for its whole lifetime.

        Accumulates every audio frame the client sends; once at least
        ``min_chunk_duration`` seconds of *new* audio has arrived, the whole
        buffer is re-transcribed and sent back as a non-final transcript.
        A ``{"type": "final"}`` command transcribes the full buffer one last
        time and clears it; ``{"type": "reset"}`` just clears it.

        Args:
            websocket: WebSocket connection
        """
        client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)
        # Audio buffer for accumulating ALL audio
        all_audio = []
        last_transcribed_samples = 0  # Track what we've already transcribed
        # For progressive transcription, we'll accumulate and transcribe the full buffer.
        # This gives better results than processing tiny chunks.
        min_chunk_duration = 2.0  # Minimum 2 seconds before transcribing
        min_chunk_samples = int(self.sample_rate * min_chunk_duration)
        try:
            # Send welcome message
            await websocket.send(json.dumps({
                "type": "info",
                "message": "Connected to ASR server",
                "sample_rate": self.sample_rate,
            }))
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary audio data.
                        # Convert bytes to float32 numpy array,
                        # assuming int16 PCM data.
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0
                        # Accumulate all audio
                        all_audio.append(audio_data)
                        total_samples = sum(len(chunk) for chunk in all_audio)
                        # Transcribe periodically when we have enough NEW audio
                        samples_since_last = total_samples - last_transcribed_samples
                        if samples_since_last >= min_chunk_samples:
                            # Re-transcribe the ENTIRE accumulated buffer
                            audio_chunk = np.concatenate(all_audio)
                            last_transcribed_samples = total_samples
                            # Transcribe the accumulated audio
                            try:
                                text = self.pipeline.transcribe(
                                    audio_chunk,
                                    sample_rate=self.sample_rate
                                )
                                if text and text.strip():
                                    response = {
                                        "type": "transcript",
                                        "text": text,
                                        "is_final": False,
                                    }
                                    await websocket.send(json.dumps(response))
                                    logger.info(f"Progressive transcription: {text}")
                            except Exception as e:
                                logger.error(f"Transcription error: {e}")
                                await websocket.send(json.dumps({
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}"
                                }))
                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                            if command.get("type") == "final":
                                # Process all accumulated audio (final transcription)
                                if all_audio:
                                    audio_chunk = np.concatenate(all_audio)
                                    text = self.pipeline.transcribe(
                                        audio_chunk,
                                        sample_rate=self.sample_rate
                                    )
                                    if text and text.strip():
                                        response = {
                                            "type": "transcript",
                                            "text": text,
                                            "is_final": True,
                                        }
                                        await websocket.send(json.dumps(response))
                                        logger.info(f"Final transcription: {text}")
                                    # Clear buffer after final transcription
                                    all_audio = []
                                    last_transcribed_samples = 0
                            elif command.get("type") == "reset":
                                # Reset buffer without transcribing anything
                                all_audio = []
                                last_transcribed_samples = 0
                                await websocket.send(json.dumps({
                                    "type": "info",
                                    "message": "Buffer reset"
                                }))
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON command: {message}")
                except Exception as e:
                    # Report the failure to the client and keep the session
                    # alive (the loop continues with the next message).
                    logger.error(f"Error processing message: {e}")
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_id}")
        finally:
            # Always unregister the connection.
            self.active_connections.discard(websocket)
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """
        Start the WebSocket server and serve until the task is cancelled.
        """
        logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Server running on ws://{self.host}:{self.port}")
            # NOTE(review): logged once at startup, so this is always 0 here.
            logger.info(f"Active connections: {len(self.active_connections)}")
            await asyncio.Future()  # Run forever

    def run(self):
        """
        Run the server (blocking). Returns cleanly on Ctrl-C.
        """
        try:
            asyncio.run(self.start())
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
def main():
    """
    Main entry point for the WebSocket server: parse CLI options and run
    the server until interrupted.
    """
    import argparse

    cli = argparse.ArgumentParser(description="ASR WebSocket Server")
    cli.add_argument("--host", default="0.0.0.0", help="Server host")
    cli.add_argument("--port", type=int, default=8766, help="Server port")
    cli.add_argument("--model", default="nemo-parakeet-tdt-0.6b-v3", help="Model name")
    cli.add_argument("--model-path", default=None, help="Local model path")
    cli.add_argument("--use-vad", action="store_true", help="Enable VAD")
    cli.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate")
    opts = cli.parse_args()

    ASRWebSocketServer(
        host=opts.host,
        port=opts.port,
        model_name=opts.model,
        model_path=opts.model_path,
        use_vad=opts.use_vad,
        sample_rate=opts.sample_rate,
    ).run()


if __name__ == "__main__":
    main()