# Changelog (recent changes affecting this adapter and its callers):
# - Fix silent None return in analyze_image_with_vision exception handler
# - Add None/empty guards after vision analysis in bot.py (image, video, GIF, Tenor)
# - Route all image/video/GIF responses through the Cheshire Cat pipeline (was
#   calling query_llama directly), enabling episodic memory storage for media
#   interactions and correct Last Prompt display in the Web UI
# - Add media_type parameter to cat_adapter.query() and forward it as
#   discord_media_type in the WebSocket payload
# - Update discord_bridge plugin to read media_type from the payload and inject
#   a MEDIA NOTE into the system prefix in the before_agent_starts hook
# - Add _extract_vision_question() helper to strip Discord mentions and bot-name
#   triggers from the user message; pass the cleaned question to the vision
#   model so specific questions go directly to it instead of the generic
#   'Describe this image in detail.' fallback
# - Pass user_prompt to all analyze_image_with_qwen / analyze_video_with_vision
#   call sites in bot.py (image, video, GIF, Tenor, embed paths)
# - Fix autonomous reaction loops skipping messages that @mention the bot or
#   have media attachments in DMs, preventing duplicate vision model calls for
#   images already being processed by the main message handler
# - Increase vision max_tokens: images 300->800, video/GIF 400->1000 (no VRAM
#   impact; KV cache is pre-allocated at model load time)
# utils/cat_client.py
|
||
"""
|
||
Cheshire Cat AI Adapter for Miku Discord Bot (Phase 3)
|
||
|
||
Routes messages through the Cheshire Cat pipeline for:
|
||
- Memory-augmented responses (episodic + declarative recall)
|
||
- Fact extraction and consolidation
|
||
- Per-user conversation isolation
|
||
|
||
Uses WebSocket for chat (per-user isolation via /ws/{user_id}).
|
||
Uses HTTP for memory management endpoints.
|
||
Falls back to query_llama() on failure for zero-downtime resilience.
|
||
"""
|
||
|
||
import aiohttp
|
||
import asyncio
|
||
import json
|
||
import time
|
||
from typing import Optional, Dict, Any, List
|
||
|
||
import globals
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger('llm') # Use existing 'llm' logger component
|
||
|
||
|
||
class CatAdapter:
    """
    Async adapter for Cheshire Cat AI.

    Uses WebSocket /ws/{user_id} for conversation (per-user memory isolation).
    Uses HTTP REST for memory management endpoints.
    Without API keys configured, HTTP POST /message defaults all users to
    user_id="user" (no isolation). WebSocket path param gives true isolation.
    """

    def __init__(self):
        # Connection settings come from the shared globals module.
        self._base_url = globals.CHESHIRE_CAT_URL.rstrip('/')
        self._api_key = globals.CHESHIRE_CAT_API_KEY
        self._timeout = globals.CHESHIRE_CAT_TIMEOUT
        # Cached health-check state, refreshed at most every
        # _health_check_interval seconds by health_check().
        self._healthy = None  # None = unknown, True/False = last check result
        self._last_health_check = 0
        self._health_check_interval = 30  # seconds between health checks
        # Circuit breaker: after repeated failures, query() skips Cat for a
        # while so callers fall back to query_llama quickly.
        self._consecutive_failures = 0
        self._max_failures_before_circuit_break = 3
        self._circuit_broken_until = 0  # timestamp when circuit breaker resets
        logger.info(f"CatAdapter initialized: {self._base_url} (timeout={self._timeout}s)")
