fix: Phase 2 integrity review - v2.0.0 rewrite & bugfixes
Memory Consolidation Plugin (828 -> 465 lines):
- Replace SentenceTransformer with cat.embedder.embed_query() for vector consistency
- Fix per-user fact isolation: source=user_id instead of global
- Add duplicate fact detection (_is_duplicate_fact, score_threshold=0.85)
- Remove ~350 lines of dead async run_consolidation() code
- Remove duplicate declarative search in before_cat_sends_message
- Unify trivial patterns into TRIVIAL_PATTERNS frozenset
- Remove all sys.stderr.write debug logging
- Remove sentence-transformers from requirements.txt (no external deps)

Loguru Fix (cheshire-cat/cat/log.py):
- Patch Cat v1.6.2 loguru format to provide default extra fields
- Fixes KeyError: 'original_name' from third-party libs (fastembed)
- Mounted via docker-compose volume

Discord Bridge:
- Copy discord_bridge.py to cat-plugins/ (was empty directory)

Test Results (6/7 pass, 100% fact recall):
- 11 facts extracted, per-user isolation working
- Duplicate detection effective (+2 on 2nd run)
- 5/5 natural language recall queries correct
This commit is contained in:
109
cat-plugins/discord_bridge/discord_bridge.py
Normal file
109
cat-plugins/discord_bridge/discord_bridge.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Discord Bridge Plugin for Cheshire Cat
|
||||
|
||||
This plugin enriches Cat's memory system with Discord context:
|
||||
- Unified user identity across all servers and DMs
|
||||
- Guild/channel metadata for context tracking
|
||||
- Minimal filtering before storage (only skip obvious junk)
|
||||
- Marks memories as unconsolidated for nightly processing
|
||||
|
||||
Phase 1 Implementation
|
||||
"""
|
||||
|
||||
from cat.mad_hatter.decorators import hook
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
|
||||
@hook(priority=100)
def before_cat_reads_message(user_message_json: dict, cat) -> dict:
    """Enrich an incoming message with Discord metadata.

    Runs BEFORE the message is processed. The Discord bot places
    ``guild_id`` / ``channel_id`` into working memory when calling the
    Cat API; this hook copies them into the message metadata.
    """
    metadata = user_message_json.setdefault('metadata', {})
    metadata.update({
        # Direct messages carry no guild, so fall back to the 'dm' marker.
        'guild_id': cat.working_memory.get('guild_id') or 'dm',
        'channel_id': cat.working_memory.get('channel_id'),
        'timestamp': datetime.now().isoformat(),
    })
    return user_message_json
|
||||
|
||||
|
||||
# Compiled once at import time: messages matching this are never stored.
# Unified single pattern (one pass) instead of a per-call list of regexes.
_TRIVIAL_MESSAGE_RE = re.compile(
    r'^(?:'
    r'\w{1,2}'                         # 1-2 character messages: "k", "ok"
    r'|lol|lmao|haha|hehe|xd|rofl'     # pure reactions
    r'|:[\w_]+:'                       # Discord emoji only: ":smile:"
    r')$'
)


@hook(priority=100)
def before_cat_stores_episodic_memory(doc, cat):
    """Filter and enrich memories before storage.

    Phase 1: Minimal filtering
    - Skip only obvious junk (1-2 char messages, pure reactions)
    - Store everything else temporarily
    - Mark as unconsolidated for nightly processing

    Returns the enriched ``doc``, or ``None`` to skip storage entirely.
    """
    message = doc.page_content.strip()

    # Skip only the most trivial messages (single combined regex,
    # lower-cased once instead of per pattern).
    if _TRIVIAL_MESSAGE_RE.match(message.lower()):
        print(f"🗑️ [Discord Bridge] Skipping trivial message: {message}")
        return None  # Don't store at all

    # Add Discord metadata to memory
    doc.metadata['consolidated'] = False  # Needs nightly processing
    doc.metadata['stored_at'] = datetime.now().isoformat()

    # Get Discord context from working memory (set by the Discord bot).
    guild_id = cat.working_memory.get('guild_id')
    channel_id = cat.working_memory.get('channel_id')

    doc.metadata['guild_id'] = guild_id or 'dm'
    doc.metadata['channel_id'] = channel_id
    doc.metadata['source'] = cat.user_id  # CRITICAL: Cat filters episodic by source=user_id!
    doc.metadata['discord_source'] = 'discord'  # Keep original value as separate field

    print(f"💾 [Discord Bridge] Storing memory (unconsolidated): {message[:50]}...")
    print(f"   User: {cat.user_id}, Guild: {doc.metadata['guild_id']}, Channel: {channel_id}")

    return doc
|
||||
|
||||
|
||||
@hook(priority=50)
def after_cat_recalls_memories(cat):
    """Log memory recall for debugging.

    Reads recalled memories from ``cat.working_memory``; purely
    observational — nothing is modified.
    """
    # The sys.stderr.write tracing that used to live here was removed:
    # this commit's stated policy is "Remove all sys.stderr.write debug
    # logging", and the unused declarative_memories fetch went with it.
    episodic_memories = cat.working_memory.get('episodic_memories', [])

    if episodic_memories:
        print(f"🧠 [Discord Bridge] Recalled {len(episodic_memories)} episodic memories for user {cat.user_id}")
        # Show which guilds the memories are from
        guilds = {doc.metadata.get('guild_id', 'unknown') for doc, score in episodic_memories}
        print(f"  From guilds: {', '.join(guilds)}")
|
||||
|
||||
|
||||
# Plugin metadata
# NOTE(review): presumably surfaced by the Cat plugin registry — confirm.
__version__ = "1.0.0"
__description__ = "Discord bridge with unified user identity and sleep consolidation support"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
sentence-transformers>=2.2.0
|
||||
246
cheshire-cat/cat/log.py
Normal file
246
cheshire-cat/cat/log.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""The log engine."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import inspect
|
||||
import traceback
|
||||
import json
|
||||
from itertools import takewhile
|
||||
from pprint import pformat
|
||||
from loguru import logger
|
||||
|
||||
from cat.env import get_env
|
||||
|
||||
def get_log_level():
    """Return the global LOG level, read from the CCAT_LOG_LEVEL env var."""
    return get_env("CCAT_LOG_LEVEL")
|
||||
|
||||
class CatLogEngine:
    """The log engine.

    Engine to filter the logs in the terminal according to the level of severity.

    Attributes
    ----------
    LOG_LEVEL : str
        Level of logging set in the `.env` file.

    Notes
    -----
    The logging level set in the `.env` file will print all the logs from that level to above.
    Available levels are:

    - `DEBUG`
    - `INFO`
    - `WARNING`
    - `ERROR`
    - `CRITICAL`

    Default to `INFO`.

    """

    def __init__(self):
        self.LOG_LEVEL = get_log_level()
        self.default_log()

        # workaround for pdfminer logging
        # https://github.com/pdfminer/pdfminer.six/issues/347
        logging.getLogger("pdfminer").setLevel(logging.WARNING)

    def show_log_level(self, record):
        """Allows to show stuff in the log based on the global setting.

        Parameters
        ----------
        record : dict
            Loguru record being filtered.

        Returns
        -------
        bool
            True when the record's level is at or above ``self.LOG_LEVEL``.
        """
        return record["level"].no >= logger.level(self.LOG_LEVEL).no

    @staticmethod
    def _patch_extras(record):
        """Provide defaults for extra fields so third-party loggers don't
        crash the custom format string (e.g. fastembed deprecation warnings)."""
        record["extra"].setdefault("original_name", "(third-party)")
        record["extra"].setdefault("original_class", "")
        record["extra"].setdefault("original_caller", "")
        record["extra"].setdefault("original_line", 0)

    def default_log(self):
        """Set the same debug level to all the project dependencies.

        Returns
        -------
        int
            Handler id returned by ``logger.add``.
        """

        time = "<green>[{time:YYYY-MM-DD HH:mm:ss.SSS}]</green>"
        level = "<level>{level: <6}</level>"
        origin = "<level>{extra[original_name]}.{extra[original_class]}.{extra[original_caller]}::{extra[original_line]}</level>"
        message = "<level>{message}</level>"
        log_format = f"{time} {level} {origin} \n{message}"

        logger.remove()
        # The patcher guarantees every extra field referenced in `origin`
        # exists, even on records emitted by third-party libraries.
        logger.configure(patcher=self._patch_extras)
        if self.LOG_LEVEL == "DEBUG":
            return logger.add(
                sys.stdout,
                colorize=True,
                format=log_format,
                backtrace=True,
                diagnose=True,
                filter=self.show_log_level
            )
        else:
            return logger.add(
                sys.stdout,
                colorize=True,
                format=log_format,
                filter=self.show_log_level,
                level=self.LOG_LEVEL
            )

    def get_caller_info(self, skip=3):
        """Get the name of a caller in the format module.class.method.

        Copied from: https://gist.github.com/techtonik/2151727

        Parameters
        ----------
        skip : int
            Specifies how many levels of stack to skip while getting caller name.

        Returns
        -------
        package : str
            Caller package.
        module : str
            Caller module.
        klass : str
            Caller classname if one otherwise None.
        caller : str
            Caller function or method (if a class exist).
        line : int
            The line of the call.


        Notes
        -----
        skip=1 means "who calls me",
        skip=2 "who calls my caller" etc.

        An empty tuple ("", "", "", None, 0) is returned if skipped levels
        exceed stack height, so callers unpacking 5 values never crash.
        """
        stack = inspect.stack()
        start = 0 + skip
        if len(stack) < start + 1:
            # Fix: the old code returned a bare "" here, which blew up the
            # 5-way unpack in log(); return an empty 5-tuple instead.
            return "", "", "", None, 0
        parentframe = stack[start][0]

        # module and packagename.
        # Fix: default both names so a frame with no resolvable module
        # (inspect.getmodule -> None) no longer raises UnboundLocalError.
        package = ""
        module = ""
        module_info = inspect.getmodule(parentframe)
        if module_info:
            mod = module_info.__name__.split(".")
            package = mod[0]
            module = ".".join(mod[1:])

        # class name.
        klass = ""
        if "self" in parentframe.f_locals:
            klass = parentframe.f_locals["self"].__class__.__name__

        # method or function name.
        caller = None
        if parentframe.f_code.co_name != "<module>":  # top level usually
            caller = parentframe.f_code.co_name

        # call line.
        line = parentframe.f_lineno

        # Remove reference to frame
        # See: https://docs.python.org/3/library/inspect.html#the-interpreter-stack
        del parentframe

        return package, module, klass, caller, line

    def __call__(self, msg, level="DEBUG"):
        """Alias of self.log()"""
        self.log(msg, level)

    def debug(self, msg):
        """Logs a DEBUG message"""
        self.log(msg, level="DEBUG")

    def info(self, msg):
        """Logs an INFO message"""
        self.log(msg, level="INFO")

    def warning(self, msg):
        """Logs a WARNING message"""
        self.log(msg, level="WARNING")

    def error(self, msg):
        """Logs an ERROR message"""
        self.log(msg, level="ERROR")

    def critical(self, msg):
        """Logs a CRITICAL message"""
        self.log(msg, level="CRITICAL")

    def log(self, msg, level="DEBUG"):
        """Log a message

        Parameters
        ----------
        msg :
            Message to be logged.
        level : str
            Logging level."""

        (package, module, klass, caller, line) = self.get_caller_info()

        custom_logger = logger.bind(
            original_name=f"{package}.{module}",
            original_line=line,
            original_class=klass,
            original_caller=caller,
        )

        # prettify
        if type(msg) in [dict, list, str]:  # TODO: should be recursive
            try:
                msg = json.dumps(msg, indent=4)
            except (TypeError, ValueError):
                # Fix: was a bare `except:` (caught even KeyboardInterrupt);
                # only serialization failures should fall through to raw logging.
                pass
        else:
            msg = pformat(msg)

        # actual log
        custom_logger.log(level, msg)

    def welcome(self):
        """Welcome message in the terminal."""
        secure = get_env("CCAT_CORE_USE_SECURE_PROTOCOLS")
        if secure != '':
            secure = 's'

        cat_host = get_env("CCAT_CORE_HOST")
        cat_port = get_env("CCAT_CORE_PORT")
        cat_address = f'http{secure}://{cat_host}:{cat_port}'

        with open("cat/welcome.txt", 'r') as f:
            print(f.read())

        print('\n=============== ^._.^ ===============\n')
        print(f'Cat REST API: {cat_address}/docs')
        print(f'Cat PUBLIC: {cat_address}/public')
        print(f'Cat ADMIN: {cat_address}/admin\n')
        print('======================================')
|
||||
|
||||
# logger instance — module-level singleton shared by the whole project.
log = CatLogEngine()
|
||||
60
cheshire-cat/docker-compose.test.yml
Normal file
60
cheshire-cat/docker-compose.test.yml
Normal file
@@ -0,0 +1,60 @@
|
||||
services:
  cheshire-cat-core:
    image: ghcr.io/cheshire-cat-ai/core:1.6.2
    container_name: miku_cheshire_cat_test
    depends_on:
      - cheshire-cat-vector-memory
    environment:
      PYTHONUNBUFFERED: "1"
      WATCHFILES_FORCE_POLLING: "true"
      CORE_HOST: ${CORE_HOST:-localhost}
      CORE_PORT: ${CORE_PORT:-1865}
      QDRANT_HOST: ${QDRANT_HOST:-cheshire-cat-vector-memory}
      QDRANT_PORT: ${QDRANT_PORT:-6333}
      CORE_USE_SECURE_PROTOCOLS: ${CORE_USE_SECURE_PROTOCOLS:-false}
      API_KEY: ${API_KEY:-}
      LOG_LEVEL: ${LOG_LEVEL:-INFO}
      DEBUG: ${DEBUG:-true}
      SAVE_MEMORY_SNAPSHOTS: ${SAVE_MEMORY_SNAPSHOTS:-false}
      # LLM endpoint on the host — presumably served by llama-swap; confirm port 8091.
      OPENAI_API_BASE: "http://host.docker.internal:8091/v1"
    ports:
      - "${CORE_PORT:-1865}:80"
    # Allow connection to host services (llama-swap)
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - ./cat/static:/app/cat/static
      - ./cat/plugins:/app/cat/plugins
      - ./cat/data:/app/cat/data
      - ./cat/log.py:/app/cat/log.py  # Patched: fix loguru KeyError for third-party libs
    restart: unless-stopped
    networks:
      - miku-test-network
      - miku-discord_default  # Connect to existing miku bot network

  cheshire-cat-vector-memory:
    image: qdrant/qdrant:v1.9.1
    container_name: miku_qdrant_test
    environment:
      LOG_LEVEL: ${LOG_LEVEL:-INFO}
    ports:
      - "6333:6333"  # Expose for debugging
    ulimits:
      # Qdrant needs a high open-file limit for its storage segments.
      nofile:
        soft: 65536
        hard: 65536
    volumes:
      - ./cat/long_term_memory/vector:/qdrant/storage
    restart: unless-stopped
    networks:
      - miku-test-network

networks:
  miku-test-network:
    driver: bridge
  # Connect to main miku-discord network to access llama-swap
  default:
    external: true
    name: miku-discord_default
  miku-discord_default:
    external: true  # Connect to your existing bot's network
||||
@@ -1,196 +1,254 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Full Pipeline Test for Memory Consolidation System
|
||||
Tests all phases: Storage → Consolidation → Fact Extraction → Recall
|
||||
Full Pipeline Test for Memory Consolidation System v2.0.0
|
||||
"""
|
||||
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
|
||||
BASE_URL = "http://localhost:1865"
|
||||
CAT_URL = "http://localhost:1865"
|
||||
QDRANT_URL = "http://localhost:6333"
|
||||
CONSOLIDATION_TIMEOUT = 180
|
||||
|
||||
def send_message(text):
|
||||
"""Send a message to Miku and get response"""
|
||||
resp = requests.post(f"{BASE_URL}/message", json={"text": text})
|
||||
return resp.json()
|
||||
|
||||
def get_qdrant_count(collection):
|
||||
"""Get count of items in Qdrant collection"""
|
||||
resp = requests.post(
|
||||
f"http://localhost:6333/collections/{collection}/points/scroll",
|
||||
json={"limit": 1000, "with_payload": False, "with_vector": False}
|
||||
)
|
||||
return len(resp.json()["result"]["points"])
|
||||
def send_message(text, timeout=30):
    """POST *text* to the Cat /message endpoint and return the decoded JSON.

    Any timeout or request failure is swallowed and reported as a stub
    payload carrying an "error" key and an empty "content" string, so
    callers never have to wrap this in try/except themselves.
    """
    payload = {"text": text}
    try:
        response = requests.post(f"{CAT_URL}/message", json=payload, timeout=timeout)
        response.raise_for_status()
        body = response.json()
    except requests.exceptions.Timeout:
        return {"error": "timeout", "content": ""}
    except Exception as exc:
        return {"error": str(exc), "content": ""}
    return body
