# Source: miku-discord/cheshire-cat/streaming_benchmark.py
# (331 lines, 11 KiB, Python — header text below was scraped from a
#  code-hosting viewer: "Raw / Permalink / Normal View / History")
#!/usr/bin/env python3
"""
Streaming Benchmark - TTFB Comparison
Measures Time To First Token (TTFT) for voice chat viability
Compares Cheshire Cat RAG vs Direct Context Loading
"""
import requests
import time
import json
import statistics
from datetime import datetime
from typing import List, Dict
# URLs of the two local backends under comparison:
# Cheshire Cat's REST API and llama.cpp's OpenAI-compatible endpoint.
CAT_URL = "http://localhost:1865"
LLAMA_SWAP_URL = "http://localhost:8091/v1"
# Test queries: short, in-character prompts; every method is benchmarked
# against this same fixed set so timings are comparable.
TEST_QUERIES: List[str] = [
"Hi Miku!",
"What's your favorite food?",
"Tell me about your friends",
"What songs do you sing?",
"How old are you?",
"Who created you?",
"Do you like green onions?",
"What's World is Mine about?",
"Tell me about Rin and Len",
"What do you like to do?"
]
# Load Miku context files
def load_miku_context() -> str:
    """Load the Miku persona context files used by the direct-LLM tests.

    Reads the lore and prompt files from the sibling bot directory and
    concatenates them, separated by blank lines. Lyrics are intentionally
    skipped — too long for the benchmark context.

    Returns:
        The combined context string; whatever was read before a missing
        file is still returned ("" when the first file is absent).
    """
    context = ""
    # Sequential best-effort read: a missing later file keeps earlier
    # content, matching the original single-try behaviour.
    for path in ("../bot/persona/miku/miku_lore.txt",
                 "../bot/persona/miku/miku_prompt.txt"):
        try:
            # Explicit UTF-8: persona files contain non-ASCII characters,
            # so don't rely on the platform default encoding.
            with open(path, "r", encoding="utf-8") as f:
                context += f.read() + "\n\n"
        except FileNotFoundError:
            print("⚠️ Could not load context files from ../bot/")
            break
    return context


# Loaded once at import time; embedded into the system prompt by
# test_direct_llama_streaming when use_context=True.
MIKU_CONTEXT = load_miku_context()
def test_cheshire_cat_non_streaming(query: str) -> Dict:
    """Benchmark one query against the Cheshire Cat /message endpoint.

    Cheshire Cat exposes no streaming API here, so only the total
    round-trip time is measurable; TTFT is estimated as a fixed fraction
    of it (see note in the returned dict).

    Args:
        query: user message to send.

    Returns:
        On success: dict with success=True, ttft_ms (estimated),
        total_time_ms, response text, method, and an explanatory note.
        On failure: dict with success=False and an error description.
    """
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so interval measurements can't be skewed by wall-clock adjustments.
    start_time = time.perf_counter()
    try:
        response = requests.post(
            f"{CAT_URL}/message",
            json={"text": query, "user_id": "benchmark_user"},
            timeout=60,
        )
        total_time = (time.perf_counter() - start_time) * 1000
        if response.status_code != 200:
            return {
                "success": False,
                "error": f"HTTP {response.status_code}",
                "method": "cheshire_cat",
            }
        data = response.json()
        content = data.get("content", "")
        # The Cat sometimes answers with a raw tool-call JSON payload
        # instead of prose; count that as a failed sample.
        if content.startswith('{"name":'):
            return {
                "success": False,
                "error": "Got tool call",
                "method": "cheshire_cat",
            }
        # Estimate TTFT as ~15% of total (RAG retrieval + first tokens)
        estimated_ttft = total_time * 0.15
        return {
            "success": True,
            "ttft_ms": estimated_ttft,
            "total_time_ms": total_time,
            "response": content,
            "method": "cheshire_cat",
            "note": "TTFT estimated (no streaming)",
        }
    except Exception as e:
        # Network failures / timeouts are expected failure modes for a
        # benchmark; report them as an unsuccessful sample, don't crash.
        return {
            "success": False,
            "error": str(e),
            "method": "cheshire_cat",
        }
