Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all
stt/Dockerfile.stt (new file, 35 lines)
@@ -0,0 +1,35 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY requirements.txt .

# Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create models directory
RUN mkdir -p /models

# Expose port
EXPOSE 8000

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV CUDA_VISIBLE_DEVICES=0
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.11/dist-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}

# Run the server
CMD ["uvicorn", "stt_server:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"]
stt/README.md (new file, 152 lines)
@@ -0,0 +1,152 @@
# Miku STT (Speech-to-Text) Server

Real-time speech-to-text service for Miku voice chat using Silero VAD (CPU) and Faster-Whisper (GPU).

## Architecture

- **Silero VAD** (CPU): Lightweight voice activity detection, runs continuously
- **Faster-Whisper** (GPU, GTX 1660): Efficient speech transcription using CTranslate2
- **FastAPI WebSocket**: Real-time bidirectional communication

## Features

- ✅ Real-time voice activity detection with conservative settings
- ✅ Streaming partial transcripts during speech
- ✅ Final transcript on speech completion
- ✅ Interruption detection (user speaking over Miku)
- ✅ Multi-user support with isolated sessions
- ✅ KV cache optimization ready (partial text for LLM precomputation)

## API Endpoints

### WebSocket: `/ws/stt/{user_id}`

Real-time STT session for a specific user.

**Client sends:** Raw PCM audio (int16, 16kHz mono, 20ms chunks = 320 samples)

**Server sends:** JSON events:
```json
// VAD events
{"type": "vad", "event": "speech_start", "speaking": true, "probability": 0.85, "timestamp": 1250.5}
{"type": "vad", "event": "speaking", "speaking": true, "probability": 0.92, "timestamp": 1270.5}
{"type": "vad", "event": "speech_end", "speaking": false, "probability": 0.35, "timestamp": 3500.0}

// Transcription events
{"type": "partial", "text": "Hello how are", "user_id": "123", "timestamp": 2000.0}
{"type": "final", "text": "Hello how are you?", "user_id": "123", "timestamp": 3500.0}

// Interruption detection
{"type": "interruption", "probability": 0.92, "timestamp": 1500.0}
```
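
For reference, the chunk-size arithmetic behind "20ms chunks = 320 samples" (a quick sketch; the constant names are illustrative):

```python
SAMPLE_RATE = 16000  # Hz
CHUNK_MS = 20        # chunk duration in milliseconds

samples_per_chunk = SAMPLE_RATE * CHUNK_MS // 1000  # 320 samples
bytes_per_chunk = samples_per_chunk * 2             # 640 bytes (int16 = 2 bytes/sample)
```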

### HTTP GET: `/health`

Health check with model status.

**Response:**
```json
{
  "status": "healthy",
  "models": {
    "vad": {"loaded": true, "device": "cpu"},
    "whisper": {"loaded": true, "model": "small", "device": "cuda"}
  },
  "sessions": {
    "active": 2,
    "users": ["user123", "user456"]
  }
}
```

## Configuration

### VAD Parameters (Conservative)

- **Threshold**: 0.5 (speech probability)
- **Min speech duration**: 250ms (avoid false triggers)
- **Min silence duration**: 500ms (don't cut off mid-sentence)
- **Speech padding**: 30ms (context around speech)
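
These map directly onto the `VADProcessor` constructor in `vad_processor.py`; a minimal sketch mirroring the server's startup configuration:

```python
from vad_processor import VADProcessor

# Conservative defaults used by the server at startup
vad = VADProcessor(
    sample_rate=16000,
    threshold=0.5,
    min_speech_duration_ms=250,
    min_silence_duration_ms=500,
    speech_pad_ms=30,
)
```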

### Whisper Parameters

- **Model**: small (balanced speed/quality, ~500MB VRAM)
- **Compute**: float16 (GPU optimization)
- **Language**: en (English)
- **Beam size**: 5 (quality/speed balance)
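
Likewise for `WhisperTranscriber` (`beam_size` defaults to 5 in `whisper_transcriber.py`):

```python
from whisper_transcriber import WhisperTranscriber

# Small model in float16 on the GTX 1660
whisper = WhisperTranscriber(
    model_size="small",
    device="cuda",
    compute_type="float16",
    language="en",
)
```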

## Usage Example

```python
import asyncio
import websockets
import numpy as np

async def stream_audio():
    uri = "ws://localhost:8001/ws/stt/user123"

    async with websockets.connect(uri) as websocket:
        # Wait for ready
        ready = await websocket.recv()
        print(ready)

        # Stream audio chunks (16kHz, 20ms chunks)
        # audio_stream: your source of int16 chunks (320 samples each)
        for audio_chunk in audio_stream:
            # Convert to bytes (int16)
            audio_bytes = audio_chunk.astype(np.int16).tobytes()
            await websocket.send(audio_bytes)

            # Receive events
            event = await websocket.recv()
            print(event)

asyncio.run(stream_audio())
```

## Docker Setup

### Build
```bash
docker-compose build miku-stt
```

### Run
```bash
docker-compose up -d miku-stt
```

### Logs
```bash
docker-compose logs -f miku-stt
```

### Test
```bash
curl http://localhost:8001/health
```

## GPU Sharing with Soprano

Both STT (Whisper) and TTS (Soprano) run on the GTX 1660, but at different times:

1. **User speaking** → Whisper active, Soprano idle
2. **LLM processing** → Both idle
3. **Miku speaking** → Soprano active, Whisper idle (VAD monitoring only)

Interruption detection runs VAD continuously but doesn't use the GPU.
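
While Miku is speaking, the bot can poll the server's `/interrupt/check` endpoint. A minimal sketch using aiohttp (assuming the same `localhost:8001` port mapping as above):

```python
import aiohttp
import asyncio

async def is_interrupting(user_id: str) -> bool:
    # POST /interrupt/check?user_id=... returns {"interrupting": bool, "user_id": str}
    # (404 if the user has no active STT session)
    async with aiohttp.ClientSession() as http:
        async with http.post(
            "http://localhost:8001/interrupt/check", params={"user_id": user_id}
        ) as resp:
            if resp.status != 200:
                return False
            data = await resp.json()
            return data["interrupting"]

print(asyncio.run(is_interrupting("user123")))
```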

## Performance

- **VAD latency**: 10-20ms per chunk (CPU)
- **Whisper latency**: ~1-2s for 2s of audio (GPU)
- **Memory usage**:
  - Silero VAD: ~100MB (CPU)
  - Faster-Whisper small: ~500MB (GPU VRAM)

## Future Improvements

- [ ] Multi-language support (auto-detect)
- [ ] Word-level timestamps for better sync
- [ ] Custom vocabulary/prompt tuning
- [ ] Speaker diarization (multiple speakers)
- [ ] Noise suppression preprocessing
Binary file not shown.
File diff suppressed because it is too large.
@@ -0,0 +1,239 @@
{
  "alignment_heads": [
    [5, 3], [5, 9], [8, 0], [8, 4], [8, 7],
    [8, 8], [9, 0], [9, 7], [9, 9], [10, 5]
  ],
  "lang_ids": [
    50259, 50260, 50261, 50262, 50263, 50264, 50265, 50266, 50267, 50268,
    50269, 50270, 50271, 50272, 50273, 50274, 50275, 50276, 50277, 50278,
    50279, 50280, 50281, 50282, 50283, 50284, 50285, 50286, 50287, 50288,
    50289, 50290, 50291, 50292, 50293, 50294, 50295, 50296, 50297, 50298,
    50299, 50300, 50301, 50302, 50303, 50304, 50305, 50306, 50307, 50308,
    50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318,
    50319, 50320, 50321, 50322, 50323, 50324, 50325, 50326, 50327, 50328,
    50329, 50330, 50331, 50332, 50333, 50334, 50335, 50336, 50337, 50338,
    50339, 50340, 50341, 50342, 50343, 50344, 50345, 50346, 50347, 50348,
    50349, 50350, 50351, 50352, 50353, 50354, 50355, 50356, 50357
  ],
  "suppress_ids": [
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63,
    90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350,
    1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667,
    6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562,
    13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075,
    21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470,
    36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361,
    50362
  ],
  "suppress_ids_begin": [220, 50257]
}
@@ -0,0 +1 @@
536b0662742c02347bc0e980a01041f333bce120
@@ -0,0 +1 @@
../../blobs/e5047537059bd8f182d9ca64c470201585015187
@@ -0,0 +1 @@
../../blobs/3e305921506d8872816023e4c273e75d2419fb89b24da97b4fe7bce14170d671
@@ -0,0 +1 @@
../../blobs/7818adb6de9fa3064d3ff81226fdd675be1f6344
@@ -0,0 +1 @@
../../blobs/c9074644d9d1205686f16d411564729461324b75
stt/requirements.txt (new file, 25 lines)
@@ -0,0 +1,25 @@
# STT Container Requirements

# Core dependencies
fastapi==0.115.6
uvicorn[standard]==0.32.1
websockets==14.1
aiohttp==3.11.11

# Audio processing
numpy==2.2.2
soundfile==0.12.1
librosa==0.10.2.post1

# VAD (CPU)
torch==2.9.1           # Latest PyTorch
torchaudio==2.9.1
silero-vad==5.1.2

# STT (GPU)
faster-whisper==1.2.1  # Latest version (Oct 31, 2025)
ctranslate2==4.5.0     # Required by faster-whisper

# Utilities
python-multipart==0.0.20
pydantic==2.10.4
stt/stt_server.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""
STT Server

FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.

Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Whisper transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from datetime import datetime

from vad_processor import VADProcessor
from whisper_transcriber import WhisperTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
whisper_transcriber: Optional[WhisperTranscriber] = None

# User session tracking
user_sessions: Dict[str, dict] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                logger.debug(f"User {self.user_id} started speaking")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)

                    # Transcribe partial every ~2 seconds for streaming
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000

                    if duration_s >= 2.0:
                        await self._transcribe_partial()

            elif event_type == "speech_end":
                self.is_speaking = False

                # Transcribe final
                await self._transcribe_final()

                # Clear buffer
                self.audio_buffer = []
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # Still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send a partial result."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously
        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                initial_prompt=self.last_transcript  # Use previous text for context
            )

            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe the final accumulated audio."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000
            )

            if text:
                self.last_transcript = text

                # Send final transcript
                await self.websocket.send_json({
                    "type": "final",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if the user is interrupting (for use during Miku's speech).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if interruption detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)

        # Interruption: per-chunk probability above a higher threshold than
        # normal VAD (no duration requirement, so a single loud chunk triggers it)
        if speech_prob > 0.7:
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


@app.on_event("startup")
async def startup_event():
    """Initialize models on server startup."""
    global vad_processor, whisper_transcriber

    logger.info("=" * 50)
    logger.info("Initializing Miku STT Server")
    logger.info("=" * 50)

    # Initialize VAD (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_processor = VADProcessor(
        sample_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=250,   # Conservative
        min_silence_duration_ms=500   # Conservative
    )
    logger.info("✓ VAD ready")

    # Initialize Whisper (GPU with cuDNN)
    logger.info("Loading Faster-Whisper model (GPU)...")
    whisper_transcriber = WhisperTranscriber(
        model_size="small",
        device="cuda",
        compute_type="float16",
        language="en"
    )
    logger.info("✓ Whisper ready")

    logger.info("=" * 50)
    logger.info("STT Server ready to accept connections")
    logger.info("=" * 50)


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on server shutdown."""
    logger.info("Shutting down STT server...")

    if whisper_transcriber:
        whisper_transcriber.cleanup()

    logger.info("STT server shutdown complete")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "whisper_ready": whisper_transcriber is not None,
        "active_sessions": len(user_sessions)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu"
            },
            "whisper": {
                "loaded": whisper_transcriber is not None,
                "model": "small",
                "device": "cuda"
            }
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys())
        }
    }


@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends: JSON events:
        - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
        - {"type": "partial", "text": "...", ...}
        - {"type": "final", "text": "...", ...}
        - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Cleanup session
        if user_id in user_sessions:
            del user_sessions[user_id]
            logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check if a user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state (note: the VAD processor is a module-level global,
    # so this state is shared across all sessions rather than per-user)
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
stt/test_stt.py (new file, 206 lines)
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Test script for the STT WebSocket server.
Sends test audio and receives VAD/transcription events.
"""

import asyncio
import websockets
import numpy as np
import json
import wave


async def test_websocket():
    """Test the STT WebSocket with generated audio."""

    uri = "ws://localhost:8001/ws/stt/test_user"

    print("🔌 Connecting to STT WebSocket...")

    async with websockets.connect(uri) as websocket:
        # Wait for ready message
        ready_msg = await websocket.recv()
        ready = json.loads(ready_msg)
        print(f"✅ {ready}")

        # Generate test audio: 2 seconds of a 440Hz tone (A note).
        # A pure tone is only a stand-in for speech; the VAD may assign it
        # a low speech probability.
        print("\n🎵 Generating test audio (2 seconds, 440Hz tone)...")
        sample_rate = 16000
        duration = 2.0
        frequency = 440  # A4 note

        t = np.linspace(0, duration, int(sample_rate * duration), False)
        audio = np.sin(frequency * 2 * np.pi * t)

        # Convert to int16
        audio_int16 = (audio * 32767).astype(np.int16)

        # Send in 20ms chunks (320 samples at 16kHz)
        chunk_size = 320  # 20ms chunks
        total_chunks = len(audio_int16) // chunk_size

        print(f"📤 Sending {total_chunks} audio chunks (20ms each)...\n")

        # Send chunks and receive events
        for i in range(0, len(audio_int16), chunk_size):
            chunk = audio_int16[i:i+chunk_size]

            # Send audio chunk
            await websocket.send(chunk.tobytes())

            # Try to receive events (non-blocking)
            try:
                response = await asyncio.wait_for(
                    websocket.recv(),
                    timeout=0.01
                )
                event = json.loads(response)

                # Print VAD events
                if event['type'] == 'vad':
                    emoji = "🟢" if event['speaking'] else "⚪"
                    print(f"{emoji} VAD: {event['event']} "
                          f"(prob={event['probability']:.3f}, "
                          f"t={event['timestamp']:.1f}ms)")

                # Print transcription events
                elif event['type'] == 'partial':
                    print(f"📝 Partial: \"{event['text']}\"")

                elif event['type'] == 'final':
                    print(f"✅ Final: \"{event['text']}\"")

                elif event['type'] == 'interruption':
                    print(f"⚠️ Interruption detected! (prob={event['probability']:.3f})")

            except asyncio.TimeoutError:
                pass  # No event yet

            # Small delay between chunks
            await asyncio.sleep(0.02)

        print("\n✅ Test audio sent successfully!")

        # Wait a bit for the final transcription
        print("⏳ Waiting for final transcription...")

        for _ in range(50):  # Wait up to 1 second
            try:
                response = await asyncio.wait_for(
                    websocket.recv(),
                    timeout=0.02
                )
                event = json.loads(response)

                if event['type'] == 'final':
                    print(f"\n✅ FINAL TRANSCRIPT: \"{event['text']}\"")
                    break
                elif event['type'] == 'vad':
                    emoji = "🟢" if event['speaking'] else "⚪"
                    print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
            except asyncio.TimeoutError:
                pass

        print("\n✅ WebSocket test complete!")


async def test_with_sample_audio():
    """Test with an actual speech audio file (if available)."""

    import sys
    import os

    if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
        audio_file = sys.argv[1]
        print(f"📂 Loading audio from: {audio_file}")

        # Load WAV file
        with wave.open(audio_file, 'rb') as wav:
            sample_rate = wav.getframerate()
            n_channels = wav.getnchannels()
            audio_data = wav.readframes(wav.getnframes())

        # Convert to numpy array
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # If stereo, convert to mono
        if n_channels == 2:
            audio_np = audio_np.reshape(-1, 2).mean(axis=1).astype(np.int16)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            print(f"⚠️ Resampling from {sample_rate}Hz to 16000Hz...")
            import librosa
            audio_float = audio_np.astype(np.float32) / 32768.0
            audio_resampled = librosa.resample(
                audio_float,
                orig_sr=sample_rate,
                target_sr=16000
            )
            audio_np = (audio_resampled * 32767).astype(np.int16)

        print(f"✅ Audio loaded: {len(audio_np)/16000:.2f} seconds")

        # Send to STT
        uri = "ws://localhost:8001/ws/stt/test_user"

        async with websockets.connect(uri) as websocket:
            ready_msg = await websocket.recv()
            print(f"✅ {json.loads(ready_msg)}")

            # Send in chunks
            chunk_size = 320  # 20ms at 16kHz

            for i in range(0, len(audio_np), chunk_size):
                chunk = audio_np[i:i+chunk_size]
                await websocket.send(chunk.tobytes())

                # Receive events
                try:
                    response = await asyncio.wait_for(
                        websocket.recv(),
                        timeout=0.01
                    )
                    event = json.loads(response)

                    if event['type'] == 'vad':
                        emoji = "🟢" if event['speaking'] else "⚪"
                        print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
                    elif event['type'] in ['partial', 'final']:
                        print(f"📝 {event['type'].title()}: \"{event['text']}\"")

                except asyncio.TimeoutError:
                    pass

                await asyncio.sleep(0.02)

            # Wait for final
            for _ in range(100):
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=0.02)
                    event = json.loads(response)
                    if event['type'] == 'final':
                        print(f"\n✅ FINAL: \"{event['text']}\"")
                        break
                except asyncio.TimeoutError:
                    pass


if __name__ == "__main__":
    import sys

    print("=" * 60)
    print(" Miku STT WebSocket Test")
    print("=" * 60)
    print()

    if len(sys.argv) > 1:
        print("📁 Testing with audio file...")
        asyncio.run(test_with_sample_audio())
    else:
        print("🎵 Testing with generated tone...")
        print("   (To test with an audio file: python test_stt.py audio.wav)")
        print()
        asyncio.run(test_websocket())
stt/vad_processor.py (new file, 204 lines)
@@ -0,0 +1,204 @@
"""
Silero VAD Processor

Lightweight CPU-based voice activity detection for real-time speech detection.
Runs continuously on audio chunks to determine when users are speaking.
"""

import torch
import numpy as np
from typing import Tuple, Optional
import logging

logger = logging.getLogger('vad')


class VADProcessor:
    """
    Voice Activity Detection using the Silero VAD model.

    Processes audio chunks and returns speech probability.
    Conservative settings to avoid cutting off speech.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        speech_pad_ms: int = 30
    ):
        """
        Initialize VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 8000 or 16000)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger (conservative)
            min_silence_duration_ms: Minimum silence to end speech (conservative)
            speech_pad_ms: Padding around speech segments
        """
        self.sample_rate = sample_rate
        self.threshold = threshold
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        # Load Silero VAD model (CPU only)
        logger.info("Loading Silero VAD model (CPU)...")
        self.model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False  # Use the PyTorch model
        )

        # Extract utility functions
        (self.get_speech_timestamps,
         self.save_audio,
         self.read_audio,
         self.VADIterator,
         self.collect_chunks) = utils

        # State tracking
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer = []

        # Chunk buffer for VAD (Silero needs at least 512 samples)
        self.vad_buffer = []
        self.min_vad_samples = 512  # Minimum samples for VAD processing
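        # NOTE (assumption, worth verifying against the pinned silero-vad
        # version): recent Silero VAD releases expect exactly 512 samples per
        # call at 16 kHz (256 at 8 kHz). Since 20 ms chunks are 320 samples,
        # the concatenated buffer passed to the model below can exceed 512.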

        logger.info(f"VAD initialized: threshold={threshold}, "
                    f"min_speech={min_speech_duration_ms}ms, "
                    f"min_silence={min_silence_duration_ms}ms")

    def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[float, bool]:
        """
        Process a single audio chunk and return speech probability.
        Buffers small chunks to meet the VAD minimum size requirement.

        Args:
            audio_chunk: Audio data as numpy array (int16 or float32)

        Returns:
            (speech_probability, is_speaking): Probability and current speaking state
        """
        # Convert to float32 if needed
        if audio_chunk.dtype == np.int16:
            audio_chunk = audio_chunk.astype(np.float32) / 32768.0

        # Add to buffer
        self.vad_buffer.append(audio_chunk)

        # Check if we have enough samples
        total_samples = sum(len(chunk) for chunk in self.vad_buffer)

        if total_samples < self.min_vad_samples:
            # Not enough samples yet; report no speech for now
            return 0.0, False

        # Concatenate buffer
        audio_full = np.concatenate(self.vad_buffer)

        # Process with VAD
        audio_tensor = torch.from_numpy(audio_full)

        with torch.no_grad():
            speech_prob = self.model(audio_tensor, self.sample_rate).item()

        # Clear buffer after processing
        self.vad_buffer = []

        # Update speaking state based on probability
        is_speaking = speech_prob > self.threshold

        return speech_prob, is_speaking

    def detect_speech_segment(
        self,
        audio_chunk: np.ndarray,
        timestamp_ms: float
    ) -> Optional[dict]:
        """
        Process a chunk and detect speech start/end events.

        Args:
            audio_chunk: Audio data
            timestamp_ms: Current timestamp in milliseconds

        Returns:
            Event dict or None:
            - {"event": "speech_start", "timestamp": float, "probability": float}
            - {"event": "speech_end", "timestamp": float, "probability": float}
            - {"event": "speaking", "probability": float}  # Ongoing speech
        """
        speech_prob, is_speaking = self.process_chunk(audio_chunk)

        # Speech started
        if is_speaking and not self.speaking:
            if self.speech_start_time is None:
                self.speech_start_time = timestamp_ms

            # Check if speech duration exceeds minimum
            speech_duration = timestamp_ms - self.speech_start_time
            if speech_duration >= self.min_speech_duration_ms:
                self.speaking = True
                self.silence_start_time = None
                logger.debug(f"Speech started at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_start",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # Speech ongoing
        elif is_speaking and self.speaking:
            self.silence_start_time = None  # Reset silence timer
            return {
                "event": "speaking",
                "probability": speech_prob,
                "timestamp": timestamp_ms
            }

        # Silence detected during speech
        elif not is_speaking and self.speaking:
            if self.silence_start_time is None:
                self.silence_start_time = timestamp_ms

            # Check if silence duration exceeds minimum
            silence_duration = timestamp_ms - self.silence_start_time
            if silence_duration >= self.min_silence_duration_ms:
                self.speaking = False
                self.speech_start_time = None
                logger.debug(f"Speech ended at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_end",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # No speech or insufficient duration
        else:
            if not is_speaking:
                self.speech_start_time = None

        return None

    def reset(self):
        """Reset VAD state."""
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer.clear()
        logger.debug("VAD state reset")

    def get_state(self) -> dict:
        """Get current VAD state."""
        return {
            "speaking": self.speaking,
            "speech_start_time": self.speech_start_time,
            "silence_start_time": self.silence_start_time
        }
stt/whisper_transcriber.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""
Faster-Whisper Transcriber

GPU-accelerated speech-to-text using faster-whisper (CTranslate2).
Supports streaming transcription with partial results.
"""

import numpy as np
from faster_whisper import WhisperModel
from typing import AsyncIterator, Optional, List
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger('whisper')


class WhisperTranscriber:
    """
    Faster-Whisper based transcription with streaming support.

    Runs on the GPU (GTX 1660) with the small model for a balance of speed and quality.
    """

    def __init__(
        self,
        model_size: str = "small",
        device: str = "cuda",
        compute_type: str = "float16",
        language: str = "en",
        beam_size: int = 5
    ):
        """
        Initialize Whisper transcriber.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
            device: Device to run on (cuda or cpu)
            compute_type: Compute precision (float16, int8, int8_float16)
            language: Language code for transcription
            beam_size: Beam search size (higher = better quality, slower)
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self.language = language
        self.beam_size = beam_size

        logger.info(f"Loading Faster-Whisper model: {model_size} on {device}...")

        # Load model
        self.model = WhisperModel(
            model_size,
            device=device,
            compute_type=compute_type,
            download_root="/models"
        )

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Whisper model loaded: {model_size} ({compute_type})")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        initial_prompt: Optional[str] = None
    ) -> str:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32)
            sample_rate: Audio sample rate
            initial_prompt: Optional prompt to guide transcription

        Returns:
            Transcribed text
        """
        loop = asyncio.get_running_loop()

        # Run transcription in the thread pool to avoid blocking the event loop
        result = await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            initial_prompt
        )

        return result

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        initial_prompt: Optional[str]
    ) -> str:
        """
        Blocking transcription call (runs in the thread pool).
        """
        # Convert to float32 if needed
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32) / 32768.0

        # Transcribe
        segments, info = self.model.transcribe(
            audio,
            language=self.language,
            beam_size=self.beam_size,
            initial_prompt=initial_prompt,
            vad_filter=False,       # We handle VAD separately
            word_timestamps=False   # Can enable for word-level timing
        )

        # Collect all segments
        text_parts = []
        for segment in segments:
            text_parts.append(segment.text.strip())

        full_text = " ".join(text_parts).strip()

        logger.debug(f"Transcribed: '{full_text}' (language: {info.language}, "
                     f"probability: {info.language_probability:.2f})")

        return full_text

    async def transcribe_streaming(
        self,
        audio_stream: AsyncIterator[np.ndarray],
        sample_rate: int = 16000,
        chunk_duration_s: float = 2.0
    ) -> AsyncIterator[dict]:
        """
        Transcribe an audio stream with partial results.

        Args:
            audio_stream: Async iterator yielding audio chunks
            sample_rate: Audio sample rate
            chunk_duration_s: Duration of each chunk to transcribe

        Yields:
            {"type": "partial", "text": "partial transcript"}
            {"type": "final", "text": "complete transcript"}
        """
        accumulated_audio = []
        chunk_samples = int(chunk_duration_s * sample_rate)

        async for audio_chunk in audio_stream:
            accumulated_audio.append(audio_chunk)

            # Check if we have enough audio for transcription
            total_samples = sum(len(chunk) for chunk in accumulated_audio)

            if total_samples >= chunk_samples:
                # Concatenate accumulated audio
                audio_data = np.concatenate(accumulated_audio)

                # Transcribe the current accumulated audio
                text = await self.transcribe_async(audio_data, sample_rate)

                if text:
                    yield {
                        "type": "partial",
                        "text": text,
                        "duration": total_samples / sample_rate
                    }

        # Final transcription of remaining audio
        if accumulated_audio:
            audio_data = np.concatenate(accumulated_audio)
            text = await self.transcribe_async(audio_data, sample_rate)

            if text:
                yield {
                    "type": "final",
                    "text": text,
                    "duration": len(audio_data) / sample_rate
                }

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        return [
            "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr",
            "pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi",
            "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no"
        ]

    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)
        logger.info("Whisper transcriber cleaned up")