#!/usr/bin/env python3
"""
ASR WebSocket Server with Live Transcription Display

This version displays transcriptions in real-time on the server console
while clients stream audio from remote machines.

Wire protocol (as implemented by ``DisplayServer.handle_client``):

* Binary frames carry signed 16-bit PCM samples in native byte order
  (decoded with ``np.frombuffer(..., dtype=np.int16)``).
* Text frames carry JSON commands: ``{"type": "final"}`` transcribes and
  clears the buffered audio, ``{"type": "reset"}`` clears it without
  transcribing.
* The server replies with JSON messages whose ``type`` is ``"info"``,
  ``"transcript"`` (with ``text`` and ``is_final``) or ``"error"``.
"""

import asyncio
import json
import logging
import sys
import threading
from datetime import datetime
from pathlib import Path

import numpy as np
import websockets

# Add project root to path so the local ``asr`` package is importable when
# this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from asr.asr_pipeline import ASRPipeline  # noqa: E402  (needs the sys.path tweak above)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('display_server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class DisplayServer:
    """
    WebSocket server with live transcription display.

    Each connected client streams audio. The server periodically re-transcribes
    everything the client has sent in the current utterance ("progressive"
    updates), prints the result on the server console, and sends it back to
    the client.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 8766,
        model_path: str = "models/parakeet",
        sample_rate: int = 16000,
        min_chunk_duration: float = 2.0,
    ):
        """
        Initialize server.

        Args:
            host: Host address to bind to
            port: Port to bind to
            model_path: Directory containing model files
            sample_rate: Audio sample rate
            min_chunk_duration: Seconds of new audio a client must send before
                another progressive transcription is run (default 2.0)
        """
        self.host = host
        self.port = port
        self.sample_rate = sample_rate
        self.min_chunk_duration = min_chunk_duration
        self.active_connections = set()

        # Terminal control codes
        self.CLEAR_LINE = '\033[2K'
        self.CURSOR_UP = '\033[1A'
        self.BOLD = '\033[1m'
        self.GREEN = '\033[92m'
        self.YELLOW = '\033[93m'
        self.BLUE = '\033[94m'
        self.RESET = '\033[0m'

        # Transcription runs in executor threads (see _transcribe). This lock
        # keeps calls into the pipeline one-at-a-time, exactly as when they ran
        # directly on the event loop, because the pipeline's thread-safety is
        # not known here.
        self._transcribe_lock = threading.Lock()

        # Initialize ASR pipeline
        logger.info("Loading ASR model...")
        self.pipeline = ASRPipeline(model_path=model_path)
        logger.info("ASR Pipeline ready")

        # Client sessions (not referenced elsewhere in this file)
        self.sessions = {}

    def print_header(self):
        """Print server header."""
        rule = "=" * 80
        print("\n" + rule)
        print(f"{self.BOLD}{self.BLUE}ASR Server - Live Transcription Display{self.RESET}")
        print(rule)
        print(f"Server: ws://{self.host}:{self.port}")
        print(f"Sample Rate: {self.sample_rate} Hz")
        print("Model: Parakeet TDT 0.6B V3")
        print(rule + "\n")

    def display_transcription(self, client_id: str, text: str, is_final: bool, is_progressive: bool = False):
        """
        Display transcription in the terminal.

        Args:
            client_id: Client identifier
            text: Transcribed text
            is_final: Whether this is the final transcription
            is_progressive: Whether this is a progressive update
        """
        timestamp = datetime.now().strftime("%H:%M:%S")

        if is_final:
            # Final transcription - bold green
            print(f"{self.GREEN}{self.BOLD}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.GREEN}  ✓ FINAL: {text}{self.RESET}\n")
        elif is_progressive:
            # Progressive update - yellow
            print(f"{self.YELLOW}[{timestamp}] {client_id}{self.RESET}")
            print(f"{self.YELLOW}  → {text}{self.RESET}\n")
        else:
            # Regular transcription
            print(f"{self.BLUE}[{timestamp}] {client_id}{self.RESET}")
            print(f"  {text}\n")

        # Flush to ensure immediate display
        sys.stdout.flush()

    def _print_banner(self, line: str) -> None:
        """Print ``line`` between two bold 80-column rules and flush stdout."""
        rule = f"{self.BOLD}{'=' * 80}{self.RESET}"
        print(f"\n{rule}")
        print(line)
        print(f"{rule}\n")
        sys.stdout.flush()

    @staticmethod
    def _format_client_id(websocket) -> str:
        """Return "host:port" for the peer, or an object-id label if unavailable."""
        address = getattr(websocket, "remote_address", None)
        if isinstance(address, (tuple, list)) and len(address) >= 2:
            return f"{address[0]}:{address[1]}"
        # No (host, port) pair available (e.g. a non-TCP transport).
        return f"client-{id(websocket):x}"

    @staticmethod
    async def _send_json(websocket, payload: dict) -> None:
        """Serialize ``payload`` as JSON and send it as a text frame."""
        await websocket.send(json.dumps(payload))

    def _transcribe_blocking(self, audio: np.ndarray):
        """Run the ASR pipeline on ``audio``; the lock prevents overlapping calls."""
        with self._transcribe_lock:
            return self.pipeline.transcribe(audio, sample_rate=self.sample_rate)

    async def _transcribe(self, audio: np.ndarray):
        """Transcribe ``audio`` in the default executor so the event loop keeps serving I/O."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_blocking, audio)

    async def handle_client(self, websocket):
        """
        Handle individual WebSocket client connection.

        Binary frames are decoded as int16 PCM, scaled to float32 in
        [-1.0, 1.0), and appended to a per-connection buffer. Whenever at least
        ``min_chunk_duration`` seconds of new audio have arrived since the last
        transcription, the *whole* buffer is re-transcribed and sent as a
        non-final transcript. Text frames are JSON commands:

        * ``"final"``: transcribe the whole buffer, send it with
          ``is_final=True``, then clear the buffer (also cleared on failure).
        * ``"reset"``: clear the buffer without transcribing.

        Transcription runs in a worker thread, so a slow model call does not
        stall other clients or the WebSocket keepalive.

        Args:
            websocket: WebSocket connection
        """
        client_id = self._format_client_id(websocket)
        logger.info(f"Client connected: {client_id}")
        self.active_connections.add(websocket)

        # Display connection
        self._print_banner(f"{self.GREEN}✓ Client connected: {client_id}{self.RESET}")

        # Audio buffer for accumulating ALL audio of the current utterance
        all_audio = []
        buffered_samples = 0  # running total of samples held in all_audio
        last_transcribed_samples = 0  # buffered_samples at the last progressive run
        min_chunk_samples = int(self.sample_rate * self.min_chunk_duration)

        try:
            # Send welcome message
            await self._send_json(websocket, {
                "type": "info",
                "message": "Connected to ASR server with live display",
                "sample_rate": self.sample_rate,
            })

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        if len(message) % 2:
                            # np.frombuffer cannot decode a partial int16 sample;
                            # reject this frame but keep the connection open.
                            logger.warning(
                                f"Dropping {len(message)}-byte audio frame from {client_id}: "
                                "not a whole number of int16 samples"
                            )
                            await self._send_json(websocket, {
                                "type": "error",
                                "message": "Invalid audio frame: length must be a multiple of 2 bytes",
                            })
                            continue

                        # Binary audio data
                        audio_data = np.frombuffer(message, dtype=np.int16)
                        audio_data = audio_data.astype(np.float32) / 32768.0

                        # Accumulate all audio
                        all_audio.append(audio_data)
                        buffered_samples += len(audio_data)

                        # Transcribe periodically when we have enough NEW audio
                        if buffered_samples - last_transcribed_samples < min_chunk_samples:
                            continue

                        audio_chunk = np.concatenate(all_audio)
                        last_transcribed_samples = buffered_samples

                        try:
                            text = await self._transcribe(audio_chunk)
                        except Exception as e:
                            logger.exception(f"Transcription error: {e}")
                            await self._send_json(websocket, {
                                "type": "error",
                                "message": f"Transcription failed: {str(e)}",
                            })
                            continue

                        if text and text.strip():
                            # Display on server
                            self.display_transcription(client_id, text, is_final=False, is_progressive=True)
                            # Send to client
                            await self._send_json(websocket, {
                                "type": "transcript",
                                "text": text,
                                "is_final": False,
                            })

                    elif isinstance(message, str):
                        # JSON command
                        try:
                            command = json.loads(message)
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON from {client_id}: {message}")
                            continue

                        if not isinstance(command, dict):
                            logger.warning(f"Ignoring non-object JSON command from {client_id}: {message}")
                            continue

                        command_type = command.get("type")

                        if command_type == "final":
                            # Process all accumulated audio (final transcription)
                            if not all_audio:
                                continue

                            audio_chunk = np.concatenate(all_audio)
                            # "final" ends the utterance: start a fresh buffer even
                            # if the transcription below fails.
                            all_audio = []
                            buffered_samples = 0
                            last_transcribed_samples = 0

                            try:
                                text = await self._transcribe(audio_chunk)
                            except Exception as e:
                                logger.exception(f"Transcription error: {e}")
                                await self._send_json(websocket, {
                                    "type": "error",
                                    "message": f"Transcription failed: {str(e)}",
                                })
                                continue

                            if text and text.strip():
                                # Display on server
                                self.display_transcription(client_id, text, is_final=True)
                                # Send to client
                                await self._send_json(websocket, {
                                    "type": "transcript",
                                    "text": text,
                                    "is_final": True,
                                })

                        elif command_type == "reset":
                            # Reset buffer
                            all_audio = []
                            buffered_samples = 0
                            last_transcribed_samples = 0
                            await self._send_json(websocket, {
                                "type": "info",
                                "message": "Buffer reset",
                            })
                            print(f"{self.YELLOW}[{client_id}] Buffer reset{self.RESET}\n")
                            sys.stdout.flush()

                        else:
                            logger.warning(f"Unknown command type from {client_id}: {command_type!r}")

                except websockets.exceptions.ConnectionClosed:
                    # A dropped peer is a normal disconnect; let the outer
                    # handler deal with it instead of logging a processing error.
                    raise
                except Exception as e:
                    logger.exception(f"Error processing message from {client_id}: {e}")
                    break

        except websockets.exceptions.ConnectionClosed as e:
            logger.debug(f"Connection closed by peer: {client_id} ({e})")
        except Exception as e:
            logger.exception(f"Unexpected error with {client_id}: {e}")
        finally:
            self.active_connections.discard(websocket)
            self._print_banner(f"{self.YELLOW}✗ Client disconnected: {client_id}{self.RESET}")
            logger.info(f"Connection closed: {client_id}")

    async def start(self):
        """Start the WebSocket server and serve until cancelled."""
        self.print_header()

        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            print(f"{self.GREEN}{self.BOLD}Server is running and ready for connections!{self.RESET}")
            print(f"{self.BOLD}Waiting for clients...{self.RESET}\n")
            sys.stdout.flush()

            # Keep server running: this future is never resolved
            await asyncio.Future()


def main():
    """Main entry point: parse CLI arguments and run the server."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR Server with Live Display")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8766, help="Port number")
    parser.add_argument("--model-path", default="models/parakeet", help="Model directory")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    parser.add_argument(
        "--min-chunk-duration",
        type=float,
        default=2.0,
        help="Seconds of new audio between progressive transcriptions",
    )

    args = parser.parse_args()

    server = DisplayServer(
        host=args.host,
        port=args.port,
        model_path=args.model_path,
        sample_rate=args.sample_rate,
        min_chunk_duration=args.min_chunk_duration,
    )

    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print(f"\n\n{server.YELLOW}Server stopped by user{server.RESET}")
        logger.info("Server stopped by user")


if __name__ == "__main__":
    main()