Phase 4 STT pipeline implemented — Silero VAD + faster-whisper — still not working well at all

STT_VOICE_TESTING.md (new file, +266 lines)
@@ -0,0 +1,266 @@

# STT Voice Testing Guide

## Phase 4B: Bot-Side STT Integration - COMPLETE ✅

All code has been deployed to containers. Ready for testing!

## Architecture Overview

```
Discord Voice (User) → Opus 48kHz stereo
        ↓
VoiceReceiver.write()
        ↓
Opus decode → Stereo-to-mono → Resample to 16kHz
        ↓
STTClient.send_audio() → WebSocket
        ↓
miku-stt:8001 (Silero VAD + Faster-Whisper)
        ↓
JSON events (vad, partial, final, interruption)
        ↓
VoiceReceiver callbacks → voice_manager
        ↓
on_final_transcript() → _generate_voice_response()
        ↓
LLM streaming → TTS tokens → Audio playback
```
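
The decode/downmix/resample step in the middle of this pipeline can be sketched in plain Python. This is a minimal illustration only: it assumes Opus decoding has already produced interleaved stereo int16 PCM at 48 kHz, and it uses naive 3:1 decimation, whereas real code should use a proper low-pass filtering resampler to avoid aliasing.

```python
import array

def stereo48k_to_mono16k(pcm: bytes) -> bytes:
    """Downmix interleaved stereo int16 48kHz PCM to mono 16kHz.

    Naive sketch: average L/R channels, then keep every 3rd sample
    (48000 / 16000 == 3). A production pipeline should low-pass
    filter before decimating.
    """
    samples = array.array("h")
    samples.frombytes(pcm)
    # Average each L/R pair into one mono sample
    mono = [(samples[i] + samples[i + 1]) // 2 for i in range(0, len(samples), 2)]
    # Decimate 3:1 to go from 48kHz to 16kHz
    out = array.array("h", mono[::3])
    return out.tobytes()

# A 20ms stereo frame at 48kHz is 960 frames = 1920 int16 samples = 3840 bytes;
# it becomes 320 mono samples at 16kHz = 640 bytes (matching the 20ms chunks
# mentioned later in this document).
frame = array.array("h", [100, 200] * 960).tobytes()
result = stereo48k_to_mono16k(frame)
print(len(result))  # 640
```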

## New Voice Commands

### 1. Start Listening
```
!miku listen
```
- Starts listening to **your** voice in the current voice channel
- You must be in the same channel as Miku
- Miku will transcribe your speech and respond with voice

```
!miku listen @username
```
- Start listening to a specific user's voice
- Useful for moderators or testing with multiple users

### 2. Stop Listening
```
!miku stop-listening
```
- Stop listening to your voice
- Miku will no longer transcribe or respond to your speech

```
!miku stop-listening @username
```
- Stop listening to a specific user

## Testing Procedure

### Test 1: Basic STT Connection
1. Join a voice channel
2. `!miku join` - Miku joins your channel
3. `!miku listen` - Start listening to your voice
4. Check bot logs for "Started listening to user"
5. Check STT logs: `docker logs miku-stt --tail 50`
   - Should show: "WebSocket connection from user {user_id}"
   - Should show: "Session started for user {user_id}"

### Test 2: VAD Detection
1. After `!miku listen`, speak into your microphone
2. Say something like: "Hello Miku, can you hear me?"
3. Check STT logs for VAD events:
   ```
   [DEBUG] VAD: speech_start probability=0.85
   [DEBUG] VAD: speaking probability=0.92
   [DEBUG] VAD: speech_end probability=0.15
   ```
4. Bot logs should show: "VAD event for user {id}: speech_start/speaking/speech_end"

### Test 3: Transcription
1. Speak clearly into the microphone: "Hey Miku, tell me a joke"
2. Watch bot logs for:
   - "Partial transcript from user {id}: Hey Miku..."
   - "Final transcript from user {id}: Hey Miku, tell me a joke"
3. Miku should respond with LLM-generated speech
4. Check channel for: "🎤 Miku: *[her response]*"

### Test 4: Interruption Detection
1. `!miku listen`
2. `!miku say Tell me a very long story about your favorite song`
3. While Miku is speaking, start talking yourself
4. Speak loudly enough to trigger VAD (probability > 0.7)
5. Expected behavior:
   - Miku's audio should stop immediately
   - Bot logs: "User {id} interrupted Miku (probability={prob})"
   - STT logs: "Interruption detected during TTS playback"
   - RVC logs: "Interrupted: Flushed {N} ZMQ chunks"

### Test 5: Multi-User (if available)
1. Have two users join the voice channel
2. `!miku listen @user1` - Listen to first user
3. `!miku listen @user2` - Listen to second user
4. Both users speak separately
5. Verify Miku responds to each user individually
6. Check STT logs for multiple active sessions

## Logs to Monitor

### Bot Logs
```bash
docker logs -f miku-bot | grep -E "(listen|STT|transcript|interrupt)"
```
Expected output:
```
[INFO] Started listening to user 123456789 (username)
[DEBUG] VAD event for user 123456789: speech_start
[DEBUG] Partial transcript from user 123456789: Hello Miku...
[INFO] Final transcript from user 123456789: Hello Miku, how are you?
[INFO] User 123456789 interrupted Miku (probability=0.82)
```
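
These log lines are driven by JSON events arriving over the STT WebSocket. A minimal dispatcher for the four event types might look like the following; the exact field names here are assumptions inferred from the event names in this document, not the server's verified schema.

```python
import json

def dispatch_stt_event(raw: str, callbacks: dict) -> str:
    """Route one JSON event from the STT server to the right callback.

    Assumed (hypothetical) event shapes:
      {"type": "vad", "event": "speech_start", "probability": 0.85}
      {"type": "partial", "text": "...", "timestamp": 1.2}
      {"type": "final", "text": "...", "timestamp": 2.5}
      {"type": "interruption", "probability": 0.82}
    Returns the event type that was handled.
    """
    event = json.loads(raw)
    etype = event.get("type")
    if etype == "vad":
        callbacks["on_vad_event"](event)
    elif etype == "partial":
        callbacks["on_partial_transcript"](event["text"], event.get("timestamp"))
    elif etype == "final":
        callbacks["on_final_transcript"](event["text"], event.get("timestamp"))
    elif etype == "interruption":
        callbacks["on_interruption"](event.get("probability"))
    return etype

finals = []
cbs = {
    "on_vad_event": lambda e: None,
    "on_partial_transcript": lambda t, ts: None,
    "on_final_transcript": lambda t, ts: finals.append(t),
    "on_interruption": lambda p: None,
}
dispatch_stt_event('{"type": "final", "text": "Hello Miku", "timestamp": 2.5}', cbs)
print(finals)  # ['Hello Miku']
```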

### STT Logs
```bash
docker logs -f miku-stt
```
Expected output:
```
[INFO] WebSocket connection from user_123456789
[INFO] Session started for user 123456789
[DEBUG] Received 320 audio samples from user_123456789
[DEBUG] VAD speech_start: probability=0.87
[INFO] Transcribing audio segment (duration=2.5s)
[INFO] Final transcript: "Hello Miku, how are you?"
```

### RVC Logs (for interruption)
```bash
docker logs -f miku-rvc-api | grep -i interrupt
```
Expected output:
```
[INFO] Interrupted: Flushed 15 ZMQ chunks, cleared 48000 RVC buffer samples
```

## Component Status

### ✅ Completed
- [x] STT container running (miku-stt:8001)
- [x] Silero VAD on CPU with chunk buffering
- [x] Faster-Whisper on GTX 1660 (1.3GB VRAM)
- [x] STTClient WebSocket client
- [x] VoiceReceiver Discord audio sink
- [x] VoiceSession STT integration
- [x] listen/stop-listening commands
- [x] /interrupt endpoint in RVC API
- [x] LLM response generation from transcripts
- [x] Interruption detection and cancellation

### ⏳ Pending Testing
- [ ] Basic STT connection test
- [ ] VAD speech detection test
- [ ] End-to-end transcription test
- [ ] LLM voice response test
- [ ] Interruption cancellation test
- [ ] Multi-user testing (if available)

### 🔧 Configuration Tuning (after testing)
- VAD sensitivity (currently threshold=0.5)
- VAD timing (min_speech=250ms, min_silence=500ms)
- Interruption threshold (currently 0.7)
- Whisper beam size and patience
- LLM streaming chunk size

## API Endpoints

### STT Container (port 8001)
- WebSocket: `ws://localhost:8001/ws/stt/{user_id}`
- Health: `http://localhost:8001/health`

### RVC Container (port 8765)
- WebSocket: `ws://localhost:8765/ws/stream`
- Interrupt: `http://localhost:8765/interrupt` (POST)
- Health: `http://localhost:8765/health`

## Troubleshooting

### No audio received from Discord
- Check bot logs for "write() called with data"
- Verify user is in same voice channel as Miku
- Check Discord permissions (View Channel, Connect, Speak)

### VAD not detecting speech
- Check chunk buffer accumulation in STT logs
- Verify audio format: PCM int16, 16kHz mono
- Try speaking louder or more clearly
- Check VAD threshold (may need adjustment)

### Transcription empty or gibberish
- Verify Whisper model loaded (check STT startup logs)
- Check GPU VRAM usage: `nvidia-smi`
- Ensure audio segments are at least 1-2 seconds long
- Try speaking more clearly with less background noise
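
The 1-2 second guidance exists because very short segments give Whisper too little audio to work with. A gate like the following would drop them before transcription; this is a sketch, and the duration threshold is this document's recommendation rather than a verified constant from `whisper_transcriber.py`.

```python
SAMPLE_RATE = 16000        # mono PCM rate used throughout this pipeline
MIN_SEGMENT_SECONDS = 1.0  # document's recommendation: at least 1-2 seconds

def worth_transcribing(num_samples: int) -> bool:
    """Skip audio segments too short for Whisper to transcribe reliably."""
    return num_samples / SAMPLE_RATE >= MIN_SEGMENT_SECONDS

print(worth_transcribing(16000 * 2))  # True: a 2.0s segment
print(worth_transcribing(4800))       # False: a 0.3s blip
```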

### Interruption not working
- Verify Miku is actually speaking (check miku_speaking flag)
- Check VAD probability in logs (must be > 0.7)
- Verify /interrupt endpoint returns success
- Check RVC logs for flushed chunks

### Multiple users causing issues
- Check STT logs for per-user session management
- Verify each user has a separate STTClient instance
- Check for resource contention on GTX 1660
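
One way to reason about the per-user requirement is a small registry that owns exactly one client per Discord user id. This is a sketch with a stub client class; the real code presumably manages STTClient instances inside VoiceSession, so names here are illustrative only.

```python
class StubSTTClient:
    """Stand-in for STTClient; only tracks lifecycle for illustration."""
    def __init__(self, user_id: str):
        self.user_id = user_id
        self.connected = True

    def close(self):
        self.connected = False

class ListenerRegistry:
    """One STT client per listening user, with idempotent start/stop."""
    def __init__(self):
        self._clients = {}

    def start_listening(self, user_id: str) -> StubSTTClient:
        # Reuse an existing session instead of opening a duplicate socket
        if user_id not in self._clients:
            self._clients[user_id] = StubSTTClient(user_id)
        return self._clients[user_id]

    def stop_listening(self, user_id: str) -> bool:
        client = self._clients.pop(user_id, None)
        if client:
            client.close()
            return True
        return False

reg = ListenerRegistry()
a = reg.start_listening("111")
b = reg.start_listening("222")
assert reg.start_listening("111") is a   # no duplicate client per user
print(reg.stop_listening("222"), reg.stop_listening("222"))  # True False
```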

## Next Steps After Testing

### Phase 4C: LLM KV Cache Precomputation
- Use partial transcripts to start LLM generation early
- Precompute KV cache for common phrases
- Reduce latency between speech end and response start

### Phase 4D: Multi-User Refinement
- Queue management for multiple simultaneous speakers
- Priority system for interruptions
- Resource allocation for multiple Whisper requests

### Phase 4E: Latency Optimization
- Profile each stage of the pipeline
- Optimize audio chunk sizes
- Reduce WebSocket message overhead
- Tune Whisper beam search parameters
- Implement VAD lookahead for quicker detection

## Hardware Utilization

### Current Allocation
- **AMD RX 6800**: LLaMA text models (idle during listen/speak)
- **GTX 1660**:
  - Listen phase: Faster-Whisper (1.3GB VRAM)
  - Speak phase: Soprano TTS + RVC (time-multiplexed)
- **CPU**: Silero VAD, audio preprocessing

### Expected Performance
- VAD latency: <50ms (CPU processing)
- Transcription latency: 200-500ms (Whisper inference)
- LLM streaming: 20-30 tokens/sec (RX 6800)
- TTS synthesis: Real-time (GTX 1660)
- Total latency (speech → response): 1-2 seconds
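
As a rough sanity check, the per-stage estimates above add up to the quoted total. The first-sentence token count below is an assumption (about 20 tokens before TTS can start speaking), not a measured value.

```python
# Rough end-to-end budget from the estimates above (all in milliseconds).
vad = 50                       # upper bound for CPU VAD
whisper = (200 + 500) / 2      # midpoint of the 200-500ms estimate
llm_first_sentence = 20 * 40   # assumed ~20 tokens at ~40ms/token (33-50ms range)
tts = 0                        # real-time synthesis overlaps playback

total_ms = vad + whisper + llm_first_sentence + tts
print(total_ms / 1000)  # 1.2 -> consistent with the 1-2 second claim
```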

## Testing Checklist

Before marking Phase 4B as complete:

- [ ] Test basic STT connection with `!miku listen`
- [ ] Verify VAD detects speech start/end correctly
- [ ] Confirm transcripts are accurate and complete
- [ ] Verify LLM voice response generation works
- [ ] Verify interruption cancels TTS playback
- [ ] Check multi-user handling (if possible)
- [ ] Verify resource cleanup on `!miku stop-listening`
- [ ] Test edge cases (silence, background noise, overlapping speech)
- [ ] Profile latencies at each stage
- [ ] Document any configuration tuning needed

---

**Status**: Code deployed, ready for user testing! 🎤🤖

VOICE_TO_VOICE_REFERENCE.md (new file, +323 lines)
@@ -0,0 +1,323 @@

# Voice-to-Voice Quick Reference

## Complete Pipeline Status ✅

All phases complete and deployed!

## Phase Completion Status

### ✅ Phase 1: Voice Connection (COMPLETE)
- Discord voice channel connection
- Audio playback via discord.py
- Resource management and cleanup

### ✅ Phase 2: Audio Streaming (COMPLETE)
- Soprano TTS server (GTX 1660)
- RVC voice conversion
- Real-time streaming via WebSocket
- Token-by-token synthesis

### ✅ Phase 3: Text-to-Voice (COMPLETE)
- LLaMA text generation (AMD RX 6800)
- Streaming token pipeline
- TTS integration with `!miku say`
- Natural conversation flow

### ✅ Phase 4A: STT Container (COMPLETE)
- Silero VAD on CPU
- Faster-Whisper on GTX 1660
- WebSocket server at port 8001
- Per-user session management
- Chunk buffering for VAD

### ✅ Phase 4B: Bot STT Integration (COMPLETE - READY FOR TESTING)
- Discord audio capture
- Opus decode + resampling
- STT client WebSocket integration
- Voice commands: `!miku listen`, `!miku stop-listening`
- LLM voice response generation
- Interruption detection and cancellation
- `/interrupt` endpoint in RVC API

## Quick Start Commands

### Setup
```bash
!miku join    # Join your voice channel
!miku listen  # Start listening to your voice
```

### Usage
- **Speak** into your microphone
- Miku will **transcribe** your speech
- Miku will **respond** with voice
- **Interrupt** her by speaking while she's talking

### Teardown
```bash
!miku stop-listening  # Stop listening to your voice
!miku leave           # Leave voice channel
```

## Architecture Diagram

```
┌─────────────────────────────────────────────────────────────────┐
│                           USER INPUT                            │
└─────────────────────────────────────────────────────────────────┘
                    │
                    │ Discord Voice (Opus 48kHz)
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                       miku-bot Container                        │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ VoiceReceiver (discord.sinks.Sink)                          │ │
│ │ - Opus decode → PCM                                         │ │
│ │ - Stereo → Mono                                             │ │
│ │ - Resample 48kHz → 16kHz                                    │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
│                   │ PCM int16, 16kHz, 20ms chunks               │
│ ┌─────────────────▼───────────────────────────────────────────┐ │
│ │ STTClient (WebSocket)                                       │ │
│ │ - Sends audio to miku-stt                                   │ │
│ │ - Receives VAD events, transcripts                          │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
└────────────────────┼────────────────────────────────────────────┘
                    │ ws://miku-stt:8001/ws/stt/{user_id}
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                       miku-stt Container                        │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ VADProcessor (Silero VAD 5.1.2) [CPU]                       │ │
│ │ - Chunk buffering (512 samples min)                         │ │
│ │ - Speech detection (threshold=0.5)                          │ │
│ │ - Events: speech_start, speaking, speech_end                │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
│                   │ Audio segments                              │
│ ┌─────────────────▼───────────────────────────────────────────┐ │
│ │ WhisperTranscriber (Faster-Whisper 1.2.1) [GTX 1660]        │ │
│ │ - Model: small (1.3GB VRAM)                                 │ │
│ │ - Transcribes speech segments                               │ │
│ │ - Returns: partial & final transcripts                      │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
└────────────────────┼────────────────────────────────────────────┘
                    │ JSON events via WebSocket
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                       miku-bot Container                        │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ voice_manager.py Callbacks                                  │ │
│ │ - on_vad_event() → Log VAD states                           │ │
│ │ - on_partial_transcript() → Show typing indicator           │ │
│ │ - on_final_transcript() → Generate LLM response             │ │
│ │ - on_interruption() → Cancel TTS playback                   │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
│                   │ Final transcript text                       │
│ ┌─────────────────▼───────────────────────────────────────────┐ │
│ │ _generate_voice_response()                                  │ │
│ │ - Build LLM prompt with conversation history                │ │
│ │ - Stream LLM response                                       │ │
│ │ - Send tokens to TTS                                        │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
└────────────────────┼────────────────────────────────────────────┘
                    │ HTTP streaming to LLaMA server
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                 llama-cpp-server (AMD RX 6800)                  │
│ - Streaming text generation                                     │
│ - 20-30 tokens/sec                                              │
│ - Returns: {"delta": {"content": "token"}}                      │
└───────────────────┬─────────────────────────────────────────────┘
                    │ Token stream
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                       miku-bot Container                        │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ audio_source.send_token()                                   │ │
│ │ - Buffers tokens                                            │ │
│ │ - Sends to RVC WebSocket                                    │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
└────────────────────┼────────────────────────────────────────────┘
                    │ ws://miku-rvc-api:8765/ws/stream
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                     miku-rvc-api Container                      │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Soprano TTS Server (miku-soprano-tts) [GTX 1660]            │ │
│ │ - Text → Audio synthesis                                    │ │
│ │ - 32kHz output                                              │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
│                   │ Raw audio via ZMQ                           │
│ ┌─────────────────▼───────────────────────────────────────────┐ │
│ │ RVC Voice Conversion [GTX 1660]                             │ │
│ │ - Voice cloning & pitch shifting                            │ │
│ │ - 48kHz output                                              │ │
│ └─────────────────┬───────────────────────────────────────────┘ │
└────────────────────┼────────────────────────────────────────────┘
                    │ PCM float32, 48kHz
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                       miku-bot Container                        │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ discord.VoiceClient                                         │ │
│ │ - Plays audio in voice channel                              │ │
│ │ - Can be interrupted by user speech                         │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
                    │
                    ▼
┌─────────────────────────────────────────────────────────────────┐
│                           USER OUTPUT                           │
│                     (Miku's voice response)                     │
└─────────────────────────────────────────────────────────────────┘
```

## Interruption Flow

```
User speaks during Miku's TTS
              │
              ▼
VAD detects speech (probability > 0.7)
              │
              ▼
STT sends interruption event
              │
              ▼
on_user_interruption() callback
              │
              ▼
_cancel_tts() → voice_client.stop()
              │
              ▼
POST http://miku-rvc-api:8765/interrupt
              │
              ▼
Flush ZMQ socket + clear RVC buffers
              │
              ▼
Miku stops speaking, ready for new input
```
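
The first two steps of this flow reduce to a simple gate: the interruption path should fire only while Miku is speaking, and only when the VAD probability clears the 0.7 threshold from the VAD settings. A minimal sketch of that decision (function name is illustrative, not from the source):

```python
INTERRUPTION_THRESHOLD = 0.7  # matches interruption_threshold in the VAD settings

def should_interrupt(miku_speaking: bool, vad_probability: float) -> bool:
    """Fire the interruption path only during Miku's own playback,
    and only for confident speech detections."""
    return miku_speaking and vad_probability > INTERRUPTION_THRESHOLD

print(should_interrupt(True, 0.82))   # True  -> stop playback, POST /interrupt
print(should_interrupt(True, 0.4))    # False -> background noise, ignore
print(should_interrupt(False, 0.9))   # False -> normal speech, transcribe instead
```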

## Hardware Utilization

### Listen Phase (User Speaking)
- **CPU**: Silero VAD processing
- **GTX 1660**: Faster-Whisper transcription (1.3GB VRAM)
- **AMD RX 6800**: Idle

### Think Phase (LLM Generation)
- **CPU**: Idle
- **GTX 1660**: Idle
- **AMD RX 6800**: LLaMA inference (20-30 tokens/sec)

### Speak Phase (Miku Responding)
- **CPU**: Silero VAD monitoring for interruption
- **GTX 1660**: Soprano TTS + RVC synthesis
- **AMD RX 6800**: Idle

## Performance Metrics

### Expected Latencies

| Stage | Latency |
|--------------------------------|-------------|
| Discord audio capture | ~20ms |
| Opus decode + resample | <10ms |
| VAD processing | <50ms |
| Whisper transcription | 200-500ms |
| LLM token generation | 33-50ms/tok |
| TTS synthesis | Real-time |
| **Total (speech → response)** | **1-2s** |

### VRAM Usage

| GPU | Component | VRAM |
|-------------|----------------|--------|
| AMD RX 6800 | LLaMA 8B Q4 | ~5.5GB |
| GTX 1660 | Whisper small | 1.3GB |
| GTX 1660 | Soprano + RVC | ~3GB |

## Key Files

### Bot Container
- `bot/utils/stt_client.py` - WebSocket client for STT
- `bot/utils/voice_receiver.py` - Discord audio sink
- `bot/utils/voice_manager.py` - Voice session with STT integration
- `bot/commands/voice.py` - Voice commands including listen/stop-listening

### STT Container
- `stt/vad_processor.py` - Silero VAD with chunk buffering
- `stt/whisper_transcriber.py` - Faster-Whisper transcription
- `stt/stt_server.py` - FastAPI WebSocket server

### RVC Container
- `soprano_to_rvc/soprano_rvc_api.py` - TTS + RVC pipeline with /interrupt endpoint

## Configuration Files

### docker-compose.yml
- Network: `miku-network` (all containers)
- Ports:
  - miku-bot: 8081 (API)
  - miku-rvc-api: 8765 (TTS)
  - miku-stt: 8001 (STT)
  - llama-cpp-server: 8080 (LLM)

### VAD Settings (stt/vad_processor.py)
```python
threshold = 0.5               # Speech detection sensitivity
min_speech = 250              # Minimum speech duration (ms)
min_silence = 500             # Silence before speech_end (ms)
interruption_threshold = 0.7  # Probability for interruption
```
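
Silero VAD consumes fixed-size windows (512 samples at 16 kHz), while Discord delivers 20 ms chunks (320 samples), so the processor buffers input and emits complete windows. A minimal sketch of that buffering, with sizes taken from this document (the real `vad_processor.py` may differ in detail):

```python
class ChunkBuffer:
    """Accumulate arbitrary-sized int16 sample chunks, emit fixed windows."""
    WINDOW = 512  # Silero VAD window size at 16kHz

    def __init__(self):
        self._pending = []

    def push(self, samples):
        """Add samples; return a list of complete 512-sample windows."""
        self._pending.extend(samples)
        windows = []
        while len(self._pending) >= self.WINDOW:
            windows.append(self._pending[:self.WINDOW])
            self._pending = self._pending[self.WINDOW:]
        return windows

buf = ChunkBuffer()
# Two 320-sample Discord chunks: the first yields nothing,
# the second completes one window (128 samples stay pending).
print(len(buf.push([0] * 320)))  # 0
print(len(buf.push([0] * 320)))  # 1
```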

### Whisper Settings (stt/whisper_transcriber.py)
```python
model = "small"           # 1.3GB VRAM
device = "cuda"
compute_type = "float16"
beam_size = 5
patience = 1.0
```

## Testing Commands

```bash
# Check all container health
curl http://localhost:8001/health  # STT
curl http://localhost:8765/health  # RVC
curl http://localhost:8080/health  # LLM

# Monitor logs
docker logs -f miku-bot | grep -E "(listen|transcript|interrupt)"
docker logs -f miku-stt
docker logs -f miku-rvc-api | grep interrupt

# Test interrupt endpoint
curl -X POST http://localhost:8765/interrupt

# Check GPU usage
nvidia-smi
```

## Troubleshooting

| Issue | Solution |
|-------|----------|
| No audio from Discord | Check bot has Connect and Speak permissions |
| VAD not detecting | Speak louder, check microphone, lower threshold |
| Empty transcripts | Speak for at least 1-2 seconds, check Whisper model |
| Interruption not working | Verify `miku_speaking=true`, check VAD probability |
| High latency | Profile each stage, check GPU utilization |

## Next Features (Phase 4C+)

- [ ] KV cache precomputation from partial transcripts
- [ ] Multi-user simultaneous conversation
- [ ] Latency optimization (<1s total)
- [ ] Voice activity history and analytics
- [ ] Emotion detection from speech patterns
- [ ] Context-aware interruption handling

---

**Ready to test!** Use `!miku join` → `!miku listen` → speak to Miku 🎤

@@ -125,7 +125,7 @@ async def on_message(message):
     if message.author == globals.client.user:
         return
 
-    # Check for voice commands first (!miku join, !miku leave, !miku voice-status, !miku test, !miku say)
+    # Check for voice commands first (!miku join, !miku leave, !miku voice-status, !miku test, !miku say, !miku listen, !miku stop-listening)
     if not isinstance(message.channel, discord.DMChannel) and message.content.strip().lower().startswith('!miku '):
         from commands.voice import handle_voice_command
 
@@ -134,7 +134,7 @@ async def on_message(message):
         cmd = parts[1].lower()
         args = parts[2:] if len(parts) > 2 else []
 
-        if cmd in ['join', 'leave', 'voice-status', 'test', 'say']:
+        if cmd in ['join', 'leave', 'voice-status', 'test', 'say', 'listen', 'stop-listening']:
            await handle_voice_command(message, cmd, args)
            return

@@ -39,6 +39,12 @@ async def handle_voice_command(message, cmd, args):
     elif cmd == 'say':
         await _handle_say(message, args)
 
+    elif cmd == 'listen':
+        await _handle_listen(message, args)
+
+    elif cmd == 'stop-listening':
+        await _handle_stop_listening(message, args)
+
     else:
         await message.channel.send(f"❌ Unknown voice command: `{cmd}`")

@@ -366,8 +372,97 @@ Keep responses short (1-3 sentences) since they will be spoken aloud."""
         await message.channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
         logger.info(f"✓ Voice say complete: {full_response.strip()}")
         await message.add_reaction("✅")
 
     except Exception as e:
-        logger.error(f"Voice say failed: {e}", exc_info=True)
-        await message.channel.send(f"❌ Voice say failed: {str(e)}")
+        logger.error(f"Failed to generate voice response: {e}", exc_info=True)
+        await message.channel.send(f"❌ Error generating voice response: {e}")
+
+
+async def _handle_listen(message, args):
+    """
+    Handle !miku listen command.
+    Start listening to a user's voice for STT.
+
+    Usage:
+        !miku listen - Start listening to command author
+        !miku listen @user - Start listening to mentioned user
+    """
+    # Check if Miku is in voice channel
+    session = voice_manager.active_session
+
+    if not session or not session.voice_client or not session.voice_client.is_connected():
+        await message.channel.send("❌ I'm not in a voice channel! Use `!miku join` first.")
+        return
+
+    # Determine target user
+    target_user = None
+    if args and len(message.mentions) > 0:
+        # Listen to mentioned user
+        target_user = message.mentions[0]
+    else:
+        # Listen to command author
+        target_user = message.author
+
+    # Check if user is in voice channel
+    if not target_user.voice or not target_user.voice.channel:
+        await message.channel.send(f"❌ {target_user.mention} is not in a voice channel!")
+        return
+
+    # Check if user is in same channel as Miku
+    if target_user.voice.channel.id != session.voice_client.channel.id:
+        await message.channel.send(
+            f"❌ {target_user.mention} must be in the same voice channel as me!"
+        )
+        return
+
+    try:
+        # Start listening to user
+        await session.start_listening(target_user)
+        await message.channel.send(
+            f"👂 Now listening to {target_user.mention}'s voice! "
+            f"Speak to me and I'll respond. Use `!miku stop-listening` to stop."
+        )
+        await message.add_reaction("👂")
+        logger.info(f"Started listening to user {target_user.id} ({target_user.name})")
+
+    except Exception as e:
+        logger.error(f"Failed to start listening: {e}", exc_info=True)
+        await message.channel.send(f"❌ Failed to start listening: {str(e)}")
+
+
+async def _handle_stop_listening(message, args):
+    """
+    Handle !miku stop-listening command.
+    Stop listening to a user's voice.
+
+    Usage:
+        !miku stop-listening - Stop listening to command author
+        !miku stop-listening @user - Stop listening to mentioned user
+    """
+    # Check if Miku is in voice channel
+    session = voice_manager.active_session
+
+    if not session:
+        await message.channel.send("❌ I'm not in a voice channel!")
+        return
+
+    # Determine target user
+    target_user = None
+    if args and len(message.mentions) > 0:
+        # Stop listening to mentioned user
+        target_user = message.mentions[0]
+    else:
+        # Stop listening to command author
+        target_user = message.author
+
+    try:
+        # Stop listening to user
+        await session.stop_listening(target_user.id)
+        await message.channel.send(f"🔇 Stopped listening to {target_user.mention}.")
+        await message.add_reaction("🔇")
+        logger.info(f"Stopped listening to user {target_user.id} ({target_user.name})")
+
+    except Exception as e:
+        logger.error(f"Failed to stop listening: {e}", exc_info=True)
+        await message.channel.send(f"❌ Failed to stop listening: {str(e)}")

@@ -22,3 +22,4 @@ transformers
 torch
 PyNaCl>=1.5.0
 websockets>=12.0
+discord-ext-voice-recv

bot/utils/stt_client.py (new file, +214 lines)
@@ -0,0 +1,214 @@
"""
|
||||||
|
STT Client for Discord Bot
|
||||||
|
|
||||||
|
WebSocket client that connects to the STT server and handles:
|
||||||
|
- Audio streaming to STT
|
||||||
|
- Receiving VAD events
|
||||||
|
- Receiving partial/final transcripts
|
||||||
|
- Interruption detection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Callable
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger('stt_client')
|
||||||
|
|
||||||
|
|
||||||
|
class STTClient:
|
||||||
|
"""
|
||||||
|
WebSocket client for STT server communication.
|
||||||
|
|
||||||
|
Handles audio streaming and receives transcription events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
stt_url: str = "ws://miku-stt:8000/ws/stt",
|
||||||
|
on_vad_event: Optional[Callable] = None,
|
||||||
|
on_partial_transcript: Optional[Callable] = None,
|
||||||
|
on_final_transcript: Optional[Callable] = None,
|
||||||
|
on_interruption: Optional[Callable] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize STT client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Discord user ID
|
||||||
|
stt_url: Base WebSocket URL for STT server
|
||||||
|
on_vad_event: Callback for VAD events (event_dict)
|
||||||
|
on_partial_transcript: Callback for partial transcripts (text, timestamp)
|
||||||
|
on_final_transcript: Callback for final transcripts (text, timestamp)
|
||||||
|
on_interruption: Callback for interruption detection (probability)
|
||||||
|
"""
|
||||||
|
self.user_id = user_id
|
||||||
|
self.stt_url = f"{stt_url}/{user_id}"
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
self.on_vad_event = on_vad_event
|
||||||
|
self.on_partial_transcript = on_partial_transcript
|
||||||
|
self.on_final_transcript = on_final_transcript
|
||||||
|
self.on_interruption = on_interruption
|
||||||
|
|
||||||
|
# Connection state
|
||||||
|
self.websocket: Optional[aiohttp.ClientWebSocket] = None
|
||||||
|
self.session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self.connected = False
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Receive task
|
||||||
|
self._receive_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
logger.info(f"STT client initialized for user {user_id}")
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to STT WebSocket server."""
|
||||||
|
if self.connected:
|
||||||
|
logger.warning(f"Already connected for user {self.user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
self.websocket = await self.session.ws_connect(
|
||||||
|
self.stt_url,
|
||||||
|
heartbeat=30
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for ready message
|
||||||
|
ready_msg = await self.websocket.receive_json()
|
||||||
|
logger.info(f"STT connected for user {self.user_id}: {ready_msg}")
|
||||||
|
|
||||||
|
self.connected = True
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
# Start receive task
|
||||||
|
self._receive_task = asyncio.create_task(self._receive_events())
|
||||||
|
|
||||||
|
logger.info(f"✓ STT WebSocket connected for user {self.user_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect STT for user {self.user_id}: {e}", exc_info=True)
|
||||||
|
await self.disconnect()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
"""Disconnect from STT WebSocket."""
|
||||||
|
logger.info(f"Disconnecting STT for user {self.user_id}")
|
||||||
|
|
||||||
|
self.running = False
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
# Cancel receive task
|
||||||
|
if self._receive_task and not self._receive_task.done():
|
||||||
|
self._receive_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._receive_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Close WebSocket
|
||||||
|
if self.websocket:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.websocket = None
|
||||||
|
|
||||||
|
# Close session
|
||||||
|
if self.session:
|
||||||
|
await self.session.close()
|
||||||
|
self.session = None
|
||||||
|
|
||||||
|
logger.info(f"✓ STT disconnected for user {self.user_id}")
|
||||||
|
|
||||||
|
async def send_audio(self, audio_data: bytes):
|
||||||
|
"""
|
||||||
|
Send audio chunk to STT server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: PCM audio (int16, 16kHz mono)
|
||||||
|
"""
|
||||||
|
if not self.connected or not self.websocket:
|
||||||
|
logger.warning(f"Cannot send audio, not connected for user {self.user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.send_bytes(audio_data)
|
||||||
|
logger.debug(f"Sent {len(audio_data)} bytes to STT")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send audio to STT: {e}")
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
async def _receive_events(self):
|
||||||
|
"""Background task to receive events from STT server."""
|
||||||
|
try:
|
||||||
|
while self.running and self.websocket:
|
||||||
|
try:
|
||||||
|
msg = await self.websocket.receive()
|
||||||
|
|
||||||
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
event = json.loads(msg.data)
|
||||||
|
await self._handle_event(event)
|
||||||
|
|
||||||
|
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||||
|
logger.info(f"STT WebSocket closed for user {self.user_id}")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
logger.error(f"STT WebSocket error for user {self.user_id}")
|
||||||
|
break
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error receiving STT event: {e}", exc_info=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.connected = False
|
||||||
|
logger.info(f"STT receive task ended for user {self.user_id}")
|
||||||
|
|
||||||
|
async def _handle_event(self, event: dict):
|
||||||
|
"""
|
||||||
|
Handle incoming STT event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: Event dictionary from STT server
|
||||||
|
"""
|
||||||
|
event_type = event.get('type')
|
||||||
|
|
||||||
|
if event_type == 'vad':
|
||||||
|
# VAD event: speech detection
|
||||||
|
logger.debug(f"VAD event: {event}")
|
||||||
|
if self.on_vad_event:
|
||||||
|
await self.on_vad_event(event)
|
||||||
|
|
||||||
|
elif event_type == 'partial':
|
||||||
|
# Partial transcript
|
||||||
|
text = event.get('text', '')
|
||||||
|
timestamp = event.get('timestamp', 0)
|
||||||
|
logger.info(f"Partial transcript [{self.user_id}]: {text}")
|
||||||
|
if self.on_partial_transcript:
|
||||||
|
await self.on_partial_transcript(text, timestamp)
|
||||||
|
|
||||||
|
elif event_type == 'final':
|
||||||
|
# Final transcript
|
||||||
|
text = event.get('text', '')
|
||||||
|
timestamp = event.get('timestamp', 0)
|
||||||
|
logger.info(f"Final transcript [{self.user_id}]: {text}")
|
||||||
|
if self.on_final_transcript:
|
||||||
|
await self.on_final_transcript(text, timestamp)
|
||||||
|
|
||||||
|
elif event_type == 'interruption':
|
||||||
|
# Interruption detected
|
||||||
|
probability = event.get('probability', 0)
|
||||||
|
logger.info(f"Interruption detected from user {self.user_id} (prob={probability:.3f})")
|
||||||
|
if self.on_interruption:
|
||||||
|
await self.on_interruption(probability)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown STT event type: {event_type}")
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if STT client is connected."""
|
||||||
|
return self.connected
|
||||||
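The event dispatch in `_handle_event` can be sanity-checked without a live STT server by feeding it synthetic JSON frames. A minimal standalone sketch of the same `final`-branch wiring (the callback names here are illustrative, not from the codebase):

```python
import asyncio
import json

received = []

async def on_final(text, timestamp):
    # Stand-in for the bot's final-transcript callback
    received.append((text, timestamp))

async def handle_event(event: dict):
    # Mirrors the 'final' branch of STTClient._handle_event
    if event.get('type') == 'final':
        await on_final(event.get('text', ''), event.get('timestamp', 0))

# Simulate one TEXT frame as the STT server would send it
frame = json.dumps({"type": "final", "text": "hello miku", "timestamp": 12.5})
asyncio.run(handle_event(json.loads(frame)))
print(received)  # [('hello miku', 12.5)]
```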
@@ -19,6 +19,7 @@ import json
 import os
 from typing import Optional
 import discord
+from discord.ext import voice_recv
 import globals
 from utils.logger import get_logger
@@ -97,12 +98,12 @@ class VoiceSessionManager:
         # 10. Create voice session
         self.active_session = VoiceSession(guild_id, voice_channel, text_channel)

-        # 11. Connect to Discord voice channel
+        # 11. Connect to Discord voice channel with VoiceRecvClient
         try:
-            voice_client = await voice_channel.connect()
+            voice_client = await voice_channel.connect(cls=voice_recv.VoiceRecvClient)
             self.active_session.voice_client = voice_client
             self.active_session.active = True
-            logger.info(f"✓ Connected to voice channel: {voice_channel.name}")
+            logger.info(f"✓ Connected to voice channel: {voice_channel.name} (with audio receiving)")
         except Exception as e:
             logger.error(f"Failed to connect to voice channel: {e}", exc_info=True)
             raise
@@ -387,7 +388,9 @@ class VoiceSession:
         self.voice_client: Optional[discord.VoiceClient] = None
         self.audio_source: Optional['MikuVoiceSource'] = None  # Forward reference
         self.tts_streamer: Optional['TTSTokenStreamer'] = None  # Forward reference
+        self.voice_receiver: Optional['VoiceReceiver'] = None  # STT receiver
         self.active = False
+        self.miku_speaking = False  # Track if Miku is currently speaking

         logger.info(f"VoiceSession created for {voice_channel.name} in guild {guild_id}")
@@ -433,6 +436,207 @@ class VoiceSession:

         except Exception as e:
             logger.error(f"Error stopping audio streaming: {e}", exc_info=True)

+    async def start_listening(self, user: discord.User):
+        """
+        Start listening to a user's voice (STT).
+
+        Args:
+            user: Discord user to listen to
+        """
+        from utils.voice_receiver import VoiceReceiverSink
+
+        try:
+            # Create receiver if it doesn't exist yet
+            if not self.voice_receiver:
+                self.voice_receiver = VoiceReceiverSink(self)
+
+            # Start receiving audio from Discord using discord-ext-voice-recv
+            if self.voice_client:
+                self.voice_client.listen(self.voice_receiver)
+                logger.info("✓ Discord voice receive started (discord-ext-voice-recv)")
+
+            # Start listening to the specific user
+            await self.voice_receiver.start_listening(user.id, user)
+            logger.info(f"✓ Started listening to {user.name}")
+
+        except Exception as e:
+            logger.error(f"Failed to start listening to {user.name}: {e}", exc_info=True)
+            raise
+
+    async def stop_listening(self, user_id: int):
+        """
+        Stop listening to a user.
+
+        Args:
+            user_id: Discord user ID
+        """
+        if self.voice_receiver:
+            await self.voice_receiver.stop_listening(user_id)
+            logger.info(f"✓ Stopped listening to user {user_id}")
+
+    async def stop_all_listening(self):
+        """Stop listening to all users."""
+        if self.voice_receiver:
+            await self.voice_receiver.stop_all()
+            self.voice_receiver = None
+            logger.info("✓ Stopped all listening")
+
+    async def on_user_vad_event(self, user_id: int, event: dict):
+        """Called when VAD detects a speech state change."""
+        event_type = event.get('event')
+        logger.debug(f"User {user_id} VAD: {event_type}")
+
+    async def on_partial_transcript(self, user_id: int, text: str):
+        """Called when a partial transcript is received."""
+        logger.info(f"Partial from user {user_id}: {text}")
+        # Could show "User is saying..." in chat
+
+    async def on_final_transcript(self, user_id: int, text: str):
+        """
+        Called when a final transcript is received.
+        This triggers the LLM response and TTS.
+        """
+        logger.info(f"Final from user {user_id}: {text}")
+
+        # Get user info
+        user = self.voice_channel.guild.get_member(user_id)
+        if not user:
+            logger.warning(f"User {user_id} not found in guild")
+            return
+
+        # Show what the user said
+        await self.text_channel.send(f"🎤 {user.name}: *\"{text}\"*")
+
+        # Generate LLM response and speak it
+        await self._generate_voice_response(user, text)
+
+    async def on_user_interruption(self, user_id: int, probability: float):
+        """
+        Called when a user interrupts Miku's speech.
+        Cancels TTS and switches back to listening.
+        """
+        if not self.miku_speaking:
+            return
+
+        logger.info(f"User {user_id} interrupted Miku (prob={probability:.3f})")
+
+        # Cancel Miku's speech
+        await self._cancel_tts()
+
+        # Show the interruption in chat
+        user = self.voice_channel.guild.get_member(user_id)
+        await self.text_channel.send(f"⚠️ *{user.name if user else 'User'} interrupted Miku*")
+
+    async def _generate_voice_response(self, user: discord.User, text: str):
+        """
+        Generate an LLM response and speak it.
+
+        Args:
+            user: User who spoke
+            text: Transcribed text
+        """
+        try:
+            self.miku_speaking = True
+
+            # Show processing
+            await self.text_channel.send("💭 *Miku is thinking...*")
+
+            # Import here to avoid circular imports
+            from utils.llm import get_current_gpu_url
+            import aiohttp
+            import json
+            import globals
+
+            # Simple system prompt for voice
+            system_prompt = """You are Hatsune Miku, the virtual singer.
+Respond naturally and concisely as Miku would in a voice conversation.
+Keep responses short (1-3 sentences) since they will be spoken aloud."""
+
+            payload = {
+                "model": globals.TEXT_MODEL,
+                "messages": [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": text}
+                ],
+                "stream": True,
+                "temperature": 0.8,
+                "max_tokens": 200
+            }
+
+            headers = {'Content-Type': 'application/json'}
+            llama_url = get_current_gpu_url()
+
+            # Stream LLM response to TTS
+            full_response = ""
+            async with aiohttp.ClientSession() as http_session:
+                async with http_session.post(
+                    f"{llama_url}/v1/chat/completions",
+                    json=payload,
+                    headers=headers,
+                    timeout=aiohttp.ClientTimeout(total=60)
+                ) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        raise Exception(f"LLM error {response.status}: {error_text}")
+
+                    # Stream tokens to TTS
+                    async for line in response.content:
+                        if not self.miku_speaking:
+                            # Interrupted
+                            break
+
+                        line = line.decode('utf-8').strip()
+                        if line.startswith('data: '):
+                            data_str = line[6:]
+                            if data_str == '[DONE]':
+                                break
+
+                            try:
+                                data = json.loads(data_str)
+                                if 'choices' in data and len(data['choices']) > 0:
+                                    delta = data['choices'][0].get('delta', {})
+                                    content = delta.get('content', '')
+                                    if content:
+                                        await self.audio_source.send_token(content)
+                                        full_response += content
+                            except json.JSONDecodeError:
+                                continue
+
+            # Flush TTS
+            if self.miku_speaking:
+                await self.audio_source.flush()
+
+            # Show the response
+            await self.text_channel.send(f"🎤 Miku: *\"{full_response.strip()}\"*")
+            logger.info(f"✓ Voice response complete: {full_response.strip()}")
+
+        except Exception as e:
+            logger.error(f"Voice response failed: {e}", exc_info=True)
+            await self.text_channel.send("❌ Sorry, I had trouble responding")
+
+        finally:
+            self.miku_speaking = False
+
+    async def _cancel_tts(self):
+        """Cancel current TTS synthesis."""
+        logger.info("Canceling TTS synthesis")
+
+        # Stop Discord playback
+        if self.voice_client and self.voice_client.is_playing():
+            self.voice_client.stop()
+
+        # Send interrupt to RVC
+        try:
+            import aiohttp
+            async with aiohttp.ClientSession() as session:
+                async with session.post("http://172.25.0.1:8765/interrupt") as resp:
+                    if resp.status == 200:
+                        logger.info("✓ TTS interrupted")
+        except Exception as e:
+            logger.error(f"Failed to interrupt TTS: {e}")
+
+        self.miku_speaking = False
+
+
 # Global singleton instance
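The SSE parsing loop in `_generate_voice_response` is easy to get subtly wrong (the `data: ` prefix, the `[DONE]` sentinel, empty deltas), and its core can be exercised offline. This is a sketch of the same token-extraction logic, not the deployed code:

```python
import json

def extract_tokens(sse_lines):
    """Yield content tokens from OpenAI-style streaming chat completion lines."""
    for raw in sse_lines:
        line = raw.strip()
        if not line.startswith('data: '):
            continue
        data_str = line[6:]
        if data_str == '[DONE]':
            break
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue  # skip malformed or partial lines
        choices = data.get('choices', [])
        if choices:
            content = choices[0].get('delta', {}).get('content', '')
            if content:
                yield content

stream = [
    'data: {"choices": [{"delta": {"content": "Hi"}}]}',
    'data: {"choices": [{"delta": {"content": " there!"}}]}',
    'data: [DONE]',
]
print(''.join(extract_tokens(stream)))  # Hi there!
```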
411
bot/utils/voice_receiver.py
Normal file
@@ -0,0 +1,411 @@
"""
Discord Voice Receiver using discord-ext-voice-recv

Captures audio from Discord voice channels and streams to STT.
Uses the discord-ext-voice-recv extension for proper audio receiving support.
"""

import asyncio
import audioop  # Note: deprecated since Python 3.11, removed in 3.13
import logging
from typing import Dict, Optional
from collections import deque

import discord
from discord.ext import voice_recv

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiverSink(voice_recv.AudioSink):
    """
    Audio sink that receives Discord audio and forwards it to STT.

    This sink processes incoming audio from Discord voice channels,
    decodes/resamples as needed, and sends it to STT clients for transcription.
    """

    def __init__(self, voice_manager, stt_url: str = "ws://miku-stt:8000/ws/stt"):
        """
        Initialize the voice receiver sink.

        Args:
            voice_manager: Reference to the VoiceManager for callbacks
            stt_url: Base URL for the STT WebSocket server, including path
                (port 8000 inside the container)
        """
        super().__init__()
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Store the event loop for thread-safe async calls.
        # Use get_running_loop() in an async context, or fall back otherwise.
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Fallback if not in an async context yet
            self.loop = asyncio.get_event_loop()

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Audio buffers per user (for resampling state)
        self.audio_buffers: Dict[int, deque] = {}

        # Per-user Opus decoders (created lazily in write())
        self._opus_decoders: Dict[int, 'discord.opus.Decoder'] = {}

        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        # Active flag
        self.active = False

        logger.info("VoiceReceiverSink initialized")

    def wants_opus(self) -> bool:
        """
        Tell discord-ext-voice-recv we want Opus data, NOT decoded PCM.

        We decode it ourselves to avoid decoder errors from discord-ext-voice-recv.

        Returns:
            True - we want Opus packets and will handle decoding
        """
        return True  # Get Opus, decode ourselves to avoid packet router errors

    def write(self, user: Optional[discord.User], data: voice_recv.VoiceData):
        """
        Called by discord-ext-voice-recv when audio is received.

        This is the main callback that receives audio packets from Discord.
        We get Opus data, decode it ourselves, resample, and forward to STT.

        Args:
            user: Discord user who sent the audio (None if unknown)
            data: Voice data container with pcm, opus, and packet info
        """
        if not user:
            return  # Skip packets from unknown users

        user_id = user.id

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        try:
            # Get Opus data (we decode ourselves to avoid PacketRouter errors)
            opus_data = data.opus
            if not opus_data:
                return

            # Create a decoder for this user if needed
            # (decodes Opus to PCM: 48kHz stereo int16)
            import discord.opus
            if user_id not in self._opus_decoders:
                self._opus_decoders[user_id] = discord.opus.Decoder()

            decoder = self._opus_decoders[user_id]

            # Decode opus -> PCM (this can fail on corrupt packets, so catch it)
            try:
                pcm_data = decoder.decode(opus_data, fec=False)
            except discord.opus.OpusError as e:
                # Skip corrupted packets silently (common at stream start)
                logger.debug(f"Skipping corrupted opus packet for user {user_id}: {e}")
                return

            if not pcm_data:
                return

            # PCM from Discord is 48kHz stereo int16.
            # Convert stereo to mono.
            if len(pcm_data) % 4 == 0:  # Stereo (2 channels * 2 bytes per sample)
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48kHz to 16kHz for STT.
            # Discord sends 20ms chunks: 960 samples @ 48kHz → 320 samples @ 16kHz
            pcm_16k, _ = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)

            # Send to the STT client (schedule on the event loop thread-safely)
            asyncio.run_coroutine_threadsafe(
                self._send_audio_chunk(user_id, pcm_16k),
                self.loop
            )

        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}", exc_info=True)
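The sample-rate math in `write()` can be checked without Discord or libopus. The sketch below reproduces the sizes only — channel averaging as in `audioop.tomono(..., 0.5, 0.5)` and a naive 3:1 decimation standing in for `audioop.ratecv` (which interpolates, but yields the same output length); `audioop` itself is deprecated since Python 3.11 and removed in 3.13, so this illustration sticks to plain byte math:

```python
import struct

# One 20 ms Discord frame: 960 samples @ 48 kHz, stereo, int16.
frames = [(i, -i) for i in range(960)]  # (left, right) sample pairs
pcm_48k_stereo = b''.join(struct.pack('<hh', l, r) for l, r in frames)
assert len(pcm_48k_stereo) == 3840  # 960 samples * 2 channels * 2 bytes

# Stereo -> mono: average the two channels
mono = [(l + r) // 2 for l, r in frames]

# 48 kHz -> 16 kHz: one output sample per three input samples
mono_16k = mono[::3]
print(len(mono), len(mono_16k))  # 960 320
```

320 samples at 2 bytes each is the 640-byte chunk that `_send_audio_chunk` receives per packet.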
    def cleanup(self):
        """
        Called when the sink is stopped.
        Cleanup of async resources is handled separately in stop_all().
        """
        logger.info("VoiceReceiverSink cleanup")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Creates an STT client connection for this user and registers callbacks.

        Args:
            user_id: Discord user ID
            user: Discord user object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize audio buffer
        self.audio_buffers[user_id] = deque(maxlen=1000)

        # Create STT client with callbacks
        stt_client = STTClient(
            user_id=user_id,
            stt_url=self.stt_url,
            on_vad_event=lambda event: asyncio.create_task(
                self._on_vad_event(user_id, event)
            ),
            on_partial_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_partial_transcript(user_id, text)
            ),
            on_final_transcript=lambda text, timestamp: asyncio.create_task(
                self._on_final_transcript(user_id, text, user)
            ),
            on_interruption=lambda prob: asyncio.create_task(
                self._on_interruption(user_id, prob)
            )
        )

        # Connect to the STT server
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            self.active = True
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}", exc_info=True)
            # Clean up partial state
            if user_id in self.audio_buffers:
                del self.audio_buffers[user_id]
            if user_id in self.users:
                del self.users[user_id]
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Disconnects the STT client and cleans up resources for this user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            logger.warning(f"Not listening to user {user_id}")
            return

        user = self.users.get(user_id)
        logger.info(f"Stopping listening to user {user.name if user else user_id}")

        # Disconnect the STT client
        stt_client = self.stt_clients[user_id]
        await stt_client.disconnect()

        # Cleanup
        del self.stt_clients[user_id]
        if user_id in self.audio_buffers:
            del self.audio_buffers[user_id]
        if user_id in self.users:
            del self.users[user_id]

        # Clean up the opus decoder for this user
        if hasattr(self, '_opus_decoders') and user_id in self._opus_decoders:
            del self._opus_decoders[user_id]

        # Update active flag
        if not self.stt_clients:
            self.active = False

        logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")

    async def stop_all(self):
        """Stop listening to all users and clean up all resources."""
        logger.info("Stopping all voice receivers")

        user_ids = list(self.stt_clients.keys())
        for user_id in user_ids:
            await self.stop_listening(user_id)

        self.active = False
        logger.info("✓ All voice receivers stopped")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send an audio chunk to the STT client.

        Buffers audio until we have 512 samples (32ms @ 16kHz), which is what
        Silero VAD expects. Discord yields 320 samples (20ms) per packet, so
        chunks are accumulated until at least 512 samples are available, sent
        in 512-sample pieces, and any partial remainder stays buffered.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16kHz mono, 320 samples = 640 bytes)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.is_connected():
            return

        try:
            # Get or create a buffer for this user
            if user_id not in self.audio_buffers:
                self.audio_buffers[user_id] = deque()

            buffer = self.audio_buffers[user_id]
            buffer.append(audio_data)

            # Silero VAD expects 512 samples @ 16kHz (1024 bytes);
            # Discord gives us 320 samples (640 bytes) every 20ms.
            SAMPLES_NEEDED = 512  # What VAD wants
            BYTES_NEEDED = SAMPLES_NEEDED * 2  # int16 = 2 bytes per sample

            # Check whether we have enough buffered audio
            total_bytes = sum(len(chunk) for chunk in buffer)

            if total_bytes >= BYTES_NEEDED:
                # Concatenate buffered chunks
                combined = b''.join(buffer)
                buffer.clear()

                # Send in 512-sample (1024-byte) chunks
                for i in range(0, len(combined), BYTES_NEEDED):
                    chunk = combined[i:i + BYTES_NEEDED]
                    if len(chunk) == BYTES_NEEDED:
                        await stt_client.send_audio(chunk)
                    else:
                        # Put the remaining partial chunk back in the buffer
                        buffer.append(chunk)

        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")
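The accumulation scheme in `_send_audio_chunk` — 640-byte (20 ms) packets in, 1024-byte (512-sample) VAD windows out, remainder carried over — can be exercised standalone. A small sketch of the same buffering logic (the `rechunk` name is illustrative):

```python
from collections import deque

BYTES_NEEDED = 512 * 2  # Silero VAD window: 512 int16 samples = 1024 bytes

def rechunk(buffer: deque, incoming: bytes):
    """Accumulate 640-byte Discord chunks, emit 1024-byte VAD chunks."""
    buffer.append(incoming)
    out = []
    if sum(len(c) for c in buffer) >= BYTES_NEEDED:
        combined = b''.join(buffer)
        buffer.clear()
        for i in range(0, len(combined), BYTES_NEEDED):
            chunk = combined[i:i + BYTES_NEEDED]
            if len(chunk) == BYTES_NEEDED:
                out.append(chunk)
            else:
                buffer.append(chunk)  # keep the partial tail buffered
    return out

buf = deque()
sent = []
for _ in range(4):  # four 20 ms packets of 640 bytes each
    sent += rechunk(buf, b'\x00' * 640)
print(len(sent), sum(len(c) for c in buf))  # 2 512
```

Four packets (2560 bytes) yield two full 1024-byte windows with 512 bytes left waiting, which shows why a fixed "buffer 2 chunks" description would be wrong: the remainder drifts from packet to packet.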
|
async def _on_vad_event(self, user_id: int, event: dict):
|
||||||
|
"""
|
||||||
|
Handle VAD event from STT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Discord user ID
|
||||||
|
event: VAD event dictionary with 'event' and 'probability' keys
|
||||||
|
"""
|
||||||
|
user = self.users.get(user_id)
|
||||||
|
event_type = event.get('event', 'unknown')
|
||||||
|
probability = event.get('probability', 0.0)
|
||||||
|
|
||||||
|
logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")
|
||||||
|
|
||||||
|
# Notify voice manager - pass the full event dict
|
||||||
|
if hasattr(self.voice_manager, 'on_user_vad_event'):
|
||||||
|
await self.voice_manager.on_user_vad_event(user_id, event)
|
||||||
|
|
||||||
|
async def _on_partial_transcript(self, user_id: int, text: str):
|
||||||
|
"""
|
||||||
|
Handle partial transcript from STT.
|
||||||
|
|
||||||
|
        Args:
            user_id: Discord user ID
            text: Partial transcript text
        """
        user = self.users.get(user_id)
        logger.info(f"[VOICE_RECEIVER] Partial [{user.name if user else user_id}]: {text}")
        print(f"[DEBUG] PARTIAL TRANSCRIPT RECEIVED: {text}")  # Extra debug

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, user: discord.User):
        """
        Handle final transcript from STT.

        This triggers the LLM response generation.

        Args:
            user_id: Discord user ID
            text: Final transcript text
            user: Discord user object
        """
        logger.info(f"[VOICE_RECEIVER] Final [{user.name if user else user_id}]: {text}")
        print(f"[DEBUG] FINAL TRANSCRIPT RECEIVED: {text}")  # Extra debug

        # Notify voice manager - THIS TRIGGERS LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    async def _on_interruption(self, user_id: int, probability: float):
        """
        Handle interruption detection from STT.

        This cancels Miku's current speech if the user interrupts.

        Args:
            user_id: Discord user ID
            probability: Interruption confidence probability
        """
        user = self.users.get(user_id)
        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")

        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
        if hasattr(self.voice_manager, 'on_user_interruption'):
            await self.voice_manager.on_user_interruption(user_id, probability)

    def get_listening_users(self) -> list:
        """
        Get list of users currently being listened to.

        Returns:
            List of dicts with user_id, username, and connection status
        """
        return [
            {
                'user_id': user_id,
                'username': self.users[user_id].name if user_id in self.users else 'Unknown',
                'connected': client.is_connected()
            }
            for user_id, client in self.stt_clients.items()
        ]

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_start(self, member: discord.Member):
        """
        Called when a member starts speaking (green circle appears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🎤 {member.name} started speaking")

    @voice_recv.AudioSink.listener()
    def on_voice_member_speaking_stop(self, member: discord.Member):
        """
        Called when a member stops speaking (green circle disappears).

        This is a virtual event from discord-ext-voice-recv based on packet activity.
        """
        if member.id in self.stt_clients:
            logger.debug(f"🔇 {member.name} stopped speaking")
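The capture path feeding these callbacks performs the 48 kHz-stereo → 16 kHz-mono conversion described in the architecture diagram. A dependency-free sketch of that conversion (the deployed code uses `audioop.tomono`/`audioop.ratecv`; the naive `[::3]` decimation here skips the low-pass filtering a production resampler applies):

```python
import struct

def discord_frame_to_stt(pcm_stereo_48k: bytes) -> bytes:
    """Convert one 20 ms Discord frame (48 kHz stereo int16) to 16 kHz mono PCM."""
    n = len(pcm_stereo_48k) // 2
    samples = struct.unpack("<%dh" % n, pcm_stereo_48k)
    # average interleaved L/R pairs into mono
    mono = [(samples[i] + samples[i + 1]) // 2 for i in range(0, n, 2)]
    # naive 3:1 decimation: 48 kHz -> 16 kHz
    mono_16k = mono[::3]
    return struct.pack("<%dh" % len(mono_16k), *mono_16k)

frame = b"\x00\x00" * 2 * 960   # 960 stereo samples = 20 ms at 48 kHz
out = discord_frame_to_stt(frame)
print(len(out))                 # → 640 (320 int16 samples = 20 ms at 16 kHz)
```

This is why the STT server expects 320-sample chunks: one Discord frame in yields exactly one 20 ms chunk out.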
419
bot/utils/voice_receiver.py.old
Normal file
@@ -0,0 +1,419 @@
"""
Discord Voice Receiver

Captures audio from Discord voice channels and streams to STT.
Handles opus decoding and audio preprocessing.
"""

import discord
import audioop
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from collections import deque

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiver(discord.sinks.Sink):
    """
    Voice Receiver for Discord Audio Capture

    Captures audio from Discord voice channels using discord.py's voice websocket.
    Processes Opus audio, decodes to PCM, resamples to 16kHz mono for STT.

    Note: Standard discord.py doesn't have built-in audio receiving.
    This implementation hooks into the voice websocket directly.
    """

import asyncio
import struct
import audioop
import logging
from typing import Dict, Optional, Callable
import discord

# Import opus decoder
try:
    import discord.opus as opus
    if not opus.is_loaded():
        opus.load_opus('opus')
except Exception as e:
    logging.error(f"Failed to load opus: {e}")

from utils.stt_client import STTClient

logger = logging.getLogger('voice_receiver')


class VoiceReceiver:
    """
    Receives and processes audio from Discord voice channel.

    This class monkey-patches the VoiceClient to intercept received RTP packets,
    decodes Opus audio, and forwards it to STT clients.
    """

    def __init__(
        self,
        voice_client: discord.VoiceClient,
        voice_manager,
        stt_url: str = "ws://miku-stt:8001"
    ):
        """
        Initialize voice receiver.

        Args:
            voice_client: Discord VoiceClient to receive audio from
            voice_manager: Voice manager instance for callbacks
            stt_url: Base URL for STT WebSocket server
        """
        self.voice_client = voice_client
        self.voice_manager = voice_manager
        self.stt_url = stt_url

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Opus decoder instances per SSRC (one per user)
        self.opus_decoders: Dict[int, any] = {}

        # Resampler state per user (for 48kHz → 16kHz)
        self.resample_state: Dict[int, tuple] = {}

        # Original receive method (for restoration)
        self._original_receive = None

        # Active flag
        self.active = False

        logger.info("VoiceReceiver initialized")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user's audio.

        Args:
            user_id: Discord user ID
            user: Discord User object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user_id}")
            return

        try:
            # Create STT client for this user
            stt_client = STTClient(
                user_id=user_id,
                stt_url=self.stt_url,
                on_vad_event=lambda event, prob: asyncio.create_task(
                    self.voice_manager.on_user_vad_event(user_id, event)
                ),
                on_partial_transcript=lambda text: asyncio.create_task(
                    self.voice_manager.on_partial_transcript(user_id, text)
                ),
                on_final_transcript=lambda text: asyncio.create_task(
                    self.voice_manager.on_final_transcript(user_id, text, user)
                ),
                on_interruption=lambda prob: asyncio.create_task(
                    self.voice_manager.on_user_interruption(user_id, prob)
                )
            )

            # Connect to STT server
            await stt_client.connect()

            # Store client
            self.stt_clients[user_id] = stt_client

            # Initialize opus decoder for this user if needed
            # (Will be done when we receive their SSRC)

            # Patch voice client to receive audio if not already patched
            if not self.active:
                await self._patch_voice_client()

            logger.info(f"✓ Started listening to user {user_id} ({user.name})")

        except Exception as e:
            logger.error(f"Failed to start listening to user {user_id}: {e}", exc_info=True)
            raise

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            logger.warning(f"Not listening to user {user_id}")
            return

        try:
            # Disconnect STT client
            stt_client = self.stt_clients.pop(user_id)
            await stt_client.disconnect()

            # Clean up decoder and resampler state
            # Note: We don't know the SSRC here, so we'll just remove by user_id
            # Actual cleanup happens in _process_audio when we match SSRC to user_id

            # If no more clients, unpatch voice client
            if not self.stt_clients:
                await self._unpatch_voice_client()

            logger.info(f"✓ Stopped listening to user {user_id}")

        except Exception as e:
            logger.error(f"Failed to stop listening to user {user_id}: {e}", exc_info=True)
            raise

    async def _patch_voice_client(self):
        """Patch VoiceClient to intercept received audio packets."""
        logger.warning("⚠️ Audio receiving not yet implemented - discord.py doesn't support receiving by default")
        logger.warning("⚠️ You need discord.py-self or a custom fork with receiving support")
        logger.warning("⚠️ STT will not receive any audio until this is implemented")
        self.active = True
        # TODO: Implement RTP packet receiving
        # This requires either:
        # 1. Using discord.py-self which has receiving support
        # 2. Monkey-patching voice_client.ws to intercept packets
        # 3. Using a separate UDP socket listener

    async def _unpatch_voice_client(self):
        """Restore original VoiceClient behavior."""
        self.active = False
        logger.info("Unpatched voice client (receiving disabled)")

    async def _process_audio(self, ssrc: int, opus_data: bytes):
        """
        Process received Opus audio packet.

        Args:
            ssrc: RTP SSRC (identifies the audio source/user)
            opus_data: Opus-encoded audio data
        """
        # TODO: Map SSRC to user_id (requires tracking voice state updates)
        # For now, this is a placeholder
        pass

    async def cleanup(self):
        """Clean up all resources."""
        # Disconnect all STT clients
        for user_id in list(self.stt_clients.keys()):
            await self.stop_listening(user_id)

        # Unpatch voice client
        if self.active:
            await self._unpatch_voice_client()

        logger.info("VoiceReceiver cleanup complete")

    def __init__(self, voice_manager):
        """
        Initialize voice receiver.

        Args:
            voice_manager: Reference to VoiceManager for callbacks
        """
        super().__init__()
        self.voice_manager = voice_manager

        # Per-user STT clients
        self.stt_clients: Dict[int, STTClient] = {}

        # Audio buffers per user (for resampling)
        self.audio_buffers: Dict[int, deque] = {}

        # User info (for logging)
        self.users: Dict[int, discord.User] = {}

        logger.info("Voice receiver initialized")

    async def start_listening(self, user_id: int, user: discord.User):
        """
        Start listening to a specific user.

        Args:
            user_id: Discord user ID
            user: Discord user object
        """
        if user_id in self.stt_clients:
            logger.warning(f"Already listening to user {user.name} ({user_id})")
            return

        logger.info(f"Starting to listen to user {user.name} ({user_id})")

        # Store user info
        self.users[user_id] = user

        # Initialize audio buffer
        self.audio_buffers[user_id] = deque(maxlen=1000)  # Max 1000 chunks

        # Create STT client with callbacks
        stt_client = STTClient(
            user_id=str(user_id),
            on_vad_event=lambda event: self._on_vad_event(user_id, event),
            on_partial_transcript=lambda text, ts: self._on_partial_transcript(user_id, text, ts),
            on_final_transcript=lambda text, ts: self._on_final_transcript(user_id, text, ts),
            on_interruption=lambda prob: self._on_interruption(user_id, prob)
        )

        # Connect to STT
        try:
            await stt_client.connect()
            self.stt_clients[user_id] = stt_client
            logger.info(f"✓ STT connected for user {user.name}")
        except Exception as e:
            logger.error(f"Failed to connect STT for user {user.name}: {e}")

    async def stop_listening(self, user_id: int):
        """
        Stop listening to a specific user.

        Args:
            user_id: Discord user ID
        """
        if user_id not in self.stt_clients:
            return

        user = self.users.get(user_id)
        logger.info(f"Stopping listening to user {user.name if user else user_id}")

        # Disconnect STT client
        stt_client = self.stt_clients[user_id]
        await stt_client.disconnect()

        # Cleanup
        del self.stt_clients[user_id]
        if user_id in self.audio_buffers:
            del self.audio_buffers[user_id]
        if user_id in self.users:
            del self.users[user_id]

        logger.info(f"✓ Stopped listening to user {user.name if user else user_id}")

    async def stop_all(self):
        """Stop listening to all users."""
        logger.info("Stopping all voice receivers")

        user_ids = list(self.stt_clients.keys())
        for user_id in user_ids:
            await self.stop_listening(user_id)

        logger.info("✓ All voice receivers stopped")

    def write(self, data: discord.sinks.core.AudioData):
        """
        Called by discord.py when audio is received.

        Args:
            data: Audio data from Discord
        """
        # Get user ID from SSRC
        user_id = data.user.id if data.user else None

        if not user_id:
            return

        # Check if we're listening to this user
        if user_id not in self.stt_clients:
            return

        # Process audio
        try:
            # Decode opus to PCM (48kHz stereo)
            pcm_data = data.pcm

            # Convert stereo to mono if needed
            if len(pcm_data) % 4 == 0:  # Stereo int16 (2 channels * 2 bytes)
                # Average left and right channels
                pcm_mono = audioop.tomono(pcm_data, 2, 0.5, 0.5)
            else:
                pcm_mono = pcm_data

            # Resample from 48kHz to 16kHz
            # Discord sends 20ms chunks at 48kHz = 960 samples
            # We need 320 samples at 16kHz (20ms)
            pcm_16k = audioop.ratecv(pcm_mono, 2, 1, 48000, 16000, None)[0]

            # Send to STT
            asyncio.create_task(self._send_audio_chunk(user_id, pcm_16k))

        except Exception as e:
            logger.error(f"Error processing audio for user {user_id}: {e}")

    async def _send_audio_chunk(self, user_id: int, audio_data: bytes):
        """
        Send audio chunk to STT client.

        Args:
            user_id: Discord user ID
            audio_data: PCM audio (int16, 16kHz mono)
        """
        stt_client = self.stt_clients.get(user_id)
        if not stt_client or not stt_client.is_connected():
            return

        try:
            await stt_client.send_audio(audio_data)
        except Exception as e:
            logger.error(f"Failed to send audio chunk for user {user_id}: {e}")

    async def _on_vad_event(self, user_id: int, event: dict):
        """Handle VAD event from STT."""
        user = self.users.get(user_id)
        event_type = event.get('event')
        probability = event.get('probability', 0)

        logger.debug(f"VAD [{user.name if user else user_id}]: {event_type} (prob={probability:.3f})")

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_user_vad_event'):
            await self.voice_manager.on_user_vad_event(user_id, event)

    async def _on_partial_transcript(self, user_id: int, text: str, timestamp: float):
        """Handle partial transcript from STT."""
        user = self.users.get(user_id)
        logger.info(f"Partial [{user.name if user else user_id}]: {text}")

        # Notify voice manager
        if hasattr(self.voice_manager, 'on_partial_transcript'):
            await self.voice_manager.on_partial_transcript(user_id, text)

    async def _on_final_transcript(self, user_id: int, text: str, timestamp: float):
        """Handle final transcript from STT."""
        user = self.users.get(user_id)
        logger.info(f"Final [{user.name if user else user_id}]: {text}")

        # Notify voice manager - THIS TRIGGERS LLM RESPONSE
        if hasattr(self.voice_manager, 'on_final_transcript'):
            await self.voice_manager.on_final_transcript(user_id, text)

    async def _on_interruption(self, user_id: int, probability: float):
        """Handle interruption detection from STT."""
        user = self.users.get(user_id)
        logger.info(f"Interruption from [{user.name if user else user_id}] (prob={probability:.3f})")

        # Notify voice manager - THIS CANCELS MIKU'S SPEECH
        if hasattr(self.voice_manager, 'on_user_interruption'):
            await self.voice_manager.on_user_interruption(user_id, probability)

    def cleanup(self):
        """Cleanup resources."""
        logger.info("Cleaning up voice receiver")
        # Async cleanup will be called separately

    def get_listening_users(self) -> list:
        """Get list of users currently being listened to."""
        return [
            {
                'user_id': user_id,
                'username': self.users[user_id].name if user_id in self.users else 'Unknown',
                'connected': client.is_connected()
            }
            for user_id, client in self.stt_clients.items()
        ]
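In both versions above, the `STTClient` callbacks are plain functions while the voice-manager handlers are coroutines; the lambdas bridge the two by scheduling a task on the running event loop. A minimal sketch of that pattern (the handler here is a stand-in, not the real voice-manager API):

```python
import asyncio

received = []

async def on_final_transcript(user_id: int, text: str):
    """Stand-in for the voice manager's async transcript handler."""
    received.append((user_id, text))

def make_callback(user_id: int):
    # A sync callback that schedules the coroutine without blocking the caller.
    return lambda text: asyncio.create_task(on_final_transcript(user_id, text))

async def main():
    cb = make_callback(123)
    task = cb("hello miku")   # returns immediately; work happens on the loop
    await task

asyncio.run(main())
print(received)               # → [(123, 'hello miku')]
```

Note that `asyncio.create_task` requires a running loop, so these callbacks must only fire from code already executing inside the bot's event loop.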
@@ -76,6 +76,33 @@ services:
      - miku-voice  # Connect to voice network for RVC/TTS
    restart: unless-stopped

  miku-stt:
    build:
      context: ./stt
      dockerfile: Dockerfile.stt
    container_name: miku-stt
    runtime: nvidia
    environment:
      - NVIDIA_VISIBLE_DEVICES=0  # GTX 1660 (same as Soprano)
      - CUDA_VISIBLE_DEVICES=0
      - NVIDIA_DRIVER_CAPABILITIES=compute,utility
      - LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/cudnn/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
    volumes:
      - ./stt:/app
      - ./stt/models:/models
    ports:
      - "8001:8000"
    networks:
      - miku-voice
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']  # GTX 1660
              capabilities: [gpu]
    restart: unless-stopped

  anime-face-detector:
    build: ./face-detector
    container_name: anime-face-detector
35
stt/Dockerfile.stt
Normal file
@@ -0,0 +1,35 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY requirements.txt .

# Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create models directory
RUN mkdir -p /models

# Expose port
EXPOSE 8000

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV CUDA_VISIBLE_DEVICES=0
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.11/dist-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}

# Run the server
CMD ["uvicorn", "stt_server:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"]
152
stt/README.md
Normal file
@@ -0,0 +1,152 @@
# Miku STT (Speech-to-Text) Server

Real-time speech-to-text service for Miku voice chat using Silero VAD (CPU) and Faster-Whisper (GPU).

## Architecture

- **Silero VAD** (CPU): Lightweight voice activity detection, runs continuously
- **Faster-Whisper** (GPU, GTX 1660): Efficient speech transcription using CTranslate2
- **FastAPI WebSocket**: Real-time bidirectional communication

## Features

- ✅ Real-time voice activity detection with conservative settings
- ✅ Streaming partial transcripts during speech
- ✅ Final transcript on speech completion
- ✅ Interruption detection (user speaking over Miku)
- ✅ Multi-user support with isolated sessions
- ✅ KV cache optimization ready (partial text for LLM precomputation)

## API Endpoints

### WebSocket: `/ws/stt/{user_id}`

Real-time STT session for a specific user.

**Client sends:** Raw PCM audio (int16, 16kHz mono, 20ms chunks = 320 samples)

**Server sends:** JSON events:
```json
// VAD events
{"type": "vad", "event": "speech_start", "speaking": true, "probability": 0.85, "timestamp": 1250.5}
{"type": "vad", "event": "speaking", "speaking": true, "probability": 0.92, "timestamp": 1270.5}
{"type": "vad", "event": "speech_end", "speaking": false, "probability": 0.35, "timestamp": 3500.0}

// Transcription events
{"type": "partial", "text": "Hello how are", "user_id": "123", "timestamp": 2000.0}
{"type": "final", "text": "Hello how are you?", "user_id": "123", "timestamp": 3500.0}

// Interruption detection
{"type": "interruption", "probability": 0.92, "timestamp": 1500.0}
```
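A consumer of this stream only needs to branch on the `type` field. A minimal dispatch sketch over the event shapes above (the return tuples are illustrative, not part of the server API):

```python
import json

def dispatch(raw: str):
    """Route one server event based on its `type` field."""
    event = json.loads(raw)
    kind = event["type"]
    if kind == "vad":
        return ("vad", event["event"], event["probability"])
    if kind in ("partial", "final"):
        return (kind, event["text"])
    if kind == "interruption":
        return ("interruption", event["probability"])
    return ("unknown", kind)

msg = '{"type": "final", "text": "Hello how are you?", "user_id": "123", "timestamp": 3500.0}'
print(dispatch(msg))  # → ('final', 'Hello how are you?')
```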

### HTTP GET: `/health`

Health check with model status.

**Response:**
```json
{
  "status": "healthy",
  "models": {
    "vad": {"loaded": true, "device": "cpu"},
    "whisper": {"loaded": true, "model": "small", "device": "cuda"}
  },
  "sessions": {
    "active": 2,
    "users": ["user123", "user456"]
  }
}
```

## Configuration

### VAD Parameters (Conservative)

- **Threshold**: 0.5 (speech probability)
- **Min speech duration**: 250ms (avoid false triggers)
- **Min silence duration**: 500ms (don't cut off mid-sentence)
- **Speech padding**: 30ms (context around speech)
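These thresholds amount to a simple hysteresis: speech must persist for 250ms before `speech_start` fires, and silence for 500ms before `speech_end`. A pure-Python sketch of one plausible reading of that logic (the actual `VADProcessor` implementation may differ in details such as padding handling):

```python
def vad_events(probs, threshold=0.5, chunk_ms=20,
               min_speech_ms=250, min_silence_ms=500):
    """Turn per-chunk speech probabilities into speech_start/speech_end events."""
    events, speaking, run_ms = [], False, 0
    for i, p in enumerate(probs):
        above = p >= threshold
        if above != speaking:
            run_ms += chunk_ms   # probability disagrees with state: accumulate
        else:
            run_ms = 0           # agreement resets the hangover counter
        if not speaking and above and run_ms >= min_speech_ms:
            speaking, run_ms = True, 0
            events.append(("speech_start", i * chunk_ms))
        elif speaking and not above and run_ms >= min_silence_ms:
            speaking, run_ms = False, 0
            events.append(("speech_end", i * chunk_ms))
    return events

# 400 ms of speech followed by 600 ms of silence
print(vad_events([0.9] * 20 + [0.1] * 30))
# → [('speech_start', 240), ('speech_end', 880)]
```

Note how the conservative settings trade latency for stability: the start event lags true speech onset by ~250ms, which is why the 30ms padding exists to recover context.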

### Whisper Parameters

- **Model**: small (balanced speed/quality, ~500MB VRAM)
- **Compute**: float16 (GPU optimization)
- **Language**: en (English)
- **Beam size**: 5 (quality/speed balance)

## Usage Example

```python
import asyncio
import websockets
import numpy as np

async def stream_audio():
    uri = "ws://localhost:8001/ws/stt/user123"

    async with websockets.connect(uri) as websocket:
        # Wait for ready
        ready = await websocket.recv()
        print(ready)

        # Stream audio chunks (16kHz, 20ms chunks)
        for audio_chunk in audio_stream:
            # Convert to bytes (int16)
            audio_bytes = audio_chunk.astype(np.int16).tobytes()
            await websocket.send(audio_bytes)

            # Receive events
            event = await websocket.recv()
            print(event)

asyncio.run(stream_audio())
```

## Docker Setup

### Build
```bash
docker-compose build miku-stt
```

### Run
```bash
docker-compose up -d miku-stt
```

### Logs
```bash
docker-compose logs -f miku-stt
```

### Test
```bash
curl http://localhost:8001/health
```

## GPU Sharing with Soprano

Both STT (Whisper) and TTS (Soprano) run on the GTX 1660, but at different times:

1. **User speaking** → Whisper active, Soprano idle
2. **LLM processing** → Both idle
3. **Miku speaking** → Soprano active, Whisper idle (VAD monitoring only)

Interruption detection runs VAD continuously, but VAD doesn't use the GPU.

## Performance

- **VAD latency**: 10-20ms per chunk (CPU)
- **Whisper latency**: ~1-2s for 2s of audio (GPU)
- **Memory usage**:
  - Silero VAD: ~100MB (CPU)
  - Faster-Whisper small: ~500MB (GPU VRAM)

## Future Improvements

- [ ] Multi-language support (auto-detect)
- [ ] Word-level timestamps for better sync
- [ ] Custom vocabulary/prompt tuning
- [ ] Speaker diarization (multiple speakers)
- [ ] Noise suppression preprocessing
Binary file not shown.
File diff suppressed because it is too large.
@@ -0,0 +1,239 @@
{
  "alignment_heads": [
    [5, 3], [5, 9], [8, 0], [8, 4], [8, 7],
    [8, 8], [9, 0], [9, 7], [9, 9], [10, 5]
  ],
  "lang_ids": [
    50259, 50260, 50261, 50262, 50263, 50264, 50265, 50266, 50267, 50268,
    50269, 50270, 50271, 50272, 50273, 50274, 50275, 50276, 50277, 50278,
    50279, 50280, 50281, 50282, 50283, 50284, 50285, 50286, 50287, 50288,
    50289, 50290, 50291, 50292, 50293, 50294, 50295, 50296, 50297, 50298,
    50299, 50300, 50301, 50302, 50303, 50304, 50305, 50306, 50307, 50308,
    50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318,
    50319, 50320, 50321, 50322, 50323, 50324, 50325, 50326, 50327, 50328,
    50329, 50330, 50331, 50332, 50333, 50334, 50335, 50336, 50337, 50338,
    50339, 50340, 50341, 50342, 50343, 50344, 50345, 50346, 50347, 50348,
    50349, 50350, 50351, 50352, 50353, 50354, 50355, 50356, 50357
  ],
  "suppress_ids": [
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27,
    28, 29, 31, 58, 59, 60, 61, 62, 63, 90,
    91, 92, 93, 359, 503, 522, 542, 873, 893, 902,
    918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253,
    3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061,
    9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635,
    15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130,
    26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425,
    49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362
  ],
  "suppress_ids_begin": [220, 50257]
}
@@ -0,0 +1 @@
536b0662742c02347bc0e980a01041f333bce120
@@ -0,0 +1 @@
../../blobs/e5047537059bd8f182d9ca64c470201585015187
@@ -0,0 +1 @@
../../blobs/3e305921506d8872816023e4c273e75d2419fb89b24da97b4fe7bce14170d671
@@ -0,0 +1 @@
../../blobs/7818adb6de9fa3064d3ff81226fdd675be1f6344
@@ -0,0 +1 @@
../../blobs/c9074644d9d1205686f16d411564729461324b75
25
stt/requirements.txt
Normal file
@@ -0,0 +1,25 @@
# STT Container Requirements

# Core dependencies
fastapi==0.115.6
uvicorn[standard]==0.32.1
websockets==14.1
aiohttp==3.11.11

# Audio processing
numpy==2.2.2
soundfile==0.12.1
librosa==0.10.2.post1

# VAD (CPU)
torch==2.9.1  # Latest PyTorch
torchaudio==2.9.1
silero-vad==5.1.2

# STT (GPU)
faster-whisper==1.2.1  # Latest version (Oct 31, 2025)
ctranslate2==4.5.0  # Required by faster-whisper

# Utilities
python-multipart==0.0.20
pydantic==2.10.4
361
stt/stt_server.py
Normal file
361
stt/stt_server.py
Normal file
@@ -0,0 +1,361 @@

```python
"""
STT Server

FastAPI WebSocket server for real-time speech-to-text.
Combines Silero VAD (CPU) and Faster-Whisper (GPU) for efficient transcription.

Architecture:
- VAD runs continuously on every audio chunk (CPU)
- Whisper transcribes only when VAD detects speech (GPU)
- Supports multiple concurrent users
- Sends partial and final transcripts via WebSocket
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import asyncio
import logging
from typing import Dict, Optional
from datetime import datetime

from vad_processor import VADProcessor
from whisper_transcriber import WhisperTranscriber

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] [%(name)s] %(message)s'
)
logger = logging.getLogger('stt_server')

# Initialize FastAPI app
app = FastAPI(title="Miku STT Server", version="1.0.0")

# Global instances (initialized on startup)
vad_processor: Optional[VADProcessor] = None
whisper_transcriber: Optional[WhisperTranscriber] = None

# User session tracking
user_sessions: Dict[str, dict] = {}


class UserSTTSession:
    """Manages STT state for a single user."""

    def __init__(self, user_id: str, websocket: WebSocket):
        self.user_id = user_id
        self.websocket = websocket
        self.audio_buffer = []
        self.is_speaking = False
        self.timestamp_ms = 0.0
        self.transcript_buffer = []
        self.last_transcript = ""

        logger.info(f"Created STT session for user {user_id}")

    async def process_audio_chunk(self, audio_data: bytes):
        """
        Process incoming audio chunk.

        Args:
            audio_data: Raw PCM audio (int16, 16kHz mono)
        """
        # Convert bytes to numpy array (int16)
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # Calculate timestamp (assuming 16kHz, 20ms chunks = 320 samples)
        chunk_duration_ms = (len(audio_np) / 16000) * 1000
        self.timestamp_ms += chunk_duration_ms

        # Run VAD on chunk
        vad_event = vad_processor.detect_speech_segment(audio_np, self.timestamp_ms)

        if vad_event:
            event_type = vad_event["event"]
            probability = vad_event["probability"]

            # Send VAD event to client
            await self.websocket.send_json({
                "type": "vad",
                "event": event_type,
                "speaking": event_type in ["speech_start", "speaking"],
                "probability": probability,
                "timestamp": self.timestamp_ms
            })

            # Handle speech events
            if event_type == "speech_start":
                self.is_speaking = True
                self.audio_buffer = [audio_np]
                logger.debug(f"User {self.user_id} started speaking")

            elif event_type == "speaking":
                if self.is_speaking:
                    self.audio_buffer.append(audio_np)

                    # Transcribe partial every ~2 seconds for streaming
                    total_samples = sum(len(chunk) for chunk in self.audio_buffer)
                    duration_s = total_samples / 16000

                    if duration_s >= 2.0:
                        await self._transcribe_partial()

            elif event_type == "speech_end":
                self.is_speaking = False

                # Transcribe final
                await self._transcribe_final()

                # Clear buffer
                self.audio_buffer = []
                logger.debug(f"User {self.user_id} stopped speaking")

        else:
            # Still accumulate audio if speaking
            if self.is_speaking:
                self.audio_buffer.append(audio_np)

    async def _transcribe_partial(self):
        """Transcribe accumulated audio and send partial result."""
        if not self.audio_buffer:
            return

        # Concatenate audio
        audio_full = np.concatenate(self.audio_buffer)

        # Transcribe asynchronously
        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000,
                initial_prompt=self.last_transcript  # Use previous for context
            )

            if text and text != self.last_transcript:
                self.last_transcript = text

                # Send partial transcript
                await self.websocket.send_json({
                    "type": "partial",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Partial [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Partial transcription failed: {e}", exc_info=True)

    async def _transcribe_final(self):
        """Transcribe final accumulated audio."""
        if not self.audio_buffer:
            return

        # Concatenate all audio
        audio_full = np.concatenate(self.audio_buffer)

        try:
            text = await whisper_transcriber.transcribe_async(
                audio_full,
                sample_rate=16000
            )

            if text:
                self.last_transcript = text

                # Send final transcript
                await self.websocket.send_json({
                    "type": "final",
                    "text": text,
                    "user_id": self.user_id,
                    "timestamp": self.timestamp_ms
                })

                logger.info(f"Final [{self.user_id}]: {text}")

        except Exception as e:
            logger.error(f"Final transcription failed: {e}", exc_info=True)

    async def check_interruption(self, audio_data: bytes) -> bool:
        """
        Check if user is interrupting (for use during Miku's speech).

        Args:
            audio_data: Raw PCM audio chunk

        Returns:
            True if interruption detected
        """
        audio_np = np.frombuffer(audio_data, dtype=np.int16)
        speech_prob, is_speaking = vad_processor.process_chunk(audio_np)

        # Interruption: high probability sustained for threshold duration
        if speech_prob > 0.7:  # Higher threshold for interruption
            await self.websocket.send_json({
                "type": "interruption",
                "probability": speech_prob,
                "timestamp": self.timestamp_ms
            })
            return True

        return False


@app.on_event("startup")
async def startup_event():
    """Initialize models on server startup."""
    global vad_processor, whisper_transcriber

    logger.info("=" * 50)
    logger.info("Initializing Miku STT Server")
    logger.info("=" * 50)

    # Initialize VAD (CPU)
    logger.info("Loading Silero VAD model (CPU)...")
    vad_processor = VADProcessor(
        sample_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=250,   # Conservative
        min_silence_duration_ms=500   # Conservative
    )
    logger.info("✓ VAD ready")

    # Initialize Whisper (GPU with cuDNN)
    logger.info("Loading Faster-Whisper model (GPU)...")
    whisper_transcriber = WhisperTranscriber(
        model_size="small",
        device="cuda",
        compute_type="float16",
        language="en"
    )
    logger.info("✓ Whisper ready")

    logger.info("=" * 50)
    logger.info("STT Server ready to accept connections")
    logger.info("=" * 50)


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on server shutdown."""
    logger.info("Shutting down STT server...")

    if whisper_transcriber:
        whisper_transcriber.cleanup()

    logger.info("STT server shutdown complete")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Miku STT Server",
        "status": "running",
        "vad_ready": vad_processor is not None,
        "whisper_ready": whisper_transcriber is not None,
        "active_sessions": len(user_sessions)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "models": {
            "vad": {
                "loaded": vad_processor is not None,
                "device": "cpu"
            },
            "whisper": {
                "loaded": whisper_transcriber is not None,
                "model": "small",
                "device": "cuda"
            }
        },
        "sessions": {
            "active": len(user_sessions),
            "users": list(user_sessions.keys())
        }
    }


@app.websocket("/ws/stt/{user_id}")
async def websocket_stt(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time STT.

    Client sends: Raw PCM audio (int16, 16kHz mono, 20ms chunks)
    Server sends JSON events:
    - {"type": "vad", "event": "speech_start|speaking|speech_end", ...}
    - {"type": "partial", "text": "...", ...}
    - {"type": "final", "text": "...", ...}
    - {"type": "interruption", "probability": 0.xx}
    """
    await websocket.accept()
    logger.info(f"STT WebSocket connected: user {user_id}")

    # Create session
    session = UserSTTSession(user_id, websocket)
    user_sessions[user_id] = session

    try:
        # Send ready message
        await websocket.send_json({
            "type": "ready",
            "user_id": user_id,
            "message": "STT session started"
        })

        # Main loop: receive audio chunks
        while True:
            # Receive binary audio data
            data = await websocket.receive_bytes()

            # Process audio chunk
            await session.process_audio_chunk(data)

    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected")

    except Exception as e:
        logger.error(f"Error in STT WebSocket for user {user_id}: {e}", exc_info=True)

    finally:
        # Cleanup session
        if user_id in user_sessions:
            del user_sessions[user_id]
        logger.info(f"STT session ended for user {user_id}")


@app.post("/interrupt/check")
async def check_interruption(user_id: str):
    """
    Check if user is interrupting (for use during Miku's speech).

    Query param:
        user_id: Discord user ID

    Returns:
        {"interrupting": bool, "user_id": str}
    """
    session = user_sessions.get(user_id)

    if not session:
        raise HTTPException(status_code=404, detail="User session not found")

    # Get current VAD state
    vad_state = vad_processor.get_state()

    return {
        "interrupting": vad_state["speaking"],
        "user_id": user_id
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
```
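The server's timestamp arithmetic assumes the bot delivers 16 kHz mono int16 PCM in 20 ms frames. A quick sanity check of that framing math (the constant and function names here are illustrative, not part of the codebase):

```python
# Framing constants the server and bot must agree on:
# 16 kHz mono int16, delivered in 20 ms WebSocket messages.
SAMPLE_RATE = 16000
CHUNK_MS = 20

samples_per_chunk = SAMPLE_RATE * CHUNK_MS // 1000   # 320 samples
bytes_per_chunk = samples_per_chunk * 2              # int16 -> 640 bytes per message

def chunk_duration_ms(num_samples: int, sample_rate: int = SAMPLE_RATE) -> float:
    """Same formula UserSTTSession.process_audio_chunk uses for timestamps."""
    return (num_samples / sample_rate) * 1000

assert samples_per_chunk == 320
assert bytes_per_chunk == 640
assert chunk_duration_ms(samples_per_chunk) == 20.0
```

If the bot sends a different frame size, the per-chunk duration formula still yields correct timestamps, but the VAD window buffering (512 samples, see `vad_processor.py`) will batch chunks differently.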
206
stt/test_stt.py
Normal file
@@ -0,0 +1,206 @@

```python
#!/usr/bin/env python3
"""
Test script for STT WebSocket server.
Sends test audio and receives VAD/transcription events.
"""

import asyncio
import websockets
import numpy as np
import json
import wave


async def test_websocket():
    """Test STT WebSocket with generated audio."""

    uri = "ws://localhost:8001/ws/stt/test_user"

    print("🔌 Connecting to STT WebSocket...")

    async with websockets.connect(uri) as websocket:
        # Wait for ready message
        ready_msg = await websocket.recv()
        ready = json.loads(ready_msg)
        print(f"✅ {ready}")

        # Generate test audio: 2 seconds of 440Hz tone (A note)
        # This simulates speech-like audio
        print("\n🎵 Generating test audio (2 seconds, 440Hz tone)...")
        sample_rate = 16000
        duration = 2.0
        frequency = 440  # A4 note

        t = np.linspace(0, duration, int(sample_rate * duration), False)
        audio = np.sin(frequency * 2 * np.pi * t)

        # Convert to int16
        audio_int16 = (audio * 32767).astype(np.int16)

        # Send in 20ms chunks (320 samples at 16kHz)
        chunk_size = 320  # 20ms chunks
        total_chunks = len(audio_int16) // chunk_size

        print(f"📤 Sending {total_chunks} audio chunks (20ms each)...\n")

        # Send chunks and receive events
        for i in range(0, len(audio_int16), chunk_size):
            chunk = audio_int16[i:i+chunk_size]

            # Send audio chunk
            await websocket.send(chunk.tobytes())

            # Try to receive events (non-blocking)
            try:
                response = await asyncio.wait_for(
                    websocket.recv(),
                    timeout=0.01
                )
                event = json.loads(response)

                # Print VAD events
                if event['type'] == 'vad':
                    emoji = "🟢" if event['speaking'] else "⚪"
                    print(f"{emoji} VAD: {event['event']} "
                          f"(prob={event['probability']:.3f}, "
                          f"t={event['timestamp']:.1f}ms)")

                # Print transcription events
                elif event['type'] == 'partial':
                    print(f"📝 Partial: \"{event['text']}\"")

                elif event['type'] == 'final':
                    print(f"✅ Final: \"{event['text']}\"")

                elif event['type'] == 'interruption':
                    print(f"⚠️ Interruption detected! (prob={event['probability']:.3f})")

            except asyncio.TimeoutError:
                pass  # No event yet

            # Small delay between chunks
            await asyncio.sleep(0.02)

        print("\n✅ Test audio sent successfully!")

        # Wait a bit for final transcription
        print("⏳ Waiting for final transcription...")

        for _ in range(50):  # Wait up to 1 second
            try:
                response = await asyncio.wait_for(
                    websocket.recv(),
                    timeout=0.02
                )
                event = json.loads(response)

                if event['type'] == 'final':
                    print(f"\n✅ FINAL TRANSCRIPT: \"{event['text']}\"")
                    break
                elif event['type'] == 'vad':
                    emoji = "🟢" if event['speaking'] else "⚪"
                    print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
            except asyncio.TimeoutError:
                pass

    print("\n✅ WebSocket test complete!")


async def test_with_sample_audio():
    """Test with actual speech audio file (if available)."""

    import sys
    import os

    if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
        audio_file = sys.argv[1]
        print(f"📂 Loading audio from: {audio_file}")

        # Load WAV file
        with wave.open(audio_file, 'rb') as wav:
            sample_rate = wav.getframerate()
            n_channels = wav.getnchannels()
            audio_data = wav.readframes(wav.getnframes())

        # Convert to numpy array
        audio_np = np.frombuffer(audio_data, dtype=np.int16)

        # If stereo, convert to mono
        if n_channels == 2:
            audio_np = audio_np.reshape(-1, 2).mean(axis=1).astype(np.int16)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            print(f"⚠️ Resampling from {sample_rate}Hz to 16000Hz...")
            import librosa
            audio_float = audio_np.astype(np.float32) / 32768.0
            audio_resampled = librosa.resample(
                audio_float,
                orig_sr=sample_rate,
                target_sr=16000
            )
            audio_np = (audio_resampled * 32767).astype(np.int16)

        print(f"✅ Audio loaded: {len(audio_np)/16000:.2f} seconds")

        # Send to STT
        uri = "ws://localhost:8001/ws/stt/test_user"

        async with websockets.connect(uri) as websocket:
            ready_msg = await websocket.recv()
            print(f"✅ {json.loads(ready_msg)}")

            # Send in chunks
            chunk_size = 320  # 20ms at 16kHz

            for i in range(0, len(audio_np), chunk_size):
                chunk = audio_np[i:i+chunk_size]
                await websocket.send(chunk.tobytes())

                # Receive events
                try:
                    response = await asyncio.wait_for(
                        websocket.recv(),
                        timeout=0.01
                    )
                    event = json.loads(response)

                    if event['type'] == 'vad':
                        emoji = "🟢" if event['speaking'] else "⚪"
                        print(f"{emoji} VAD: {event['event']} (prob={event['probability']:.3f})")
                    elif event['type'] in ['partial', 'final']:
                        print(f"📝 {event['type'].title()}: \"{event['text']}\"")

                except asyncio.TimeoutError:
                    pass

                await asyncio.sleep(0.02)

            # Wait for final
            for _ in range(100):
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=0.02)
                    event = json.loads(response)
                    if event['type'] == 'final':
                        print(f"\n✅ FINAL: \"{event['text']}\"")
                        break
                except asyncio.TimeoutError:
                    pass


if __name__ == "__main__":
    import sys

    print("=" * 60)
    print(" Miku STT WebSocket Test")
    print("=" * 60)
    print()

    if len(sys.argv) > 1:
        print("📁 Testing with audio file...")
        asyncio.run(test_with_sample_audio())
    else:
        print("🎵 Testing with generated tone...")
        print("   (To test with audio file: python test_stt.py audio.wav)")
        print()
        asyncio.run(test_websocket())
```
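The test script resamples file input with librosa; on the bot side, the architecture calls for Opus-decoded 48 kHz stereo to be downmixed and resampled to 16 kHz mono before it reaches this WebSocket. That conversion is not part of this commit; below is a minimal stdlib-only sketch of the idea (the function name is hypothetical, and the naive 3:1 decimation has no anti-alias filter, so a production pipeline should low-pass first):

```python
import array

def stereo48k_to_mono16k(pcm: bytes) -> bytes:
    """Downmix interleaved stereo int16 @ 48 kHz to mono int16 @ 16 kHz.

    Sketch only: averages L/R pairs, then keeps every 3rd sample
    (48000 / 3 = 16000). Real code should low-pass before decimating.
    """
    samples = array.array('h')
    samples.frombytes(pcm)
    # Average each interleaved L/R pair down to one mono sample
    mono = [(samples[i] + samples[i + 1]) // 2 for i in range(0, len(samples), 2)]
    # Naive 3:1 decimation to reach 16 kHz
    return array.array('h', mono[::3]).tobytes()

# 10 ms of stereo @ 48 kHz = 480 frames = 960 int16 values = 1920 bytes,
# which should shrink to 160 mono samples = 320 bytes
out = stereo48k_to_mono16k(bytes(1920))
assert len(out) == 320
```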
204
stt/vad_processor.py
Normal file
@@ -0,0 +1,204 @@

```python
"""
Silero VAD Processor

Lightweight CPU-based Voice Activity Detection for real-time speech detection.
Runs continuously on audio chunks to determine when users are speaking.
"""

import torch
import numpy as np
from typing import Tuple, Optional
import logging

logger = logging.getLogger('vad')


class VADProcessor:
    """
    Voice Activity Detection using Silero VAD model.

    Processes audio chunks and returns speech probability.
    Conservative settings to avoid cutting off speech.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        threshold: float = 0.5,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 500,
        speech_pad_ms: int = 30
    ):
        """
        Initialize VAD processor.

        Args:
            sample_rate: Audio sample rate (must be 8000 or 16000)
            threshold: Speech probability threshold (0.0-1.0)
            min_speech_duration_ms: Minimum speech duration to trigger (conservative)
            min_silence_duration_ms: Minimum silence to end speech (conservative)
            speech_pad_ms: Padding around speech segments
        """
        self.sample_rate = sample_rate
        self.threshold = threshold
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        # Load Silero VAD model (CPU only)
        logger.info("Loading Silero VAD model (CPU)...")
        self.model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False  # Use PyTorch model
        )

        # Extract utility functions
        (self.get_speech_timestamps,
         self.save_audio,
         self.read_audio,
         self.VADIterator,
         self.collect_chunks) = utils

        # State tracking
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer = []

        # Chunk buffer for VAD (Silero v5 expects fixed 512-sample windows at 16kHz)
        self.vad_buffer = []
        self.min_vad_samples = 512  # Window size for VAD processing

        logger.info(f"VAD initialized: threshold={threshold}, "
                    f"min_speech={min_speech_duration_ms}ms, "
                    f"min_silence={min_silence_duration_ms}ms")

    def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[float, bool]:
        """
        Process single audio chunk and return speech probability.
        Buffers small chunks to meet the VAD's fixed window size.

        Args:
            audio_chunk: Audio data as numpy array (int16 or float32)

        Returns:
            (speech_probability, is_speaking): Probability and current speaking state
        """
        # Convert to float32 if needed
        if audio_chunk.dtype == np.int16:
            audio_chunk = audio_chunk.astype(np.float32) / 32768.0

        # Add to buffer
        self.vad_buffer.append(audio_chunk)

        # Check if we have enough samples
        total_samples = sum(len(chunk) for chunk in self.vad_buffer)

        if total_samples < self.min_vad_samples:
            # Not enough samples yet: report no new probability, but keep
            # the current speaking state instead of forcing False
            return 0.0, self.speaking

        # Concatenate buffer and take exactly one VAD window.
        # Silero VAD v5 requires exactly 512 samples per call at 16 kHz,
        # so feed one fixed-size window and carry the remainder over.
        audio_full = np.concatenate(self.vad_buffer)
        window = audio_full[:self.min_vad_samples]
        remainder = audio_full[self.min_vad_samples:]
        self.vad_buffer = [remainder] if len(remainder) > 0 else []

        # Process with VAD
        audio_tensor = torch.from_numpy(window)

        with torch.no_grad():
            speech_prob = self.model(audio_tensor, self.sample_rate).item()

        # Update speaking state based on probability
        is_speaking = speech_prob > self.threshold

        return speech_prob, is_speaking

    def detect_speech_segment(
        self,
        audio_chunk: np.ndarray,
        timestamp_ms: float
    ) -> Optional[dict]:
        """
        Process chunk and detect speech start/end events.

        Args:
            audio_chunk: Audio data
            timestamp_ms: Current timestamp in milliseconds

        Returns:
            Event dict or None:
            - {"event": "speech_start", "timestamp": float, "probability": float}
            - {"event": "speech_end", "timestamp": float, "probability": float}
            - {"event": "speaking", "probability": float}  # Ongoing speech
        """
        speech_prob, is_speaking = self.process_chunk(audio_chunk)

        # Speech started
        if is_speaking and not self.speaking:
            if self.speech_start_time is None:
                self.speech_start_time = timestamp_ms

            # Check if speech duration exceeds minimum
            speech_duration = timestamp_ms - self.speech_start_time
            if speech_duration >= self.min_speech_duration_ms:
                self.speaking = True
                self.silence_start_time = None
                logger.debug(f"Speech started at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_start",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # Speech ongoing
        elif is_speaking and self.speaking:
            self.silence_start_time = None  # Reset silence timer
            return {
                "event": "speaking",
                "probability": speech_prob,
                "timestamp": timestamp_ms
            }

        # Silence detected during speech
        elif not is_speaking and self.speaking:
            if self.silence_start_time is None:
                self.silence_start_time = timestamp_ms

            # Check if silence duration exceeds minimum
            silence_duration = timestamp_ms - self.silence_start_time
            if silence_duration >= self.min_silence_duration_ms:
                self.speaking = False
                self.speech_start_time = None
                logger.debug(f"Speech ended at {timestamp_ms}ms, prob={speech_prob:.3f}")
                return {
                    "event": "speech_end",
                    "timestamp": timestamp_ms,
                    "probability": speech_prob
                }

        # No speech or insufficient duration
        else:
            if not is_speaking:
                self.speech_start_time = None

        return None

    def reset(self):
        """Reset VAD state."""
        self.speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.audio_buffer.clear()
        self.vad_buffer.clear()
        logger.debug("VAD state reset")

    def get_state(self) -> dict:
        """Get current VAD state."""
        return {
            "speaking": self.speaking,
            "speech_start_time": self.speech_start_time,
            "silence_start_time": self.silence_start_time
        }
```
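The `speech_start` / `speech_end` hysteresis in `detect_speech_segment` can be exercised without torch by replaying precomputed probabilities. A sketch assuming one probability per 20 ms chunk (the `vad_events` helper is illustrative, not part of the module):

```python
def vad_events(probs, threshold=0.5, min_speech_ms=250, min_silence_ms=500, step_ms=20.0):
    """Replay of detect_speech_segment's hysteresis state machine on a
    list of precomputed speech probabilities, one per 20 ms chunk."""
    speaking = False
    speech_start = None
    silence_start = None
    events = []
    t = 0.0
    for p in probs:
        t += step_ms
        active = p > threshold
        if active and not speaking:
            if speech_start is None:
                speech_start = t
            if t - speech_start >= min_speech_ms:
                speaking = True
                silence_start = None
                events.append(("speech_start", t))
        elif active and speaking:
            silence_start = None            # any speech resets the silence timer
        elif not active and speaking:
            if silence_start is None:
                silence_start = t
            if t - silence_start >= min_silence_ms:
                speaking = False
                speech_start = None
                events.append(("speech_end", t))
        else:
            speech_start = None             # candidate speech too short: discard
    return events

# 1 s of confident speech followed by 1 s of silence (50 chunks each):
# the start event fires only after min_speech_ms of sustained speech, and
# the end event only after min_silence_ms of sustained silence.
events = vad_events([0.9] * 50 + [0.1] * 50)
print(events)  # [('speech_start', 280.0), ('speech_end', 1520.0)]
```

This also shows why the settings feel "conservative": with the defaults, a response to a short utterance cannot begin until at least 500 ms of trailing silence have elapsed.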
193
stt/whisper_transcriber.py
Normal file
@@ -0,0 +1,193 @@

```python
"""
Faster-Whisper Transcriber

GPU-accelerated speech-to-text using faster-whisper (CTranslate2).
Supports streaming transcription with partial results.
"""

import numpy as np
from faster_whisper import WhisperModel
from typing import AsyncIterator, Optional, List
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger('whisper')


class WhisperTranscriber:
    """
    Faster-Whisper based transcription with streaming support.

    Runs on GPU (GTX 1660) with small model for balance of speed/quality.
    """

    def __init__(
        self,
        model_size: str = "small",
        device: str = "cuda",
        compute_type: str = "float16",
        language: str = "en",
        beam_size: int = 5
    ):
        """
        Initialize Whisper transcriber.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
            device: Device to run on (cuda or cpu)
            compute_type: Compute precision (float16, int8, int8_float16)
            language: Language code for transcription
            beam_size: Beam search size (higher = better quality, slower)
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self.language = language
        self.beam_size = beam_size

        logger.info(f"Loading Faster-Whisper model: {model_size} on {device}...")

        # Load model
        self.model = WhisperModel(
            model_size,
            device=device,
            compute_type=compute_type,
            download_root="/models"
        )

        # Thread pool for blocking transcription calls
        self.executor = ThreadPoolExecutor(max_workers=2)

        logger.info(f"Whisper model loaded: {model_size} ({compute_type})")

    async def transcribe_async(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000,
        initial_prompt: Optional[str] = None
    ) -> str:
        """
        Transcribe audio asynchronously (non-blocking).

        Args:
            audio: Audio data as numpy array (float32)
            sample_rate: Audio sample rate
            initial_prompt: Optional prompt to guide transcription

        Returns:
            Transcribed text
        """
        loop = asyncio.get_running_loop()

        # Run transcription in thread pool to avoid blocking the event loop
        result = await loop.run_in_executor(
            self.executor,
            self._transcribe_blocking,
            audio,
            sample_rate,
            initial_prompt
        )

        return result

    def _transcribe_blocking(
        self,
        audio: np.ndarray,
        sample_rate: int,
        initial_prompt: Optional[str]
    ) -> str:
        """
        Blocking transcription call (runs in thread pool).
        """
        # Convert to float32 if needed
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32) / 32768.0

        # Transcribe
        segments, info = self.model.transcribe(
            audio,
            language=self.language,
            beam_size=self.beam_size,
            initial_prompt=initial_prompt,
            vad_filter=False,        # We handle VAD separately
            word_timestamps=False    # Can enable for word-level timing
        )

        # Collect all segments
        text_parts = []
        for segment in segments:
            text_parts.append(segment.text.strip())

        full_text = " ".join(text_parts).strip()

        logger.debug(f"Transcribed: '{full_text}' (language: {info.language}, "
                     f"probability: {info.language_probability:.2f})")

        return full_text

    async def transcribe_streaming(
        self,
        audio_stream: AsyncIterator[np.ndarray],
        sample_rate: int = 16000,
        chunk_duration_s: float = 2.0
    ) -> AsyncIterator[dict]:
        """
        Transcribe audio stream with partial results.

        Args:
            audio_stream: Async iterator yielding audio chunks
            sample_rate: Audio sample rate
            chunk_duration_s: Duration of each chunk to transcribe

        Yields:
            {"type": "partial", "text": "partial transcript"}
            {"type": "final", "text": "complete transcript"}
        """
        accumulated_audio = []
        chunk_samples = int(chunk_duration_s * sample_rate)

        async for audio_chunk in audio_stream:
            accumulated_audio.append(audio_chunk)

            # Check if we have enough audio for transcription
            total_samples = sum(len(chunk) for chunk in accumulated_audio)

            if total_samples >= chunk_samples:
                # Concatenate accumulated audio
                audio_data = np.concatenate(accumulated_audio)

                # Transcribe current accumulated audio
                text = await self.transcribe_async(audio_data, sample_rate)

                if text:
                    yield {
                        "type": "partial",
                        "text": text,
                        "duration": total_samples / sample_rate
                    }

        # Final transcription of remaining audio
        if accumulated_audio:
            audio_data = np.concatenate(accumulated_audio)
            text = await self.transcribe_async(audio_data, sample_rate)

            if text:
                yield {
                    "type": "final",
                    "text": text,
                    "duration": len(audio_data) / sample_rate
                }

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes."""
        return [
            "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr",
            "pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi",
            "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no"
        ]

    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)
        logger.info("Whisper transcriber cleaned up")
```
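`transcribe_async` offloads the blocking CTranslate2 call to a thread pool so the asyncio event loop keeps serving WebSocket traffic while the GPU works. The pattern in isolation, with a stand-in for the model call (names here are illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def blocking_transcribe(audio):
    """Stand-in for the blocking WhisperModel.transcribe call."""
    return f"{len(audio)} samples"

async def transcribe_async(audio):
    # Same offload pattern as WhisperTranscriber.transcribe_async:
    # hand the blocking call to the pool, await its result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, blocking_transcribe, audio)

result = asyncio.run(transcribe_async([0] * 320))
print(result)  # 320 samples
executor.shutdown(wait=True)
```

With `max_workers=2`, at most two transcriptions run concurrently; further requests queue in the pool rather than stalling the server.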