Phase 3: Unified Cheshire Cat integration with WebSocket-based per-user isolation

Key changes:
- CatAdapter (bot/utils/cat_client.py): WebSocket /ws/{user_id} for chat
  queries instead of HTTP POST (fixes per-user memory isolation when no
  API keys are configured — HTTP defaults all users to user_id='user')
- Memory management API: 8 endpoints for status, stats, facts, episodic
  memories, consolidation trigger, multi-step delete with confirmation
- Web UI: Memory tab (tab9) with collection stats, fact/episodic browser,
  manual consolidation trigger, and 3-step delete flow requiring exact
  confirmation string
- Bot integration: Cat-first response path with query_llama fallback for
  both text and embed responses, server mood detection
- Discord bridge plugin: fixed .pop() to .get() (UserMessage is a Pydantic
  BaseModelDict, not a raw dict), metadata extraction via extra attributes
- Unified docker-compose: Cat + Qdrant services merged into main compose,
  bot depends_on Cat healthcheck
- All plugins (discord_bridge, memory_consolidation, miku_personality)
  consolidated into cat-plugins/ for volume mount
- query_llama deprecated but functional for compatibility
This commit is contained in:
2026-02-07 20:22:03 +02:00
parent edb88e9ede
commit 14e1a8df51
14 changed files with 1382 additions and 70 deletions

479
bot/utils/cat_client.py Normal file
View File

