331 lines
11 KiB
Python
331 lines
11 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Streaming Benchmark - TTFT Comparison
|
||
|
|
Measures Time To First Token (TTFT) for voice chat viability
|
||
|
|
Compares Cheshire Cat RAG vs Direct Context Loading
|
||
|
|
"""
|
||
|
|
|
||
|
|
import requests
|
||
|
|
import time
|
||
|
|
import json
|
||
|
|
import statistics
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Dict
|
||
|
|
|
||
|
|
# URLs
# Service endpoints under test.
CAT_URL = "http://localhost:1865"  # Cheshire Cat REST API (RAG backend)
LLAMA_SWAP_URL = "http://localhost:8091/v1"  # llama.cpp OpenAI-compatible API

# Test queries
# Short, persona-flavored prompts sent identically to every backend so
# the TTFT numbers are comparable across methods.
TEST_QUERIES = [
    "Hi Miku!",
    "What's your favorite food?",
    "Tell me about your friends",
    "What songs do you sing?",
    "How old are you?",
    "Who created you?",
    "Do you like green onions?",
    "What's World is Mine about?",
    "Tell me about Rin and Len",
    "What do you like to do?"
]
|
||
|
|
|
||
|
|
# Load Miku context files
def load_miku_context() -> str:
    """Load the current bot's context files.

    Reads the persona lore and prompt files and concatenates them with
    blank-line separators. If a file is missing, returns whatever was
    read before it (empty string if none were found) and prints a
    warning, matching the original best-effort behavior.

    Returns:
        The concatenated context text, possibly empty.
    """
    context = ""
    try:
        # NOTE(review): paths are relative to the CWD, not this file —
        # run the script from its own directory.
        for path in ("../bot/persona/miku/miku_lore.txt",
                     "../bot/persona/miku/miku_prompt.txt"):
            # Explicit encoding: persona files contain non-ASCII text and
            # the platform default codec is not guaranteed to be UTF-8.
            with open(path, "r", encoding="utf-8") as f:
                context += f.read() + "\n\n"
        # Skip lyrics for now - too long
    except FileNotFoundError:
        print("⚠️ Could not load context files from ../bot/")
    return context


MIKU_CONTEXT = load_miku_context()
|
||
|
|
|
||
|
|
def test_cheshire_cat_non_streaming(query: str) -> Dict:
|
||
|
|
"""Test Cheshire Cat (no streaming available, measure total time)"""
|
||
|
|
start_time = time.time()
|
||
|
|
|
||
|
|
try:
|
||
|
|
response = requests.post(
|
||
|
|
f"{CAT_URL}/message",
|
||
|
|
json={"text": query, "user_id": "benchmark_user"},
|
||
|
|
timeout=60
|
||
|
|
)
|
||
|
|
|
||
|
|
total_time = (time.time() - start_time) * 1000
|
||
|
|
|
||
|
|
if response.status_code != 200:
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": f"HTTP {response.status_code}",
|
||
|
|
"method": "cheshire_cat"
|
||
|
|
}
|
||
|
|
|
||
|
|
data = response.json()
|
||
|
|
content = data.get("content", "")
|
||
|
|
|
||
|
|
# Filter tool calls
|
||
|
|
if content.startswith('{"name":'):
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": "Got tool call",
|
||
|
|
"method": "cheshire_cat"
|
||
|
|
}
|
||
|
|
|
||
|
|
# Estimate TTFT as ~15% of total (RAG retrieval + first tokens)
|
||
|
|
estimated_ttft = total_time * 0.15
|
||
|
|
|
||
|
|
return {
|
||
|
|
"success": True,
|
||
|
|
"ttft_ms": estimated_ttft,
|
||
|
|
"total_time_ms": total_time,
|
||
|
|
"response": content,
|
||
|
|
"method": "cheshire_cat",
|
||
|
|
"note": "TTFT estimated (no streaming)"
|
||
|
|
}
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": str(e),
|
||
|
|
"method": "cheshire_cat"
|
||
|
|
}
|
||
|
|
|
||
|
|
def test_direct_llama_streaming(query: str, use_context: bool = True) -> Dict:
|
||
|
|
"""Test direct llama.cpp with streaming to measure TTFT"""
|
||
|
|
start_time = time.time()
|
||
|
|
first_token_time = None
|
||
|
|
full_response = ""
|
||
|
|
chunks_received = 0
|
||
|
|
|
||
|
|
# Build system prompt
|
||
|
|
if use_context:
|
||
|
|
system_prompt = f"""You are Hatsune Miku, the virtual singer! Be cheerful, cute, and use emojis 🎶💙
|
||
|
|
|
||
|
|
CONTEXT:
|
||
|
|
{MIKU_CONTEXT}
|
||
|
|
|
||
|
|
Keep responses SHORT (2-3 sentences). Stay in character!"""
|
||
|
|
else:
|
||
|
|
system_prompt = "You are Hatsune Miku, the virtual singer! Be cheerful and cute. Keep responses SHORT."
|
||
|
|
|
||
|
|
payload = {
|
||
|
|
"model": "darkidol",
|
||
|
|
"messages": [
|
||
|
|
{"role": "system", "content": system_prompt},
|
||
|
|
{"role": "user", "content": query}
|
||
|
|
],
|
||
|
|
"stream": True,
|
||
|
|
"temperature": 0.8,
|
||
|
|
"max_tokens": 150
|
||
|
|
}
|
||
|
|
|
||
|
|
try:
|
||
|
|
response = requests.post(
|
||
|
|
f"{LLAMA_SWAP_URL}/chat/completions",
|
||
|
|
json=payload,
|
||
|
|
stream=True,
|
||
|
|
timeout=60
|
||
|
|
)
|
||
|
|
|
||
|
|
if response.status_code != 200:
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": f"HTTP {response.status_code}",
|
||
|
|
"method": f"direct_ctx={use_context}"
|
||
|
|
}
|
||
|
|
|
||
|
|
# Read streaming response line by line
|
||
|
|
for line in response.iter_lines():
|
||
|
|
if not line:
|
||
|
|
continue
|
||
|
|
|
||
|
|
line = line.decode('utf-8').strip()
|
||
|
|
|
||
|
|
if line == "data: [DONE]":
|
||
|
|
break
|
||
|
|
|
||
|
|
if line.startswith("data: "):
|
||
|
|
try:
|
||
|
|
json_str = line[6:] # Remove "data: " prefix
|
||
|
|
data = json.loads(json_str)
|
||
|
|
|
||
|
|
delta = data.get("choices", [{}])[0].get("delta", {})
|
||
|
|
content = delta.get("content", "")
|
||
|
|
|
||
|
|
if content:
|
||
|
|
if first_token_time is None:
|
||
|
|
first_token_time = (time.time() - start_time) * 1000
|
||
|
|
|
||
|
|
full_response += content
|
||
|
|
chunks_received += 1
|
||
|
|
|
||
|
|
except json.JSONDecodeError:
|
||
|
|
continue
|
||
|
|
|
||
|
|
total_time = (time.time() - start_time) * 1000
|
||
|
|
|
||
|
|
if first_token_time is None:
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": "No tokens received",
|
||
|
|
"method": f"direct_ctx={use_context}"
|
||
|
|
}
|
||
|
|
|
||
|
|
return {
|
||
|
|
"success": True,
|
||
|
|
"ttft_ms": first_token_time,
|
||
|
|
"total_time_ms": total_time,
|
||
|
|
"response": full_response.strip(),
|
||
|
|
"chunks": chunks_received,
|
||
|
|
"method": f"direct_ctx={use_context}",
|
||
|
|
"context_size": len(system_prompt) if use_context else 0
|
||
|
|
}
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
return {
|
||
|
|
"success": False,
|
||
|
|
"error": str(e),
|
||
|
|
"method": f"direct_ctx={use_context}"
|
||
|
|
}
|
||
|
|
|
||
|
|
def _print_benchmark_result(result: Dict, ttft_prefix: str = "") -> None:
    """Print one per-query result line pair (or its error line).

    Args:
        result: Result dict from one of the test_* functions.
        ttft_prefix: Prefix for the TTFT figure ("~" when estimated).
    """
    if result['success']:
        print(f" TTFT: {ttft_prefix}{result['ttft_ms']:.0f}ms | Total: {result['total_time_ms']:.0f}ms")
        print(f" Response: {result['response'][:80]}...")
    else:
        print(f" ❌ Error: {result.get('error')}")


def run_comparison(query: str) -> Dict:
    """Run all three methods on the same query.

    Returns a dict keyed by method name; 2s pauses between methods keep
    the backends from overlapping work.
    """
    print(f"\n📝 Query: {query}")

    results = {}

    # Test 1: Cheshire Cat (RAG)
    print(" 🐱 Testing Cheshire Cat...")
    cat_result = test_cheshire_cat_non_streaming(query)
    results['cheshire_cat'] = cat_result
    # "~" marks the TTFT as an estimate (Cat has no streaming endpoint).
    _print_benchmark_result(cat_result, ttft_prefix="~")

    time.sleep(2)

    # Test 2: Direct with full context
    print(" 📄 Testing Direct + Full Context...")
    direct_ctx_result = test_direct_llama_streaming(query, use_context=True)
    results['direct_with_context'] = direct_ctx_result
    _print_benchmark_result(direct_ctx_result)

    time.sleep(2)

    # Test 3: Direct without context (minimal)
    print(" ⚡ Testing Direct + Minimal Context...")
    direct_min_result = test_direct_llama_streaming(query, use_context=False)
    results['direct_minimal'] = direct_min_result
    _print_benchmark_result(direct_min_result)

    return results
|
||
|
|
|
||
|
|
def main():
    """Run the full benchmark suite, print summary stats, save JSON.

    Iterates TEST_QUERIES through run_comparison, then prints per-method
    TTFT/total statistics, a voice-chat viability verdict, and writes
    the raw results to a timestamped JSON file.
    """
    print("=" * 80)
    print("⚡ STREAMING BENCHMARK - Time To First Token (TTFT) Comparison")
    print("=" * 80)
    print("\nComparing three approaches:")
    print(" 1. 🐱 Cheshire Cat (RAG with embeddings)")
    print(" 2. 📄 Direct LLM + Full Context (current bot approach)")
    print(" 3. ⚡ Direct LLM + Minimal Context (baseline)")
    print("\n" + "=" * 80)

    all_results = []

    for i, query in enumerate(TEST_QUERIES, 1):
        print(f"\n[{i}/{len(TEST_QUERIES)}]")
        results = run_comparison(query)
        results['query'] = query
        all_results.append(results)

        # Pause between queries to avoid back-to-back load on the backends.
        if i < len(TEST_QUERIES):
            print("\n⏳ Waiting 3s before next query...")
            time.sleep(3)

    # Calculate statistics
    print("\n" + "=" * 80)
    print("📊 RESULTS SUMMARY")
    print("=" * 80)

    methods = ['cheshire_cat', 'direct_with_context', 'direct_minimal']
    method_names = {
        'cheshire_cat': '🐱 Cheshire Cat (RAG)',
        'direct_with_context': '📄 Direct + Full Context',
        'direct_minimal': '⚡ Direct + Minimal'
    }

    # Collect successful timings once per method; reused by both the
    # summary and the viability sections (the original recomputed them
    # and also built an unused `responses` list — both removed).
    method_ttfts: Dict[str, List[float]] = {}

    for method in methods:
        ttfts = []
        totals = []

        for result in all_results:
            if method in result and result[method].get('success'):
                ttfts.append(result[method]['ttft_ms'])
                totals.append(result[method]['total_time_ms'])

        method_ttfts[method] = ttfts

        if ttfts:
            print(f"\n{method_names[method]}")
            print(f" Success Rate: {len(ttfts)}/{len(all_results)} ({len(ttfts)/len(all_results)*100:.0f}%)")
            print(f" TTFT (Time To First Token):")
            print(f" Mean: {statistics.mean(ttfts):>6.0f} ms")
            print(f" Median: {statistics.median(ttfts):>6.0f} ms")
            print(f" Min: {min(ttfts):>6.0f} ms")
            print(f" Max: {max(ttfts):>6.0f} ms")
            print(f" Total Generation Time:")
            print(f" Mean: {statistics.mean(totals):>6.0f} ms")
            print(f" Median: {statistics.median(totals):>6.0f} ms")

    # Voice chat assessment
    print("\n" + "=" * 80)
    print("🎤 VOICE CHAT VIABILITY (based on TTFT)")
    print("=" * 80)

    for method in methods:
        ttfts = method_ttfts[method]
        if ttfts:
            mean_ttft = statistics.mean(ttfts)
            # Thresholds are rough conversational-latency bands.
            if mean_ttft < 500:
                status = "✅ EXCELLENT"
            elif mean_ttft < 1000:
                status = "✅ GOOD"
            elif mean_ttft < 1500:
                status = "⚠️ ACCEPTABLE"
            else:
                status = "❌ TOO SLOW"

            print(f"{method_names[method]}: {status} ({mean_ttft:.0f}ms mean TTFT)")

    # Save detailed results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"streaming_benchmark_{timestamp}.json"

    with open(output_file, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\n💾 Detailed results saved to: {output_file}")
    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()
|