480 lines
20 KiB
Python
480 lines
20 KiB
Python
|
|
# utils/cat_client.py
|
||
|
|
"""
|
||
|
|
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
|
||
|
|
|
||
|
|
Routes messages through the Cheshire Cat pipeline for:
|
||
|
|
- Memory-augmented responses (episodic + declarative recall)
|
||
|
|
- Fact extraction and consolidation
|
||
|
|
- Per-user conversation isolation
|
||
|
|
|
||
|
|
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
|
||
|
|
Uses HTTP for memory management endpoints.
|
||
|
|
On failure, returns None so callers can fall back to query_llama() for zero-downtime resilience.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import aiohttp
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from typing import Optional, Dict, Any, List
|
||
|
|
|
||
|
|
import globals
|
||
|
|
from utils.logger import get_logger
|
||
|
|
|
||
|
|
# Module logger for the Cheshire Cat adapter, registered under the name 'cat_client'.
logger = get_logger('cat_client')
|
||
|
|
|
||
|
|
|
||
|
|
class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.

    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.

    Failure handling: every failed query (transport error, timeout, Cat error
    frame, closed socket or empty reply) counts toward a circuit breaker.
    After ``_max_failures_before_circuit_break`` consecutive failures,
    ``query`` returns None immediately for ``_circuit_break_duration``
    seconds so the caller can use its fallback without waiting on Cat.
    """

    def __init__(self):
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_break_duration = 60  # seconds the breaker stays open once tripped
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_headers(self) -> dict:
        """Build HTTP request headers, adding a Bearer token when an API key is set."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.
        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"

    def _ws_url(self, ws_user_id: str) -> str:
        """Build the Cat WebSocket URL ``/ws/{ws_user_id}`` from the HTTP base URL.

        Only the scheme prefix is rewritten (http -> ws, https -> wss); the
        rest of the base URL is left untouched.
        """
        base = self._base_url
        if base.startswith("https://"):
            base = "wss://" + base[len("https://"):]
        elif base.startswith("http://"):
            base = "ws://" + base[len("http://"):]
        return f"{base}/ws/{ws_user_id}"

    @staticmethod
    def _path_segment(value) -> str:
        """Percent-encode ``value`` for use as exactly one URL path segment.

        Values come from the Web UI, so '/', '?', '#' etc. are escaped to keep
        them from reshaping the endpoint path.

        Raises:
            ValueError: if the value is empty, "." or "..". Dot segments are
                removed by URL normalization and would retarget the request
                at a parent endpoint (e.g. a whole-collection DELETE).
        """
        from urllib.parse import quote

        text = str(value)
        if text in ("", ".", ".."):
            raise ValueError(f"Invalid URL path segment: {value!r}")
        return quote(text, safe="")

    def _record_failure(self, reason: str) -> None:
        """Count one failed Cat query and trip the circuit breaker at the threshold.

        All failure paths must go through here. ``_is_circuit_broken`` only
        honours the failure count while ``_circuit_broken_until`` is in the
        future. A count that reached the threshold without arming the
        deadline would simply be reset on the next call.
        """
        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            self._circuit_broken_until = time.time() + self._circuit_break_duration
            logger.warning(f"Circuit breaker activated: {reason}")

    @staticmethod
    def _parse_cat_frame(raw) -> tuple:
        """Decode one Cat WebSocket text frame into ``(msg_type, text)``.

        ``msg_type`` is the frame's "type" field ("" when missing or null), or
        None when the frame is not a JSON object (``text`` then holds the
        reason). For "chat" frames ``text`` is the reply ("content", else
        "text", else ""). For "error" frames it is the "description". For
        any other type it is the frame's "content" (or "").
        """
        try:
            msg = json.loads(raw)
        except (json.JSONDecodeError, TypeError) as e:
            return None, str(e)
        if not isinstance(msg, dict):
            return None, f"unexpected JSON payload type: {type(msg).__name__}"

        msg_type = msg.get("type") or ""
        if msg_type == "chat":
            return msg_type, msg.get("content") or msg.get("text") or ""
        if msg_type == "error":
            return msg_type, msg.get("description", "Unknown Cat error")
        return msg_type, msg.get("content", "")

    async def _receive_final_reply(self, ws) -> tuple:
        """Read frames from ``ws`` until Cat sends a terminal one.

        Intermediate frames (chat_token streaming, notification, unknown
        types, non-text or non-JSON frames) are logged and skipped.

        Returns:
            ("chat", reply), ("error", description), ("closed", None) when the
            server closes the socket, or ("ws_error", detail) on an error
            frame. There is no timeout here; callers bound the call with
            asyncio.wait_for.
        """
        closing_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )
        while True:
            ws_msg = await ws.receive()

            if ws_msg.type in closing_types:
                return "closed", None
            if ws_msg.type == aiohttp.WSMsgType.ERROR:
                return "ws_error", str(ws.exception())
            if ws_msg.type != aiohttp.WSMsgType.TEXT:
                logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                continue

            msg_type, text = self._parse_cat_frame(ws_msg.data)
            if msg_type is None:
                logger.warning(f"Cat WS non-JSON message: {text}")
            elif msg_type in ("chat", "error"):
                return msg_type, text
            elif msg_type == "notification":
                logger.debug(f"Cat notification: {text}")
            elif msg_type != "chat_token":
                # chat_token frames are streaming fragments; we wait for the final "chat".
                logger.debug(f"Cat WS unknown msg type: {msg_type}")

    async def _ws_exchange(self, ws_user_id: str, payload: Dict[str, Any]) -> tuple:
        """Open ``/ws/{ws_user_id}``, send ``payload`` and wait for Cat's terminal frame.

        Has no timeout of its own: callers wrap it in ``asyncio.wait_for`` so
        the handshake, the send and the receive share one deadline. (The
        ``timeout`` argument of ``ws_connect`` only governs the close
        handshake, so it is deliberately not used.)

        Returns:
            The tuple produced by ``_receive_final_reply``.
        """
        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(self._ws_url(ws_user_id)) as ws:
                await ws.send_json(payload)
                return await self._receive_final_reply(ws)

    # ------------------------------------------------------------------
    # Health / circuit breaker
    # ------------------------------------------------------------------

    async def health_check(self) -> bool:
        """
        Check if Cheshire Cat is reachable and healthy.
        Caches result for ``_health_check_interval`` seconds to avoid
        hammering the endpoint.
        """
        now = time.time()
        if now - self._last_health_check < self._health_check_interval and self._healthy is not None:
            return self._healthy

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    self._healthy = response.status == 200
                    self._last_health_check = now
                    if self._healthy:
                        logger.debug("Cat health check: OK")
                    else:
                        logger.warning(f"Cat health check failed: status {response.status}")
                    return self._healthy
        except Exception as e:
            self._healthy = False
            self._last_health_check = now
            logger.warning(f"Cat health check error: {e}")
            return False

    def _is_circuit_broken(self) -> bool:
        """Check if circuit breaker is active (too many consecutive failures)."""
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            if time.time() < self._circuit_broken_until:
                return True
            # Circuit breaker expired, allow retry
            logger.info("Circuit breaker reset, allowing Cat retry")
            self._consecutive_failures = 0
        return False

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
    ) -> Optional[str]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation.
        Without API keys, HTTP POST /message defaults all users to user_id="user"
        (no isolation). The WebSocket path parameter provides true per-user isolation
        because Cat's auth handler uses user_id from the path when no keys are set.

        The whole exchange (connect, send, wait for the final "chat" frame) is
        bounded by ``self._timeout`` seconds. Any failure is recorded against
        the circuit breaker.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context

        Returns:
            Cat's response text, or None if Cat is disabled, unavailable, errored
            or replied with nothing (caller should fallback)
        """
        if not globals.USE_CHESHIRE_CAT:
            return None

        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)

        # Build message payload with Discord metadata for our plugin hooks.
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {"text": text}
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        if mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type

        try:
            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")
            kind, value = await asyncio.wait_for(
                self._ws_exchange(cat_user_id, payload),
                timeout=self._timeout,
            )
        except asyncio.TimeoutError:
            logger.error(f"Cat WS timeout after {self._timeout}s")
            self._record_failure("WS timeout")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._record_failure(str(e))
            return None

        if kind == "chat" and isinstance(value, str) and value.strip():
            self._consecutive_failures = 0
            logger.info(f"🐱 Cat response for {cat_user_id}: {value[:100]}...")
            return value

        if kind == "error":
            logger.error(f"Cat WS error: {value}")
        elif kind == "closed":
            logger.warning("Cat WS connection closed by server")
        elif kind == "ws_error":
            logger.error(f"Cat WS error frame: {value}")
        else:
            logger.warning("Cat returned empty response via WS")
        self._record_failure(f"Cat WS {kind}")
        return None

    # ===================================================================
    # MEMORY MANAGEMENT API (for Web UI)
    # ===================================================================

    async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
        """
        Get memory collection statistics from Cat.
        Returns the decoded JSON of GET /memory/collections, or None on failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    logger.error(f"Failed to get memory stats: {response.status}")
                    return None
        except Exception as e:
            logger.error(f"Error getting memory stats: {e}")
            return None

    async def get_memory_points(
        self,
        collection: str = "declarative",
        limit: int = 100,
        offset: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get one page of points from a memory collection.

        Args:
            collection: Collection name; percent-encoded as a single path segment.
            limit: Maximum points to return in this page.
            offset: Pagination cursor from a previous page's "next_offset".

        Returns:
            The decoded JSON response, or None on failure (including an
            invalid collection name).
        """
        try:
            params = {"limit": limit}
            if offset:
                params["offset"] = offset

            collection_segment = self._path_segment(collection)
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections/{collection_segment}/points",
                    headers=self._get_headers(),
                    params=params,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    logger.error(f"Failed to get {collection} points: {response.status}")
                    return None
        except Exception as e:
            logger.error(f"Error getting memory points: {e}")
            return None

    async def get_all_facts(self) -> List[Dict[str, Any]]:
        """
        Retrieve ALL declarative memory points (facts) with pagination.

        Returns a flat list of ``{"id", "content", "metadata"}`` dicts. On a
        failed page or an unexpected error, the facts collected so far are
        returned. Pagination stops if the server repeats a cursor, which
        would otherwise loop forever.
        """
        all_facts: List[Dict[str, Any]] = []
        offset = None
        seen_offsets = set()

        try:
            while True:
                result = await self.get_memory_points(
                    collection="declarative",
                    limit=100,
                    offset=offset
                )
                if not result:
                    break

                for point in result.get("points") or []:
                    payload = point.get("payload") or {}
                    all_facts.append({
                        "id": point.get("id"),
                        "content": payload.get("page_content", ""),
                        "metadata": payload.get("metadata", {}),
                    })

                offset = result.get("next_offset")
                if not offset:
                    break
                if offset in seen_offsets:
                    logger.warning(f"Pagination offset {offset} repeated; stopping fact retrieval")
                    break
                seen_offsets.add(offset)

            logger.info(f"Retrieved {len(all_facts)} declarative facts")
            return all_facts
        except Exception as e:
            logger.error(f"Error retrieving all facts: {e}")
            return all_facts

    async def delete_memory_point(self, collection: str, point_id: str) -> bool:
        """Delete a single memory point by ID.

        Both ``collection`` and ``point_id`` are percent-encoded as single path
        segments so a crafted value cannot retarget the DELETE to another
        endpoint. Returns True only on HTTP 200.
        """
        try:
            url = (
                f"{self._base_url}/memory/collections/"
                f"{self._path_segment(collection)}/points/{self._path_segment(point_id)}"
            )
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    url,
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info(f"Deleted point {point_id} from {collection}")
                        return True
                    logger.error(f"Failed to delete point: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error deleting point: {e}")
            return False

    async def wipe_all_memories(self) -> bool:
        """
        Delete ALL memory collections (episodic + declarative).
        This is the nuclear option — requires multi-step confirmation in the UI.
        Returns True only on HTTP 200.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        logger.warning("🗑️ ALL memory collections wiped!")
                        return True
                    error = await response.text()
                    logger.error(f"Failed to wipe memories: {response.status} - {error}")
                    return False
        except Exception as e:
            logger.error(f"Error wiping memories: {e}")
            return False

    async def wipe_conversation_history(self) -> bool:
        """Clear working memory / conversation history. Returns True only on HTTP 200."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/conversation_history",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info("Conversation history cleared")
                        return True
                    logger.error(f"Failed to clear conversation history: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error clearing conversation history: {e}")
            return False

    async def trigger_consolidation(self, timeout: float = 300) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.
        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text.
        Uses WebSocket with a system user ID for proper context.

        Args:
            timeout: Seconds allowed for the whole exchange (connect + reply);
                consolidation can be very slow, hence the 300s default.

        Returns:
            Cat's reply text; a human-readable status string on timeout, a
            server-side close, a WebSocket error frame or a Cat error; or None
            on an unexpected transport failure.
        """
        logger.info("🌙 Triggering memory consolidation via WS...")
        try:
            kind, value = await asyncio.wait_for(
                self._ws_exchange("system_consolidation", {"text": "consolidate now"}),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            logger.error(f"Consolidation timed out (>{timeout}s)")
            return "Consolidation timed out"
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None

        if kind == "chat":
            reply = value if isinstance(value, str) else str(value)
            logger.info(f"Consolidation result: {reply[:200]}")
            return reply
        if kind == "error":
            logger.error(f"Consolidation error: {value}")
            return f"Consolidation error: {value}"
        if kind == "closed":
            logger.warning("Consolidation WS closed by server")
            return "Connection closed during consolidation"
        return f"WebSocket error: {value}"
|
||
|
|
|
||
|
|
|
||
|
|
# Singleton instance
|
||
|
|
# Constructed at import time: CatAdapter.__init__ reads globals.CHESHIRE_CAT_URL,
# globals.CHESHIRE_CAT_API_KEY and globals.CHESHIRE_CAT_TIMEOUT, so those
# settings must already be populated when this module is first imported.
cat_adapter = CatAdapter()
|