Changed STT to Parakeet — still experimental, though performance seems to be better
This commit is contained in:
@@ -2,13 +2,13 @@
|
||||
STT Server
|
||||
|
||||
FastAPI WebSocket server for real-time speech-to-text.
|
||||
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.
|
||||
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
|
||||
|
||||
Architecture:
|
||||
- VAD runs continuously on every audio chunk (CPU)
|
||||
- Whisper transcribes only when VAD detects speech (GPU)
|
||||
- Parakeet transcribes only when VAD detects speech (GPU)
|
||||
- Supports multiple concurrent users
|
||||
- Sends partial and final transcripts via WebSocket
|
||||
- Sends partial and final transcripts via WebSocket with word-level tokens
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||
@@ -20,7 +20,7 @@ from typing import Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from vad_processor import VADProcessor
|
||||
from whisper_transcriber import WhisperTranscriber
|
||||
from parakeet_transcriber import ParakeetTranscriber
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -34,7 +34,7 @@ app = FastAPI(title="Miku STT Server", version="1.0.0")
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
vad_processor: Optional[VADProcessor] = None
|
||||
whisper_transcriber: Optional[WhisperTranscriber] = None
|
||||
parakeet_transcriber: Optional[ParakeetTranscriber] = None
|
||||
|
||||
# User session tracking
|
||||
user_sessions: Dict[str, dict] = {}
|
||||
@@ -117,39 +117,40 @@ class UserSTTSession:
|
||||
self.audio_buffer.append(audio_np)
|
||||
|
||||
async def _transcribe_partial(self):
|
||||
"""Transcribe accumulated audio and send partial result."""
|
||||
"""Transcribe accumulated audio and send partial result with word tokens."""
|
||||
if not self.audio_buffer:
|
||||
return
|
||||
|
||||
# Concatenate audio
|
||||
audio_full = np.concatenate(self.audio_buffer)
|
||||
|
||||
# Transcribe asynchronously
|
||||
# Transcribe asynchronously with word-level timestamps
|
||||
try:
|
||||
text = await whisper_transcriber.transcribe_async(
|
||||
result = await parakeet_transcriber.transcribe_async(
|
||||
audio_full,
|
||||
sample_rate=16000,
|
||||
initial_prompt=self.last_transcript # Use previous for context
|
||||
return_timestamps=True
|
||||
)
|
||||
|
||||
if text and text != self.last_transcript:
|
||||
self.last_transcript = text
|
||||
if result and result.get("text") and result["text"] != self.last_transcript:
|
||||
self.last_transcript = result["text"]
|
||||
|
||||
# Send partial transcript
|
||||
# Send partial transcript with word tokens for LLM pre-computation
|
||||
await self.websocket.send_json({
|
||||
"type": "partial",
|
||||
"text": text,
|
||||
"text": result["text"],
|
||||
"words": result.get("words", []), # Word-level tokens
|
||||
"user_id": self.user_id,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
|
||||
logger.info(f"Partial [{self.user_id}]: {text}")
|
||||
logger.info(f"Partial [{self.user_id}]: {result['text']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Partial transcription failed: {e}", exc_info=True)
|
||||
|
||||
async def _transcribe_final(self):
|
||||
"""Transcribe final accumulated audio."""
|
||||
"""Transcribe final accumulated audio with word tokens."""
|
||||
if not self.audio_buffer:
|
||||
return
|
||||
|
||||
@@ -157,23 +158,25 @@ class UserSTTSession:
|
||||
audio_full = np.concatenate(self.audio_buffer)
|
||||
|
||||
try:
|
||||
text = await whisper_transcriber.transcribe_async(
|
||||
result = await parakeet_transcriber.transcribe_async(
|
||||
audio_full,
|
||||
sample_rate=16000
|
||||
sample_rate=16000,
|
||||
return_timestamps=True
|
||||
)
|
||||
|
||||
if text:
|
||||
self.last_transcript = text
|
||||
if result and result.get("text"):
|
||||
self.last_transcript = result["text"]
|
||||
|
||||
# Send final transcript
|
||||
# Send final transcript with word tokens
|
||||
await self.websocket.send_json({
|
||||
"type": "final",
|
||||
"text": text,
|
||||
"text": result["text"],
|
||||
"words": result.get("words", []), # Word-level tokens for LLM
|
||||
"user_id": self.user_id,
|
||||
"timestamp": self.timestamp_ms
|
||||
})
|
||||
|
||||
logger.info(f"Final [{self.user_id}]: {text}")
|
||||
logger.info(f"Final [{self.user_id}]: {result['text']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Final transcription failed: {e}", exc_info=True)
|
||||
@@ -206,7 +209,7 @@ class UserSTTSession:
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize models on server startup."""
|
||||
global vad_processor, whisper_transcriber
|
||||
global vad_processor, parakeet_transcriber
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("Initializing Miku STT Server")
|
||||
@@ -222,15 +225,14 @@ async def startup_event():
|
||||
)
|
||||
logger.info("✓ VAD ready")
|
||||
|
||||
# Initialize Whisper (GPU with cuDNN)
|
||||
logger.info("Loading Faster-Whisper model (GPU)...")
|
||||
whisper_transcriber = WhisperTranscriber(
|
||||
model_size="small",
|
||||
# Initialize Parakeet (GPU)
|
||||
logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
|
||||
parakeet_transcriber = ParakeetTranscriber(
|
||||
model_name="nvidia/parakeet-tdt-0.6b-v3",
|
||||
device="cuda",
|
||||
compute_type="float16",
|
||||
language="en"
|
||||
)
|
||||
logger.info("✓ Whisper ready")
|
||||
logger.info("✓ Parakeet ready")
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("STT Server ready to accept connections")
|
||||
@@ -242,8 +244,8 @@ async def shutdown_event():
|
||||
"""Cleanup on server shutdown."""
|
||||
logger.info("Shutting down STT server...")
|
||||
|
||||
if whisper_transcriber:
|
||||
whisper_transcriber.cleanup()
|
||||
if parakeet_transcriber:
|
||||
parakeet_transcriber.cleanup()
|
||||
|
||||
logger.info("STT server shutdown complete")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user