|
||
|
||
def _get_headers(self) -> dict:
|
||
"""Build request headers with optional auth."""
|
||
headers = {'Content-Type': 'application/json'}
|
||
if self._api_key:
|
||
headers['Authorization'] = f'Bearer {self._api_key}'
|
||
return headers
|
||
|
||
def _user_id_for_discord(self, user_id: str) -> str:
|
||
"""
|
||
Format Discord user ID for Cat's user namespace.
|
||
Cat uses user_id to isolate working memory and episodic memories.
|
||
"""
|
||
return f"discord_{user_id}"
|
||
|
||
async def health_check(self) -> bool:
|
||
"""
|
||
Check if Cheshire Cat is reachable and healthy.
|
||
Caches result to avoid hammering the endpoint.
|
||
"""
|
||
now = time.time()
|
||
if now - self._last_health_check < self._health_check_interval and self._healthy is not None:
|
||
return self._healthy
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(
|
||
f"{self._base_url}/",
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=10)
|
||
) as response:
|
||
self._healthy = response.status == 200
|
||
self._last_health_check = now
|
||
if self._healthy:
|
||
logger.debug("Cat health check: OK")
|
||
else:
|
||
logger.warning(f"Cat health check failed: status {response.status}")
|
||
return self._healthy
|
||
except Exception as e:
|
||
self._healthy = False
|
||
self._last_health_check = now
|
||
logger.warning(f"Cat health check error: {e}")
|
||
return False
|
||
|
||
def _is_circuit_broken(self) -> bool:
|
||
"""Check if circuit breaker is active (too many consecutive failures)."""
|
||
if self._consecutive_failures >= self._max_failures_before_circuit_break:
|
||
if time.time() < self._circuit_broken_until:
|
||
return True
|
||
# Circuit breaker expired, allow retry
|
||
logger.info("Circuit breaker reset, allowing Cat retry")
|
||
self._consecutive_failures = 0
|
||
return False
|
||
|
||
    async def query(
        self,
        text: str,
        user_id: str,
        guild_id: Optional[str] = None,
        author_name: Optional[str] = None,
        mood: Optional[str] = None,
        response_type: str = "dm_response",
        media_type: Optional[str] = None,
    ) -> Optional[tuple]:
        """
        Send a message through the Cat pipeline via WebSocket and get a response.

        Uses WebSocket /ws/{user_id} for per-user memory isolation.
        Without API keys, HTTP POST /message defaults all users to user_id="user"
        (no isolation). The WebSocket path parameter provides true per-user isolation
        because Cat's auth handler uses user_id from the path when no keys are set.

        Args:
            text: User's message text
            user_id: Discord user ID (will be namespaced as discord_{user_id})
            guild_id: Optional guild ID for server context
            author_name: Display name of the user
            mood: Current mood name (passed as metadata for Cat hooks)
            response_type: Type of response context
            media_type: Type of media attachment ("image", "video", "gif", "tenor_gif")

        Returns:
            Tuple of (response_text, full_prompt) on success, or None if Cat
            is unavailable (caller should fallback to query_llama)
        """
        # Feature flag: Cat integration can be disabled entirely.
        if not globals.USE_CHESHIRE_CAT:
            return None

        # Too many recent failures — skip Cat so the caller falls back fast.
        if self._is_circuit_broken():
            logger.debug("Circuit breaker active, skipping Cat")
            return None

        cat_user_id = self._user_id_for_discord(user_id)

        # Build message payload with Discord metadata for our plugin hooks.
        # The discord_bridge plugin's before_cat_reads_message hook reads
        # these custom keys from the message dict.
        payload = {
            "text": text,
        }
        if guild_id:
            payload["discord_guild_id"] = str(guild_id)
        if author_name:
            payload["discord_author_name"] = author_name
        # When evil mode is active, send the evil mood name instead of the normal mood
        if globals.EVIL_MODE:
            payload["discord_mood"] = getattr(globals, 'EVIL_DM_MOOD', 'evil_neutral')
        elif mood:
            payload["discord_mood"] = mood
        if response_type:
            payload["discord_response_type"] = response_type
        # Pass evil mode flag so discord_bridge stores it in working_memory
        payload["discord_evil_mode"] = globals.EVIL_MODE
        # Pass media type so discord_bridge can add MEDIA NOTE to the prompt
        if media_type:
            payload["discord_media_type"] = media_type

        try:
            # Build WebSocket URL from HTTP base URL
            ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
            ws_url = f"{ws_base}/ws/{cat_user_id}"

            logger.debug(f"Querying Cat via WS: user={cat_user_id}, text={text[:80]}...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=self._timeout,
                ) as ws:
                    # Send the message
                    await ws.send_json(payload)

                    # Read responses until we get the final "chat" type message.
                    # Cat may send intermediate messages (chat_token for streaming,
                    # notification for status updates). We want the final "chat" one.
                    reply_text = None
                    full_prompt = ""
                    deadline = asyncio.get_event_loop().time() + self._timeout

                    while True:
                        # The overall deadline spans ALL frames; each receive
                        # only gets whatever time remains.
                        remaining = deadline - asyncio.get_event_loop().time()
                        if remaining <= 0:
                            logger.error(f"Cat WS timeout after {self._timeout}s")
                            break

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error(f"Cat WS receive timeout after {self._timeout}s")
                            break

                        # Handle WebSocket close/error frames
                        if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
                            logger.warning("Cat WS connection closed by server")
                            break
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            logger.error(f"Cat WS error frame: {ws.exception()}")
                            break
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            logger.debug(f"Cat WS non-text frame type: {ws_msg.type}")
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError) as e:
                            logger.warning(f"Cat WS non-JSON message: {e}")
                            continue

                        msg_type = msg.get("type", "")

                        if msg_type == "chat":
                            # Final response — extract text and full prompt
                            reply_text = msg.get("content") or msg.get("text", "")
                            full_prompt = msg.get("full_prompt", "")
                            break
                        elif msg_type == "chat_token":
                            # Streaming token — skip, we wait for final
                            continue
                        elif msg_type == "error":
                            error_desc = msg.get("description", "Unknown Cat error")
                            logger.error(f"Cat WS error: {error_desc}")
                            break
                        elif msg_type == "notification":
                            logger.debug(f"Cat notification: {msg.get('content', '')}")
                            continue
                        else:
                            logger.debug(f"Cat WS unknown msg type: {msg_type}")
                            continue

                    # A non-empty reply counts as success and resets the
                    # failure streak used by the circuit breaker.
                    if reply_text and reply_text.strip():
                        self._consecutive_failures = 0
                        logger.info(f"🐱 Cat response for {cat_user_id}: {reply_text[:100]}...")
                        return reply_text, full_prompt
                    else:
                        logger.warning("Cat returned empty response via WS")
                        self._consecutive_failures += 1
                        return None

        except asyncio.TimeoutError:
            logger.error(f"Cat WS connection timeout after {self._timeout}s")
            self._consecutive_failures += 1
            if self._consecutive_failures >= self._max_failures_before_circuit_break:
                # Open the breaker for 60 seconds.
                self._circuit_broken_until = time.time() + 60
                logger.warning("Circuit breaker activated (WS timeout)")
            return None
        except Exception as e:
            logger.error(f"Cat WS query error: {e}")
            self._consecutive_failures += 1
            if self._consecutive_failures >= self._max_failures_before_circuit_break:
                self._circuit_broken_until = time.time() + 60
                logger.warning(f"Circuit breaker activated: {e}")
            return None