@@ -0,0 +1,479 @@
# utils/cat_client.py
"""
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
Routes messages through the Cheshire Cat pipeline for:
- Memory-augmented responses (episodic + declarative recall)
- Fact extraction and consolidation
- Per-user conversation isolation
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
Uses HTTP for memory management endpoints.
Falls back to query_llama() on failure for zero-downtime resilience.
"""
import aiohttp
import asyncio
import json
import time
from typing import Optional, Dict, Any, List
import globals
from utils.logger import get_logger
logger = get_logger('cat_client')
class CatAdapter:
"""
Async adapter for Cheshire Cat AI.
Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
Uses HTTP REST for memory management endpoints.
Without API keys configured, HTTP POST /message defaults all users to
user_id="user" (no isolation). WebSocket path param gives true isolation.
"""
def __init__(self):
self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
self._api_key = globals.CHESHIRE_CAT_API_KEY
self._timeout = globals.CHESHIRE_CAT_TIMEOUT
self._healthy = None # None = unknown, True/False = last check result
self._last_health_check = 0
self._health_check_interval = 30 # seconds between health checks
self._consecutive_failures = 0
self._max_failures_before_circuit_break = 3
self._circuit_broken_until = 0 # timestamp when circuit breaker resets
logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")
def _get_headers(self) -> dict:
"""Build request headers with optional auth."""
headers = {'Content-Type': 'application/json'}
if self._api_key:
headers['Authorization'] = f'Bearer {self._api_key}'
return headers
def _user_id_for_discord(self, user_id: str) -> str:
"""
Format Discord user ID for Cat's user namespace.
Cat uses user_id to isolate working memory and episodic memories.
"""
return f"discord_{user_id}"
async def health_check(self) -> bool:
"""
Check if Cheshire Cat is reachable and healthy.
Caches result to avoid hammering the endpoint.
"""
now = time.time()
if now - self._last_health_check < self._health_check_interval and self._healthy is not None:
return self._healthy
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self._base_url}/",
headers=self._get_headers(),
timeout=aiohttp.ClientTimeout(total=10)
) as response:
self._healthy = response.status == 200
self._last_health_check = now
if self._healthy:
logger.debug("Cat health check: OK")
else:
logger.warning(f"Cat health check failed: status {response.status}")
return self._healthy
except Exception as e:
self._healthy = False
self._last_health_check = now
logger.warning(f"Cat health check error: {e}")
return False
def _is_circuit_broken(self) -> bool:
"""Check if circuit breaker is active (too many consecutive failures)."""
if self._consecutive_failures >= self._max_failures_before_circuit_break:
if time.time() < self._circuit_broken_until:
return True
# Circuit breaker expired, allow retry
logger.info("Circuit breaker reset, allowing Cat retry")
self._consecutive_failures = 0
return False
async def query(
self,
text: str,
user_id: str,
guild_id: Optional[str] = None,
author_name: Optional[str] = None,
mood: Optional[str] = None,
response_type: str = "dm_response",
) -> Optional[str]:
"""
Send a message through the Cat pipeline via WebSocket and get a response.
Uses WebSocket /ws/{user_id} for per-user memory isolation.
Without API keys, HTTP POST /message defaults all users to user_id="user"
(no isolation). The WebSocket path parameter provides true per-user isolation
because Cat's auth handler uses user_id from the path when no keys are set.
Args:
text: User's message text
user_id: Discord user ID (will be namespaced as discord_{user_id})
guild_id: Optional guild ID for server context
author_name: Display name of the user
mood: Current mood name (passed as metadata for Cat hooks)
response_type: Type of response context
Returns:
Cat's response text, or None if Cat is unavailable (caller should fallback)
"""
if not globals.USE_CHESHIRE_CAT:
return None
if self._is_circuit_broken():
logger.debug("Circuit breaker active, skipping Cat")
return None
cat_user_id = self._user_id_for_discord(user_id)
# Build message payload with Discord metadata for our plugin hooks.
# The discord_bridge plugin's before_cat_reads_message hook reads
# these custom keys from the message dict.
payload = {
"text": text,
}
if guild_id:
payload["discord_guild_id"] = str(guild_id)
if author_name:
payload["discord_author_name"] = author_name
if mood:
payload["discord_mood"] = mood
if response_type:
payload["discord_response_type"] = response_type
try:
# Build WebSocket URL from HTTP base URL
ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_base}/ws/{cat_user_id}"
logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")
async with aiohttp.ClientSession() as session:
async with session.ws_connect(
ws_url,
timeout=self._timeout,
) as ws:
# Send the message
await ws.send_json(payload)
# Read responses until we get the final "chat" type message.
# Cat may send intermediate messages (chat_token for streaming,
# notification for status updates). We want the final "chat" one.
reply_text = None
deadline = asyncio.get_event_loop().time() + self._timeout
while True:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
logger.error(f"Cat WS timeout after {self._timeout}s")
break
try:
ws_msg = await asyncio.wait_for(
ws.receive(),
timeout=remaining
)
except asyncio.TimeoutError:
logger.error(f"Cat WS receive timeout after {self._timeout}s")
break
# Handle WebSocket close/error frames
if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
logger.warning("Cat WS connection closed by server")
break
if ws_msg.type == aiohttp.WSMsgType.ERROR:
logger.error(f"Cat WS error frame: {ws.exception()}")
break
if ws_msg.type != aiohttp.WSMsgType.TEXT:
logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
continue
try:
msg = json.loads(ws_msg.data)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Cat WS non-JSON message: {e}")
continue
msg_type = msg.get("type", "")
if msg_type == "chat":
# Final response — extract text
reply_text = msg.get("content") or msg.get("text", "")
break
elif msg_type == "chat_token":
# Streaming token — skip, we wait for final
continue
elif msg_type == "error":
error_desc = msg.get("description", "Unknown Cat error")
logger.error(f"Cat WS error: {error_desc}")
break
elif msg_type == "notification":
logger.debug(f"Cat notification: {msg.get('content', '')}")
continue
else:
logger.debug(f"Cat WS unknown msg type: {msg_type}")
continue
if reply_text and reply_text.strip():
self._consecutive_failures = 0
logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
return reply_text
else:
logger.warning("Cat returned empty response via WS")
self._consecutive_failures += 1
return None
except asyncio.TimeoutError:
logger.error(f"Cat WS connection timeout after {self._timeout}s")
self._consecutive_failures += 1
if self._consecutive_failures >= self._max_failures_before_circuit_break:
self._circuit_broken_until = time.time() + 60
logger.warning("Circuit breaker activated (WS timeout)")
return None
except Exception as e:
logger.error(f"Cat WS query error: {e}")
self._consecutive_failures += 1
if self._consecutive_failures >= self._max_failures_before_circuit_break:
self._circuit_broken_until = time.time() + 60
logger.warning(f"Circuit breaker activated: {e}")
return None
# ===================================================================
# MEMORY MANAGEMENT API (for Web UI)
# ===================================================================
async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
"""
Get memory collection statistics from Cat.
Returns dict with collection names and point counts.
"""
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self._base_url}/memory/collections",
headers=self._get_headers(),
timeout=aiohttp.ClientTimeout(total=15)
) as response:
if response.status == 200:
data = await response.json()
return data
else:
logger.error(f"Failed to get memory stats: {response.status}")
return None
except Exception as e:
logger.error(f"Error getting memory stats: {e}")
return None
async def get_memory_points(
self,
collection: str = "declarative",
limit: int = 100,
offset: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Get all points from a memory collection.
Returns paginated list of memory points.
"""
try:
params = {"limit": limit}
if offset:
params["offset"] = offset
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self._base_url}/memory/collections/{collection}/points",
headers=self._get_headers(),
params=params,
timeout=aiohttp.ClientTimeout(total=30)
) as response:
if response.status == 200:
return await response.json()
else:
logger.error(f"Failed to get {collection} points: {response.status}")
return None
except Exception as e:
logger.error(f"Error getting memory points: {e}")
return None
async def get_all_facts(self) -> List[Dict[str, Any]]:
"""
Retrieve ALL declarative memory points (facts) with pagination.
Returns a flat list of all fact dicts.
"""
all_facts = []
offset = None
try:
while True:
result = await self.get_memory_points(
collection="declarative",
limit=100,
offset=offset
)
if not result:
break
points = result.get("points", [])
for point in points:
payload = point.get("payload", {})
fact = {
"id": point.get("id"),
"content": payload.get("page_content", ""),
"metadata": payload.get("metadata", {}),
}
all_facts.append(fact)
offset = result.get("next_offset")
if not offset:
break
logger.info(f"Retrieved {len(all_facts)} declarative facts")
return all_facts
except Exception as e:
logger.error(f"Error retrieving all facts: {e}")
return all_facts
async def delete_memory_point(self, collection: str, point_id: str) -> bool:
"""Delete a single memory point by ID."""
try:
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{self._base_url}/memory/collections/{collection}/points/{point_id}",
headers=self._get_headers(),
timeout=aiohttp.ClientTimeout(total=15)
) as response:
if response.status == 200:
logger.info(f"Deleted point {point_id} from {collection}")
return True
else:
logger.error(f"Failed to delete point: {response.status}")
return False
except Exception as e:
logger.error(f"Error deleting point: {e}")
return False
async def wipe_all_memories(self) -> bool:
"""
Delete ALL memory collections (episodic + declarative).
This is the nuclear option — requires multi-step confirmation in the UI.
"""
try:
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{self._base_url}/memory/collections",
headers=self._get_headers(),
timeout=aiohttp.ClientTimeout(total=30)
) as response:
if response.status == 200:
logger.warning("🗑️ ALL memory collections wiped!")
return True
else:
error = await response.text()
logger.error(f"Failed to wipe memories: {response.status} - {error}")
return False
except Exception as e:
logger.error(f"Error wiping memories: {e}")
return False
async def wipe_conversation_history(self) -> bool:
"""Clear working memory / conversation history."""
try:
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{self._base_url}/memory/conversation_history",
headers=self._get_headers(),
timeout=aiohttp.ClientTimeout(total=15)
) as response:
if response.status == 200:
logger.info("Conversation history cleared")
return True
else:
logger.error(f"Failed to clear conversation history: {response.status}")
return False
except Exception as e:
logger.error(f"Error clearing conversation history: {e}")
return False
async def trigger_consolidation(self) -> Optional[str]:
"""
Trigger memory consolidation by sending a special message via WebSocket.
The memory_consolidation plugin's tool 'consolidate_memories' is
triggered when it sees 'consolidate now' in the text.
Uses WebSocket with a system user ID for proper context.
"""
try:
ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_base}/ws/system_consolidation"
logger.info("🌙 Triggering memory consolidation via WS...")
async with aiohttp.ClientSession() as session:
async with session.ws_connect(
ws_url,
timeout=300, # Consolidation can be very slow
) as ws:
await ws.send_json({"text": "consolidate now"})
# Wait for the final chat response
deadline = asyncio.get_event_loop().time() + 300
while True:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
logger.error("Consolidation timed out (>300s)")
return "Consolidation timed out"
try:
ws_msg = await asyncio.wait_for(
ws.receive(),
timeout=remaining
)
except asyncio.TimeoutError:
logger.error("Consolidation WS receive timeout")
return "Consolidation timed out waiting for response"
if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
logger.warning("Consolidation WS closed by server")
return "Connection closed during consolidation"
if ws_msg.type == aiohttp.WSMsgType.ERROR:
return f"WebSocket error: {ws.exception()}"
if ws_msg.type != aiohttp.WSMsgType.TEXT:
continue
try:
msg = json.loads(ws_msg.data)
except (json.JSONDecodeError, TypeError):
continue
msg_type = msg.get("type", "")
if msg_type == "chat":
reply = msg.get("content") or msg.get("text", "")
logger.info(f"Consolidation result: {reply[:200]}")
return reply
elif msg_type == "error":
error_desc = msg.get("description", "Unknown error")
logger.error(f"Consolidation error: {error_desc}")
return f"Consolidation error: {error_desc}"
else:
continue
except asyncio.TimeoutError:
logger.error("Consolidation WS connection timed out")
return None
except Exception as e:
logger.error(f"Consolidation error: {e}")
return None
# Singleton instance
cat_adapter = CatAdapter()