- discord_bridge before_agent_starts now checks evil_mode from
working_memory to load the correct personality files:
Normal: miku_lore/prompt/lyrics + /app/moods/{mood}.txt
Evil: evil_miku_lore/prompt/lyrics + /app/moods/evil/{mood}.txt
- Reads files directly instead of relying on cross-plugin working_memory
- cat_client.query() returns (response, full_prompt) tuple
- Full prompt includes system prefix + recalled memories + conversation
- API /prompt/cat returns full_prompt field
864 lines
37 KiB
Python
864 lines
37 KiB
Python
# utils/cat_client.py
|
||
"""
|
||
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
|
||
|
||
Routes messages through the Cheshire Cat pipeline for:
|
||
- Memory-augmented responses (episodic + declarative recall)
|
||
- Fact extraction and consolidation
|
||
- Per-user conversation isolation
|
||
|
||
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
|
||
Uses HTTP for memory management endpoints.
|
||
Falls back to query_llama() on failure for zero-downtime resilience.
|
||
"""
|
||
|
||
import aiohttp
|
||
import asyncio
|
||
import json
|
||
import time
|
||
from typing import Optional, Dict, Any, List
|
||
|
||
import globals
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger('llm') # Use existing 'llm' logger component
|
||
|
||
|
||
class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.
    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.

    Failure handling: ``query`` counts consecutive failures; once
    ``_max_failures_before_circuit_break`` is reached, further queries are
    skipped for ``_CIRCUIT_BREAK_SECONDS`` so callers can fall back at once.
    """

    # How long (seconds) the circuit breaker stays open once tripped.
    _CIRCUIT_BREAK_SECONDS = 60
    # Upper bound (seconds) for one consolidation run over WebSocket.
    _CONSOLIDATION_TIMEOUT = 300
    # Service addresses used to derive the Qdrant URL from the Cat URL.
    _CAT_INTERNAL_URL = "http://cheshire-cat:80"
    _QDRANT_INTERNAL_URL = "http://cheshire-cat-vector-memory:6333"

    def __init__(self):
        """Read connection settings from ``globals`` and reset health state."""
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------

    def _get_headers(self) -> dict:
        """Build request headers with optional auth."""
        headers = {'Content-Type': 'application/json'}
        if self._api_key:
            headers['Authorization'] = f'Bearer {self._api_key}'
        return headers

    def _user_id_for_discord(self, user_id: str) -> str:
        """
        Format Discord user ID for Cat's user namespace.
        Cat uses user_id to isolate working memory and episodic memories.
        """
        return f"discord_{user_id}"

    def _ws_url(self, cat_user_id: str) -> str:
        """Return the WebSocket URL for ``cat_user_id`` derived from the HTTP base URL."""
        ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
        return f"{ws_base}/ws/{cat_user_id}"

    def _qdrant_url(self) -> str:
        """Return the Qdrant base URL derived from the Cat base URL.

        NOTE: this is a plain string replacement. If the Cat URL is not
        ``_CAT_INTERNAL_URL`` the result is the Cat URL unchanged, so Qdrant
        requests would be sent to the Cat host instead.
        """
        return self._base_url.replace(self._CAT_INTERNAL_URL, self._QDRANT_INTERNAL_URL)

    # ------------------------------------------------------------------
    # Health / circuit breaker
    # ------------------------------------------------------------------

    async def health_check(self) -> bool:
        """
        Check if Cheshire Cat is reachable and healthy.
        Caches result to avoid hammering the endpoint.

        Returns:
            True when GET {base_url}/ answered 200 (or the cached result says
            so), False on any other status or on any exception.
        """
        now = time.time()
        cache_is_fresh = now - self._last_health_check < self._health_check_interval
        if cache_is_fresh and self._healthy is not None:
            return self._healthy

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    self._healthy = response.status == 200
                    self._last_health_check = now
                    if self._healthy:
                        logger.debug("Cat health check: OK")
                    else:
                        logger.warning(f"Cat health check failed: status {response.status}")
                    return self._healthy
        except Exception as e:
            self._healthy = False
            self._last_health_check = now
            logger.warning(f"Cat health check error: {e}")
            return False

    def _is_circuit_broken(self) -> bool:
        """Check if circuit breaker is active (too many consecutive failures)."""
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            if time.time() < self._circuit_broken_until:
                return True
            # Circuit breaker expired, allow retry
            logger.info("Circuit breaker reset, allowing Cat retry")
            self._consecutive_failures = 0
        return False

    def _record_failure(self, activation_message: str) -> None:
        """Count one failed query and open the circuit breaker at the threshold.

        Every failure path of ``query`` goes through here so the breaker
        expiry is always set when the threshold is reached; otherwise
        ``_is_circuit_broken`` would see an expiry in the past and reset the
        counter immediately.

        Args:
            activation_message: Warning logged when this failure trips the breaker.
        """
        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures_before_circuit_break:
            self._circuit_broken_until = time.time() + self._CIRCUIT_BREAK_SECONDS
            logger.warning(activation_message)

    # ------------------------------------------------------------------
    # Conversation
    # ------------------------------------------------------------------

    def _build_query_payload(
        self,
        text: str,
        guild_id: Optional[str],
        author_name: Optional[str],
        mood: Optional[str],
        response_type: str,
    ) -> dict:
        """Build the WS message dict, adding the custom ``discord_*`` keys."""
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {"text": text}
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        # When evil mode is active, send the evil mood name instead of the normal mood
        if globals.EVIL_MODE:
            payload["discord_mood"] = getattr(globals, 'EVIL_DM_MOOD', 'evil_neutral')
        elif mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type
        # Pass evil mode flag so discord_bridge stores it in working_memory
        payload["discord_evil_mode"] = globals.EVIL_MODE
        return payload

    async def _read_chat_reply(self, ws) -> tuple:
        """Read WS frames until the final "chat" message or a terminal condition.

        Cat may send intermediate messages (chat_token for streaming,
        notification for status updates); those are skipped.

        Returns:
            ``(reply_text, full_prompt)``. ``reply_text`` is None when the
            deadline passed, the socket closed/errored, or Cat sent an
            "error" message.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._timeout
        closed_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                logger.error(f"Cat WS timeout after {self._timeout}s")
                return None, ""

            try:
                ws_msg = await asyncio.wait_for(ws.receive(), timeout=remaining)
            except asyncio.TimeoutError:
                logger.error(f"Cat WS receive timeout after {self._timeout}s")
                return None, ""

            # Handle WebSocket close/error frames
            if ws_msg.type in closed_types:
                logger.warning("Cat WS connection closed by server")
                return None, ""
            if ws_msg.type == aiohttp.WSMsgType.ERROR:
                logger.error(f"Cat WS error frame: {ws.exception()}")
                return None, ""
            if ws_msg.type != aiohttp.WSMsgType.TEXT:
                logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                continue

            try:
                msg = json.loads(ws_msg.data)
            except (json.JSONDecodeError, TypeError) as e:
                logger.warning(f"Cat WS non-JSON message: {e}")
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object: nothing to dispatch on.
                logger.debug("Cat WS JSON message is not an object, skipping")
                continue

            msg_type = msg.get("type", "")
            if msg_type == "chat":
                # Final response — extract text and full prompt
                return msg.get("content") or msg.get("text", ""), msg.get("full_prompt", "")
            if msg_type == "error":
                error_desc = msg.get("description", "Unknown Cat error")
                logger.error(f"Cat WS error: {error_desc}")
                return None, ""
            if msg_type == "chat_token":
                # Streaming token — skip, we wait for final
                continue
            if msg_type == "notification":
                logger.debug(f"Cat notification: {msg.get('content', '')}")
            else:
                logger.debug(f"Cat WS unknown msg type: {msg_type}")

    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
    ) -> Optional[tuple]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation.
        Without API keys, HTTP POST /message defaults all users to user_id="user"
        (no isolation). The WebSocket path parameter provides true per-user isolation
        because Cat's auth handler uses user_id from the path when no keys are set.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context

        Returns:
            Tuple of (response_text, full_prompt) on success, or None if Cat
            is unavailable (caller should fallback to query_llama). None is
            also returned when Cat is disabled, the circuit breaker is open,
            or the reply was empty. Never raises: every exception is logged
            and counted as a failure.
        """
        if not globals.USE_CHESHIRE_CAT:
            return None

        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)
        payload = self._build_query_payload(text, guild_id, author_name, mood, response_type)

        try:
            ws_url = self._ws_url(cat_user_id)
            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=self._timeout,
                ) as ws:
                    await ws.send_json(payload)
                    reply_text, full_prompt = await self._read_chat_reply(ws)

            if reply_text and reply_text.strip():
                self._consecutive_failures = 0
                logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
                return reply_text, full_prompt

            logger.warning("Cat returned empty response via WS")
            self._record_failure("Circuit breaker activated (empty response)")
            return None

        except asyncio.TimeoutError:
            logger.error(f"Cat WS connection timeout after {self._timeout}s")
            self._record_failure("Circuit breaker activated (WS timeout)")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._record_failure(f"Circuit breaker activated: {e}")
            return None

    # ===================================================================
    # MEMORY MANAGEMENT API (for Web UI)
    # ===================================================================

    async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
        """
        Get memory collection statistics with actual counts from Qdrant.

        Returns:
            ``{"collections": [{"name": ..., "vectors_count": ...}, ...]}``
            for episodic, declarative and procedural. A collection whose
            request did not return 200 is reported with a count of 0.
            None if any request raised.
        """
        try:
            # Query Qdrant directly for accurate counts
            qdrant_host = self._qdrant_url()
            collections_data = []

            # One session for all three lookups instead of one per collection.
            async with aiohttp.ClientSession() as session:
                for collection_name in ("episodic", "declarative", "procedural"):
                    async with session.get(
                        f"{qdrant_host}/collections/{collection_name}",
                        timeout=aiohttp.ClientTimeout(total=10)
                    ) as response:
                        count = 0
                        if response.status == 200:
                            data = await response.json()
                            count = data.get("result", {}).get("points_count", 0)
                        collections_data.append({
                            "name": collection_name,
                            "vectors_count": count
                        })

            return {"collections": collections_data}
        except Exception as e:
            logger.error(f"Error getting memory stats from Qdrant: {e}")
            return None

    async def get_memory_points(
        self,
        collection: str = "declarative",
        limit: int = 100,
        offset: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get one page of points from a memory collection via Qdrant.
        Cat doesn't expose /memory/collections/{id}/points, so we query Qdrant directly.

        Args:
            collection: Collection name used in the Qdrant URL.
            limit: Page size sent to the scroll endpoint.
            offset: Scroll offset from a previous page; omitted when falsy.

        Returns:
            The ``result`` object of the scroll response, or None on a
            non-200 status or any exception.
        """
        try:
            # Use Qdrant directly (Cat's vector memory backend)
            qdrant_host = self._qdrant_url()

            payload = {"limit": limit, "with_payload": True, "with_vector": False}
            if offset:
                payload["offset"] = offset

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{qdrant_host}/collections/{collection}/points/scroll",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("result", {})
                    logger.error(f"Failed to get {collection} points from Qdrant: {response.status}")
                    return None
        except Exception as e:
            logger.error(f"Error getting memory points from Qdrant: {e}")
            return None

    async def get_all_facts(self) -> List[Dict[str, Any]]:
        """
        Retrieve ALL declarative memory points (facts) with pagination.

        Returns:
            A flat list of ``{"id", "content", "metadata"}`` dicts. If a page
            fails, the facts collected so far are returned.
        """
        all_facts = []
        offset = None

        try:
            while True:
                result = await self.get_memory_points(
                    collection="declarative",
                    limit=100,
                    offset=offset
                )
                if not result:
                    break

                for point in result.get("points", []):
                    payload = point.get("payload", {})
                    all_facts.append({
                        "id": point.get("id"),
                        "content": payload.get("page_content", ""),
                        "metadata": payload.get("metadata", {}),
                    })

                offset = result.get("next_offset")
                if not offset:
                    break

            logger.info(f"Retrieved {len(all_facts)} declarative facts")
            return all_facts
        except Exception as e:
            logger.error(f"Error retrieving all facts: {e}")
            return all_facts

    async def delete_memory_point(self, collection: str, point_id: str) -> bool:
        """Delete a single memory point by ID via Qdrant.

        Returns:
            True on HTTP 200, False on any other status or exception.
        """
        try:
            qdrant_host = self._qdrant_url()

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{qdrant_host}/collections/{collection}/points/delete",
                    json={"points": [point_id]},
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info(f"Deleted memory point {point_id} from {collection}")
                        return True
                    logger.error(f"Failed to delete point: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error deleting memory point: {e}")
            return False

    async def update_memory_point(self, collection: str, point_id: str, content: str, metadata: dict = None) -> bool:
        """Update an existing memory point's content and/or metadata.

        The point is fetched with its vector, re-embedded through
        ``{base_url}/embedder`` only when ``content`` differs from the stored
        ``page_content``, and then written back with a PUT.

        Args:
            collection: Collection name used in the Qdrant URL.
            point_id: ID of the point to rewrite.
            content: New ``page_content``.
            metadata: Replacement metadata; None keeps the stored metadata.

        Returns:
            True when the final PUT returned 200, otherwise False.
        """
        try:
            qdrant_host = self._qdrant_url()

            async with aiohttp.ClientSession() as session:
                # First, get the existing point to retrieve its vector
                async with session.post(
                    f"{qdrant_host}/collections/{collection}/points",
                    json={"ids": [point_id], "with_vector": True, "with_payload": True},
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status != 200:
                        logger.error(f"Failed to fetch point {point_id}: {response.status}")
                        return False
                    data = await response.json()

                points = data.get("result", [])
                if not points:
                    logger.error(f"Point {point_id} not found")
                    return False

                existing_point = points[0]
                existing_vector = existing_point.get("vector")
                existing_payload = existing_point.get("payload", {})

                # If content changed, we need to re-embed it
                new_vector = existing_vector
                if content != existing_payload.get("page_content"):
                    embedding = None
                    # Context manager so the response is released on every path.
                    async with session.post(
                        f"{self._base_url}/embedder",
                        json={"text": content},
                        headers=self._get_headers(),
                        timeout=aiohttp.ClientTimeout(total=30)
                    ) as embed_response:
                        if embed_response.status == 200:
                            embed_data = await embed_response.json()
                            embedding = embed_data.get("embedding")
                    if embedding:
                        new_vector = embedding
                    else:
                        # Covers a non-200 status and a 200 without an
                        # embedding; never write a null vector.
                        logger.warning("Failed to re-embed content, keeping old vector")

                updated_payload = {
                    "page_content": content,
                    "metadata": metadata if metadata is not None else existing_payload.get("metadata", {})
                }

                # Update the point
                async with session.put(
                    f"{qdrant_host}/collections/{collection}/points",
                    json={
                        "points": [{
                            "id": point_id,
                            "vector": new_vector,
                            "payload": updated_payload
                        }]
                    },
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as update_response:
                    if update_response.status == 200:
                        logger.info(f"✏️ Updated memory point {point_id} in {collection}")
                        return True
                    logger.error(f"Failed to update point: {update_response.status}")
                    return False

        except Exception as e:
            logger.error(f"Error updating memory point: {e}")
            return False

    async def create_memory_point(self, collection: str, content: str, user_id: str, source: str, metadata: dict = None) -> Optional[str]:
        """Create a new memory point manually.

        Args:
            collection: Target collection; "declarative" and "episodic" get
                extra metadata (see below).
            content: Text stored as ``page_content`` and sent to the embedder.
            user_id: Stored as ``metadata["user_id"]`` for declarative points,
                or as ``metadata["source"]`` for episodic points.
            source: Stored as ``metadata["source"]`` (overridden for episodic).
            metadata: Extra metadata; copied, the caller's dict is not modified.

        Returns:
            The generated point ID (UUID4 string) when the insert returned
            200, otherwise None.
        """
        try:
            import uuid  # only needed here

            point_id = str(uuid.uuid4())

            async with aiohttp.ClientSession() as session:
                # Get vector embedding from Cat
                async with session.post(
                    f"{self._base_url}/embedder",
                    json={"text": content},
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status != 200:
                        logger.error(f"Failed to embed content: {response.status}")
                        return None
                    data = await response.json()

                vector = data.get("embedding")
                if not vector:
                    logger.error("No embedding returned from Cat")
                    return None

                # Copy so the caller's metadata dict is never mutated.
                point_metadata = dict(metadata or {})
                point_metadata["source"] = source
                point_metadata["when"] = time.time()

                # For declarative memories, add user_id to metadata
                # For episodic, it's in the source field
                if collection == "declarative":
                    point_metadata["user_id"] = user_id
                elif collection == "episodic":
                    point_metadata["source"] = user_id

                payload = {
                    "page_content": content,
                    "metadata": point_metadata
                }

                # Insert into Qdrant
                qdrant_host = self._qdrant_url()
                async with session.put(
                    f"{qdrant_host}/collections/{collection}/points",
                    json={
                        "points": [{
                            "id": point_id,
                            "vector": vector,
                            "payload": payload
                        }]
                    },
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as insert_response:
                    if insert_response.status == 200:
                        logger.info(f"✨ Created new {collection} memory point: {point_id}")
                        return point_id
                    logger.error(f"Failed to insert point: {insert_response.status}")
                    return None

        except Exception as e:
            logger.error(f"Error creating memory point: {e}")
            return None

    async def wipe_all_memories(self) -> bool:
        """
        Delete ALL memory collections (episodic + declarative).
        This is the nuclear option — requires multi-step confirmation in the UI.

        Returns:
            True on HTTP 200 from DELETE /memory/collections, else False.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/collections",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    if response.status == 200:
                        logger.warning("🗑️ ALL memory collections wiped!")
                        return True
                    error = await response.text()
                    logger.error(f"Failed to wipe memories: {response.status} - {error}")
                    return False
        except Exception as e:
            logger.error(f"Error wiping memories: {e}")
            return False

    async def wipe_conversation_history(self) -> bool:
        """Clear working memory / conversation history.

        Returns:
            True on HTTP 200 from DELETE /memory/conversation_history, else False.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.delete(
                    f"{self._base_url}/memory/conversation_history",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15)
                ) as response:
                    if response.status == 200:
                        logger.info("Conversation history cleared")
                        return True
                    logger.error(f"Failed to clear conversation history: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error clearing conversation history: {e}")
            return False

    async def trigger_consolidation(self) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.

        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text.
        Uses WebSocket with a system user ID for proper context.

        Returns:
            The text of the final "chat" message, a human-readable failure
            description (timeout, closed socket, error frame/message), or
            None when connecting raised.
        """
        try:
            ws_url = self._ws_url("system_consolidation")
            logger.info("🌙 Triggering memory consolidation via WS...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=self._CONSOLIDATION_TIMEOUT,  # Consolidation can be very slow
                ) as ws:
                    await ws.send_json({"text": "consolidate now"})

                    # Wait for the final chat response
                    loop = asyncio.get_running_loop()
                    deadline = loop.time() + self._CONSOLIDATION_TIMEOUT
                    closed_types = (
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.CLOSED,
                    )

                    while True:
                        remaining = deadline - loop.time()
                        if remaining <= 0:
                            logger.error(f"Consolidation timed out (>{self._CONSOLIDATION_TIMEOUT}s)")
                            return "Consolidation timed out"

                        try:
                            ws_msg = await asyncio.wait_for(ws.receive(), timeout=remaining)
                        except asyncio.TimeoutError:
                            logger.error("Consolidation WS receive timeout")
                            return "Consolidation timed out waiting for response"

                        if ws_msg.type in closed_types:
                            logger.warning("Consolidation WS closed by server")
                            return "Connection closed during consolidation"
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            return f"WebSocket error: {ws.exception()}"
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError):
                            continue
                        if not isinstance(msg, dict):
                            continue

                        msg_type = msg.get("type", "")
                        if msg_type == "chat":
                            reply = msg.get("content") or msg.get("text", "")
                            logger.info(f"Consolidation result: {reply[:200]}")
                            return reply
                        if msg_type == "error":
                            error_desc = msg.get("description", "Unknown error")
                            logger.error(f"Consolidation error: {error_desc}")
                            return f"Consolidation error: {error_desc}"
                        # Any other message type (tokens, notifications): keep waiting.

        except asyncio.TimeoutError:
            logger.error("Consolidation WS connection timed out")
            return None
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None

    # ====================================================================
    # Admin API helpers – plugin toggling & LLM model switching
    # ====================================================================

    async def wait_for_ready(self, max_wait: int = 120, interval: int = 5) -> bool:
        """Wait for Cat to become reachable, polling with interval.

        Used on startup to avoid race conditions when bot starts before Cat.

        Args:
            max_wait: Total seconds to keep polling.
            interval: Seconds between attempts (the last sleep is shortened
                so the call does not run past ``max_wait``).

        Returns:
            True once Cat responds with 200, False if max_wait exceeded.
        """
        start = time.time()
        attempt = 0
        while time.time() - start < max_wait:
            attempt += 1
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(
                        f"{self._base_url}/",
                        # Same headers as health_check, which polls the same endpoint.
                        headers=self._get_headers(),
                        timeout=aiohttp.ClientTimeout(total=5),
                    ) as resp:
                        if resp.status == 200:
                            elapsed = time.time() - start
                            logger.info(f"🐱 Cat is ready (took {elapsed:.1f}s, {attempt} attempts)")
                            self._healthy = True
                            self._last_health_check = time.time()
                            return True
                        logger.debug(f"Cat not ready yet: status {resp.status}")
            except Exception as e:
                # Expected while Cat is still starting; keep polling but leave a trace.
                logger.debug(f"Cat not ready yet (attempt {attempt}): {e}")
            if attempt == 1:
                logger.info(f"⏳ Waiting for Cat to become ready (up to {max_wait}s)...")
            remaining = max_wait - (time.time() - start)
            await asyncio.sleep(max(0, min(interval, remaining)))
        logger.error(f"Cat did not become ready within {max_wait}s ({attempt} attempts)")
        return False

    async def toggle_plugin(self, plugin_id: str) -> bool:
        """Toggle a Cat plugin on/off via the admin API.

        PUT /plugins/toggle/{plugin_id}
        Returns True on success, False on failure.
        """
        url = f"{self._base_url}/plugins/toggle/{plugin_id}"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.put(
                    url,
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 200:
                        logger.info(f"🐱 Toggled Cat plugin: {plugin_id}")
                        return True
                    body = await resp.text()
                    logger.error(f"Cat plugin toggle failed ({resp.status}): {body}")
                    return False
        except Exception as e:
            logger.error(f"Cat plugin toggle error for {plugin_id}: {e}")
            return False

    async def set_llm_model(self, model_name: str) -> bool:
        """Switch the Cheshire Cat's active LLM model via settings API.

        The Cat settings API uses UUIDs: we must first GET /settings/ to find
        the setting_id for LLMOpenAIChatConfig, then PUT /settings/{setting_id}.
        llama-swap handles the actual model loading based on model_name.
        Returns True on success, False on failure.
        """
        try:
            # Both requests share one session.
            async with aiohttp.ClientSession() as session:
                # Step 1: Find the setting_id for LLMOpenAIChatConfig
                setting_id = None
                async with session.get(
                    f"{self._base_url}/settings/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status != 200:
                        logger.error(f"Cat settings GET failed ({resp.status})")
                        return False
                    data = await resp.json()

                for setting in data.get("settings", []):
                    if setting.get("name") == "LLMOpenAIChatConfig":
                        setting_id = setting.get("setting_id")
                        break

                if not setting_id:
                    logger.error("Could not find LLMOpenAIChatConfig setting_id in Cat settings")
                    return False

                # Step 2: PUT updated config to /settings/{setting_id}
                payload = {
                    "name": "LLMOpenAIChatConfig",
                    "value": {
                        "openai_api_key": "sk-dummy",
                        "model_name": model_name,
                        "temperature": 0.8,
                        "streaming": False,
                    },
                    "category": "llm_factory",
                }
                async with session.put(
                    f"{self._base_url}/settings/{setting_id}",
                    json=payload,
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 200:
                        logger.info(f"🐱 Set Cat LLM model to: {model_name}")
                        return True
                    body = await resp.text()
                    logger.error(f"Cat LLM model switch failed ({resp.status}): {body}")
                    return False
        except Exception as e:
            logger.error(f"Cat LLM model switch error: {e}")
            return False

    async def _fetch_active_plugins(self) -> Optional[List[str]]:
        """Return the IDs of active plugins, or None when the state is unknown.

        None (request failed) is kept distinct from an empty list (request
        succeeded, nothing active) so callers of the toggle API can avoid
        flipping a plugin based on a failed lookup.
        """
        url = f"{self._base_url}/plugins"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url,
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status != 200:
                        logger.error(f"Cat get_active_plugins failed ({resp.status})")
                        return None
                    data = await resp.json()
            installed = data.get("installed", [])
            # Entries without an "id" are skipped rather than failing the whole list.
            return [p["id"] for p in installed if p.get("active") and "id" in p]
        except Exception as e:
            logger.error(f"Cat get_active_plugins error: {e}")
            return None

    async def get_active_plugins(self) -> list:
        """Get list of active Cat plugin IDs.

        GET /plugins → returns {"installed": [...], "filters": {...}}
        Each plugin has "id" and "active" fields.

        Returns:
            The active plugin IDs; an empty list when the request failed.
        """
        active = await self._fetch_active_plugins()
        return active if active is not None else []

    async def _switch_personality(
        self,
        label: str,
        disable_plugin: str,
        enable_plugin: str,
        model_name: str,
    ) -> bool:
        """Disable one personality plugin, enable the other, then set the LLM.

        Each step is attempted even if an earlier one failed. Plugin toggles
        are skipped entirely when the current plugin state cannot be read,
        because the Cat API is a toggle and a blind call could flip a plugin
        the wrong way.

        Returns:
            True only if every step succeeded.
        """
        logger.info(f"🐱 Switching Cat to {label} personality...")
        success = True

        # Check current plugin state
        active = await self._fetch_active_plugins()
        if active is None:
            logger.error("Could not read Cat plugin state; skipping plugin toggles")
            success = False
        else:
            # Step 1: Disable the outgoing personality (only if currently active)
            if disable_plugin in active:
                if not await self.toggle_plugin(disable_plugin):
                    logger.error(f"Failed to disable {disable_plugin} plugin")
                    success = False
                await asyncio.sleep(1)
            else:
                logger.debug(f"{disable_plugin} already disabled, skipping toggle")

            # Step 2: Enable the incoming personality (only if currently inactive)
            if enable_plugin not in active:
                if not await self.toggle_plugin(enable_plugin):
                    logger.error(f"Failed to enable {enable_plugin} plugin")
                    success = False
            else:
                logger.debug(f"{enable_plugin} already active, skipping toggle")

        # Step 3: Switch the LLM model
        if not await self.set_llm_model(model_name):
            logger.error(f"Failed to switch Cat LLM to {model_name}")
            success = False

        return success

    async def switch_to_evil_personality(self) -> bool:
        """Disable miku_personality, enable evil_miku_personality, switch LLM to darkidol.

        Checks current plugin state first to avoid double-toggling
        (the Cat API is a toggle, not enable/disable).
        Returns True if all operations succeed, False if any fail.
        """
        return await self._switch_personality(
            label="Evil Miku",
            disable_plugin="miku_personality",
            enable_plugin="evil_miku_personality",
            model_name="darkidol",
        )

    async def switch_to_normal_personality(self) -> bool:
        """Disable evil_miku_personality, enable miku_personality, switch LLM to llama3.1.

        Checks current plugin state first to avoid double-toggling.
        Returns True if all operations succeed, False if any fail.
        """
        return await self._switch_personality(
            label="normal Miku",
            disable_plugin="evil_miku_personality",
            enable_plugin="miku_personality",
            model_name="llama3.1",
        )
|
||
|
||
|
||
# Singleton instance
|
||
# Built at import time: CatAdapter.__init__ reads globals.CHESHIRE_CAT_URL,
# CHESHIRE_CAT_API_KEY and CHESHIRE_CAT_TIMEOUT, so those must already be set
# when this module is first imported.
cat_adapter = CatAdapter()
|