|
||
|
||
# ===================================================================
|
||
# MEMORY MANAGEMENT API (for Web UI)
|
||
# ===================================================================
|
||
|
||
async def get_memory_stats(self) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
Get memory collection statistics with actual counts from Qdrant.
|
||
Returns dict with collection names and point counts.
|
||
"""
|
||
try:
|
||
# Query Qdrant directly for accurate counts
|
||
qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333")
|
||
|
||
collections_data = []
|
||
for collection_name in ["episodic", "declarative", "procedural"]:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(
|
||
f"{qdrant_host}/collections/{collection_name}",
|
||
timeout=aiohttp.ClientTimeout(total=10)
|
||
) as response:
|
||
if response.status == 200:
|
||
data = await response.json()
|
||
count = data.get("result", {}).get("points_count", 0)
|
||
collections_data.append({
|
||
"name": collection_name,
|
||
"vectors_count": count
|
||
})
|
||
else:
|
||
collections_data.append({
|
||
"name": collection_name,
|
||
"vectors_count": 0
|
||
})
|
||
|
||
return {"collections": collections_data}
|
||
except Exception as e:
|
||
logger.error(f"Error getting memory stats from Qdrant: {e}")
|
||
return None
|
||
|
||
async def get_memory_points(
|
||
self,
|
||
collection: str = "declarative",
|
||
limit: int = 100,
|
||
offset: Optional[str] = None
|
||
) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
Get all points from a memory collection via Qdrant.
|
||
Cat doesn't expose /memory/collections/{id}/points, so we query Qdrant directly.
|
||
Returns paginated list of memory points.
|
||
"""
|
||
try:
|
||
# Use Qdrant directly (Cat's vector memory backend)
|
||
# Qdrant is accessible at the same host, port 6333 internally
|
||
qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333")
|
||
|
||
payload = {"limit": limit, "with_payload": True, "with_vector": False}
|
||
if offset:
|
||
payload["offset"] = offset
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(
|
||
f"{qdrant_host}/collections/{collection}/points/scroll",
|
||
json=payload,
|
||
timeout=aiohttp.ClientTimeout(total=30)
|
||
) as response:
|
||
if response.status == 200:
|
||
data = await response.json()
|
||
return data.get("result", {})
|
||
else:
|
||
logger.error(f"Failed to get {collection} points from Qdrant: {response.status}")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"Error getting memory points from Qdrant: {e}")
|
||
return None
|
||
|
||
async def get_all_facts(self) -> List[Dict[str, Any]]:
|
||
"""
|
||
Retrieve ALL declarative memory points (facts) with pagination.
|
||
Returns a flat list of all fact dicts.
|
||
"""
|
||
all_facts = []
|
||
offset = None
|
||
|
||
try:
|
||
while True:
|
||
result = await self.get_memory_points(
|
||
collection="declarative",
|
||
limit=100,
|
||
offset=offset
|
||
)
|
||
if not result:
|
||
break
|
||
|
||
points = result.get("points", [])
|
||
for point in points:
|
||
payload = point.get("payload", {})
|
||
fact = {
|
||
"id": point.get("id"),
|
||
"content": payload.get("page_content", ""),
|
||
"metadata": payload.get("metadata", {}),
|
||
}
|
||
all_facts.append(fact)
|
||
|
||
offset = result.get("next_offset")
|
||
if not offset:
|
||
break
|
||
|
||
logger.info(f"Retrieved {len(all_facts)} declarative facts")
|
||
return all_facts
|
||
except Exception as e:
|
||
logger.error(f"Error retrieving all facts: {e}")
|
||
return all_facts
|
||
|
||
async def delete_memory_point(self, collection: str, point_id: str) -> bool:
|
||
"""Delete a single memory point by ID via Qdrant."""
|
||
try:
|
||
qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333")
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(
|
||
f"{qdrant_host}/collections/{collection}/points/delete",
|
||
json={"points": [point_id]},
|
||
timeout=aiohttp.ClientTimeout(total=15)
|
||
) as response:
|
||
if response.status == 200:
|
||
logger.info(f"Deleted memory point {point_id} from {collection}")
|
||
return True
|
||
else:
|
||
logger.error(f"Failed to delete point: {response.status}")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"Error deleting memory point: {e}")
|
||
return False
|
||
|
||
async def update_memory_point(self, collection: str, point_id: str, content: str, metadata: dict = None) -> bool:
|
||
"""Update an existing memory point's content and/or metadata."""
|
||
try:
|
||
# First, get the existing point to retrieve its vector
|
||
qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333")
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
# Get existing point
|
||
async with session.post(
|
||
f"{qdrant_host}/collections/{collection}/points",
|
||
json={"ids": [point_id], "with_vector": True, "with_payload": True},
|
||
timeout=aiohttp.ClientTimeout(total=15)
|
||
) as response:
|
||
if response.status != 200:
|
||
logger.error(f"Failed to fetch point {point_id}: {response.status}")
|
||
return False
|
||
|
||
data = await response.json()
|
||
points = data.get("result", [])
|
||
if not points:
|
||
logger.error(f"Point {point_id} not found")
|
||
return False
|
||
|
||
existing_point = points[0]
|
||
existing_vector = existing_point.get("vector")
|
||
existing_payload = existing_point.get("payload", {})
|
||
|
||
# If content changed, we need to re-embed it
|
||
if content != existing_payload.get("page_content"):
|
||
# Call Cat's embedder to get new vector
|
||
embed_response = await session.post(
|
||
f"{self._base_url}/embedder",
|
||
json={"text": content},
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=30)
|
||
)
|
||
if embed_response.status == 200:
|
||
embed_data = await embed_response.json()
|
||
new_vector = embed_data.get("embedding")
|
||
else:
|
||
logger.warning(f"Failed to re-embed content, keeping old vector")
|
||
new_vector = existing_vector
|
||
else:
|
||
new_vector = existing_vector
|
||
|
||
# Build updated payload
|
||
updated_payload = {
|
||
"page_content": content,
|
||
"metadata": metadata if metadata is not None else existing_payload.get("metadata", {})
|
||
}
|
||
|
||
# Update the point
|
||
async with session.put(
|
||
f"{qdrant_host}/collections/{collection}/points",
|
||
json={
|
||
"points": [{
|
||
"id": point_id,
|
||
"vector": new_vector,
|
||
"payload": updated_payload
|
||
}]
|
||
},
|
||
timeout=aiohttp.ClientTimeout(total=15)
|
||
) as update_response:
|
||
if update_response.status == 200:
|
||
logger.info(f"✏️ Updated memory point {point_id} in {collection}")
|
||
return True
|
||
else:
|
||
logger.error(f"Failed to update point: {update_response.status}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error updating memory point: {e}")
|
||
return False
|
||
|
||
async def create_memory_point(self, collection: str, content: str, user_id: str, source: str, metadata: dict = None) -> Optional[str]:
|
||
"""Create a new memory point manually."""
|
||
try:
|
||
import uuid
|
||
import time
|
||
|
||
# Generate a unique ID
|
||
point_id = str(uuid.uuid4())
|
||
|
||
# Get vector embedding from Cat
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(
|
||
f"{self._base_url}/embedder",
|
||
json={"text": content},
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=30)
|
||
) as response:
|
||
if response.status != 200:
|
||
logger.error(f"Failed to embed content: {response.status}")
|
||
return None
|
||
|
||
data = await response.json()
|
||
vector = data.get("embedding")
|
||
if not vector:
|
||
logger.error("No embedding returned from Cat")
|
||
return None
|
||
|
||
# Build payload
|
||
payload = {
|
||
"page_content": content,
|
||
"metadata": metadata or {}
|
||
}
|
||
payload["metadata"]["source"] = source
|
||
payload["metadata"]["when"] = time.time()
|
||
|
||
# For declarative memories, add user_id to metadata
|
||
# For episodic, it's in the source field
|
||
if collection == "declarative":
|
||
payload["metadata"]["user_id"] = user_id
|
||
elif collection == "episodic":
|
||
payload["metadata"]["source"] = user_id
|
||
|
||
# Insert into Qdrant
|
||
qdrant_host = self._base_url.replace("http://cheshire-cat:80", "http://cheshire-cat-vector-memory:6333")
|
||
|
||
async with session.put(
|
||
f"{qdrant_host}/collections/{collection}/points",
|
||
json={
|
||
"points": [{
|
||
"id": point_id,
|
||
"vector": vector,
|
||
"payload": payload
|
||
}]
|
||
},
|
||
timeout=aiohttp.ClientTimeout(total=15)
|
||
) as insert_response:
|
||
if insert_response.status == 200:
|
||
logger.info(f"✨ Created new {collection} memory point: {point_id}")
|
||
return point_id
|
||
else:
|
||
logger.error(f"Failed to insert point: {insert_response.status}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error creating memory point: {e}")
|
||
return None
|
||
|
||
async def wipe_all_memories(self) -> bool:
|
||
"""
|
||
Delete ALL memory collections (episodic + declarative).
|
||
This is the nuclear option — requires multi-step confirmation in the UI.
|
||
"""
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.delete(
|
||
f"{self._base_url}/memory/collections",
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=30)
|
||
) as response:
|
||
if response.status == 200:
|
||
logger.warning("🗑️ ALL memory collections wiped!")
|
||
return True
|
||
else:
|
||
error = await response.text()
|
||
logger.error(f"Failed to wipe memories: {response.status} - {error}")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"Error wiping memories: {e}")
|
||
return False
|
||
|
||
async def wipe_conversation_history(self) -> bool:
|
||
"""Clear working memory / conversation history."""
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.delete(
|
||
f"{self._base_url}/memory/conversation_history",
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=15)
|
||
) as response:
|
||
if response.status == 200:
|
||
logger.info("Conversation history cleared")
|
||
return True
|
||
else:
|
||
logger.error(f"Failed to clear conversation history: {response.status}")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"Error clearing conversation history: {e}")
|
||
return False
|
||
|
||
    async def trigger_consolidation(self) -> Optional[str]:
        """
        Trigger memory consolidation by sending a special message via WebSocket.
        The memory_consolidation plugin's tool 'consolidate_memories' is
        triggered when it sees 'consolidate now' in the text.
        Uses WebSocket with a system user ID for proper context.

        Returns:
            A human-readable result/status string, or None on connection errors.
        """
        try:
            # Dedicated system user keeps consolidation chatter out of any
            # real user's working memory.
            ws_base = self._base_url.replace("http://", "ws://").replace("https://", "wss://")
            ws_url = f"{ws_base}/ws/system_consolidation"

            logger.info("🌙 Triggering memory consolidation via WS...")

            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(
                    ws_url,
                    timeout=300,  # Consolidation can be very slow
                ) as ws:
                    await ws.send_json({"text": "consolidate now"})

                    # Wait for the final chat response
                    deadline = asyncio.get_event_loop().time() + 300

                    while True:
                        # The 300s budget spans ALL frames, not each receive.
                        remaining = deadline - asyncio.get_event_loop().time()
                        if remaining <= 0:
                            logger.error("Consolidation timed out (>300s)")
                            return "Consolidation timed out"

                        try:
                            ws_msg = await asyncio.wait_for(
                                ws.receive(),
                                timeout=remaining
                            )
                        except asyncio.TimeoutError:
                            logger.error("Consolidation WS receive timeout")
                            return "Consolidation timed out waiting for response"

                        # Close/error frames end the wait with a status string.
                        if ws_msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED):
                            logger.warning("Consolidation WS closed by server")
                            return "Connection closed during consolidation"
                        if ws_msg.type == aiohttp.WSMsgType.ERROR:
                            return f"WebSocket error: {ws.exception()}"
                        if ws_msg.type != aiohttp.WSMsgType.TEXT:
                            continue

                        try:
                            msg = json.loads(ws_msg.data)
                        except (json.JSONDecodeError, TypeError):
                            continue

                        msg_type = msg.get("type", "")
                        if msg_type == "chat":
                            # Final answer from the consolidation run.
                            reply = msg.get("content") or msg.get("text", "")
                            logger.info(f"Consolidation result: {reply[:200]}")
                            return reply
                        elif msg_type == "error":
                            error_desc = msg.get("description", "Unknown error")
                            logger.error(f"Consolidation error: {error_desc}")
                            return f"Consolidation error: {error_desc}"
                        else:
                            # chat_token / notification frames are ignored.
                            continue

        except asyncio.TimeoutError:
            logger.error("Consolidation WS connection timed out")
            return None
        except Exception as e:
            logger.error(f"Consolidation error: {e}")
            return None