def test_direct_llama_streaming(query: str, use_context: bool = True) -> Dict:
    """Benchmark one query against llama.cpp with SSE streaming enabled.

    Measures the real TTFT (elapsed time until the first non-empty
    content delta arrives) and the total generation time.

    Args:
        query: user message to send.
        use_context: when True, embed the full Miku context files in the
            system prompt; when False, use a one-line minimal persona.

    Returns:
        On success: dict with success=True, ttft_ms, total_time_ms,
        response, chunks, method, and context_size.
        On failure: dict with success=False and an error description.
    """
    # Monotonic clock — immune to wall-clock adjustments mid-run.
    start_time = time.perf_counter()
    first_token_time = None
    full_response = ""
    chunks_received = 0
    # Build system prompt
    if use_context:
        system_prompt = f"""You are Hatsune Miku, the virtual singer! Be cheerful, cute, and use emojis 🎶💙
CONTEXT:
{MIKU_CONTEXT}
Keep responses SHORT (2-3 sentences). Stay in character!"""
    else:
        system_prompt = "You are Hatsune Miku, the virtual singer! Be cheerful and cute. Keep responses SHORT."
    payload = {
        "model": "darkidol",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        "stream": True,
        "temperature": 0.8,
        "max_tokens": 150,
    }
    try:
        response = requests.post(
            f"{LLAMA_SWAP_URL}/chat/completions",
            json=payload,
            stream=True,
            timeout=60,
        )
        if response.status_code != 200:
            return {
                "success": False,
                "error": f"HTTP {response.status_code}",
                "method": f"direct_ctx={use_context}",
            }
        # Consume the SSE stream line by line; each "data: " line carries
        # one OpenAI-style chunk containing a delta of generated text.
        for line in response.iter_lines():
            if not line:
                continue
            line = line.decode('utf-8').strip()
            if line == "data: [DONE]":
                break
            if line.startswith("data: "):
                try:
                    json_str = line[6:]  # Remove "data: " prefix
                    data = json.loads(json_str)
                    delta = data.get("choices", [{}])[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        # First non-empty delta marks the measured TTFT.
                        if first_token_time is None:
                            first_token_time = (time.perf_counter() - start_time) * 1000
                        full_response += content
                        chunks_received += 1
                except json.JSONDecodeError:
                    # Skip malformed / partial SSE lines.
                    continue
        total_time = (time.perf_counter() - start_time) * 1000
        if first_token_time is None:
            return {
                "success": False,
                "error": "No tokens received",
                "method": f"direct_ctx={use_context}",
            }
        return {
            "success": True,
            "ttft_ms": first_token_time,
            "total_time_ms": total_time,
            "response": full_response.strip(),
            "chunks": chunks_received,
            "method": f"direct_ctx={use_context}",
            "context_size": len(system_prompt) if use_context else 0,
        }
    except Exception as e:
        # Report network errors / timeouts as a failed sample.
        return {
            "success": False,
            "error": str(e),
            "method": f"direct_ctx={use_context}",
        }
def run_comparison(query: str) -> Dict:
"""Run all three methods on the same query"""
print(f"\n📝 Query: {query}")
results = {}
# Test 1: Cheshire Cat (RAG)
print(" 🐱 Testing Cheshire Cat...")
cat_result = test_cheshire_cat_non_streaming(query)
results['cheshire_cat'] = cat_result
if cat_result['success']:
print(f" TTFT: ~{cat_result['ttft_ms']:.0f}ms | Total: {cat_result['total_time_ms']:.0f}ms")
print(f" Response: {cat_result['response'][:80]}...")
else:
print(f" ❌ Error: {cat_result.get('error')}")
time.sleep(2)
# Test 2: Direct with full context
print(" 📄 Testing Direct + Full Context...")
direct_ctx_result = test_direct_llama_streaming(query, use_context=True)
results['direct_with_context'] = direct_ctx_result
if direct_ctx_result['success']:
print(f" TTFT: {direct_ctx_result['ttft_ms']:.0f}ms | Total: {direct_ctx_result['total_time_ms']:.0f}ms")
print(f" Response: {direct_ctx_result['response'][:80]}...")
else:
print(f" ❌ Error: {direct_ctx_result.get('error')}")
time.sleep(2)
# Test 3: Direct without context (minimal)
print(" ⚡ Testing Direct + Minimal Context...")
direct_min_result = test_direct_llama_streaming(query, use_context=False)
results['direct_minimal'] = direct_min_result
if direct_min_result['success']:
print(f" TTFT: {direct_min_result['ttft_ms']:.0f}ms | Total: {direct_min_result['total_time_ms']:.0f}ms")
print(f" Response: {direct_min_result['response'][:80]}...")
else:
print(f" ❌ Error: {direct_min_result.get('error')}")
return results
def main():
print("=" * 80)
print("⚡ STREAMING BENCHMARK - Time To First Token (TTFT) Comparison")
print("=" * 80)
print("\nComparing three approaches:")
print(" 1. 🐱 Cheshire Cat (RAG with embeddings)")
print(" 2. 📄 Direct LLM + Full Context (current bot approach)")
print(" 3. ⚡ Direct LLM + Minimal Context (baseline)")
print("\n" + "=" * 80)
all_results = []
for i, query in enumerate(TEST_QUERIES, 1):
print(f"\n[{i}/{len(TEST_QUERIES)}]")
results = run_comparison(query)
results['query'] = query
all_results.append(results)
if i < len(TEST_QUERIES):
print("\n⏳ Waiting 3s before next query...")
time.sleep(3)
# Calculate statistics
print("\n" + "=" * 80)
print("📊 RESULTS SUMMARY")
print("=" * 80)
methods = ['cheshire_cat', 'direct_with_context', 'direct_minimal']
method_names = {
'cheshire_cat': '🐱 Cheshire Cat (RAG)',
'direct_with_context': '📄 Direct + Full Context',
'direct_minimal': '⚡ Direct + Minimal'
}
for method in methods:
ttfts = []
totals = []
responses = []
for result in all_results:
if method in result and result[method].get('success'):
ttfts.append(result[method]['ttft_ms'])
totals.append(result[method]['total_time_ms'])
responses.append({
'query': result['query'],
'response': result[method]['response']
})
if ttfts:
print(f"\n{method_names[method]}")
print(f" Success Rate: {len(ttfts)}/{len(all_results)} ({len(ttfts)/len(all_results)*100:.0f}%)")
print(f" TTFT (Time To First Token):")
print(f" Mean: {statistics.mean(ttfts):>6.0f} ms")
print(f" Median: {statistics.median(ttfts):>6.0f} ms")
print(f" Min: {min(ttfts):>6.0f} ms")
print(f" Max: {max(ttfts):>6.0f} ms")
print(f" Total Generation Time:")
print(f" Mean: {statistics.mean(totals):>6.0f} ms")
print(f" Median: {statistics.median(totals):>6.0f} ms")
# Voice chat assessment
print("\n" + "=" * 80)
print("🎤 VOICE CHAT VIABILITY (based on TTFT)")
print("=" * 80)
for method in methods:
ttfts = [r[method]['ttft_ms'] for r in all_results if method in r and r[method].get('success')]
if ttfts:
mean_ttft = statistics.mean(ttfts)
if mean_ttft < 500:
status = "✅ EXCELLENT"
elif mean_ttft < 1000:
status = "✅ GOOD"
elif mean_ttft < 1500:
status = "⚠️ ACCEPTABLE"
else:
status = "❌ TOO SLOW"
print(f"{method_names[method]}: {status} ({mean_ttft:.0f}ms mean TTFT)")
# Save detailed results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"streaming_benchmark_{timestamp}.json"
with open(output_file, 'w') as f:
json.dump(all_results, f, indent=2)
print(f"\n💾 Detailed results saved to: {output_file}")
print("\n" + "=" * 80)
if __name__ == "__main__":
main()