Implemented experimental, production-ready voice chat; relegated the old flow to voice debug mode. New Web UI panel for Voice Chat.
@@ -49,6 +49,15 @@ class ParakeetTranscriber:
 
         logger.info(f"Loading Parakeet model: {model_name} on {device}...")
 
+        # Set PyTorch memory allocator settings for better memory management
+        if device == "cuda":
+            # Enable expandable segments to reduce fragmentation
+            import os
+            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+            # Clear cache before loading model
+            torch.cuda.empty_cache()
+
         # Load model via NeMo from HuggingFace
         self.model = EncDecRNNTBPEModel.from_pretrained(
             model_name=model_name,
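The allocator tweak above only takes effect if `PYTORCH_CUDA_ALLOC_CONF` is set before the CUDA caching allocator initializes, i.e. before the first CUDA tensor is created. A minimal standalone sketch of the same setup (not the project's module; the printout is just for inspection):

    # Standalone sketch, assuming a CUDA device is present; mirrors the allocator setup in this commit.
    import os
    os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')  # must happen before CUDA init

    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # drop cached blocks before loading a large model
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        print(f"free={free_bytes / 1e9:.1f} GB of {total_bytes / 1e9:.1f} GB")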
@@ -58,6 +67,11 @@ class ParakeetTranscriber:
         self.model.eval()
         if device == "cuda":
             self.model = self.model.cuda()
+            # Enable memory efficient attention if available
+            try:
+                self.model.encoder.use_memory_efficient_attention = True
+            except:
+                pass
 
         # Thread pool for blocking transcription calls
         self.executor = ThreadPoolExecutor(max_workers=2)
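The bare `try/except` above flips an encoder flag that may not exist on every NeMo build, which is presumably why it is allowed to fail silently. A more explicit sketch of the same idea; the attribute name is taken from this diff and assumed optional:

    # Sketch: same optional optimization, but with an explicit capability check instead of a bare except.
    import logging

    logger = logging.getLogger(__name__)

    def enable_memory_efficient_attention(model) -> bool:
        """Flip the optional encoder flag if this model exposes it; return whether it was set."""
        encoder = getattr(model, "encoder", None)
        if encoder is not None and hasattr(encoder, "use_memory_efficient_attention"):
            encoder.use_memory_efficient_attention = True
            return True
        logger.debug("Memory-efficient attention flag not available; continuing without it")
        return False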
@@ -119,7 +133,7 @@ class ParakeetTranscriber:
 
         # Transcribe using NeMo model
         with torch.no_grad():
-            # Convert to tensor
+            # Convert to tensor and keep on GPU to avoid CPU/GPU bouncing
             audio_signal = torch.from_numpy(audio).unsqueeze(0)
             audio_signal_len = torch.tensor([len(audio)])
 
@@ -127,12 +141,14 @@ class ParakeetTranscriber:
                 audio_signal = audio_signal.cuda()
                 audio_signal_len = audio_signal_len.cuda()
 
-            # Get transcription with timestamps
-            # NeMo returns list of Hypothesis objects when timestamps=True
+            # Get transcription
+            # NeMo returns list of Hypothesis objects
+            # Note: timestamps=True causes significant VRAM usage (~1-2GB extra)
+            # Only enable for final transcriptions, not streaming partials
             transcriptions = self.model.transcribe(
-                audio=[audio_signal.squeeze(0).cpu().numpy()],
+                audio=[audio], # Pass NumPy array directly (NeMo handles it efficiently)
                 batch_size=1,
-                timestamps=True # Enable timestamps to get word-level data
+                timestamps=return_timestamps # Only use timestamps when explicitly requested
             )
 
             # Extract text from Hypothesis object
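The change above avoids an extra GPU-to-CPU copy and makes word timestamps opt-in, since `timestamps=True` costs roughly 1-2 GB of extra VRAM per the comment. A sketch of the two call shapes, assuming `model` was loaded with `EncDecRNNTBPEModel.from_pretrained` and `audio` is a 1-D float32 NumPy array at 16 kHz; return types vary across NeMo versions, so the text extraction mirrors the guard used in this diff:

    # Sketch of the two call shapes used above; `model` and `audio` are assumptions, not project code.
    import numpy as np
    import torch

    def run_transcription(model, audio: np.ndarray, want_timestamps: bool):
        with torch.no_grad():
            hyps = model.transcribe(
                audio=[audio],               # NumPy array passed directly, no manual GPU round-trip
                batch_size=1,
                timestamps=want_timestamps   # True only for final transcriptions (extra VRAM)
            )
        hyp = hyps[0]
        text = hyp.text.strip() if hasattr(hyp, "text") else str(hyp).strip()
        words = []
        if want_timestamps and getattr(hyp, "timestamp", None):
            words = hyp.timestamp.get("word", [])
        return text, words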
@@ -144,9 +160,9 @@ class ParakeetTranscriber:
             # Hypothesis object has .text attribute
             text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()
 
-            # Extract word-level timestamps if available
+            # Extract word-level timestamps if available and requested
             words = []
-            if hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
+            if return_timestamps and hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
                 # timestamp is a dict with 'word' key containing list of word timestamps
                 word_timestamps = hypothesis.timestamp.get('word', [])
                 for word_info in word_timestamps:
@@ -165,6 +181,10 @@ class ParakeetTranscriber:
                 }
             else:
                 return text
 
+        # Note: We do NOT call torch.cuda.empty_cache() here
+        # That breaks PyTorch's memory allocator and causes fragmentation
+        # Let PyTorch manage its own memory pool
+
     async def transcribe_streaming(
         self,
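Rather than calling `torch.cuda.empty_cache()` per request, the comment above leaves memory management to PyTorch's caching allocator. If fragmentation is a concern, a hedged alternative is to simply watch the allocator counters; all calls below are standard `torch.cuda` APIs:

    # Sketch: observe allocator behaviour instead of forcing empty_cache() on every request.
    import torch

    def log_cuda_memory(tag: str) -> None:
        if not torch.cuda.is_available():
            return
        allocated_gb = torch.cuda.memory_allocated() / 1e9  # bytes held by live tensors
        reserved_gb = torch.cuda.memory_reserved() / 1e9    # pool retained by the caching allocator
        print(f"[{tag}] allocated={allocated_gb:.2f} GB reserved={reserved_gb:.2f} GB")

    # Reserved memory growing while allocated stays flat is the usual sign of fragmentation.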
@@ -22,6 +22,7 @@ silero-vad==5.1.2
 huggingface-hub>=0.30.0,<1.0
 nemo_toolkit[asr]==2.4.0
 omegaconf==2.3.0
+cuda-python>=12.3 # Enable CUDA graphs for faster decoding
 
 # Utilities
 python-multipart==0.0.20
@@ -51,6 +51,9 @@ class UserSTTSession:
         self.timestamp_ms = 0.0
         self.transcript_buffer = []
         self.last_transcript = ""
+        self.last_partial_duration = 0.0 # Track when we last sent a partial
+        self.last_speech_timestamp = 0.0 # Track last time we detected speech
+        self.speech_timeout_ms = 3000 # Force finalization after 3s of no new speech
 
         logger.info(f"Created STT session for user {user_id}")
 
@@ -75,6 +78,8 @@ class UserSTTSession:
             event_type = vad_event["event"]
             probability = vad_event["probability"]
 
+            logger.debug(f"VAD event for user {self.user_id}: {event_type} (prob={probability:.3f})")
+
             # Send VAD event to client
             await self.websocket.send_json({
                 "type": "vad",
@@ -88,63 +93,91 @@ class UserSTTSession:
             if event_type == "speech_start":
                 self.is_speaking = True
                 self.audio_buffer = [audio_np]
-                logger.debug(f"User {self.user_id} started speaking")
+                self.last_partial_duration = 0.0
+                self.last_speech_timestamp = self.timestamp_ms
+                logger.info(f"[STT] User {self.user_id} SPEECH START")
 
             elif event_type == "speaking":
                 if self.is_speaking:
                     self.audio_buffer.append(audio_np)
+                    self.last_speech_timestamp = self.timestamp_ms # Update speech timestamp
 
-                    # Transcribe partial every ~2 seconds for streaming
+                    # Transcribe partial every ~1 second for streaming (reduced from 2s)
                     total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                     duration_s = total_samples / 16000
 
-                    if duration_s >= 2.0:
+                    # More frequent partials for better responsiveness
+                    if duration_s >= 1.0:
+                        logger.debug(f"Triggering partial transcription at {duration_s:.1f}s")
                         await self._transcribe_partial()
+                        # Keep buffer for final transcription, but mark progress
+                        self.last_partial_duration = duration_s
 
             elif event_type == "speech_end":
                 self.is_speaking = False
 
+                logger.info(f"[STT] User {self.user_id} SPEECH END (VAD detected) - transcribing final")
+
                 # Transcribe final
                 await self._transcribe_final()
 
                 # Clear buffer
                 self.audio_buffer = []
+                self.last_partial_duration = 0.0
                 logger.debug(f"User {self.user_id} stopped speaking")
 
         else:
-            # Still accumulate audio if speaking
+            # No VAD event - still accumulate audio if speaking
            if self.is_speaking:
                 self.audio_buffer.append(audio_np)
+
+                # Check for timeout
+                time_since_speech = self.timestamp_ms - self.last_speech_timestamp
+
+                if time_since_speech >= self.speech_timeout_ms:
+                    # Timeout - user probably stopped but VAD didn't detect it
+                    logger.warning(f"[STT] User {self.user_id} SPEECH TIMEOUT after {time_since_speech:.0f}ms - forcing finalization")
+                    self.is_speaking = False
+
+                    # Force final transcription
+                    await self._transcribe_final()
+
+                    # Clear buffer
+                    self.audio_buffer = []
+                    self.last_partial_duration = 0.0
 
     async def _transcribe_partial(self):
-        """Transcribe accumulated audio and send partial result with word tokens."""
+        """Transcribe accumulated audio and send partial result (no timestamps to save VRAM)."""
         if not self.audio_buffer:
             return
 
         # Concatenate audio
         audio_full = np.concatenate(self.audio_buffer)
 
-        # Transcribe asynchronously with word-level timestamps
+        # Transcribe asynchronously WITHOUT timestamps for partials (saves 1-2GB VRAM)
         try:
             result = await parakeet_transcriber.transcribe_async(
                 audio_full,
                 sample_rate=16000,
-                return_timestamps=True
+                return_timestamps=False # Disable timestamps for partials to reduce VRAM usage
             )
 
-            if result and result.get("text") and result["text"] != self.last_transcript:
-                self.last_transcript = result["text"]
+            # Result is just a string when timestamps=False
+            text = result if isinstance(result, str) else result.get("text", "")
+
+            if text and text != self.last_transcript:
+                self.last_transcript = text
 
-                # Send partial transcript with word tokens for LLM pre-computation
+                # Send partial transcript without word tokens (saves memory)
                 await self.websocket.send_json({
                     "type": "partial",
-                    "text": result["text"],
-                    "words": result.get("words", []), # Word-level tokens
+                    "text": text,
+                    "words": [], # No word tokens for partials
                     "user_id": self.user_id,
                     "timestamp": self.timestamp_ms
                 })
 
-                logger.info(f"Partial [{self.user_id}]: {result['text']}")
+                logger.info(f"Partial [{self.user_id}]: {text}")
 
         except Exception as e:
             logger.error(f"Partial transcription failed: {e}", exc_info=True)
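With these changes the session pushes `vad` events and lighter `partial` transcripts (text only, empty `words`) over the WebSocket, and a timeout now forces finalization when VAD misses the end of speech. A listen-only client sketch for those two message shapes; the endpoint URL is a placeholder and the audio-upload side is omitted:

    # Listen-only client sketch; only the "vad" and "partial" payload shapes are taken from this commit.
    import asyncio
    import json

    import websockets  # pip install websockets

    async def listen(url: str = "ws://localhost:8000/ws/stt"):  # hypothetical endpoint
        async with websockets.connect(url) as ws:
            async for raw in ws:
                msg = json.loads(raw)
                if msg.get("type") == "vad":
                    pass  # e.g. drive a speaking indicator in the Web UI
                elif msg.get("type") == "partial":
                    # Partials carry text but an empty "words" list (timestamps are final-only)
                    print(f"[{msg.get('user_id')}] partial: {msg.get('text')}")

    # asyncio.run(listen())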
@@ -220,8 +253,8 @@ async def startup_event():
     vad_processor = VADProcessor(
         sample_rate=16000,
         threshold=0.5,
-        min_speech_duration_ms=250, # Conservative
-        min_silence_duration_ms=500 # Conservative
+        min_speech_duration_ms=250, # Conservative - wait 250ms before starting
+        min_silence_duration_ms=300 # Reduced from 500ms - detect silence faster
     )
     logger.info("✓ VAD ready")
 
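Dropping `min_silence_duration_ms` from 500 ms to 300 ms makes end-of-speech detection faster at the cost of more aggressive cut-offs. `VADProcessor` is this project's own streaming wrapper, but the same knobs can be sanity-checked offline against a recorded clip with the silero-vad helpers (the parameter names below are silero-vad's offline API; the file name is a placeholder):

    # Offline sanity-check of the tuned VAD values using the silero-vad package pinned in requirements.
    from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

    model = load_silero_vad()
    wav = read_audio("sample_16k.wav", sampling_rate=16000)  # any 16 kHz test clip

    segments = get_speech_timestamps(
        wav,
        model,
        threshold=0.5,
        sampling_rate=16000,
        min_speech_duration_ms=250,   # wait 250ms of speech before declaring speech_start
        min_silence_duration_ms=300,  # declare speech_end after 300ms of silence
        return_seconds=True,
    )
    print(segments)  # e.g. [{'start': 0.5, 'end': 2.3}, ...]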