# utils/cat_client.py
"""
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)

Routes messages through the Cheshire Cat pipeline for:
- Memory-augmented responses (episodic + declarative recall)
- Fact extraction and consolidation
- Per-user conversation isolation

Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
Uses HTTP for memory management endpoints.
Falls back to query_llama() on failure for zero-downtime resilience.
"""
import aiohttp
import asyncio
import json
import time
from typing import Optional, Dict, Any, List

import globals
from utils.logger import get_logger

logger = get_logger('llm')  # Use existing 'llm' logger component


class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.

    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.
    """

    def __init__(self):
        # Endpoint configuration comes from the shared globals module.
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT

        # Cached health state.
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks

        # Circuit-breaker state: after enough consecutive failures we stop
        # calling Cat for a while so the caller falls back immediately.
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets

        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    def _get_headers(self) -> dict:
        """Build request headers with optional auth."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.

        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"
""" return f"discord_{user_id}" async def health_check(self) -> bool: """ Check if Cheshire Cat is reachable and healthy. Caches result to avoid hammering the endpoint. """ now = time.time() if now - self._last_health_check < self._health_check_interval and self._healthy is not None: return self._healthy try: async with aiohttp.ClientSession() as session: async with session.get( f"{self._base_url}/", headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=10) ) as response: self._healthy = response.status == 200 self._last_health_check = now if self._healthy: logger.debug("Cat health check: OK") else: logger.warning(f"Cat health check failed: status {response.status}") return self._healthy except Exception as e: self._healthy = False self._last_health_check = now logger.warning(f"Cat health check error: {e}") return False def _is_circuit_broken(self) -> bool: """Check if circuit breaker is active (too many consecutive failures).""" if self._consecutive_failures >= self._max_failures_before_circuit_break: if time.time() < self._circuit_broken_until: return True # Circuit breaker expired, allow retry logger.info("Circuit breaker reset, allowing Cat retry") self._consecutive_failures = 0 return False async def query( self, text: str, user_id: str, guild_id: Optional[str] = None, author_name: Optional[str] = None, mood: Optional[str] = None, response_type: str = "dm_response", media_type: Optional[str] = None, ) -> Optional[tuple]: """ Send a message through the Cat pipeline via WebSocket and get a response. Uses WebSocket /ws/{user_id} for per-user memory isolation. Without API keys, HTTP POST /message defaults all users to user_id="user" (no isolation). The WebSocket path parameter provides true per-user isolation because Cat's auth handler uses user_id from the path when no keys are set. 
Args: text: User's message text user_id: Discord user ID (will be namespaced as discord_{user_id}) guild_id: Optional guild ID for server context author_name: Display name of the user mood: Current mood name (passed as metadata for Cat hooks) response_type: Type of response context media_type: Type of media attachment ("image", "video", "gif", "tenor_gif") Returns: Tuple of (response_text, full_prompt) on success, or None if Cat is unavailable (caller should fallback to query_llama) """ if not globals.USE_CHESHIRE_CAT: return None if self._is_circuit_broken(): logger.debug("Circuit breaker active, skipping Cat") return None cat_user_id = self._user_id_for_discord(user_id) # Build message payload with Discord metadata for our plugin hooks. # The discord_bridge plugin's before_cat_reads_message hook reads # these custom keys from the message dict. payload = { "text": text, } if guild_id: payload["discord_guild_id"] = str(guild_id) if author_name: payload["discord_author_name"] = author_name # When evil mode is active, send the evil mood name instead of the normal mood if globals.EVIL_MODE: payload["discord_mood"] = getattr(globals, 'EVIL_DM_MOOD', 'evil_neutral') elif mood: payload["discord_mood"] = mood if response_type: payload["discord_response_type"] = response_type # Pass evil mode flag so discord_bridge stores it in working_memory payload["discord_evil_mode"] = globals.EVIL_MODE # Pass media type so discord_bridge can add MEDIA NOTE to the prompt if media_type: payload["discord_media_type"] = media_type try: # Build WebSocket URL from HTTP base URL ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://") ws_url = f"{ws_base}/ws/{cat_user_id}" logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...") async with aiohttp.ClientSession() as session: async with session.ws_connect( ws_url, timeout=self._timeout, ) as ws: # Send the message await ws.send_json(payload) # Read responses until we get the final "chat" type 
message. # Cat may send intermediate messages (chat_token for streaming, # notification for status updates). We want the final "chat" one. reply_text = None full_prompt = "" deadline = asyncio.get_event_loop().time() + self._timeout while True: remaining = deadline - asyncio.get_event_loop().time() if remaining <= 0: logger.error(f"Cat WS timeout after {self._timeout}s") break try: ws_msg = await asyncio.wait_for( ws.receive(), timeout=remaining ) except asyncio.TimeoutError: logger.error(f"Cat WS receive timeout after {self._timeout}s") break # Handle WebSocket close/error frames if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED): logger.warning("Cat WS connection closed by server") break if ws_msg.type == aiohttp.WSMsgType.ERROR: logger.error(f"Cat WS error frame: {ws.exception()}") break if ws_msg.type != aiohttp.WSMsgType.TEXT: logger.debug(f"Cat WS non-text frame type: {ws_msg.type}") continue try: msg = json.loads(ws_msg.data) except (json.JSONDecodeError, TypeError) as e: logger.warning(f"Cat WS non-JSON message: {e}") continue msg_type = msg.get("type", "") if msg_type == "chat": # Final response — extract text and full prompt reply_text = msg.get("content") or msg.get("text", "") full_prompt = msg.get("full_prompt", "") break elif msg_type == "chat_token": # Streaming token — skip, we wait for final continue elif msg_type == "error": error_desc = msg.get("description", "Unknown Cat error") logger.error(f"Cat WS error: {error_desc}") break elif msg_type == "notification": logger.debug(f"Cat notification: {msg.get('content', '')}") continue else: logger.debug(f"Cat WS unknown msg type: {msg_type}") continue if reply_text and reply_text.strip(): self._consecutive_failures = 0 logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...") return reply_text, full_prompt else: logger.warning("Cat returned empty response via WS") self._consecutive_failures += 1 return None except asyncio.TimeoutError: 
logger.error(f"Cat WS connection timeout after {self._timeout}s") self._consecutive_failures += 1 if self._consecutive_failures >= self._max_failures_before_circuit_break: self._circuit_broken_until = time.time() + 60 logger.warning("Circuit breaker activated (WS timeout)") return None except Exception as e: logger.error(f"Cat WS query error: {e}") self._consecutive_failures += 1 if self._consecutive_failures >= self._max_failures_before_circuit_break: self._circuit_broken_until = time.time() + 60 logger.warning(f"Circuit breaker activated: {e}") return None # =================================================================== # MEMORY MANAGEMENT API (for Web UI) # =================================================================== async def get_memory_stats(self) -> Optional[Dict[str, Any]]: """ Get memory collection statistics with actual counts from Qdrant. Returns dict with collection names and point counts. """ try: # Query Qdrant directly for accurate counts qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333") collections_data = [] for collection_name in ["episodic", "declarative", "procedural"]: async with aiohttp.ClientSession() as session: async with session.get( f"{qdrant_host}/collections/{collection_name}", timeout=aiohttp.ClientTimeout(total=10) ) as response: if response.status == 200: data = await response.json() count = data.get("result", {}).get("points_count", 0) collections_data.append({ "name": collection_name, "vectors_count": count }) else: collections_data.append({ "name": collection_name, "vectors_count": 0 }) return {"collections": collections_data} except Exception as e: logger.error(f"Error getting memory stats from Qdrant: {e}") return None async def get_memory_points( self, collection: str = "declarative", limit: int = 100, offset: Optional[str] = None ) -> Optional[Dict[str, Any]]: """ Get all points from a memory collection via Qdrant. 
Cat doesn't expose /memory/collections/{id}/points, so we query Qdrant directly. Returns paginated list of memory points. """ try: # Use Qdrant directly (Cat's vector memory backend) # Qdrant is accessible at the same host, port 6333 internally qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333") payload = {"limit": limit, "with_payload": True, "with_vector": False} if offset: payload["offset"] = offset async with aiohttp.ClientSession() as session: async with session.post( f"{qdrant_host}/collections/{collection}/points/scroll", json=payload, timeout=aiohttp.ClientTimeout(total=30) ) as response: if response.status == 200: data = await response.json() return data.get("result", {}) else: logger.error(f"Failed to get {collection} points from Qdrant: {response.status}") return None except Exception as e: logger.error(f"Error getting memory points from Qdrant: {e}") return None async def get_all_facts(self) -> List[Dict[str, Any]]: """ Retrieve ALL declarative memory points (facts) with pagination. Returns a flat list of all fact dicts. 
""" all_facts = [] offset = None try: while True: result = await self.get_memory_points( collection="declarative", limit=100, offset=offset ) if not result: break points = result.get("points", []) for point in points: payload = point.get("payload", {}) fact = { "id": point.get("id"), "content": payload.get("page_content", ""), "metadata": payload.get("metadata", {}), } all_facts.append(fact) offset = result.get("next_offset") if not offset: break logger.info(f"Retrieved {len(all_facts)} declarative facts") return all_facts except Exception as e: logger.error(f"Error retrieving all facts: {e}") return all_facts async def delete_memory_point(self, collection: str, point_id: str) -> bool: """Delete a single memory point by ID via Qdrant.""" try: qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333") async with aiohttp.ClientSession() as session: async with session.post( f"{qdrant_host}/collections/{collection}/points/delete", json={"points": [point_id]}, timeout=aiohttp.ClientTimeout(total=15) ) as response: if response.status == 200: logger.info(f"Deleted memory point {point_id} from {collection}") return True else: logger.error(f"Failed to delete point: {response.status}") return False except Exception as e: logger.error(f"Error deleting memory point: {e}") return False async def update_memory_point(self, collection: str, point_id: str, content: str, metadata: dict = None) -> bool: """Update an existing memory point's content and/or metadata.""" try: # First, get the existing point to retrieve its vector qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333") async with aiohttp.ClientSession() as session: # Get existing point async with session.post( f"{qdrant_host}/collections/{collection}/points", json={"ids": [point_id], "with_vector": True, "with_payload": True}, timeout=aiohttp.ClientTimeout(total=15) ) as response: if response.status != 200: logger.error(f"Failed 
to fetch point {point_id}: {response.status}") return False data = await response.json() points = data.get("result", []) if not points: logger.error(f"Point {point_id} not found") return False existing_point = points[0] existing_vector = existing_point.get("vector") existing_payload = existing_point.get("payload", {}) # If content changed, we need to re-embed it if content != existing_payload.get("page_content"): # Call Cat's embedder to get new vector embed_response = await session.post( f"{self._base_url}/embedder", json={"text": content}, headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=30) ) if embed_response.status == 200: embed_data = await embed_response.json() new_vector = embed_data.get("embedding") else: logger.warning(f"Failed to re-embed content, keeping old vector") new_vector = existing_vector else: new_vector = existing_vector # Build updated payload updated_payload = { "page_content": content, "metadata": metadata if metadata is not None else existing_payload.get("metadata", {}) } # Update the point async with session.put( f"{qdrant_host}/collections/{collection}/points", json={ "points": [{ "id": point_id, "vector": new_vector, "payload": updated_payload }] }, timeout=aiohttp.ClientTimeout(total=15) ) as update_response: if update_response.status == 200: logger.info(f"✏️ Updated memory point {point_id} in {collection}") return True else: logger.error(f"Failed to update point: {update_response.status}") return False except Exception as e: logger.error(f"Error updating memory point: {e}") return False async def create_memory_point(self, collection: str, content: str, user_id: str, source: str, metadata: dict = None) -> Optional[str]: """Create a new memory point manually.""" try: import uuid import time # Generate a unique ID point_id = str(uuid.uuid4()) # Get vector embedding from Cat async with aiohttp.ClientSession() as session: async with session.post( f"{self._base_url}/embedder", json={"text": content}, 
headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=30) ) as response: if response.status != 200: logger.error(f"Failed to embed content: {response.status}") return None data = await response.json() vector = data.get("embedding") if not vector: logger.error("No embedding returned from Cat") return None # Build payload payload = { "page_content": content, "metadata": metadata or {} } payload["metadata"]["source"] = source payload["metadata"]["when"] = time.time() # For declarative memories, add user_id to metadata # For episodic, it's in the source field if collection == "declarative": payload["metadata"]["user_id"] = user_id elif collection == "episodic": payload["metadata"]["source"] = user_id # Insert into Qdrant qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333") async with session.put( f"{qdrant_host}/collections/{collection}/points", json={ "points": [{ "id": point_id, "vector": vector, "payload": payload }] }, timeout=aiohttp.ClientTimeout(total=15) ) as insert_response: if insert_response.status == 200: logger.info(f"✨ Created new {collection} memory point: {point_id}") return point_id else: logger.error(f"Failed to insert point: {insert_response.status}") return None except Exception as e: logger.error(f"Error creating memory point: {e}") return None async def wipe_all_memories(self) -> bool: """ Delete ALL memory collections (episodic + declarative). This is the nuclear option — requires multi-step confirmation in the UI. 
""" try: async with aiohttp.ClientSession() as session: async with session.delete( f"{self._base_url}/memory/collections", headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=30) ) as response: if response.status == 200: logger.warning("🗑️ ALL memory collections wiped!") return True else: error = await response.text() logger.error(f"Failed to wipe memories: {response.status} - {error}") return False except Exception as e: logger.error(f"Error wiping memories: {e}") return False async def wipe_conversation_history(self) -> bool: """Clear working memory / conversation history.""" try: async with aiohttp.ClientSession() as session: async with session.delete( f"{self._base_url}/memory/conversation_history", headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=15) ) as response: if response.status == 200: logger.info("Conversation history cleared") return True else: logger.error(f"Failed to clear conversation history: {response.status}") return False except Exception as e: logger.error(f"Error clearing conversation history: {e}") return False async def trigger_consolidation(self) -> Optional[str]: """ Trigger memory consolidation by sending a special message via WebSocket. The memory_consolidation plugin's tool 'consolidate_memories' is triggered when it sees 'consolidate now' in the text. Uses WebSocket with a system user ID for proper context. 
""" try: ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://") ws_url = f"{ws_base}/ws/system_consolidation" logger.info("🌙 Triggering memory consolidation via WS...") async with aiohttp.ClientSession() as session: async with session.ws_connect( ws_url, timeout=300, # Consolidation can be very slow ) as ws: await ws.send_json({"text": "consolidate now"}) # Wait for the final chat response deadline = asyncio.get_event_loop().time() + 300 while True: remaining = deadline - asyncio.get_event_loop().time() if remaining <= 0: logger.error("Consolidation timed out (>300s)") return "Consolidation timed out" try: ws_msg = await asyncio.wait_for( ws.receive(), timeout=remaining ) except asyncio.TimeoutError: logger.error("Consolidation WS receive timeout") return "Consolidation timed out waiting for response" if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED): logger.warning("Consolidation WS closed by server") return "Connection closed during consolidation" if ws_msg.type == aiohttp.WSMsgType.ERROR: return f"WebSocket error: {ws.exception()}" if ws_msg.type != aiohttp.WSMsgType.TEXT: continue try: msg = json.loads(ws_msg.data) except (json.JSONDecodeError, TypeError): continue msg_type = msg.get("type", "") if msg_type == "chat": reply = msg.get("content") or msg.get("text", "") logger.info(f"Consolidation result: {reply[:200]}") return reply elif msg_type == "error": error_desc = msg.get("description", "Unknown error") logger.error(f"Consolidation error: {error_desc}") return f"Consolidation error: {error_desc}" else: continue except asyncio.TimeoutError: logger.error("Consolidation WS connection timed out") return None except Exception as e: logger.error(f"Consolidation error: {e}") return None # ==================================================================== # Admin API helpers – plugin toggling & LLM model switching # ==================================================================== 
async def wait_for_ready(self, max_wait: int = 120, interval: int = 5) -> bool: """Wait for Cat to become reachable, polling with interval. Used on startup to avoid race conditions when bot starts before Cat. Returns True once Cat responds, False if max_wait exceeded. """ start = time.time() attempt = 0 while time.time() - start < max_wait: attempt += 1 try: async with aiohttp.ClientSession() as session: async with session.get( f"{self._base_url}/", timeout=aiohttp.ClientTimeout(total=5), ) as resp: if resp.status == 200: elapsed = time.time() - start logger.info(f"🐱 Cat is ready (took {elapsed:.1f}s, {attempt} attempts)") self._healthy = True self._last_health_check = time.time() return True except Exception: pass if attempt == 1: logger.info(f"⏳ Waiting for Cat to become ready (up to {max_wait}s)...") await asyncio.sleep(interval) logger.error(f"Cat did not become ready within {max_wait}s ({attempt} attempts)") return False async def toggle_plugin(self, plugin_id: str) -> bool: """Toggle a Cat plugin on/off via the admin API. PUT /plugins/toggle/{plugin_id} Returns True on success, False on failure. """ url = f"{self._base_url}/plugins/toggle/{plugin_id}" try: async with aiohttp.ClientSession() as session: async with session.put( url, headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status == 200: logger.info(f"🐱 Toggled Cat plugin: {plugin_id}") return True else: body = await resp.text() logger.error(f"Cat plugin toggle failed ({resp.status}): {body}") return False except Exception as e: logger.error(f"Cat plugin toggle error for {plugin_id}: {e}") return False async def set_llm_model(self, model_name: str) -> bool: """Switch the Cheshire Cat's active LLM model via settings API. The Cat settings API uses UUIDs: we must first GET /settings/ to find the setting_id for LLMOpenAIChatConfig, then PUT /settings/{setting_id}. llama-swap handles the actual model loading based on model_name. 
Returns True on success, False on failure. """ try: # Step 1: Find the setting_id for LLMOpenAIChatConfig setting_id = None async with aiohttp.ClientSession() as session: async with session.get( f"{self._base_url}/settings/", headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=10), ) as resp: if resp.status != 200: logger.error(f"Cat settings GET failed ({resp.status})") return False data = await resp.json() for s in data.get("settings", []): if s.get("name") == "LLMOpenAIChatConfig": setting_id = s["setting_id"] break if not setting_id: logger.error("Could not find LLMOpenAIChatConfig setting_id in Cat settings") return False # Step 2: PUT updated config to /settings/{setting_id} payload = { "name": "LLMOpenAIChatConfig", "value": { "openai_api_key": "sk-dummy", "model_name": model_name, "temperature": 0.8, "streaming": False, }, "category": "llm_factory", } async with aiohttp.ClientSession() as session: async with session.put( f"{self._base_url}/settings/{setting_id}", json=payload, headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status == 200: logger.info(f"🐱 Set Cat LLM model to: {model_name}") return True else: body = await resp.text() logger.error(f"Cat LLM model switch failed ({resp.status}): {body}") return False except Exception as e: logger.error(f"Cat LLM model switch error: {e}") return False async def get_active_plugins(self) -> list: """Get list of active Cat plugin IDs. GET /plugins → returns {\"installed\": [...], \"filters\": {...}} Each plugin has \"id\" and \"active\" fields. 
""" url = f"{self._base_url}/plugins" try: async with aiohttp.ClientSession() as session: async with session.get( url, headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=10), ) as resp: if resp.status == 200: data = await resp.json() installed = data.get("installed", []) return [p["id"] for p in installed if p.get("active")] else: logger.error(f"Cat get_active_plugins failed ({resp.status})") return [] except Exception as e: logger.error(f"Cat get_active_plugins error: {e}") return [] async def switch_to_evil_personality(self) -> bool: """Disable miku_personality, enable evil_miku_personality, switch LLM to darkidol. Checks current plugin state first to avoid double-toggling (the Cat API is a toggle, not enable/disable). Returns True if all operations succeed, False if any fail. """ logger.info("🐱 Switching Cat to Evil Miku personality...") success = True # Check current plugin state active = await self.get_active_plugins() # Step 1: Disable normal personality (only if currently active) if "miku_personality" in active: if not await self.toggle_plugin("miku_personality"): logger.error("Failed to disable miku_personality plugin") success = False await asyncio.sleep(1) else: logger.debug("miku_personality already disabled, skipping toggle") # Step 2: Enable evil personality (only if currently inactive) if "evil_miku_personality" not in active: if not await self.toggle_plugin("evil_miku_personality"): logger.error("Failed to enable evil_miku_personality plugin") success = False else: logger.debug("evil_miku_personality already active, skipping toggle") # Step 3: Switch LLM model to darkidol (the uncensored evil model) if not await self.set_llm_model("darkidol"): logger.error("Failed to switch Cat LLM to darkidol") success = False return success async def switch_to_normal_personality(self) -> bool: """Disable evil_miku_personality, enable miku_personality, switch LLM to llama3.1. Checks current plugin state first to avoid double-toggling. 
Returns True if all operations succeed, False if any fail. """ logger.info("🐱 Switching Cat to normal Miku personality...") success = True # Check current plugin state active = await self.get_active_plugins() # Step 1: Disable evil personality (only if currently active) if "evil_miku_personality" in active: if not await self.toggle_plugin("evil_miku_personality"): logger.error("Failed to disable evil_miku_personality plugin") success = False await asyncio.sleep(1) else: logger.debug("evil_miku_personality already disabled, skipping toggle") # Step 2: Enable normal personality (only if currently inactive) if "miku_personality" not in active: if not await self.toggle_plugin("miku_personality"): logger.error("Failed to enable miku_personality plugin") success = False else: logger.debug("miku_personality already active, skipping toggle") # Step 3: Switch LLM model back to llama3.1 (normal model) if not await self.set_llm_model("llama3.1"): logger.error("Failed to switch Cat LLM to llama3.1") success = False return success # Singleton instance cat_adapter = CatAdapter()