Phase 3: Unified Cheshire Cat integration with WebSocket-based per-user isolation
Key changes:
- CatAdapter (bot/utils/cat_client.py): WebSocket /ws/{user_id} for chat
queries instead of HTTP POST (fixes per-user memory isolation when no
API keys are configured — HTTP defaults all users to user_id='user')
- Memory management API: 8 endpoints for status, stats, facts, episodic
memories, consolidation trigger, multi-step delete with confirmation
- Web UI: Memory tab (tab9) with collection stats, fact/episodic browser,
manual consolidation trigger, and 3-step delete flow requiring exact
confirmation string
- Bot integration: Cat-first response path with query_llama fallback for
both text and embed responses, server mood detection
- Discord bridge plugin: fixed .pop() to .get() (UserMessage is a Pydantic
BaseModelDict, not a raw dict), metadata extraction via extra attributes
- Unified docker-compose: Cat + Qdrant services merged into main compose,
bot depends_on Cat healthcheck
- All plugins (discord_bridge, memory_consolidation, miku_personality)
consolidated into cat-plugins/ for volume mount
- query_llama deprecated but functional for compatibility
This commit is contained in:
479
bot/utils/cat_client.py
Normal file
479
bot/utils/cat_client.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# utils/cat_client.py
|
||||
"""
|
||||
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
|
||||
|
||||
Routes messages through the Cheshire Cat pipeline for:
|
||||
- Memory-augmented responses (episodic + declarative recall)
|
||||
- Fact extraction and consolidation
|
||||
- Per-user conversation isolation
|
||||
|
||||
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
|
||||
Uses HTTP for memory management endpoints.
|
||||
Falls back to query_llama() on failure for zero-downtime resilience.
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import globals
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger('cat_client')
|
||||
|
||||
|
||||
class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.
    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.

    A small circuit breaker skips Cat for 60 seconds after several
    consecutive failures so callers can fall back to query_llama() quickly.
    """

    def __init__(self) -> None:
        self._base_url: str = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key: Optional[str] = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        self._healthy: Optional[bool] = None  # None = unknown, True/False = last check result
        self._last_health_check: float = 0
        self._health_check_interval = 30  # seconds between health checks
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_broken_until: float = 0  # timestamp when circuit breaker resets
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    def _get_headers(self) -> dict:
        """Build request headers with optional Bearer auth."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.
        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"

    def _record_success(self) -> None:
        """Reset the consecutive-failure streak after a successful Cat reply."""
        self._consecutive_failures = 0

    def _record_failure(self, reason: str) -> None:
        """
        Count a failed Cat interaction; trip the circuit breaker at threshold.

        Centralizing this fixes an inconsistency: the empty-response path used
        to increment the counter without ever setting _circuit_broken_until,
        so _is_circuit_broken() immediately reset the count and the breaker
        could never actually trip on repeated empty responses.

        Args:
            reason: Short human-readable description for the warning log.
        """
        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            # Hold the circuit open for 60 seconds before allowing a retry.
            self._circuit_broken_until = time.time() + 60
            logger.warning(f"Circuit breaker activated: {reason}")

    async def health_check(self) -> bool:
        """
        Check if Cheshire Cat is reachable and healthy.

        Caches the result for _health_check_interval seconds to avoid
        hammering the endpoint; any transport error counts as unhealthy.

        Returns:
            True if Cat answered GET / with HTTP 200, False otherwise.
        """
        now = time.time()
        if now - self._last_health_check < self._health_check_interval and self._healthy is not None:
            return self._healthy

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    self._healthy = response.status == 200
                    self._last_health_check = now
                    if self._healthy:
                        logger.debug("Cat health check: OK")
                    else:
                        logger.warning(f"Cat health check failed: status {response.status}")
                    return self._healthy
        except Exception as e:
            self._healthy = False
            self._last_health_check = now
            logger.warning(f"Cat health check error: {e}")
            return False

    def _is_circuit_broken(self) -> bool:
        """Check if circuit breaker is active (too many consecutive failures)."""
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            if time.time() < self._circuit_broken_until:
                return True
            # Circuit breaker expired, allow retry
            logger.info("Circuit breaker reset, allowing Cat retry")
            self._consecutive_failures = 0
        return False

    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
    ) -> Optional[str]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation.
        Without API keys, HTTP POST /message defaults all users to user_id="user"
        (no isolation). The WebSocket path parameter provides true per-user isolation
        because Cat's auth handler uses user_id from the path when no keys are set.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context

        Returns:
            Cat's response text, or None if Cat is unavailable (caller should fallback)
        """
        if not globals.USE_CHESHIRE_CAT:
            return None

        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)

        # Build message payload with Discord metadata for our plugin hooks.
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {
            "text": text,
        }
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        if mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type

        try:
            # Build WebSocket URL from HTTP base URL
            ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
            ws_url = f"{ws_base}/ws/{cat_user_id}"

            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=self._timeout,
                ) as ws:
                    # Send the message
                    await ws.send_json(payload)

                    # Read responses until we get the final "chat" type message.
                    # Cat may send intermediate messages (chat_token for streaming,
                    # notification for status updates). We want the final "chat" one.
                    reply_text = None
                    # time.monotonic() is immune to wall-clock adjustments,
                    # unlike the deprecated asyncio.get_event_loop() pattern.
                    deadline = time.monotonic() + self._timeout

                    while True:
                        remaining = deadline - time.monotonic()
                        if remaining <= 0:
                            logger.error(f"Cat WS timeout after {self._timeout}s")
                            break

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error(f"Cat WS receive timeout after {self._timeout}s")
                            break

                        # Handle WebSocket close/error frames
                        if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
                            logger.warning("Cat WS connection closed by server")
                            break
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"Cat WS error frame: {ws.exception()}")
                            break
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError) as e:
                            logger.warning(f"Cat WS non-JSON message: {e}")
                            continue

                        msg_type = msg.get("type", "")

                        if msg_type == "chat":
                            # Final response — extract text
                            reply_text = msg.get("content") or msg.get("text", "")
                            break
                        elif msg_type == "chat_token":
                            # Streaming token — skip, we wait for final
                            continue
                        elif msg_type == "error":
                            error_desc = msg.get("description", "Unknown Cat error")
                            logger.error(f"Cat WS error: {error_desc}")
                            break
                        elif msg_type == "notification":
                            logger.debug(f"Cat notification: {msg.get('content', '')}")
                            continue
                        else:
                            logger.debug(f"Cat WS unknown msg type: {msg_type}")
                            continue

                    if reply_text and reply_text.strip():
                        self._record_success()
                        logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
                        return reply_text
                    else:
                        logger.warning("Cat returned empty response via WS")
                        # Counted toward the circuit breaker too (see _record_failure).
                        self._record_failure("empty response")
                        return None

        except asyncio.TimeoutError:
            logger.error(f"Cat WS connection timeout after {self._timeout}s")
            self._record_failure("WS timeout")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._record_failure(str(e))
            return None

    # ===================================================================
    # MEMORY MANAGEMENT API (for Web UI)
    # ===================================================================

    async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
        """
        Get memory collection statistics from Cat.
        Returns dict with collection names and point counts, or None on error.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data
                    else:
                        logger.error(f"Failed to get memory stats: {response.status}")
                        return None
        except Exception as e:
            logger.error(f"Error getting memory stats: {e}")
            return None

    async def get_memory_points(
        self,
        collection: str = "declarative",
        limit: int = 100,
        offset: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get points from a memory collection (one page).

        Args:
            collection: Collection name ("declarative" or "episodic").
            limit: Max points per page.
            offset: Pagination cursor from a previous response's "next_offset".

        Returns:
            Paginated response dict from Cat, or None on error.
        """
        try:
            params: Dict[str, Any] = {"limit": limit}
            if offset:
                params["offset"] = offset

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/memory/collections/{collection}/points",
                    headers=self._get_headers(),
                    params=params,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    else:
                        logger.error(f"Failed to get {collection} points: {response.status}")
                        return None
        except Exception as e:
            logger.error(f"Error getting memory points: {e}")
            return None

    async def get_all_facts(self) -> List[Dict[str, Any]]:
        """
        Retrieve ALL declarative memory points (facts) with pagination.

        Returns:
            Flat list of {"id", "content", "metadata"} dicts. On error,
            returns whatever was collected before the failure (best-effort).
        """
        all_facts: List[Dict[str, Any]] = []
        offset: Optional[str] = None

        try:
            while True:
                result = await self.get_memory_points(
                    collection="declarative",
                    limit=100,
                    offset=offset
                )
                if not result:
                    break

                points = result.get("points", [])
                for point in points:
                    payload = point.get("payload", {})
                    fact = {
                        "id": point.get("id"),
                        "content": payload.get("page_content", ""),
                        "metadata": payload.get("metadata", {}),
                    }
                    all_facts.append(fact)

                offset = result.get("next_offset")
                if not offset:
                    break

            logger.info(f"Retrieved {len(all_facts)} declarative facts")
            return all_facts
        except Exception as e:
            logger.error(f"Error retrieving all facts: {e}")
            return all_facts

    async def delete_memory_point(self, collection: str, point_id: str) -> bool:
        """Delete a single memory point by ID. Returns True on success."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections/{collection}/points/{point_id}",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info(f"Deleted point {point_id} from {collection}")
                        return True
                    else:
                        logger.error(f"Failed to delete point: {response.status}")
                        return False
        except Exception as e:
            logger.error(f"Error deleting point: {e}")
            return False

    async def wipe_all_memories(self) -> bool:
        """
        Delete ALL memory collections (episodic + declarative).
        This is the nuclear option — requires multi-step confirmation in the UI.

        Returns:
            True if Cat confirmed the wipe, False otherwise.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        logger.warning("🗑️ ALL memory collections wiped!")
                        return True
                    else:
                        error = await response.text()
                        logger.error(f"Failed to wipe memories: {response.status} - {error}")
                        return False
        except Exception as e:
            logger.error(f"Error wiping memories: {e}")
            return False

    async def wipe_conversation_history(self) -> bool:
        """Clear working memory / conversation history. Returns True on success."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/conversation_history",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info("Conversation history cleared")
                        return True
                    else:
                        logger.error(f"Failed to clear conversation history: {response.status}")
                        return False
        except Exception as e:
            logger.error(f"Error clearing conversation history: {e}")
            return False

    async def trigger_consolidation(self) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.

        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text.
        Uses WebSocket with a system user ID for proper context.

        Returns:
            Cat's textual result (or a human-readable failure description),
            or None on transport-level errors.
        """
        try:
            ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
            ws_url = f"{ws_base}/ws/system_consolidation"

            logger.info("🌙 Triggering memory consolidation via WS...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=300,  # Consolidation can be very slow
                ) as ws:
                    await ws.send_json({"text": "consolidate now"})

                    # Wait for the final chat response (monotonic clock, see query()).
                    deadline = time.monotonic() + 300

                    while True:
                        remaining = deadline - time.monotonic()
                        if remaining <= 0:
                            logger.error("Consolidation timed out (>300s)")
                            return "Consolidation timed out"

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error("Consolidation WS receive timeout")
                            return "Consolidation timed out waiting for response"

                        if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
                            logger.warning("Consolidation WS closed by server")
                            return "Connection closed during consolidation"
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            return f"WebSocket error: {ws.exception()}"
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError):
                            continue

                        msg_type = msg.get("type", "")
                        if msg_type == "chat":
                            reply = msg.get("content") or msg.get("text", "")
                            logger.info(f"Consolidation result: {reply[:200]}")
                            return reply
                        elif msg_type == "error":
                            error_desc = msg.get("description", "Unknown error")
                            logger.error(f"Consolidation error: {error_desc}")
                            return f"Consolidation error: {error_desc}"
                        else:
                            continue

        except asyncio.TimeoutError:
            logger.error("Consolidation WS connection timed out")
            return None
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None
||||
# Module-level singleton — import `cat_adapter` rather than constructing new
# adapters, so health-check caching and circuit-breaker state are shared
# across the whole bot process.
cat_adapter = CatAdapter()
|
||||
@@ -152,6 +152,13 @@ async def query_llama(user_prompt, user_id, guild_id=None, response_type="dm_res
|
||||
"""
|
||||
Query llama.cpp server via llama-swap with OpenAI-compatible API.
|
||||
|
||||
.. deprecated:: Phase 3
|
||||
For main conversation flow, prefer routing through the Cheshire Cat pipeline
|
||||
(via cat_client.CatAdapter.query) which provides memory-augmented responses.
|
||||
This function remains available for specialized use cases (vision, bipolar mode,
|
||||
image generation, autonomous, sentiment analysis) and as a fallback when Cat
|
||||
is unavailable.
|
||||
|
||||
Args:
|
||||
user_prompt: The user's input
|
||||
user_id: User identifier (used for DM history)
|
||||
|
||||
Reference in New Issue
Block a user