# utils/cat_client.py
"""
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)

Routes messages through the Cheshire Cat pipeline for:
- Memory-augmented responses (episodic + declarative recall)
- Fact extraction and consolidation
- Per-user conversation isolation

Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
Uses HTTP for memory management endpoints.
Falls back to query_llama() on failure for zero-downtime resilience.
"""

import aiohttp
import asyncio
import json
import time
from typing import Optional, Dict, Any, List, Tuple

import globals
from utils.logger import get_logger

logger = get_logger('cat_client')


class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.

    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.
    """

    def __init__(self):
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets
        self._circuit_break_duration = 60  # seconds the breaker stays open once tripped
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    def _get_headers(self) -> dict:
        """Build request headers with optional auth."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.

        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"

    def _ws_url(self, cat_user_id: str) -> str:
        """Build the WebSocket URL for a given Cat user from the HTTP base URL."""
        ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
        return f"{ws_base}/ws/{cat_user_id}"

    async def health_check(self) -> bool:
        """
        Check if Cheshire Cat is reachable and healthy.

        Caches result to avoid hammering the endpoint.
        """
        now = time.time()
        if now - self._last_health_check < self._health_check_interval and self._healthy is not None:
            return self._healthy

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    self._healthy = response.status == 200
                    self._last_health_check = now
                    if self._healthy:
                        logger.debug("Cat health check: OK")
                    else:
                        logger.warning(f"Cat health check failed: status {response.status}")
                    return self._healthy
        except Exception as e:
            self._healthy = False
            self._last_health_check = now
            logger.warning(f"Cat health check error: {e}")
            return False

    def _is_circuit_broken(self) -> bool:
        """Check if circuit breaker is active (too many consecutive failures)."""
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            if time.time() < self._circuit_broken_until:
                return True
            # Circuit breaker expired, allow retry
            logger.info("Circuit breaker reset, allowing Cat retry")
            self._consecutive_failures = 0
        return False

    def _record_failure(self, reason: str) -> None:
        """
        Record one failed Cat interaction and arm the circuit breaker when
        the consecutive-failure threshold is reached.

        Centralized here so every failure path (timeouts, exceptions, empty
        responses) counts toward the breaker consistently — previously the
        empty-response path incremented the counter without ever arming
        _circuit_broken_until.
        """
        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            self._circuit_broken_until = time.time() + self._circuit_break_duration
            logger.warning(f"Circuit breaker activated: {reason}")

    async def _await_final_chat(self, ws, timeout: float) -> Tuple[str, str]:
        """
        Drain a Cat WebSocket until the final "chat" message or a failure.

        Cat may send intermediate messages (chat_token for streaming,
        notification for status updates); those are skipped.

        Returns:
            (status, detail) where status is one of:
              "chat"      - detail is the reply text
              "timeout"   - overall deadline expired
              "closed"    - server closed the connection
              "ws_error"  - transport-level error frame; detail is the exception
              "cat_error" - Cat sent an error message; detail is its description
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                return "timeout", ""
            try:
                ws_msg = await asyncio.wait_for(ws.receive(), timeout=remaining)
            except asyncio.TimeoutError:
                return "timeout", ""

            # Handle WebSocket close/error frames
            if ws_msg.type in (aiohttp.WSMsgType.CLOSE,
                               aiohttp.WSMsgType.CLOSING,
                               aiohttp.WSMsgType.CLOSED):
                return "closed", ""
            if ws_msg.type == aiohttp.WSMsgType.ERROR:
                return "ws_error", str(ws.exception())
            if ws_msg.type != aiohttp.WSMsgType.TEXT:
                logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                continue

            try:
                msg = json.loads(ws_msg.data)
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(f"Cat WS non-JSON message: {e}")
                continue

            msg_type = msg.get("type", "")
            if msg_type == "chat":
                # Final response — extract text
                return "chat", msg.get("content") or msg.get("text", "")
            elif msg_type == "chat_token":
                # Streaming token — skip, we wait for the final message
                continue
            elif msg_type == "error":
                return "cat_error", msg.get("description", "Unknown Cat error")
            elif msg_type == "notification":
                logger.debug(f"Cat notification: {msg.get('content', '')}")
                continue
            else:
                logger.debug(f"Cat WS unknown msg type: {msg_type}")
                continue

    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
    ) -> Optional[str]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation. Without API
        keys, HTTP POST /message defaults all users to user_id="user" (no
        isolation). The WebSocket path parameter provides true per-user
        isolation because Cat's auth handler uses user_id from the path when no
        keys are set.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context

        Returns:
            Cat's response text, or None if Cat is unavailable (caller should
            fallback).
        """
        if not globals.USE_CHESHIRE_CAT:
            return None
        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)

        # Build message payload with Discord metadata for our plugin hooks.
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {
            "text": text,
        }
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        if mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type

        try:
            ws_url = self._ws_url(cat_user_id)
            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    # Send the same auth headers as the HTTP calls — without
                    # them the WS handshake is rejected when an API key is set.
                    headers=self._get_headers(),
                    timeout=self._timeout,
                ) as ws:
                    await ws.send_json(payload)

                    status, detail = await self._await_final_chat(ws, self._timeout)

            reply_text = None
            if status == "chat":
                reply_text = detail
            elif status == "timeout":
                logger.error(f"Cat WS timeout after {self._timeout}s")
            elif status == "closed":
                logger.warning("Cat WS connection closed by server")
            elif status == "ws_error":
                logger.error(f"Cat WS error frame: {detail}")
            elif status == "cat_error":
                logger.error(f"Cat WS error: {detail}")

            if reply_text and reply_text.strip():
                self._consecutive_failures = 0
                logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
                return reply_text
            else:
                logger.warning("Cat returned empty response via WS")
                self._record_failure("empty WS response")
                return None

        except asyncio.TimeoutError:
            logger.error(f"Cat WS connection timeout after {self._timeout}s")
            self._record_failure("WS timeout")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._record_failure(str(e))
            return None

    # ===================================================================
    # MEMORY MANAGEMENT API (for Web UI)
    # ===================================================================

    async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
        """
        Get memory collection statistics from Cat.

        Returns dict with collection names and point counts.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data
                    else:
                        logger.error(f"Failed to get memory stats: {response.status}")
                        return None
        except Exception as e:
            logger.error(f"Error getting memory stats: {e}")
            return None

    async def get_memory_points(
        self,
        collection: str = "declarative",
        limit: int = 100,
        offset: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get all points from a memory collection.

        Returns paginated list of memory points.
        """
        try:
            params = {"limit": limit}
            if offset:
                params["offset"] = offset
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections/{collection}/points",
                    headers=self._get_headers(),
                    params=params,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    else:
                        logger.error(f"Failed to get {collection} points: {response.status}")
                        return None
        except Exception as e:
            logger.error(f"Error getting memory points: {e}")
            return None

    async def get_all_facts(self) -> List[Dict[str, Any]]:
        """
        Retrieve ALL declarative memory points (facts) with pagination.

        Returns a flat list of all fact dicts.
        """
        all_facts = []
        offset = None
        try:
            while True:
                result = await self.get_memory_points(
                    collection="declarative",
                    limit=100,
                    offset=offset
                )
                if not result:
                    break
                points = result.get("points", [])
                for point in points:
                    payload = point.get("payload", {})
                    fact = {
                        "id": point.get("id"),
                        "content": payload.get("page_content", ""),
                        "metadata": payload.get("metadata", {}),
                    }
                    all_facts.append(fact)
                offset = result.get("next_offset")
                if not offset:
                    break
            logger.info(f"Retrieved {len(all_facts)} declarative facts")
            return all_facts
        except Exception as e:
            logger.error(f"Error retrieving all facts: {e}")
            return all_facts

    async def delete_memory_point(self, collection: str, point_id: str) -> bool:
        """Delete a single memory point by ID."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections/{collection}/points/{point_id}",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info(f"Deleted point {point_id} from {collection}")
                        return True
                    else:
                        logger.error(f"Failed to delete point: {response.status}")
                        return False
        except Exception as e:
            logger.error(f"Error deleting point: {e}")
            return False

    async def wipe_all_memories(self) -> bool:
        """
        Delete ALL memory collections (episodic + declarative).

        This is the nuclear option — requires multi-step confirmation in the UI.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        logger.warning("🗑️ ALL memory collections wiped!")
                        return True
                    else:
                        error = await response.text()
                        logger.error(f"Failed to wipe memories: {response.status} - {error}")
                        return False
        except Exception as e:
            logger.error(f"Error wiping memories: {e}")
            return False

    async def wipe_conversation_history(self) -> bool:
        """Clear working memory / conversation history."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/conversation_history",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info("Conversation history cleared")
                        return True
                    else:
                        logger.error(f"Failed to clear conversation history: {response.status}")
                        return False
        except Exception as e:
            logger.error(f"Error clearing conversation history: {e}")
            return False

    async def trigger_consolidation(self) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.

        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text. Uses WebSocket
        with a system user ID for proper context.
        """
        try:
            ws_url = self._ws_url("system_consolidation")
            logger.info("🌙 Triggering memory consolidation via WS...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    headers=self._get_headers(),  # auth, same as HTTP calls
                    timeout=300,  # Consolidation can be very slow
                ) as ws:
                    await ws.send_json({"text": "consolidate now"})

                    # Wait for the final chat response
                    status, detail = await self._await_final_chat(ws, 300)

            if status == "chat":
                logger.info(f"Consolidation result: {detail[:200]}")
                return detail
            if status == "timeout":
                logger.error("Consolidation timed out (>300s)")
                return "Consolidation timed out"
            if status == "closed":
                logger.warning("Consolidation WS closed by server")
                return "Connection closed during consolidation"
            if status == "ws_error":
                return f"WebSocket error: {detail}"
            # status == "cat_error"
            logger.error(f"Consolidation error: {detail}")
            return f"Consolidation error: {detail}"

        except asyncio.TimeoutError:
            logger.error("Consolidation WS connection timed out")
            return None
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None


# Singleton instance
cat_adapter = CatAdapter()