Files
miku-discord/bot/utils/cat_client.py
koko210Serve 34167eddae feat: Restore mood system and implement comprehensive memory editor UI
MOOD SYSTEM FIX:
- Mount bot/moods directory in docker-compose.yml for Cat container access
- Update miku_personality plugin to load mood descriptions from .txt files
- Add Cat logger for debugging mood loading (replaces print statements)
- Moods now dynamically loaded from working_memory instead of hardcoded neutral
2026-02-10 22:03:54 +02:00

640 lines
28 KiB
Python

# utils/cat_client.py
"""
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
Routes messages through the Cheshire Cat pipeline for:
- Memory-augmented responses (episodic + declarative recall)
- Fact extraction and consolidation
- Per-user conversation isolation
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
Uses HTTP for memory management endpoints.
Falls back to query_llama() on failure for zero-downtime resilience.
"""
import aiohttp
import asyncio
import json
import time
from typing import Optional, Dict, Any, List
import globals
from utils.logger import get_logger
# Module-level logger used by everything in this file; it reuses the
# existing 'llm' logger component.
logger = get_logger('llm')
class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints (Cat and Qdrant).

    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.

    A small circuit breaker protects the chat path: after
    ``_max_failures_before_circuit_break`` consecutive failures, ``query()``
    returns None immediately for ``_circuit_break_cooldown`` seconds so the
    caller can fall back without waiting on a dead service.
    """

    # Internal docker-compose addresses. The Qdrant URL is derived from the
    # Cat URL by swapping one for the other; any other base URL is left
    # untouched by the substitution.
    _CAT_INTERNAL_URL = "http://cheshire-cat:80"
    _QDRANT_INTERNAL_URL = "http://cheshire-cat-vector-memory:6333"

    # Collections reported by get_memory_stats().
    _MEMORY_COLLECTIONS = ("episodic", "declarative", "procedural")

    def __init__(self):
        """Read connection settings from ``globals`` and reset health/breaker state."""
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_break_cooldown = 60  # seconds the breaker stays open
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    # ===================================================================
    # INTERNAL HELPERS
    # ===================================================================

    def _get_headers(self) -> dict:
        """Build request headers with optional auth."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.

        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"

    def _ws_url(self, cat_user_id: str) -> str:
        """Return the WebSocket chat URL for ``cat_user_id``, derived from the HTTP base URL."""
        ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
        return f"{ws_base}/ws/{cat_user_id}"

    def _qdrant_url(self) -> str:
        """
        Return the base URL used for direct Qdrant requests.

        Only the known internal Cat address is rewritten to the Qdrant
        address; any other base URL is returned unchanged.
        """
        return self._base_url.replace(self._CAT_INTERNAL_URL, self._QDRANT_INTERNAL_URL)

    def _record_failure(self, trip_message: str) -> None:
        """
        Count one failed Cat query and open the circuit breaker if the
        failure threshold has been reached.

        Every failure path goes through here so the breaker deadline is
        always set together with the counter. Previously the "empty response"
        path only bumped the counter, which meant the breaker never actually
        opened for it.

        Args:
            trip_message: Warning logged when this failure opens the breaker.
        """
        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            self._circuit_broken_until = time.time() + self._circuit_break_cooldown
            logger.warning(trip_message)

    def _is_circuit_broken(self) -> bool:
        """Check if circuit breaker is active (too many consecutive failures)."""
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            if time.time() < self._circuit_broken_until:
                return True
            # Circuit breaker expired, allow retry
            logger.info("Circuit breaker reset, allowing Cat retry")
            self._consecutive_failures = 0
        return False

    # ===================================================================
    # HEALTH + CHAT
    # ===================================================================

    async def health_check(self) -> bool:
        """
        Check if Cheshire Cat is reachable and healthy.

        Caches result to avoid hammering the endpoint: a known result younger
        than ``_health_check_interval`` seconds is returned without a request.

        Returns:
            True if GET {base_url}/ answered with status 200, False otherwise
            (including any connection error).
        """
        now = time.time()
        cache_is_fresh = now - self._last_health_check < self._health_check_interval
        if cache_is_fresh and self._healthy is not None:
            return self._healthy

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    self._healthy = response.status == 200
                    self._last_health_check = now
                    if self._healthy:
                        logger.debug("Cat health check: OK")
                    else:
                        logger.warning(f"Cat health check failed: status {response.status}")
                    return self._healthy
        except Exception as e:
            self._healthy = False
            self._last_health_check = now
            logger.warning(f"Cat health check error: {e}")
            return False

    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
    ) -> Optional[str]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation.
        Without API keys, HTTP POST /message defaults all users to user_id="user"
        (no isolation). The WebSocket path parameter provides true per-user isolation
        because Cat's auth handler uses user_id from the path when no keys are set.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context

        Returns:
            Cat's response text, or None if Cat is unavailable (caller should fallback).
            None is also returned when Cat is disabled, the circuit breaker is
            open, or no non-empty final "chat" message arrived in time.
        """
        if not globals.USE_CHESHIRE_CAT:
            return None

        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)

        # Build message payload with Discord metadata for our plugin hooks.
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {"text": text}
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        if mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type

        try:
            ws_url = self._ws_url(cat_user_id)
            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")

            reply_text = None
            closed_frames = (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSING,
                aiohttp.WSMsgType.CLOSED,
            )

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=self._timeout,
                ) as ws:
                    # Send the message
                    await ws.send_json(payload)

                    # Read responses until we get the final "chat" type message.
                    # Cat may send intermediate messages (chat_token for streaming,
                    # notification for status updates). We want the final "chat" one.
                    # One overall deadline bounds the whole exchange, not each frame.
                    loop = asyncio.get_running_loop()
                    deadline = loop.time() + self._timeout
                    while True:
                        remaining = deadline - loop.time()
                        if remaining <= 0:
                            logger.error(f"Cat WS timeout after {self._timeout}s")
                            break

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error(f"Cat WS receive timeout after {self._timeout}s")
                            break

                        # Handle WebSocket close/error frames
                        if ws_msg.type in closed_frames:
                            logger.warning("Cat WS connection closed by server")
                            break
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"Cat WS error frame: {ws.exception()}")
                            break
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError) as e:
                            logger.warning(f"Cat WS non-JSON message: {e}")
                            continue

                        # Valid JSON that is not an object (list, number, ...)
                        # has no "type" field; skip it instead of crashing on .get().
                        if not isinstance(msg, dict):
                            logger.debug(f"Cat WS non-object JSON message: {type(msg).__name__}")
                            continue

                        msg_type = msg.get("type", "")
                        if msg_type == "chat":
                            # Final response — extract text
                            reply_text = msg.get("content") or msg.get("text", "")
                            break
                        if msg_type == "error":
                            error_desc = msg.get("description", "Unknown Cat error")
                            logger.error(f"Cat WS error: {error_desc}")
                            break
                        if msg_type == "chat_token":
                            # Streaming token — skip, we wait for final
                            continue
                        if msg_type == "notification":
                            logger.debug(f"Cat notification: {msg.get('content', '')}")
                            continue
                        logger.debug(f"Cat WS unknown msg type: {msg_type}")

            if isinstance(reply_text, str) and reply_text.strip():
                self._consecutive_failures = 0
                logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
                return reply_text

            logger.warning("Cat returned empty response via WS")
            self._record_failure("Circuit breaker activated (empty response)")
            return None

        except asyncio.TimeoutError:
            logger.error(f"Cat WS connection timeout after {self._timeout}s")
            self._record_failure("Circuit breaker activated (WS timeout)")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._record_failure(f"Circuit breaker activated: {e}")
            return None

    # ===================================================================
    # MEMORY MANAGEMENT API (for Web UI)
    # ===================================================================

    async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
        """
        Get memory collection statistics with actual counts from Qdrant.

        Returns:
            ``{"collections": [{"name": ..., "vectors_count": ...}, ...]}`` with
            one entry per collection (count 0 when Qdrant answers non-200),
            or None if any request raised.
        """
        try:
            # Query Qdrant directly for accurate counts
            qdrant_host = self._qdrant_url()
            collections_data = []

            # One session is reused for all collections.
            async with aiohttp.ClientSession() as session:
                for collection_name in self._MEMORY_COLLECTIONS:
                    async with session.get(
                        f"{qdrant_host}/collections/{collection_name}",
                        timeout=aiohttp.ClientTimeout(total=10)
                    ) as response:
                        count = 0
                        if response.status == 200:
                            data = await response.json()
                            count = data.get("result", {}).get("points_count", 0)
                        collections_data.append({
                            "name": collection_name,
                            "vectors_count": count
                        })

            return {"collections": collections_data}
        except Exception as e:
            logger.error(f"Error getting memory stats from Qdrant: {e}")
            return None

    async def get_memory_points(
        self,
        collection: str = "declarative",
        limit: int = 100,
        offset: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get one page of points from a memory collection via Qdrant.

        Cat doesn't expose /memory/collections/{id}/points, so we query Qdrant directly.

        Args:
            collection: Collection name to scroll.
            limit: Maximum number of points in the page.
            offset: Scroll offset from a previous page; omitted from the
                request when falsy.

        Returns:
            The ``result`` object of Qdrant's scroll response, or None on a
            non-200 status or any error.
        """
        try:
            # Use Qdrant directly (Cat's vector memory backend)
            qdrant_host = self._qdrant_url()

            payload = {"limit": limit, "with_payload": True, "with_vector": False}
            if offset:
                payload["offset"] = offset

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{qdrant_host}/collections/{collection}/points/scroll",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("result", {})
                    logger.error(f"Failed to get {collection} points from Qdrant: {response.status}")
                    return None
        except Exception as e:
            logger.error(f"Error getting memory points from Qdrant: {e}")
            return None

    async def get_all_facts(self) -> List[Dict[str, Any]]:
        """
        Retrieve ALL declarative memory points (facts) with pagination.

        Returns:
            A flat list of ``{"id", "content", "metadata"}`` dicts. If a page
            fails, the facts collected so far are returned.
        """
        all_facts = []
        offset = None

        try:
            while True:
                result = await self.get_memory_points(
                    collection="declarative",
                    limit=100,
                    offset=offset
                )
                if not result:
                    break

                for point in result.get("points", []):
                    payload = point.get("payload") or {}
                    all_facts.append({
                        "id": point.get("id"),
                        "content": payload.get("page_content", ""),
                        "metadata": payload.get("metadata", {}),
                    })

                # Stop when the page carries no usable next offset.
                offset = result.get("next_offset")
                if not offset:
                    break

            logger.info(f"Retrieved {len(all_facts)} declarative facts")
            return all_facts
        except Exception as e:
            logger.error(f"Error retrieving all facts: {e}")
            return all_facts

    async def delete_memory_point(self, collection: str, point_id: str) -> bool:
        """
        Delete a single memory point by ID via Qdrant.

        Returns:
            True if Qdrant answered 200, False on any other status or error.
        """
        try:
            qdrant_host = self._qdrant_url()

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{qdrant_host}/collections/{collection}/points/delete",
                    json={"points": [point_id]},
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info(f"Deleted memory point {point_id} from {collection}")
                        return True
                    logger.error(f"Failed to delete point: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error deleting memory point: {e}")
            return False

    async def update_memory_point(self, collection: str, point_id: str, content: str, metadata: dict = None) -> bool:
        """
        Update an existing memory point's content and/or metadata.

        The point is fetched from Qdrant with its vector. If ``content``
        differs from the stored ``page_content`` the text is re-embedded via
        Cat's /embedder endpoint; if that fails or yields no embedding, the
        old vector is kept. The point is then written back with the same ID.

        Args:
            collection: Collection containing the point.
            point_id: ID of the point to update.
            content: New ``page_content``.
            metadata: Replacement metadata; when None the stored metadata is kept.

        Returns:
            True if the upsert answered 200, False otherwise.
        """
        try:
            qdrant_host = self._qdrant_url()
            points_url = f"{qdrant_host}/collections/{collection}/points"

            async with aiohttp.ClientSession() as session:
                # First, get the existing point to retrieve its vector
                async with session.post(
                    points_url,
                    json={"ids": [point_id], "with_vector": True, "with_payload": True},
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status != 200:
                        logger.error(f"Failed to fetch point {point_id}: {response.status}")
                        return False
                    data = await response.json()

                points = data.get("result", [])
                if not points:
                    logger.error(f"Point {point_id} not found")
                    return False

                existing_point = points[0]
                existing_vector = existing_point.get("vector")
                existing_payload = existing_point.get("payload") or {}

                # If content changed, we need to re-embed it
                new_vector = existing_vector
                if content != existing_payload.get("page_content"):
                    # `async with` guarantees the embedder response is released
                    # back to the connection pool on every path.
                    async with session.post(
                        f"{self._base_url}/embedder",
                        json={"text": content},
                        headers=self._get_headers(),
                        timeout=aiohttp.ClientTimeout(total=30)
                    ) as embed_response:
                        if embed_response.status == 200:
                            embed_data = await embed_response.json()
                            embedding = embed_data.get("embedding")
                            if embedding:
                                new_vector = embedding
                            else:
                                # Never upsert a null vector over a valid one.
                                logger.warning("Embedder returned no embedding, keeping old vector")
                        else:
                            logger.warning("Failed to re-embed content, keeping old vector")

                # Build updated payload
                updated_payload = {
                    "page_content": content,
                    "metadata": metadata if metadata is not None else existing_payload.get("metadata", {})
                }

                # Update the point
                async with session.put(
                    points_url,
                    json={
                        "points": [{
                            "id": point_id,
                            "vector": new_vector,
                            "payload": updated_payload
                        }]
                    },
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as update_response:
                    if update_response.status == 200:
                        logger.info(f"✏️ Updated memory point {point_id} in {collection}")
                        return True
                    logger.error(f"Failed to update point: {update_response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error updating memory point: {e}")
            return False

    async def create_memory_point(self, collection: str, content: str, user_id: str, source: str, metadata: dict = None) -> Optional[str]:
        """
        Create a new memory point manually.

        The content is embedded via Cat's /embedder endpoint and inserted
        into Qdrant under a freshly generated UUID.

        Args:
            collection: Target collection.
            content: Text stored as ``page_content``.
            user_id: Stored as ``metadata["user_id"]`` for "declarative";
                stored as ``metadata["source"]`` for "episodic".
            source: Stored as ``metadata["source"]`` (overridden by
                ``user_id`` for "episodic").
            metadata: Extra metadata to store. The dict is copied, never mutated.

        Returns:
            The new point ID, or None if embedding or insertion failed.
        """
        try:
            import uuid

            # Generate a unique ID
            point_id = str(uuid.uuid4())

            async with aiohttp.ClientSession() as session:
                # Get vector embedding from Cat
                async with session.post(
                    f"{self._base_url}/embedder",
                    json={"text": content},
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status != 200:
                        logger.error(f"Failed to embed content: {response.status}")
                        return None
                    data = await response.json()

                vector = data.get("embedding")
                if not vector:
                    logger.error("No embedding returned from Cat")
                    return None

                # Work on a copy so the caller's metadata dict is not modified.
                point_metadata = dict(metadata or {})
                point_metadata["source"] = source
                point_metadata["when"] = time.time()

                # For declarative memories, add user_id to metadata
                # For episodic, it's in the source field
                if collection == "declarative":
                    point_metadata["user_id"] = user_id
                elif collection == "episodic":
                    point_metadata["source"] = user_id

                payload = {
                    "page_content": content,
                    "metadata": point_metadata
                }

                # Insert into Qdrant
                qdrant_host = self._qdrant_url()
                async with session.put(
                    f"{qdrant_host}/collections/{collection}/points",
                    json={
                        "points": [{
                            "id": point_id,
                            "vector": vector,
                            "payload": payload
                        }]
                    },
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as insert_response:
                    if insert_response.status == 200:
                        logger.info(f"✨ Created new {collection} memory point: {point_id}")
                        return point_id
                    logger.error(f"Failed to insert point: {insert_response.status}")
                    return None
        except Exception as e:
            logger.error(f"Error creating memory point: {e}")
            return None

    async def wipe_all_memories(self) -> bool:
        """
        Delete ALL memory collections via Cat's DELETE /memory/collections.

        This is the nuclear option — requires multi-step confirmation in the UI.

        Returns:
            True if Cat answered 200, False otherwise.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        logger.warning("🗑️ ALL memory collections wiped!")
                        return True
                    error = await response.text()
                    logger.error(f"Failed to wipe memories: {response.status} - {error}")
                    return False
        except Exception as e:
            logger.error(f"Error wiping memories: {e}")
            return False

    async def wipe_conversation_history(self) -> bool:
        """
        Clear working memory / conversation history.

        NOTE(review): the request carries no user identifier, so which user's
        history Cat clears depends on Cat's default — confirm this is intended.

        Returns:
            True if Cat answered 200, False otherwise.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/conversation_history",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info("Conversation history cleared")
                        return True
                    logger.error(f"Failed to clear conversation history: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error clearing conversation history: {e}")
            return False

    async def trigger_consolidation(self) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.

        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text.
        Uses WebSocket with a system user ID for proper context.

        Returns:
            The final chat reply text, a human-readable failure description
            (timeout, closed connection, WebSocket or Cat error), or None if
            connecting failed or an unexpected exception occurred.
        """
        consolidation_timeout = 300  # Consolidation can be very slow

        try:
            ws_url = self._ws_url("system_consolidation")
            logger.info("🌙 Triggering memory consolidation via WS...")

            closed_frames = (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSING,
                aiohttp.WSMsgType.CLOSED,
            )

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=consolidation_timeout,
                ) as ws:
                    await ws.send_json({"text": "consolidate now"})

                    # Wait for the final chat response
                    loop = asyncio.get_running_loop()
                    deadline = loop.time() + consolidation_timeout
                    while True:
                        remaining = deadline - loop.time()
                        if remaining <= 0:
                            logger.error("Consolidation timed out (>300s)")
                            return "Consolidation timed out"

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error("Consolidation WS receive timeout")
                            return "Consolidation timed out waiting for response"

                        if ws_msg.type in closed_frames:
                            logger.warning("Consolidation WS closed by server")
                            return "Connection closed during consolidation"
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            return f"WebSocket error: {ws.exception()}"
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError):
                            continue

                        # Ignore JSON frames that are not objects.
                        if not isinstance(msg, dict):
                            continue

                        msg_type = msg.get("type", "")
                        if msg_type == "chat":
                            reply = msg.get("content") or msg.get("text", "")
                            logger.info(f"Consolidation result: {reply[:200]}")
                            return reply
                        if msg_type == "error":
                            error_desc = msg.get("description", "Unknown error")
                            logger.error(f"Consolidation error: {error_desc}")
                            return f"Consolidation error: {error_desc}"
                        # Tokens, notifications and unknown types: keep waiting.

        except asyncio.TimeoutError:
            logger.error("Consolidation WS connection timed out")
            return None
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None
# Singleton instance, created when this module is imported.
# NOTE: CatAdapter.__init__ reads CHESHIRE_CAT_URL, CHESHIRE_CAT_API_KEY and
# CHESHIRE_CAT_TIMEOUT from `globals`, so those must be set before import.
cat_adapter = CatAdapter()