Compare commits


2 Commits

13 changed files with 683 additions and 43 deletions

bot/utils/error_handler.py (new file, 267 lines added)

@@ -0,0 +1,267 @@
# utils/error_handler.py
import aiohttp
import traceback
import datetime
import re
from utils.logger import get_logger
logger = get_logger('error_handler')
# Webhook URL for error notifications
ERROR_WEBHOOK_URL = "https://discord.com/api/webhooks/1462216811293708522/4kdGenpxZFsP0z3VBgebYENODKmcRrmEzoIwCN81jCirnAxuU2YvxGgwGCNBb6TInA9Z"
# User-friendly error message that Miku will say
MIKU_ERROR_MESSAGE = "Someone tell Koko-nii there is a problem with my AI."
def is_error_response(response_text: str) -> bool:
"""
Detect if a response text is an error message.
Args:
response_text: The response text to check
Returns:
bool: True if the response appears to be an error message
"""
if not response_text or not isinstance(response_text, str):
return False
response_lower = response_text.lower().strip()
# Common error patterns
error_patterns = [
r'^error:?\s*\d{3}', # "Error: 502" or "Error 502"
r'^error:?\s+', # "Error: " or "Error "
r'^\d{3}\s+error', # "502 Error"
r'^sorry,?\s+(there\s+was\s+)?an?\s+error', # "Sorry, an error" or "Sorry, there was an error"
r'^sorry,?\s+the\s+response\s+took\s+too\s+long', # Timeout error
r'connection\s+(refused|failed|error|timeout)',
r'timed?\s*out',
r'failed\s+to\s+(connect|respond|process)',
r'service\s+unavailable',
r'internal\s+server\s+error',
r'bad\s+gateway',
r'gateway\s+timeout',
]
# Check if response matches any error pattern
for pattern in error_patterns:
if re.search(pattern, response_lower):
return True
# Check for HTTP status codes indicating errors
if re.match(r'^\d{3}$', response_text.strip()):
status_code = int(response_text.strip())
if status_code >= 400: # HTTP error codes
return True
return False
async def send_error_webhook(error_message: str, context: dict = None):
"""
Send error notification to the webhook.
Args:
error_message: The error message or exception details
context: Optional dictionary with additional context (user, channel, etc.)
"""
try:
def truncate_field(text: str, max_length: int = 1000) -> str:
"""Truncate text to fit Discord's field value limit (1024 chars)."""
if not text:
return "N/A"
text = str(text)
if len(text) > max_length:
return text[:max_length - 20] + "\n...(truncated)"
return text
# Build embed for webhook
embed = {
"title": "🚨 Miku Bot Error",
"color": 0xFF0000, # Red color
"timestamp": datetime.datetime.utcnow().isoformat(),
"fields": []
}
# Add error message (limit to 1000 chars to leave room for code blocks)
error_value = f"```\n{truncate_field(error_message, 900)}\n```"
embed["fields"].append({
"name": "Error Message",
"value": error_value,
"inline": False
})
# Add context if provided
if context:
if 'user' in context and context['user']:
embed["fields"].append({
"name": "User",
"value": truncate_field(context['user'], 200),
"inline": True
})
if 'channel' in context and context['channel']:
embed["fields"].append({
"name": "Channel",
"value": truncate_field(context['channel'], 200),
"inline": True
})
if 'guild' in context and context['guild']:
embed["fields"].append({
"name": "Server",
"value": truncate_field(context['guild'], 200),
"inline": True
})
if 'prompt' in context and context['prompt']:
prompt_value = f"```\n{truncate_field(context['prompt'], 400)}\n```"
embed["fields"].append({
"name": "User Prompt",
"value": prompt_value,
"inline": False
})
if 'exception_type' in context and context['exception_type']:
embed["fields"].append({
"name": "Exception Type",
"value": f"`{truncate_field(context['exception_type'], 200)}`",
"inline": True
})
if 'traceback' in context and context['traceback']:
tb_value = f"```python\n{truncate_field(context['traceback'], 800)}\n```"
embed["fields"].append({
"name": "Traceback",
"value": tb_value,
"inline": False
})
# Ensure we have at least one field (Discord requirement)
if not embed["fields"]:
embed["fields"].append({
"name": "Status",
"value": "Error occurred with no additional context",
"inline": False
})
# Send webhook
payload = {
"content": "<@344584170839236608>", # Mention Koko-nii
"embeds": [embed]
}
async with aiohttp.ClientSession() as session:
async with session.post(ERROR_WEBHOOK_URL, json=payload) as response:
if response.status in [200, 204]:
logger.info(f"✅ Error webhook sent successfully")
else:
error_text = await response.text()
logger.error(f"❌ Failed to send error webhook: {response.status} - {error_text}")
except Exception as e:
logger.error(f"❌ Exception while sending error webhook: {e}")
logger.error(traceback.format_exc())
async def handle_llm_error(
error: Exception,
user_prompt: str = None,
user_id: str = None,
guild_id: str = None,
author_name: str = None
) -> str:
"""
Handle LLM errors by logging them and sending webhook notification.
Args:
error: The exception that occurred
user_prompt: The user's prompt (if available)
user_id: The user ID (if available)
guild_id: The guild ID (if available)
author_name: The user's display name (if available)
Returns:
str: User-friendly error message for Miku to say
"""
logger.error(f"🚨 LLM Error occurred: {type(error).__name__}: {str(error)}")
# Build context
context = {
"exception_type": type(error).__name__,
"traceback": traceback.format_exc()
}
if user_prompt:
context["prompt"] = user_prompt
if author_name:
context["user"] = author_name
elif user_id:
context["user"] = f"User ID: {user_id}"
if guild_id:
context["guild"] = f"Guild ID: {guild_id}"
# Get full error message
error_message = f"{type(error).__name__}: {str(error)}"
# Send webhook notification
await send_error_webhook(error_message, context)
return MIKU_ERROR_MESSAGE
async def handle_response_error(
response_text: str,
user_prompt: str = None,
user_id: str = None,
guild_id: str = None,
author_name: str = None,
channel_name: str = None
) -> str:
"""
Handle error responses from the LLM by checking if the response is an error message.
Args:
response_text: The response text from the LLM
user_prompt: The user's prompt (if available)
user_id: The user ID (if available)
guild_id: The guild ID (if available)
author_name: The user's display name (if available)
channel_name: The channel name (if available)
Returns:
str: Either the original response (if not an error) or user-friendly error message
"""
if not is_error_response(response_text):
return response_text
logger.error(f"🚨 Error response detected: {response_text}")
# Build context
context = {}
if user_prompt:
context["prompt"] = user_prompt
if author_name:
context["user"] = author_name
elif user_id:
context["user"] = f"User ID: {user_id}"
if channel_name:
context["channel"] = channel_name
elif guild_id:
context["channel"] = f"Guild ID: {guild_id}"
if guild_id:
context["guild"] = f"Guild ID: {guild_id}"
# Send webhook notification
await send_error_webhook(f"LLM returned error response: {response_text}", context)
return MIKU_ERROR_MESSAGE


@@ -11,6 +11,7 @@ from utils.context_manager import get_context_for_response_type, get_complete_co
from utils.moods import load_mood_description
from utils.conversation_history import conversation_history
from utils.logger import get_logger
from utils.error_handler import handle_llm_error, handle_response_error
logger = get_logger('llm')
@@ -281,8 +282,18 @@ Please respond in a way that reflects this emotional tone.{pfp_context}"""
# Escape asterisks for actions (e.g., *adjusts hair* becomes \*adjusts hair\*)
reply = _escape_markdown_actions(reply)
# Check if the reply is an error response and handle it
reply = await handle_response_error(
reply,
user_prompt=user_prompt,
user_id=str(user_id),
guild_id=str(guild_id) if guild_id else None,
author_name=author_name
)
# Save to conversation history (only if both prompt and reply are non-empty)
if user_prompt and user_prompt.strip() and reply and reply.strip():
# Don't save error messages to history
if user_prompt and user_prompt.strip() and reply and reply.strip() and reply != "Someone tell Koko-nii there is a problem with my AI.":
# Add user message to history
conversation_history.add_message(
channel_id=channel_id,
@@ -298,21 +309,44 @@ Please respond in a way that reflects this emotional tone.{pfp_context}"""
is_bot=True
)
# Also save to legacy globals for backward compatibility
if user_prompt and user_prompt.strip() and reply and reply.strip():
# Also save to legacy globals for backward compatibility (skip error messages)
if user_prompt and user_prompt.strip() and reply and reply.strip() and reply != "Someone tell Koko-nii there is a problem with my AI.":
globals.conversation_history[user_id].append((user_prompt, reply))
return reply
else:
error_text = await response.text()
logger.error(f"Error from llama-swap: {response.status} - {error_text}")
# Send webhook notification for HTTP errors
await handle_response_error(
f"Error: {response.status}",
user_prompt=user_prompt,
user_id=str(user_id),
guild_id=str(guild_id) if guild_id else None,
author_name=author_name
)
# Don't save error responses to conversation history
return f"Error: {response.status}"
return "Someone tell Koko-nii there is a problem with my AI."
except asyncio.TimeoutError:
return "Sorry, the response took too long. Please try again."
logger.error("Timeout error in query_llama")
return await handle_llm_error(
asyncio.TimeoutError("Request timed out after 300 seconds"),
user_prompt=user_prompt,
user_id=str(user_id),
guild_id=str(guild_id) if guild_id else None,
author_name=author_name
)
except Exception as e:
logger.error(f"Error in query_llama: {e}")
return f"Sorry, there was an error: {str(e)}"
return await handle_llm_error(
e,
user_prompt=user_prompt,
user_id=str(user_id),
guild_id=str(guild_id) if guild_id else None,
author_name=author_name
)
# Backward compatibility alias for existing code
query_ollama = query_llama


@@ -62,6 +62,7 @@ COMPONENTS = {
'voice_manager': 'Voice channel session management',
'voice_commands': 'Voice channel commands',
'voice_audio': 'Voice audio streaming and TTS',
'error_handler': 'Error detection and webhook notifications',
}
# Global configuration


@@ -9,13 +9,22 @@ RUN apt-get update && apt-get install -y \
python3-pip \
ffmpeg \
libsndfile1 \
sox \
libsox-dev \
libsox-fmt-all \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY requirements.txt .
# Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt
# Upgrade pip to avoid dependency resolution issues
RUN pip3 install --upgrade pip
# Install dependencies for the sox package (required by NeMo) in the correct order
RUN pip3 install --no-cache-dir numpy==2.2.2 typing-extensions
# Install Python dependencies with legacy resolver (NeMo has complex dependencies)
RUN pip3 install --no-cache-dir --use-deprecated=legacy-resolver -r requirements.txt
# Copy application code
COPY . .

stt/PARAKEET_MIGRATION.md (new file, 114 lines added)

@@ -0,0 +1,114 @@
# NVIDIA Parakeet Migration
## Summary
Replaced Faster-Whisper with NVIDIA Parakeet TDT (Token-and-Duration Transducer) for real-time speech transcription.
## Changes Made
### 1. New Transcriber: `parakeet_transcriber.py`
- **Model**: `nvidia/parakeet-tdt-0.6b-v3` (600M parameters)
- **Features**:
- Real-time streaming transcription
- Word-level timestamps for LLM pre-computation
- GPU-accelerated (CUDA)
- Lower latency than Faster-Whisper
- Native PyTorch (no CTranslate2 dependency)
### 2. Requirements Updated
**Removed**:
- `faster-whisper==1.2.1`
- `ctranslate2==4.5.0`
**Added**:
- `transformers==4.47.1` - HuggingFace model loading
- `accelerate==1.2.1` - GPU optimization
- `sentencepiece==0.2.0` - Tokenization
**Kept**:
- `torch==2.9.1` & `torchaudio==2.9.1` - Core ML framework
- `silero-vad==5.1.2` - VAD still uses Silero (CPU)
### 3. Server Updates: `stt_server.py`
**Changes**:
- Import `ParakeetTranscriber` instead of `WhisperTranscriber`
- Partial transcripts now include `words` array with timestamps
- Final transcripts include `words` array for LLM pre-computation
- Startup logs show "Loading NVIDIA Parakeet TDT model"
**Word-level Token Format**:
```json
{
"type": "partial",
"text": "hello world",
"words": [
{"word": "hello", "start_time": 0.0, "end_time": 0.5},
{"word": "world", "start_time": 0.5, "end_time": 1.0}
],
"user_id": "123",
"timestamp": 1234.56
}
```
## Advantages Over Faster-Whisper
1. **Real-time Performance**: TDT architecture designed for streaming
2. **No cuDNN Issues**: Native PyTorch, no CTranslate2 library loading problems
3. **Word-level Tokens**: Enables LLM prompt pre-computation during speech
4. **Lower Latency**: Optimized for real-time use cases
5. **Better GPU Utilization**: Uses standard PyTorch CUDA
6. **Simpler Dependencies**: No external compiled libraries
## Deployment
1. **Build Container**:
```bash
docker-compose build miku-stt
```
2. **First Run** (downloads model ~600MB):
```bash
docker-compose up miku-stt
```
The model will be cached in the `/models` volume for subsequent runs.
3. **Verify GPU Usage**:
```bash
docker exec miku-stt nvidia-smi
```
You should see a `python3` process using VRAM (~1.5GB for model + inference).
## Testing
Same test procedure as before:
1. Join voice channel
2. `!miku listen`
3. Speak clearly
4. Check logs for "Parakeet model loaded"
5. Verify transcripts appear faster than before
## Bot-Side Compatibility
No changes are needed to the bot code - the STT WebSocket protocol is identical. The bot will automatically receive word-level tokens in partial/final transcript messages.
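For illustration, here is a minimal sketch of a bot-side consumer for these messages. The endpoint URL, function name, and use of the `websockets` client library are assumptions; only the message fields (`type`, `text`, `words`, `user_id`, `timestamp`) come from the format documented above.
```python
# Illustrative sketch only: the URL and function name are assumptions,
# not the bot's actual STT client code.
import asyncio
import json
import websockets  # assumed client library

async def consume_transcripts(url: str = "ws://miku-stt:8000/ws"):
    async with websockets.connect(url) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            words = [w["word"] for w in msg.get("words", [])]
            if msg.get("type") == "partial":
                # Word-level tokens arrive while the user is still speaking
                print(f"[partial] {msg['text']} ({len(words)} words so far)")
            elif msg.get("type") == "final":
                # Final transcript: hand the full text to the LLM here
                print(f"[final] {msg['user_id']}: {msg['text']}")

if __name__ == "__main__":
    asyncio.run(consume_transcripts())
```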
### Future Enhancement: LLM Pre-computation
The `words` array can be used to start LLM inference before the full transcript completes (see the sketch after this list):
- Send partial words to LLM as they arrive
- LLM begins processing prompt tokens
- Faster response time when user finishes speaking
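A rough sketch of that idea follows. The `precompute_prompt` and `generate_reply` methods are hypothetical stand-ins for whatever prefill/completion interface the LLM backend exposes; only the partial/final message shape comes from this migration.
```python
# Hypothetical pre-computation flow; method names on `llm` are placeholders.
from typing import Dict, List

class PrecomputeBuffer:
    """Tracks the latest partial hypothesis so a prompt-prefill request
    can be started before the final transcript arrives."""

    def __init__(self) -> None:
        self.words: List[str] = []

    def update(self, message: Dict) -> str:
        # Each partial replaces the previous hypothesis rather than appending
        self.words = [w["word"] for w in message.get("words", [])]
        return " ".join(self.words)

async def on_partial(message: Dict, buffer: PrecomputeBuffer, llm) -> None:
    prompt_so_far = buffer.update(message)
    if prompt_so_far:
        # Assumed prefill hook: warm the prompt cache while the user speaks
        await llm.precompute_prompt(prompt_so_far)

async def on_final(message: Dict, llm) -> str:
    # The final text drives the real completion request
    return await llm.generate_reply(message["text"])
```
Whether this actually reduces response latency depends on the backend supporting incremental prompt caching.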
## Rollback (if needed)
To revert to Faster-Whisper:
1. Restore `requirements.txt` from git
2. Restore `stt_server.py` from git
3. Delete `parakeet_transcriber.py`
4. Rebuild container
## Performance Expectations
- **Model Load Time**: ~5-10 seconds (the first run downloads the model from HuggingFace)
- **VRAM Usage**: ~1.5GB (vs ~800MB for Whisper small)
- **Latency**: ~200-500ms for 2-second audio chunks
- **GPU Utilization**: 30-60% during active transcription
- **Accuracy**: Similar to Whisper small (designed for English)


@@ -0,0 +1 @@
6d590f77001d318fb17a0b5bf7ee329a91b52598

stt/parakeet_transcriber.py (new file, 209 lines added)

@@ -0,0 +1,209 @@
"""
NVIDIA Parakeet TDT Transcriber
Real-time streaming ASR using NVIDIA's Parakeet TDT (Token-and-Duration Transducer) model.
Supports streaming transcription with word-level timestamps for LLM pre-computation.
Model: nvidia/parakeet-tdt-0.6b-v3
- 600M parameters
- Real-time capable on GPU
- Word-level timestamps
- Streaming support via NVIDIA NeMo
"""
import numpy as np
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from typing import Optional, List, Dict, Any
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger('parakeet')
class ParakeetTranscriber:
"""
NVIDIA Parakeet-based streaming transcription with word-level tokens.
Uses NVIDIA NeMo for proper model loading and inference.
"""
def __init__(
self,
model_name: str = "nvidia/parakeet-tdt-0.6b-v3",
device: str = "cuda",
language: str = "en"
):
"""
Initialize Parakeet transcriber.
Args:
model_name: HuggingFace model identifier
device: Device to run on (cuda or cpu)
language: Language code (Parakeet primarily supports English)
"""
self.model_name = model_name
self.device = device
self.language = language
logger.info(f"Loading Parakeet model: {model_name} on {device}...")
# Load model via NeMo from HuggingFace
self.model = EncDecRNNTBPEModel.from_pretrained(
model_name=model_name,
map_location=device
)
self.model.eval()
if device == "cuda":
self.model = self.model.cuda()
# Thread pool for blocking transcription calls
self.executor = ThreadPoolExecutor(max_workers=2)
logger.info(f"Parakeet model loaded on {device}")
async def transcribe_async(
self,
audio: np.ndarray,
sample_rate: int = 16000,
return_timestamps: bool = False
) -> str:
"""
Transcribe audio asynchronously (non-blocking).
Args:
audio: Audio data as numpy array (float32)
sample_rate: Audio sample rate (Parakeet expects 16kHz)
return_timestamps: Whether to return word-level timestamps
Returns:
Transcribed text (or dict with timestamps if return_timestamps=True)
"""
loop = asyncio.get_event_loop()
# Run transcription in thread pool to avoid blocking
result = await loop.run_in_executor(
self.executor,
self._transcribe_blocking,
audio,
sample_rate,
return_timestamps
)
return result
def _transcribe_blocking(
self,
audio: np.ndarray,
sample_rate: int,
return_timestamps: bool
):
"""
Blocking transcription call (runs in thread pool).
"""
# Convert to float32 if needed
if audio.dtype != np.float32:
audio = audio.astype(np.float32) / 32768.0
# Ensure correct sample rate (Parakeet expects 16kHz)
if sample_rate != 16000:
logger.warning(f"Audio sample rate is {sample_rate}Hz, Parakeet expects 16kHz. Resampling...")
import torchaudio
audio_tensor = torch.from_numpy(audio).unsqueeze(0)
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
audio_tensor = resampler(audio_tensor)
audio = audio_tensor.squeeze(0).numpy()
sample_rate = 16000
# Transcribe using NeMo model
with torch.no_grad():
# Convert to tensor
audio_signal = torch.from_numpy(audio).unsqueeze(0)
audio_signal_len = torch.tensor([len(audio)])
if self.device == "cuda":
audio_signal = audio_signal.cuda()
audio_signal_len = audio_signal_len.cuda()
# Get transcription with timestamps
# NeMo returns list of Hypothesis objects when timestamps=True
transcriptions = self.model.transcribe(
audio=[audio_signal.squeeze(0).cpu().numpy()],
batch_size=1,
timestamps=True # Enable timestamps to get word-level data
)
# Extract text from Hypothesis object
hypothesis = transcriptions[0] if transcriptions else None
if hypothesis is None:
text = ""
words = []
else:
# Hypothesis object has .text attribute
text = hypothesis.text.strip() if hasattr(hypothesis, 'text') else str(hypothesis).strip()
# Extract word-level timestamps if available
words = []
if hasattr(hypothesis, 'timestamp') and hypothesis.timestamp:
# timestamp is a dict with 'word' key containing list of word timestamps
word_timestamps = hypothesis.timestamp.get('word', [])
for word_info in word_timestamps:
words.append({
"word": word_info.get('word', ''),
"start_time": word_info.get('start', 0.0),
"end_time": word_info.get('end', 0.0)
})
logger.debug(f"Transcribed: '{text}' with {len(words)} words")
if return_timestamps:
return {
"text": text,
"words": words
}
else:
return text
async def transcribe_streaming(
self,
audio_chunks: List[np.ndarray],
sample_rate: int = 16000,
chunk_size_ms: int = 500
) -> Dict[str, Any]:
"""
Transcribe audio chunks with streaming support.
Args:
audio_chunks: List of audio chunks to process
sample_rate: Audio sample rate
chunk_size_ms: Size of each chunk in milliseconds
Returns:
Dict with partial and word-level results
"""
if not audio_chunks:
return {"text": "", "words": []}
# Concatenate all chunks
audio_data = np.concatenate(audio_chunks)
# Transcribe with timestamps for streaming
result = await self.transcribe_async(
audio_data,
sample_rate,
return_timestamps=True
)
return result
def get_supported_languages(self) -> List[str]:
"""Get list of supported language codes."""
# Parakeet TDT v3 primarily supports English
return ["en"]
def cleanup(self):
"""Cleanup resources."""
self.executor.shutdown(wait=True)
logger.info("Parakeet transcriber cleaned up")


@@ -6,7 +6,7 @@ uvicorn[standard]==0.32.1
websockets==14.1
aiohttp==3.11.11
# Audio processing
# Audio processing (install numpy first for sox dependency)
numpy==2.2.2
soundfile==0.12.1
librosa==0.10.2.post1
@@ -16,9 +16,12 @@ torch==2.9.1 # Latest PyTorch
torchaudio==2.9.1
silero-vad==5.1.2
# STT (GPU)
faster-whisper==1.2.1 # Latest version (Oct 31, 2025)
ctranslate2==4.5.0 # Required by faster-whisper
# STT (GPU) - NVIDIA NeMo for Parakeet
# Parakeet TDT 0.6b-v3 requires NeMo 2.4
# Fix huggingface-hub version conflict with transformers
huggingface-hub>=0.30.0,<1.0
nemo_toolkit[asr]==2.4.0
omegaconf==2.3.0
# Utilities
python-multipart==0.0.20


@@ -2,13 +2,13 @@
STT Server
FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.
Combines Silero VAD (CPU) and NVIDIA Parakeet (GPU) for efficient transcription.
Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Whisper transcribes only when VAD detects speech (GPU)
- Parakeet transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket
- Sends partial and final transcripts via WebSocket with word-level tokens
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
@@ -20,7 +20,7 @@ from typing import Dict, Optional
from datetime import datetime
from vad_processor import VADProcessor
from whisper_transcriber import WhisperTranscriber
from parakeet_transcriber import ParakeetTranscriber
# Configure logging
logging.basicConfig(
@@ -34,7 +34,7 @@ app = FastAPI(title="Miku STT Server", version="1.0.0")
# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
whisper_transcriber: Optional[WhisperTranscriber] = None
parakeet_transcriber: Optional[ParakeetTranscriber] = None
# User session tracking
user_sessions: Dict[str, dict] = {}
@@ -117,39 +117,40 @@ class UserSTTSession:
self.audio_buffer.append(audio_np)
async def _transcribe_partial(self):
"""Transcribe accumulated audio and send partial result."""
"""Transcribe accumulated audio and send partial result with word tokens."""
if not self.audio_buffer:
return
# Concatenate audio
audio_full = np.concatenate(self.audio_buffer)
# Transcribe asynchronously
# Transcribe asynchronously with word-level timestamps
try:
text = await whisper_transcriber.transcribe_async(
result = await parakeet_transcriber.transcribe_async(
audio_full,
sample_rate=16000,
initial_prompt=self.last_transcript # Use previous for context
return_timestamps=True
)
if text and text != self.last_transcript:
self.last_transcript = text
if result and result.get("text") and result["text"] != self.last_transcript:
self.last_transcript = result["text"]
# Send partial transcript
# Send partial transcript with word tokens for LLM pre-computation
await self.websocket.send_json({
"type": "partial",
"text": text,
"text": result["text"],
"words": result.get("words", []), # Word-level tokens
"user_id": self.user_id,
"timestamp": self.timestamp_ms
})
logger.info(f"Partial [{self.user_id}]: {text}")
logger.info(f"Partial [{self.user_id}]: {result['text']}")
except Exception as e:
logger.error(f"Partial transcription failed: {e}", exc_info=True)
async def _transcribe_final(self):
"""Transcribe final accumulated audio."""
"""Transcribe final accumulated audio with word tokens."""
if not self.audio_buffer:
return
@@ -157,23 +158,25 @@ class UserSTTSession:
audio_full = np.concatenate(self.audio_buffer)
try:
text = await whisper_transcriber.transcribe_async(
result = await parakeet_transcriber.transcribe_async(
audio_full,
sample_rate=16000
sample_rate=16000,
return_timestamps=True
)
if text:
self.last_transcript = text
if result and result.get("text"):
self.last_transcript = result["text"]
# Send final transcript
# Send final transcript with word tokens
await self.websocket.send_json({
"type": "final",
"text": text,
"text": result["text"],
"words": result.get("words", []), # Word-level tokens for LLM
"user_id": self.user_id,
"timestamp": self.timestamp_ms
})
logger.info(f"Final [{self.user_id}]: {text}")
logger.info(f"Final [{self.user_id}]: {result['text']}")
except Exception as e:
logger.error(f"Final transcription failed: {e}", exc_info=True)
@@ -206,7 +209,7 @@ class UserSTTSession:
@app.on_event("startup")
async def startup_event():
"""Initialize models on server startup."""
global vad_processor, whisper_transcriber
global vad_processor, parakeet_transcriber
logger.info("=" * 50)
logger.info("Initializing Miku STT Server")
@@ -222,15 +225,14 @@ async def startup_event():
)
logger.info("✓ VAD ready")
# Initialize Whisper (GPU with cuDNN)
logger.info("Loading Faster-Whisper model (GPU)...")
whisper_transcriber = WhisperTranscriber(
model_size="small",
# Initialize Parakeet (GPU)
logger.info("Loading NVIDIA Parakeet TDT model (GPU)...")
parakeet_transcriber = ParakeetTranscriber(
model_name="nvidia/parakeet-tdt-0.6b-v3",
device="cuda",
compute_type="float16",
language="en"
)
logger.info("Whisper ready")
logger.info("Parakeet ready")
logger.info("=" * 50)
logger.info("STT Server ready to accept connections")
@@ -242,8 +244,8 @@ async def shutdown_event():
"""Cleanup on server shutdown."""
logger.info("Shutting down STT server...")
if whisper_transcriber:
whisper_transcriber.cleanup()
if parakeet_transcriber:
parakeet_transcriber.cleanup()
logger.info("STT server shutdown complete")