|
||||
|
||||
|
||||
def qdrant_scroll(collection, limit=200, filt=None):
    """Scroll up to *limit* points (payload only, no vectors) from a Qdrant collection."""
    request_body = {
        "limit": limit,
        "with_payload": True,
        "with_vector": False,
    }
    if filt:
        request_body["filter"] = filt
    reply = requests.post(
        f"{QDRANT_URL}/collections/{collection}/points/scroll",
        json=request_body,
    )
    return reply.json()["result"]["points"]
|
||||
|
||||
|
||||
def qdrant_count(collection):
    """Number of points currently in *collection* (bounded by the default scroll limit)."""
    points = qdrant_scroll(collection)
    return len(points)
|
||||
|
||||
|
||||
def section(title):
    """Print a banner separating one test section from the next."""
    bar = "=" * 70
    print(f"\n{bar}")
    print(f" {title}")
    print(bar)
|
||||
|
||||
|
||||
print("=" * 70)
|
||||
print("🧪 FULL PIPELINE TEST - Memory Consolidation System")
|
||||
print(" FULL PIPELINE TEST - Memory Consolidation v2.0.0")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
requests.get(f"{CAT_URL}/", timeout=5)
|
||||
except Exception:
|
||||
print("ERROR: Cat not reachable"); sys.exit(1)
|
||||
try:
|
||||
requests.get(f"{QDRANT_URL}/collections", timeout=5)
|
||||
except Exception:
|
||||
print("ERROR: Qdrant not reachable"); sys.exit(1)
|
||||
|
||||
episodic_start = qdrant_count("episodic")
|
||||
declarative_start = qdrant_count("declarative")
|
||||
print(f"\nStarting state: {episodic_start} episodic, {declarative_start} declarative")
|
||||
|
||||
results = {}
|
||||
|
||||
# TEST 1: Trivial Message Filtering
|
||||
print("\n📋 TEST 1: Trivial Message Filtering")
|
||||
print("-" * 70)
|
||||
section("TEST 1: Trivial Message Filtering")
|
||||
|
||||
trivial_messages = ["lol", "k", "ok", "haha", "xd"]
|
||||
important_message = "My name is Alex and I live in Seattle"
|
||||
|
||||
print("Sending trivial messages (should be filtered out)...")
|
||||
trivial_messages = ["lol", "k", "ok", "haha", "xd", "brb"]
|
||||
print(f"Sending {len(trivial_messages)} trivial messages...")
|
||||
for msg in trivial_messages:
|
||||
send_message(msg)
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.3)
|
||||
|
||||
print("Sending important message...")
|
||||
send_message(important_message)
|
||||
time.sleep(1)
|
||||
# Count only USER episodic memories (exclude Miku's responses)
|
||||
user_episodic = qdrant_scroll("episodic", filt={
|
||||
"must_not": [{"key": "metadata.speaker", "match": {"value": "miku"}}]
|
||||
})
|
||||
trivial_user_stored = len(user_episodic) - episodic_start
|
||||
episodic_after_trivial = qdrant_count("episodic")
|
||||
|
||||
episodic_count = get_qdrant_count("episodic")
|
||||
print(f"\n✅ Episodic memories stored: {episodic_count}")
|
||||
if episodic_count < len(trivial_messages):
|
||||
print(" ✓ Trivial filtering working! (some messages were filtered)")
|
||||
# discord_bridge filters trivial user messages, but Miku still responds
|
||||
# so we only check user-side storage
|
||||
if trivial_user_stored < len(trivial_messages):
|
||||
print(f" PASS - Only {trivial_user_stored}/{len(trivial_messages)} user trivial messages stored")
|
||||
print(f" (Total episodic incl. Miku responses: {episodic_after_trivial})")
|
||||
results["trivial_filtering"] = True
|
||||
else:
|
||||
print(" ⚠️ Trivial filtering may not be active")
|
||||
print(f" WARN - All {trivial_user_stored} trivial messages stored")
|
||||
results["trivial_filtering"] = False
|
||||
|
||||
# TEST 2: Miku's Response Storage
|
||||
print("\n📋 TEST 2: Miku's Response Storage")
|
||||
print("-" * 70)
|
||||
# TEST 2: Important Message Storage
|
||||
section("TEST 2: Important Message Storage")
|
||||
|
||||
print("Sending message and checking if Miku's response is stored...")
|
||||
resp = send_message("Tell me a very short fact about music")
|
||||
miku_said = resp["content"]
|
||||
print(f"Miku said: {miku_said[:80]}...")
|
||||
time.sleep(2)
|
||||
|
||||
# Check for Miku's messages in episodic
|
||||
resp = requests.post(
|
||||
"http://localhost:6333/collections/episodic/points/scroll",
|
||||
json={
|
||||
"limit": 100,
|
||||
"with_payload": True,
|
||||
"with_vector": False,
|
||||
"filter": {"must": [{"key": "metadata.speaker", "match": {"value": "miku"}}]}
|
||||
}
|
||||
)
|
||||
miku_messages = resp.json()["result"]["points"]
|
||||
print(f"\n✅ Miku's messages in memory: {len(miku_messages)}")
|
||||
if miku_messages:
|
||||
print(f" Example: {miku_messages[0]['payload']['page_content'][:60]}...")
|
||||
print(" ✓ Bidirectional memory working!")
|
||||
else:
|
||||
print(" ⚠️ Miku's responses not being stored")
|
||||
|
||||
# TEST 3: Add Rich Personal Information
|
||||
print("\n📋 TEST 3: Adding Personal Information")
|
||||
print("-" * 70)
|
||||
|
||||
personal_info = [
|
||||
personal_facts = [
|
||||
"My name is Sarah Chen",
|
||||
"I'm 28 years old",
|
||||
"I work as a data scientist at Google",
|
||||
"My favorite color is blue",
|
||||
"I love playing piano",
|
||||
"I live in Seattle, Washington",
|
||||
"I work as a software engineer at Microsoft",
|
||||
"My favorite color is forest green",
|
||||
"I love playing piano and have practiced for 15 years",
|
||||
"I'm learning Japanese, currently at N3 level",
|
||||
"I have a cat named Luna",
|
||||
"I'm allergic to peanuts",
|
||||
"I live in Tokyo, Japan",
|
||||
"My hobbies include photography and hiking"
|
||||
"My birthday is March 15th",
|
||||
"I graduated from UW in 2018",
|
||||
"I enjoy hiking on weekends",
|
||||
]
|
||||
|
||||
print(f"Adding {len(personal_info)} messages with personal information...")
|
||||
for info in personal_info:
|
||||
send_message(info)
|
||||
print(f"Sending {len(personal_facts)} personal info messages...")
|
||||
for i, fact in enumerate(personal_facts, 1):
|
||||
resp = send_message(fact)
|
||||
status = "OK" if "error" not in resp else "ERR"
|
||||
print(f" [{i}/{len(personal_facts)}] {status} {fact[:50]}")
|
||||
time.sleep(0.5)
|
||||
|
||||
episodic_after = get_qdrant_count("episodic")
|
||||
print(f"\n✅ Total episodic memories: {episodic_after}")
|
||||
print(f" ({episodic_after - episodic_count} new memories added)")
|
||||
time.sleep(1)
|
||||
episodic_after_personal = qdrant_count("episodic")
|
||||
personal_stored = episodic_after_personal - episodic_after_trivial
|
||||
print(f"\n Episodic memories from personal info: {personal_stored}")
|
||||
results["important_storage"] = personal_stored >= len(personal_facts)
|
||||
print(f" {'PASS' if results['important_storage'] else 'FAIL'} - Expected >={len(personal_facts)}, got {personal_stored}")
|
||||
|
||||
# TEST 4: Memory Consolidation
|
||||
print("\n📋 TEST 4: Memory Consolidation & Fact Extraction")
|
||||
print("-" * 70)
|
||||
# TEST 3: Miku Response Storage
|
||||
section("TEST 3: Bidirectional Memory (Miku Response Storage)")
|
||||
|
||||
print("Triggering consolidation...")
|
||||
resp = send_message("consolidate now")
|
||||
consolidation_result = resp["content"]
|
||||
print(f"\n{consolidation_result}")
|
||||
miku_points = qdrant_scroll("episodic", filt={
|
||||
"must": [{"key": "metadata.speaker", "match": {"value": "miku"}}]
|
||||
})
|
||||
print(f" Miku's memories in episodic: {len(miku_points)}")
|
||||
if miku_points:
|
||||
print(f" Sample: \"{miku_points[0]['payload']['page_content'][:70]}\"")
|
||||
results["miku_storage"] = True
|
||||
print(" PASS")
|
||||
else:
|
||||
results["miku_storage"] = False
|
||||
print(" FAIL - No Miku responses in episodic memory")
|
||||
|
||||
time.sleep(2)
|
||||
# TEST 4: Per-User Source Tagging
|
||||
section("TEST 4: Per-User Source Tagging")
|
||||
|
||||
# Check declarative facts
|
||||
declarative_count = get_qdrant_count("declarative")
|
||||
print(f"\n✅ Declarative facts extracted: {declarative_count}")
|
||||
user_points = qdrant_scroll("episodic", filt={
|
||||
"must": [{"key": "metadata.source", "match": {"value": "user"}}]
|
||||
})
|
||||
print(f" Points with source='user': {len(user_points)}")
|
||||
|
||||
if declarative_count > 0:
|
||||
# Show sample facts
|
||||
resp = requests.post(
|
||||
"http://localhost:6333/collections/declarative/points/scroll",
|
||||
json={"limit": 5, "with_payload": True, "with_vector": False}
|
||||
)
|
||||
facts = resp.json()["result"]["points"]
|
||||
print("\nSample facts:")
|
||||
for i, fact in enumerate(facts[:5], 1):
|
||||
print(f" {i}. {fact['payload']['page_content']}")
|
||||
global_points = qdrant_scroll("episodic", filt={
|
||||
"must": [{"key": "metadata.source", "match": {"value": "global"}}]
|
||||
})
|
||||
print(f" Points with source='global' (old bug): {len(global_points)}")
|
||||
|
||||
# TEST 5: Fact Recall
|
||||
print("\n📋 TEST 5: Declarative Fact Recall")
|
||||
print("-" * 70)
|
||||
results["user_tagging"] = len(user_points) > 0 and len(global_points) == 0
|
||||
print(f" {'PASS' if results['user_tagging'] else 'FAIL'}")
|
||||
|
||||
queries = [
|
||||
"What is my name?",
|
||||
"How old am I?",
|
||||
"Where do I work?",
|
||||
"What's my favorite color?",
|
||||
"What am I allergic to?"
|
||||
]
|
||||
# TEST 5: Memory Consolidation
|
||||
section("TEST 5: Memory Consolidation & Fact Extraction")
|
||||
|
||||
print("Testing fact recall with queries...")
|
||||
correct_recalls = 0
|
||||
for query in queries:
|
||||
resp = send_message(query)
|
||||
answer = resp["content"]
|
||||
print(f"\n❓ {query}")
|
||||
print(f"💬 Miku: {answer[:150]}...")
|
||||
|
||||
# Basic heuristic: check if answer contains likely keywords
|
||||
keywords = {
|
||||
"What is my name?": ["Sarah", "Chen"],
|
||||
"How old am I?": ["28"],
|
||||
"Where do I work?": ["Google", "data scientist"],
|
||||
"What's my favorite color?": ["blue"],
|
||||
"What am I allergic to?": ["peanut"]
|
||||
}
|
||||
|
||||
if any(kw.lower() in answer.lower() for kw in keywords[query]):
|
||||
print(" ✓ Correct recall!")
|
||||
correct_recalls += 1
|
||||
else:
|
||||
print(" ⚠️ May not have recalled correctly")
|
||||
|
||||
print(f" Triggering consolidation (timeout={CONSOLIDATION_TIMEOUT}s)...")
|
||||
t0 = time.time()
|
||||
resp = send_message("consolidate now", timeout=CONSOLIDATION_TIMEOUT)
|
||||
elapsed = time.time() - t0
|
||||
|
||||
if "error" in resp:
|
||||
print(f" WARN - HTTP issue: {resp['error']} ({elapsed:.0f}s)")
|
||||
print(" Waiting 60s for background completion...")
|
||||
time.sleep(60)
|
||||
else:
|
||||
print(f" Completed in {elapsed:.1f}s")
|
||||
content = resp.get("content", "")
|
||||
print(f" Response: {content[:120]}...")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
declarative_after = qdrant_count("declarative")
|
||||
new_facts = declarative_after - declarative_start
|
||||
print(f"\n Declarative facts: {declarative_start} -> {declarative_after} (+{new_facts})")
|
||||
|
||||
results["consolidation"] = new_facts >= 5
|
||||
print(f" {'PASS' if results['consolidation'] else 'FAIL'} - {'>=5 facts' if results['consolidation'] else f'only {new_facts}'}")
|
||||
|
||||
all_facts = qdrant_scroll("declarative")
|
||||
print(f"\n All declarative facts ({len(all_facts)}):")
|
||||
for i, f in enumerate(all_facts, 1):
|
||||
content = f["payload"]["page_content"]
|
||||
meta = f["payload"].get("metadata", {})
|
||||
source = meta.get("source", "?")
|
||||
ftype = meta.get("fact_type", "?")
|
||||
print(f" {i}. [{source}|{ftype}] {content}")
|
||||
|
||||
# TEST 6: Duplicate Detection
|
||||
section("TEST 6: Duplicate Detection (2nd consolidation)")
|
||||
|
||||
facts_before_2nd = qdrant_count("declarative")
|
||||
print(f" Facts before: {facts_before_2nd}")
|
||||
print(f" Running consolidation again...")
|
||||
|
||||
resp = send_message("consolidate now", timeout=CONSOLIDATION_TIMEOUT)
|
||||
time.sleep(3)
|
||||
|
||||
facts_after_2nd = qdrant_count("declarative")
|
||||
new_dupes = facts_after_2nd - facts_before_2nd
|
||||
print(f" Facts after: {facts_after_2nd} (+{new_dupes})")
|
||||
|
||||
results["dedup"] = new_dupes <= 2
|
||||
print(f" {'PASS' if results['dedup'] else 'FAIL'} - {new_dupes} new facts (<=2 expected)")
|
||||
|
||||
# TEST 7: Fact Recall
|
||||
section("TEST 7: Fact Recall via Natural Language")
|
||||
|
||||
queries = {
|
||||
"What is my name?": ["sarah", "chen"],
|
||||
"How old am I?": ["28"],
|
||||
"Where do I live?": ["seattle"],
|
||||
"Where do I work?": ["microsoft", "software engineer"],
|
||||
"What am I allergic to?": ["peanut"],
|
||||
}
|
||||
|
||||
correct = 0
|
||||
for question, keywords in queries.items():
|
||||
resp = send_message(question)
|
||||
answer = resp.get("content", "")
|
||||
hit = any(kw.lower() in answer.lower() for kw in keywords)
|
||||
if hit:
|
||||
correct += 1
|
||||
icon = "OK" if hit else "??"
|
||||
print(f" {icon} Q: {question}")
|
||||
print(f" A: {answer[:150]}")
|
||||
time.sleep(1)
|
||||
|
||||
print(f"\n✅ Fact recall accuracy: {correct_recalls}/{len(queries)} ({correct_recalls/len(queries)*100:.0f}%)")
|
||||
accuracy = correct / len(queries) * 100
|
||||
results["recall"] = correct >= 3
|
||||
print(f"\n Recall: {correct}/{len(queries)} ({accuracy:.0f}%)")
|
||||
print(f" {'PASS' if results['recall'] else 'FAIL'} (threshold: >=3)")
|
||||
|
||||
# TEST 6: Conversation History Recall
|
||||
print("\n📋 TEST 6: Conversation History (Episodic) Recall")
|
||||
print("-" * 70)
|
||||
# FINAL SUMMARY
|
||||
section("FINAL SUMMARY")
|
||||
|
||||
print("Asking about conversation history...")
|
||||
resp = send_message("What have we talked about today?")
|
||||
summary = resp["content"]
|
||||
print(f"💬 Miku's summary:\n{summary}")
|
||||
total = len(results)
|
||||
passed = sum(1 for v in results.values() if v)
|
||||
print()
|
||||
for name, ok in results.items():
|
||||
print(f" [{'PASS' if ok else 'FAIL'}] {name}")
|
||||
|
||||
# Final Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 FINAL SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f"✅ Episodic memories: {get_qdrant_count('episodic')}")
|
||||
print(f"✅ Declarative facts: {declarative_count}")
|
||||
print(f"✅ Miku's messages stored: {len(miku_messages)}")
|
||||
print(f"✅ Fact recall accuracy: {correct_recalls}/{len(queries)}")
|
||||
print(f"\n Score: {passed}/{total}")
|
||||
print(f" Episodic: {qdrant_count('episodic')}")
|
||||
print(f" Declarative: {qdrant_count('declarative')}")
|
||||
|
||||
# Overall verdict
|
||||
if declarative_count >= 5 and correct_recalls >= 3:
|
||||
print("\n🎉 PIPELINE TEST: PASS")
|
||||
print(" All major components working correctly!")
|
||||
if passed == total:
|
||||
print("\n ALL TESTS PASSED!")
|
||||
elif passed >= total - 1:
|
||||
print("\n MOSTLY PASSING - minor issues only")
|
||||
else:
|
||||
print("\n⚠️ PIPELINE TEST: PARTIAL PASS")
|
||||
print(" Some components may need adjustment")
|
||||
print("\n SOME TESTS FAILED - review above")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
Reference in New Issue
Block a user