|
||
|
||
# ====================================================================
|
||
# Admin API helpers – plugin toggling & LLM model switching
|
||
# ====================================================================
|
||
|
||
async def wait_for_ready(self, max_wait: int = 120, interval: int = 5) -> bool:
|
||
"""Wait for Cat to become reachable, polling with interval.
|
||
|
||
Used on startup to avoid race conditions when bot starts before Cat.
|
||
Returns True once Cat responds, False if max_wait exceeded.
|
||
"""
|
||
start = time.time()
|
||
attempt = 0
|
||
while time.time() - start < max_wait:
|
||
attempt += 1
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(
|
||
f"{self._base_url}/",
|
||
timeout=aiohttp.ClientTimeout(total=5),
|
||
) as resp:
|
||
if resp.status == 200:
|
||
elapsed = time.time() - start
|
||
logger.info(f"🐱 Cat is ready (took {elapsed:.1f}s, {attempt} attempts)")
|
||
self._healthy = True
|
||
self._last_health_check = time.time()
|
||
return True
|
||
except Exception:
|
||
pass
|
||
if attempt == 1:
|
||
logger.info(f"⏳ Waiting for Cat to become ready (up to {max_wait}s)...")
|
||
await asyncio.sleep(interval)
|
||
logger.error(f"Cat did not become ready within {max_wait}s ({attempt} attempts)")
|
||
return False
|
||
|
||
async def toggle_plugin(self, plugin_id: str) -> bool:
|
||
"""Toggle a Cat plugin on/off via the admin API.
|
||
|
||
PUT /plugins/toggle/{plugin_id}
|
||
Returns True on success, False on failure.
|
||
"""
|
||
url = f"{self._base_url}/plugins/toggle/{plugin_id}"
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.put(
|
||
url,
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=15),
|
||
) as resp:
|
||
if resp.status == 200:
|
||
logger.info(f"🐱 Toggled Cat plugin: {plugin_id}")
|
||
return True
|
||
else:
|
||
body = await resp.text()
|
||
logger.error(f"Cat plugin toggle failed ({resp.status}): {body}")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"Cat plugin toggle error for {plugin_id}: {e}")
|
||
return False
|
||
|
||
    async def set_llm_model(self, model_name: str) -> bool:
        """Switch the Cheshire Cat's active LLM model via settings API.

        The Cat settings API uses UUIDs: we must first GET /settings/ to find
        the setting_id for LLMOpenAIChatConfig, then PUT /settings/{setting_id}.
        llama-swap handles the actual model loading based on model_name.
        Returns True on success, False on failure.
        """
        try:
            # Step 1: Find the setting_id for LLMOpenAIChatConfig
            setting_id = None
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self._base_url}/settings/",
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status != 200:
                        logger.error(f"Cat settings GET failed ({resp.status})")
                        return False
                    data = await resp.json()
                    for s in data.get("settings", []):
                        if s.get("name") == "LLMOpenAIChatConfig":
                            setting_id = s["setting_id"]
                            break

            if not setting_id:
                logger.error("Could not find LLMOpenAIChatConfig setting_id in Cat settings")
                return False

            # Step 2: PUT updated config to /settings/{setting_id}
            payload = {
                "name": "LLMOpenAIChatConfig",
                "value": {
                    # Placeholder key — presumably ignored by the
                    # OpenAI-compatible backend; verify against llama-swap.
                    "openai_api_key": "sk-dummy",
                    "model_name": model_name,
                    "temperature": 0.8,
                    "streaming": False,
                },
                "category": "llm_factory",
            }
            async with aiohttp.ClientSession() as session:
                async with session.put(
                    f"{self._base_url}/settings/{setting_id}",
                    json=payload,
                    headers=self._get_headers(),
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 200:
                        logger.info(f"🐱 Set Cat LLM model to: {model_name}")
                        return True
                    else:
                        body = await resp.text()
                        logger.error(f"Cat LLM model switch failed ({resp.status}): {body}")
                        return False
        except Exception as e:
            logger.error(f"Cat LLM model switch error: {e}")
            return False
|
||
|
||
async def get_active_plugins(self) -> list:
|
||
"""Get list of active Cat plugin IDs.
|
||
|
||
GET /plugins → returns {\"installed\": [...], \"filters\": {...}}
|
||
Each plugin has \"id\" and \"active\" fields.
|
||
"""
|
||
url = f"{self._base_url}/plugins"
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(
|
||
url,
|
||
headers=self._get_headers(),
|
||
timeout=aiohttp.ClientTimeout(total=10),
|
||
) as resp:
|
||
if resp.status == 200:
|
||
data = await resp.json()
|
||
installed = data.get("installed", [])
|
||
return [p["id"] for p in installed if p.get("active")]
|
||
else:
|
||
logger.error(f"Cat get_active_plugins failed ({resp.status})")
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"Cat get_active_plugins error: {e}")
|
||
return []
|
||
|
||
async def switch_to_evil_personality(self) -> bool:
|
||
"""Disable miku_personality, enable evil_miku_personality, switch LLM to darkidol.
|
||
|
||
Checks current plugin state first to avoid double-toggling
|
||
(the Cat API is a toggle, not enable/disable).
|
||
Returns True if all operations succeed, False if any fail.
|
||
"""
|
||
logger.info("🐱 Switching Cat to Evil Miku personality...")
|
||
success = True
|
||
|
||
# Check current plugin state
|
||
active = await self.get_active_plugins()
|
||
|
||
# Step 1: Disable normal personality (only if currently active)
|
||
if "miku_personality" in active:
|
||
if not await self.toggle_plugin("miku_personality"):
|
||
logger.error("Failed to disable miku_personality plugin")
|
||
success = False
|
||
await asyncio.sleep(1)
|
||
else:
|
||
logger.debug("miku_personality already disabled, skipping toggle")
|
||
|
||
# Step 2: Enable evil personality (only if currently inactive)
|
||
if "evil_miku_personality" not in active:
|
||
if not await self.toggle_plugin("evil_miku_personality"):
|
||
logger.error("Failed to enable evil_miku_personality plugin")
|
||
success = False
|
||
else:
|
||
logger.debug("evil_miku_personality already active, skipping toggle")
|
||
|
||
# Step 3: Switch LLM model to darkidol (the uncensored evil model)
|
||
if not await self.set_llm_model("darkidol"):
|
||
logger.error("Failed to switch Cat LLM to darkidol")
|
||
success = False
|
||
|
||
return success
|
||
|
||
async def switch_to_normal_personality(self) -> bool:
|
||
"""Disable evil_miku_personality, enable miku_personality, switch LLM to llama3.1.
|
||
|
||
Checks current plugin state first to avoid double-toggling.
|
||
Returns True if all operations succeed, False if any fail.
|
||
"""
|
||
logger.info("🐱 Switching Cat to normal Miku personality...")
|
||
success = True
|
||
|
||
# Check current plugin state
|
||
active = await self.get_active_plugins()
|
||
|
||
# Step 1: Disable evil personality (only if currently active)
|
||
if "evil_miku_personality" in active:
|
||
if not await self.toggle_plugin("evil_miku_personality"):
|
||
logger.error("Failed to disable evil_miku_personality plugin")
|
||
success = False
|
||
await asyncio.sleep(1)
|
||
else:
|
||
logger.debug("evil_miku_personality already disabled, skipping toggle")
|
||
|
||
# Step 2: Enable normal personality (only if currently inactive)
|
||
if "miku_personality" not in active:
|
||
if not await self.toggle_plugin("miku_personality"):
|
||
logger.error("Failed to enable miku_personality plugin")
|
||
success = False
|
||
else:
|
||
logger.debug("miku_personality already active, skipping toggle")
|
||
|
||
# Step 3: Switch LLM model back to llama3.1 (normal model)
|
||
if not await self.set_llm_model("llama3.1"):
|
||
logger.error("Failed to switch Cat LLM to llama3.1")
|
||
success = False
|
||
|
||
return success
|
||
|
||
|
||
# Singleton instance — the bot and web UI import `cat_adapter`, not the class.
cat_adapter = CatAdapter()
|