""" WebSocket server for streaming ASR using onnx-asr """ import asyncio import websockets import numpy as np import json import logging from asr.asr_pipeline import ASRPipeline from typing import Optional logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class ASRWebSocketServer: """ WebSocket server for real-time speech recognition. """ def __init__( self, host: str = "0.0.0.0", port: int = 8766, model_name: str = "nemo-parakeet-tdt-0.6b-v3", model_path: Optional[str] = None, use_vad: bool = False, sample_rate: int = 16000, ): """ Initialize WebSocket server. Args: host: Server host address port: Server port model_name: ASR model name model_path: Optional local model path use_vad: Whether to use VAD sample_rate: Expected audio sample rate """ self.host = host self.port = port self.sample_rate = sample_rate logger.info("Initializing ASR Pipeline...") self.pipeline = ASRPipeline( model_name=model_name, model_path=model_path, use_vad=use_vad, ) logger.info("ASR Pipeline ready") self.active_connections = set() async def handle_client(self, websocket): """ Handle individual WebSocket client connection. Args: websocket: WebSocket connection """ client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}" logger.info(f"Client connected: {client_id}") self.active_connections.add(websocket) # Audio buffer for accumulating ALL audio all_audio = [] last_transcribed_samples = 0 # Track what we've already transcribed # For progressive transcription, we'll accumulate and transcribe the full buffer # This gives better results than processing tiny chunks min_chunk_duration = 2.0 # Minimum 2 seconds before transcribing min_chunk_samples = int(self.sample_rate * min_chunk_duration) try: # Send welcome message await websocket.send(json.dumps({ "type": "info", "message": "Connected to ASR server", "sample_rate": self.sample_rate, })) async for message in websocket: try: if isinstance(message, bytes): # Binary audio data # Convert bytes to float32 numpy array # Assuming int16 PCM data audio_data = np.frombuffer(message, dtype=np.int16) audio_data = audio_data.astype(np.float32) / 32768.0 # Accumulate all audio all_audio.append(audio_data) total_samples = sum(len(chunk) for chunk in all_audio) # Transcribe periodically when we have enough NEW audio samples_since_last = total_samples - last_transcribed_samples if samples_since_last >= min_chunk_samples: audio_chunk = np.concatenate(all_audio) last_transcribed_samples = total_samples # Transcribe the accumulated audio try: text = self.pipeline.transcribe( audio_chunk, sample_rate=self.sample_rate ) if text and text.strip(): response = { "type": "transcript", "text": text, "is_final": False, } await websocket.send(json.dumps(response)) logger.info(f"Progressive transcription: {text}") except Exception as e: logger.error(f"Transcription error: {e}") await websocket.send(json.dumps({ "type": "error", "message": f"Transcription failed: {str(e)}" })) elif isinstance(message, str): # JSON command try: command = json.loads(message) if command.get("type") == "final": # Process all accumulated audio (final transcription) if all_audio: audio_chunk = np.concatenate(all_audio) text = self.pipeline.transcribe( audio_chunk, sample_rate=self.sample_rate ) if text and text.strip(): response = { "type": "transcript", "text": text, "is_final": True, } await websocket.send(json.dumps(response)) logger.info(f"Final transcription: {text}") # Clear buffer after final 
transcription all_audio = [] last_transcribed_samples = 0 elif command.get("type") == "reset": # Reset buffer all_audio = [] last_transcribed_samples = 0 await websocket.send(json.dumps({ "type": "info", "message": "Buffer reset" })) except json.JSONDecodeError: logger.warning(f"Invalid JSON command: {message}") except Exception as e: logger.error(f"Error processing message: {e}") await websocket.send(json.dumps({ "type": "error", "message": str(e) })) except websockets.exceptions.ConnectionClosed: logger.info(f"Client disconnected: {client_id}") finally: self.active_connections.discard(websocket) logger.info(f"Connection closed: {client_id}") async def start(self): """ Start the WebSocket server. """ logger.info(f"Starting WebSocket server on {self.host}:{self.port}") async with websockets.serve(self.handle_client, self.host, self.port): logger.info(f"Server running on ws://{self.host}:{self.port}") logger.info(f"Active connections: {len(self.active_connections)}") await asyncio.Future() # Run forever def run(self): """ Run the server (blocking). """ try: asyncio.run(self.start()) except KeyboardInterrupt: logger.info("Server stopped by user") def main(): """ Main entry point for the WebSocket server. """ import argparse parser = argparse.ArgumentParser(description="ASR WebSocket Server") parser.add_argument("--host", default="0.0.0.0", help="Server host") parser.add_argument("--port", type=int, default=8766, help="Server port") parser.add_argument("--model", default="nemo-parakeet-tdt-0.6b-v3", help="Model name") parser.add_argument("--model-path", default=None, help="Local model path") parser.add_argument("--use-vad", action="store_true", help="Enable VAD") parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate") args = parser.parse_args() server = ASRWebSocketServer( host=args.host, port=args.port, model_name=args.model, model_path=args.model_path, use_vad=args.use_vad, sample_rate=args.sample_rate, ) server.run() if __name__ == "__main__": main()
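
# ---------------------------------------------------------------------------
# Example client (illustrative sketch only, kept commented out so it is not
# part of the server module). It follows the protocol implemented in
# handle_client above: binary frames of 16 kHz mono int16 PCM, then the JSON
# command {"type": "final"} to request the final transcript. The file path
# "sample.wav" and the 0.5 s chunk size are arbitrary assumptions.
# ---------------------------------------------------------------------------
#
# import asyncio
# import json
# import wave
# import websockets
#
# async def stream_file(path="sample.wav", uri="ws://localhost:8766"):
#     async with websockets.connect(uri) as ws:
#         with wave.open(path, "rb") as wav:
#             # Send roughly 0.5 s of int16 PCM per binary frame
#             frames_per_chunk = wav.getframerate() // 2
#             while True:
#                 chunk = wav.readframes(frames_per_chunk)
#                 if not chunk:
#                     break
#                 await ws.send(chunk)  # triggers progressive transcripts
#         await ws.send(json.dumps({"type": "final"}))  # request final result
#         async for message in ws:
#             result = json.loads(message)
#             if result.get("type") == "transcript" and result.get("is_final"):
#                 print(result["text"])
#                 break
#
# asyncio.run(stream_file())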