Files
LittleWhiteBox/modules/story-summary/vector/recall.js

520 lines
22 KiB
JavaScript
Raw Normal View History

2026-01-26 01:16:35 +08:00
// Story Summary - Recall Engine
// L1 chunk + L2 event 召回
// - 全量向量打分
// - 指数衰减加权 Query Embedding
// - 实体/参与者加分
// - MMR 去重
// - floor 稀疏去重
import { getAllEventVectors, getAllChunkVectors, getChunksByFloors, getMeta } from './chunk-store.js';
import { embed, getEngineFingerprint } from './embedder.js';
import { xbLog } from '../../../core/debug-core.js';
import { getContext } from '../../../../../../extensions.js';
import { getSummaryStore } from '../data/store.js';
const MODULE_ID = 'recall';

/**
 * Tuning knobs for the recall pipeline.
 * Frozen so no code path can silently mutate a shared tuning value at runtime.
 */
const CONFIG = Object.freeze({
    // ── Query construction (buildQuerySegments / embedWeightedQuery) ──
    QUERY_MSG_COUNT: 5, // number of trailing chat messages used to build the query
    QUERY_DECAY_BETA: 0.7, // exponential-decay rate of per-message weights; larger = older messages fade faster
    QUERY_MAX_CHARS: 600, // character cap for the newest message in the query window
    QUERY_CONTEXT_CHARS: 240, // character cap for each older message in the window

    // ── Candidate pools: top-N by similarity that are handed to MMR ──
    CANDIDATE_CHUNKS: 120,
    CANDIDATE_EVENTS: 100,

    // ── MMR output sizes ──
    TOP_K_CHUNKS: 40, // L1 chunks, before the per-floor de-duplication
    TOP_K_EVENTS: 35,

    // Threshold on the raw cosine similarity (bonuses are not counted).
    MIN_SIMILARITY: 0.35,
    // MMR weight on relevance; (1 - λ) weights the penalty for similarity to already-picked items.
    MMR_LAMBDA: 0.72,

    // ── Additive score bonuses for L2 events ──
    BONUS_PARTICIPANT_HIT: 0.08,
    BONUS_TEXT_HIT: 0.05,
    BONUS_WORLD_TOPIC_HIT: 0.06,

    // Maximum number of L1 chunks kept from the same floor.
    FLOOR_LIMIT: 1,
});
// ═══════════════════════════════════════════════════════════════════════════
// 工具函数
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Cosine similarity of two equal-length numeric vectors.
 * Returns 0 when either vector is missing/empty, the lengths differ,
 * or either vector has zero magnitude.
 */
function cosineSimilarity(a, b) {
    const comparable = a?.length && b?.length && a.length === b.length;
    if (!comparable) return 0;

    let dotProduct = 0;
    let squaredA = 0;
    let squaredB = 0;
    for (let k = 0; k < a.length; k += 1) {
        const x = a[k];
        const y = b[k];
        dotProduct += x * y;
        squaredA += x * x;
        squaredB += y * y;
    }

    if (!squaredA || !squaredB) return 0;
    return dotProduct / (Math.sqrt(squaredA) * Math.sqrt(squaredB));
}
/**
 * Scale a vector to unit L2 length. A zero vector is returned unchanged
 * (divided by 1) instead of producing NaNs.
 */
function normalizeVec(v) {
    let sumOfSquares = 0;
    for (const component of v) sumOfSquares += component * component;
    const magnitude = Math.sqrt(sumOfSquares) || 1;
    return v.map(component => component / magnitude);
}
/**
 * Canonicalize a name/label for comparison: NFKC-normalize, strip
 * zero-width characters (U+200B–U+200D, U+FEFF) and trim.
 * Falsy input yields an empty string.
 */
function normalize(s) {
    const raw = s ? String(s) : '';
    const composed = raw.normalize('NFKC');
    const visibleOnly = composed.replace(/[\u200B-\u200D\uFEFF]/g, '');
    return visibleOnly.trim();
}
/**
 * Remove model "thinking" blocks and `[tts:…]` markers from a chat message,
 * then trim. The patterns are applied one after another in a fixed order.
 */
function stripNoise(text) {
    const noisePatterns = [
        /<think>[\s\S]*?<\/think>/gi,
        /<thinking>[\s\S]*?<\/thinking>/gi,
        /\[tts:[^\]]*\]/gi,
    ];
    let cleaned = String(text || '');
    for (const pattern of noisePatterns) {
        cleaned = cleaned.replace(pattern, '');
    }
    return cleaned.trim();
}
/**
 * Weights for n items that grow exponentially toward the last item:
 * w_i ∝ exp(beta · (i − (n − 1))), normalized to sum to 1.
 * n = 0 yields an empty array.
 */
function buildExpDecayWeights(n, beta) {
    const newestIndex = n - 1;
    const raw = Array.from({ length: n }, (_, i) => Math.exp(beta * (i - newestIndex)));
    let total = 0;
    for (const w of raw) total += w;
    const divisor = total || 1;
    return raw.map(w => w / divisor);
}
// ═══════════════════════════════════════════════════════════════════════════
// Query 构建
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Turn the last `count` chat messages into `speaker: text` query segments,
 * oldest first.
 *
 * - When `excludeLastAi` is set and the final chat message is not from the
 *   user, that message is dropped first (e.g. an AI reply being generated).
 * - Noise (thinking blocks, TTS markers) is stripped; messages that end up
 *   empty are skipped.
 * - The newest message in the window is capped at CONFIG.QUERY_MAX_CHARS,
 *   older ones at CONFIG.QUERY_CONTEXT_CHARS (UTF-16 units), without ever
 *   cutting a surrogate pair in half.
 *
 * @param {object[]} chat        chat messages (`name`, `is_user`, `mes`)
 * @param {number}   count       window size; values <= 0 yield no segments
 * @param {boolean}  excludeLastAi
 * @returns {string[]}
 */
function buildQuerySegments(chat, count, excludeLastAi) {
    if (!chat?.length || !(count > 0)) return [];

    // Truncate to at most `max` UTF-16 units; back off one unit if the cut
    // would leave a lone high surrogate (half an emoji) at the end.
    const clip = (str, max) => {
        if (str.length <= max) return str;
        let end = max;
        const lastCode = str.charCodeAt(end - 1);
        if (lastCode >= 0xd800 && lastCode <= 0xdbff) end -= 1;
        return str.slice(0, end);
    };

    let messages = chat;
    if (excludeLastAi && !messages[messages.length - 1]?.is_user) {
        messages = messages.slice(0, -1);
    }

    return messages.slice(-count).map((m, idx, arr) => {
        const clean = stripNoise(m.mes);
        if (!clean) return '';
        const speaker = m.name || (m.is_user ? '用户' : '角色');
        // The newest message carries the most intent, so it gets the larger budget.
        const limit = idx === arr.length - 1 ? CONFIG.QUERY_MAX_CHARS : CONFIG.QUERY_CONTEXT_CHARS;
        return `${speaker}: ${clip(clean, limit)}`;
    }).filter(Boolean);
}
/**
 * Embed every query segment and pool them into one unit-length query vector,
 * weighting each segment by its exponential-decay weight (newer = heavier).
 *
 * @param {string[]} segments      output of buildQuerySegments
 * @param {object}   vectorConfig  passed through to `embed`
 * @returns {Promise<{vector: number[], weights: number[]} | null>}
 *   null when there are no segments or the embedder returned no dimensions.
 */
async function embedWeightedQuery(segments, vectorConfig) {
    if (!segments?.length) return null;

    const weights = buildExpDecayWeights(segments.length, CONFIG.QUERY_DECAY_BETA);
    const vecs = await embed(segments, vectorConfig);
    const dims = vecs?.[0]?.length || 0;
    if (!dims) return null;

    const pooled = new Array(dims);
    for (let d = 0; d < dims; d++) {
        let acc = 0;
        // Accumulate segments in the same order for every dimension.
        for (let s = 0; s < vecs.length; s++) {
            acc += (vecs[s][d] || 0) * weights[s];
        }
        pooled[d] = acc;
    }

    return { vector: normalizeVec(pooled), weights };
}
// ═══════════════════════════════════════════════════════════════════════════
// 实体抽取
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Collect known entity names for keyword matching against the query.
 *
 * The names come from, in this order:
 * - event participants;
 * - `store.json.characters.main` (strings or `{ name }`);
 * - arc names;
 * - world topics (topics containing `::` are skipped);
 * - relationship endpoints.
 *
 * All names are normalized and de-duplicated, keeping first-seen order.
 * The result drops:
 * - entries shorter than 2 UTF-16 units;
 * - the user's own name and generic pronouns/role words;
 * - punctuation/symbol-only strings;
 * - anything that looks like an HTML/XML tag.
 *
 * It is capped at 5000 entries.
 */
function buildEntityLexicon(store, allEvents) {
    const userName = normalize(getContext().name1);
    const names = new Set();
    const addName = (raw) => {
        const value = normalize(raw);
        if (value) names.add(value);
    };

    for (const ev of allEvents || []) {
        for (const participant of ev.participants || []) addName(participant);
    }

    const json = store?.json || {};
    for (const member of json.characters?.main || []) {
        addName(typeof member === 'string' ? member : member?.name);
    }
    for (const arc of json.arcs || []) addName(arc?.name);
    for (const entry of json.world || []) {
        const topic = normalize(entry?.topic);
        // Namespaced topics ("a::b") are not treated as entity names.
        if (topic && !topic.includes('::')) names.add(topic);
    }
    for (const rel of json.characters?.relationships || []) {
        addName(rel?.from);
        addName(rel?.to);
    }

    const stopWords = new Set(
        [userName, '我', '你', '他', '她', '它', '用户', '角色', 'assistant'].map(normalize).filter(Boolean)
    );
    const isSymbolsOnly = (s) => /^[\s\p{P}\p{S}]+$/u.test(s);
    const looksLikeTag = (s) => /<[^>]+>/.test(s);

    const lexicon = [];
    for (const name of names) {
        if (name.length < 2 || stopWords.has(name) || isSymbolsOnly(name) || looksLikeTag(name)) continue;
        lexicon.push(name);
    }
    return lexicon.slice(0, 5000);
}
/**
 * Return lexicon entries that occur (as substrings) in the normalized text.
 * Longer entries are tested first, and at most 20 hits are returned.
 * The caller's lexicon array is not mutated.
 */
function extractEntities(text, lexicon) {
    const haystack = normalize(text);
    if (!haystack || !lexicon?.length) return [];

    const longestFirst = Array.from(lexicon).sort((x, y) => y.length - x.length);
    const found = [];
    for (const entity of longestFirst) {
        if (found.length >= 20) break;
        if (haystack.includes(entity)) found.push(entity);
    }
    return found;
}
// ═══════════════════════════════════════════════════════════════════════════
// MMR
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Maximal Marginal Relevance selection.
 *
 * Greedily picks up to `k` candidates, each round choosing the one that
 * maximizes `lambda * relevance - (1 - lambda) * maxSim`. Here `maxSim` is the
 * highest cosine similarity to any already-selected item, floored at 0.
 * Ties go to the earlier candidate. Candidates are identified by `_id`; a
 * candidate whose `_id` was already selected is never picked again.
 *
 * Each candidate's `maxSim` is updated incrementally against the item just
 * selected. The original recomputed it against every selected item on every
 * round. Results are identical, but the work drops from O(k²·n) to O(k·n)
 * similarity computations.
 *
 * @param {object[]} candidates  items with an `_id` field (not mutated)
 * @param {number}   k           maximum number of items to return
 * @param {number}   lambda      relevance/diversity trade-off in [0, 1]
 * @param {(c: object) => number[]} getVector  embedding accessor
 * @param {(c: object) => number}   getScore   relevance accessor
 * @returns {object[]} selected candidates in selection order
 */
function mmrSelect(candidates, k, lambda, getVector, getScore) {
    const selected = [];
    const ids = new Set();
    const vectors = candidates.map(c => getVector(c));
    const scores = candidates.map(c => getScore(c));
    // maxSim[i]: highest similarity between candidates[i] and any selected item (0 before any pick).
    const maxSim = new Array(candidates.length).fill(0);

    while (selected.length < k && candidates.length) {
        let bestIdx = -1;
        let bestScore = -Infinity;
        for (let i = 0; i < candidates.length; i++) {
            if (ids.has(candidates[i]._id)) continue;
            const score = lambda * scores[i] - (1 - lambda) * maxSim[i];
            if (score > bestScore) {
                bestScore = score;
                bestIdx = i;
            }
        }
        if (bestIdx < 0) break;

        const best = candidates[bestIdx];
        selected.push(best);
        ids.add(best._id);
        if (selected.length >= k) break;

        // Fold the new pick into every remaining candidate's redundancy score.
        const bestVector = vectors[bestIdx];
        for (let i = 0; i < candidates.length; i++) {
            if (ids.has(candidates[i]._id)) continue;
            const vC = vectors[i];
            if (!vC?.length) continue;
            const sim = cosineSimilarity(vC, bestVector);
            if (sim > maxSim[i]) maxSim[i] = sim;
        }
    }
    return selected;
}
// ═══════════════════════════════════════════════════════════════════════════
// L1 Chunks 检索
// ═══════════════════════════════════════════════════════════════════════════
/**
 * L1 retrieval: score every stored chunk vector against the query.
 *
 * Steps:
 * 1. Keep those at or above CONFIG.MIN_SIMILARITY.
 * 2. Take the top CONFIG.CANDIDATE_CHUNKS.
 * 3. Diversify with MMR, keeping CONFIG.TOP_K_CHUNKS.
 * 4. Keep at most CONFIG.FLOOR_LIMIT chunks per floor.
 * 5. Hydrate the survivors with their stored text.
 *
 * Floor and chunk index are parsed from ids shaped like `c-<floor>-<idx>`
 * (0 when the id does not match).
 *
 * Returns [] when there is no chat, no query vector, no stored vectors,
 * or when the stored engine fingerprint differs from the current one.
 */
async function searchChunks(queryVector, vectorConfig) {
    const chatId = getContext().chatId;
    if (!chatId || !queryVector?.length) return [];

    const meta = await getMeta(chatId);
    const currentFingerprint = getEngineFingerprint(vectorConfig);
    // Vectors produced by another embedding engine live in a different space.
    const engineMismatch = meta.fingerprint && meta.fingerprint !== currentFingerprint;
    if (engineMismatch) return [];

    const chunkVectors = await getAllChunkVectors(chatId);
    if (!chunkVectors.length) return [];

    const idPattern = /c-(\d+)-(\d+)/;
    const scored = chunkVectors.map((cv) => {
        const parts = String(cv.chunkId).match(idPattern);
        const floor = parts ? parseInt(parts[1], 10) : 0;
        const chunkIdx = parts ? parseInt(parts[2], 10) : 0;
        return {
            _id: cv.chunkId,
            chunkId: cv.chunkId,
            floor,
            chunkIdx,
            similarity: cosineSimilarity(queryVector, cv.vector),
            vector: cv.vector,
        };
    });

    const bySimilarityDesc = (x, y) => y.similarity - x.similarity;
    const pool = scored
        .filter(item => item.similarity >= CONFIG.MIN_SIMILARITY)
        .sort(bySimilarityDesc)
        .slice(0, CONFIG.CANDIDATE_CHUNKS);

    const diversified = mmrSelect(
        pool,
        CONFIG.TOP_K_CHUNKS,
        CONFIG.MMR_LAMBDA,
        item => item.vector,
        item => item.similarity
    );

    // Keep only the best CONFIG.FLOOR_LIMIT chunk(s) from each floor.
    const perFloor = new Map();
    const sparse = diversified.sort(bySimilarityDesc).filter((item) => {
        const used = perFloor.get(item.floor) || 0;
        if (used >= CONFIG.FLOOR_LIMIT) return false;
        perFloor.set(item.floor, used + 1);
        return true;
    });

    const floors = Array.from(new Set(sparse.map(item => item.floor)));
    const chunks = await getChunksByFloors(chatId, floors);
    const chunksById = new Map(chunks.map(chunk => [chunk.chunkId, chunk]));

    return sparse.flatMap((item) => {
        const chunk = chunksById.get(item.chunkId);
        if (!chunk) return [];
        return [{
            chunkId: item.chunkId,
            floor: item.floor,
            chunkIdx: item.chunkIdx,
            speaker: chunk.speaker,
            isUser: chunk.isUser,
            text: chunk.text,
            similarity: item.similarity,
        }];
    });
}
// ═══════════════════════════════════════════════════════════════════════════
// L2 Events 检索
// ═══════════════════════════════════════════════════════════════════════════
/**
 * L2 retrieval: score every summary event against the query vector, add
 * keyword bonuses, then pick a diverse top-K with MMR.
 *
 * Additive bonuses (CONFIG.BONUS_*):
 * - participant: an event participant (other than the user) is a query entity.
 * - text: some query entity occurs in the event's title/summary.
 * - world: a hard-constraint world topic (category inventory/rule/knowledge)
 *   is a query entity and occurs in the event's title/summary.
 *
 * Only the raw cosine similarity is compared with CONFIG.MIN_SIMILARITY, so an
 * event without a stored vector (similarity 0) never qualifies, whatever its
 * bonuses.
 *
 * @param {number[]} queryVector    query embedding
 * @param {object[]} allEvents      events with `id`, `title`, `summary`, `participants`
 * @param {object}   vectorConfig   embedding config, used for the fingerprint check
 * @param {object}   store          summary store; `store.json.world` supplies topics
 * @param {string[]} queryEntities  entities extracted from the query text
 * @returns {Promise<Array<{event: object, similarity: number, _recallType: string, _recallReason: string}>>}
 *   Sorted by final score. `similarity` holds the bonus-inclusive score.
 *   `_recallType` is DIRECT when the participant bonus fired, otherwise SIMILAR.
 */
async function searchEvents(queryVector, allEvents, vectorConfig, store, queryEntities) {
    const { chatId, name1 } = getContext();
    if (!chatId || !queryVector?.length) return [];

    const meta = await getMeta(chatId);
    const fp = getEngineFingerprint(vectorConfig);
    // Vectors produced by another embedding engine are not comparable.
    if (meta?.fingerprint && meta.fingerprint !== fp) return [];

    const eventVectors = await getAllEventVectors(chatId);
    const vectorMap = new Map(eventVectors.map(v => [v.eventId, v.vector]));
    if (!vectorMap.size) return [];

    const userName = normalize(name1);
    // Normalize query entities once instead of once per event.
    const queryTerms = [...new Set((queryEntities || []).map(normalize))];
    const querySet = new Set(queryTerms);

    // Only hard-constraint world categories can earn the world bonus, and only
    // topics that were themselves extracted from the query. World entries are
    // read null-safely, as buildEntityLexicon does.
    const hardCategories = new Set(['inventory', 'rule', 'knowledge']);
    const matchedTopics = (store?.json?.world || [])
        .filter(w => hardCategories.has(String(w?.category).toLowerCase()))
        .map(w => normalize(w.topic))
        .filter(topic => topic && querySet.has(topic));

    const scored = (allEvents || []).map((event) => {
        const vector = vectorMap.get(event.id);
        const similarity = vector ? cosineSimilarity(queryVector, vector) : 0;
        const text = normalize(`${event.title || ''} ${event.summary || ''}`);
        const participants = (event.participants || []).map(normalize).filter(Boolean);

        let bonus = 0;
        const reasons = [];
        if (participants.some(p => p !== userName && querySet.has(p))) {
            bonus += CONFIG.BONUS_PARTICIPANT_HIT;
            reasons.push('participant');
        }
        if (queryTerms.some(term => text.includes(term))) {
            bonus += CONFIG.BONUS_TEXT_HIT;
            reasons.push('text');
        }
        if (matchedTopics.some(topic => text.includes(topic))) {
            bonus += CONFIG.BONUS_WORLD_TOPIC_HIT;
            reasons.push('world');
        }

        return {
            _id: event.id,
            event,
            similarity,
            bonus,
            finalScore: similarity + bonus,
            reasons,
            isDirect: reasons.includes('participant'),
            vector,
        };
    });

    const candidates = scored
        .filter(s => s.similarity >= CONFIG.MIN_SIMILARITY)
        .sort((a, b) => b.finalScore - a.finalScore)
        .slice(0, CONFIG.CANDIDATE_EVENTS);

    const selected = mmrSelect(
        candidates,
        CONFIG.TOP_K_EVENTS,
        CONFIG.MMR_LAMBDA,
        c => c.vector,
        c => c.finalScore
    );

    return selected
        .sort((a, b) => b.finalScore - a.finalScore)
        .map(s => ({
            event: s.event,
            similarity: s.finalScore,
            _recallType: s.isDirect ? 'DIRECT' : 'SIMILAR',
            _recallReason: s.reasons.length ? s.reasons.join('+') : '相似',
        }));
}
// ═══════════════════════════════════════════════════════════════════════════
// 日志
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Render a human-readable recall report. The report covers:
 * - query segments, with their weights sorted from highest to lowest;
 * - extracted entities;
 * - L1 chunk hits (first 15 shown);
 * - L2 event hits;
 * - summary counts.
 *
 * The query header reports the actual number of segments and the configured
 * decay rate. It used to hard-code "5" and "0.7", which was wrong whenever
 * messages were skipped or the config changed. Previews are null-safe, so a
 * chunk without text does not crash the report.
 *
 * @returns {string} multi-line report text
 */
function formatRecallLog({ elapsed, segments, weights, chunkResults, eventResults, allEvents, queryEntities }) {
    // Truncate for display; missing text renders as an empty string instead of throwing.
    const preview = (text, max) => {
        const str = String(text ?? '');
        return str.length > max ? str.slice(0, max) + '...' : str;
    };
    const segs = segments || [];

    const lines = [
        '╔══════════════════════════════════════════════════════════════╗',
        '║ 记忆召回报告 ║',
        '╠══════════════════════════════════════════════════════════════╣',
        `║ 耗时: ${elapsed}ms`,
        '╚══════════════════════════════════════════════════════════════╝',
        '',
        '┌─────────────────────────────────────────────────────────────┐',
        `│ 【查询构建】最近 ${segs.length} 条消息,指数衰减加权 (β=${CONFIG.QUERY_DECAY_BETA})`,
        '│ 权重越高 = 对召回方向影响越大 │',
        '└─────────────────────────────────────────────────────────────┘',
    ];

    // Show segments from highest to lowest weight.
    const segmentsSorted = segs.map((s, i) => ({
        idx: i + 1,
        weight: weights?.[i] ?? 0,
        text: s,
    })).sort((a, b) => b.weight - a.weight);
    segmentsSorted.forEach((s, rank) => {
        const bar = '█'.repeat(Math.round(s.weight * 20));
        const marker = rank === 0 ? ' ◀ 主导' : '';
        lines.push(` ${(s.weight * 100).toFixed(1).padStart(5)}% ${bar.padEnd(12)} ${preview(s.text, 60)}${marker}`);
    });

    lines.push('');
    lines.push('┌─────────────────────────────────────────────────────────────┐');
    lines.push('│ 【提取实体】用于判断"亲身经历"(DIRECT) │');
    lines.push('└─────────────────────────────────────────────────────────────┘');
    lines.push(` ${queryEntities?.length ? queryEntities.join('、') : '(无)'}`);

    lines.push('');
    lines.push('┌─────────────────────────────────────────────────────────────┐');
    lines.push(`│ 【L1 原文片段】召回 ${chunkResults.length}`);
    lines.push('└─────────────────────────────────────────────────────────────┘');
    chunkResults.slice(0, 15).forEach((c, i) => {
        lines.push(` ${String(i + 1).padStart(2)}. #${String(c.floor).padStart(3)} [${c.speaker}] ${preview(c.text, 50)}`);
        lines.push(` 相似度: ${c.similarity.toFixed(3)}`);
    });
    if (chunkResults.length > 15) {
        lines.push(` ... 还有 ${chunkResults.length - 15}`);
    }

    lines.push('');
    lines.push('┌─────────────────────────────────────────────────────────────┐');
    lines.push(`│ 【L2 事件记忆】召回 ${eventResults.length} / ${allEvents.length}`);
    lines.push('│ DIRECT=亲身经历 SIMILAR=相关背景 │');
    lines.push('└─────────────────────────────────────────────────────────────┘');
    eventResults.forEach((e, i) => {
        const type = e._recallType === 'DIRECT' ? '★ DIRECT ' : ' SIMILAR';
        const title = e.event.title || '(无标题)';
        lines.push(` ${String(i + 1).padStart(2)}. ${type} ${title}`);
        lines.push(` 相似度: ${e.similarity.toFixed(3)} | 原因: ${e._recallReason}`);
    });

    // Summary counts
    const directCount = eventResults.filter(e => e._recallType === 'DIRECT').length;
    const similarCount = eventResults.filter(e => e._recallType === 'SIMILAR').length;
    lines.push('');
    lines.push('┌─────────────────────────────────────────────────────────────┐');
    lines.push('│ 【统计】 │');
    lines.push('└─────────────────────────────────────────────────────────────┘');
    lines.push(` L1 片段: ${chunkResults.length}`);
    lines.push(` L2 事件: ${eventResults.length} 条 (DIRECT: ${directCount}, SIMILAR: ${similarCount})`);
    lines.push(` 实体命中: ${queryEntities?.length || 0}`);
    lines.push('');
    return lines.join('\n');
}
// ═══════════════════════════════════════════════════════════════════════════
// 主入口
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Main entry: recall L1 chunks and L2 events relevant to the current chat.
 *
 * Pipeline:
 * 1. Build decay-weighted query segments from recent chat.
 * 2. Embed and pool them into one query vector.
 * 3. Extract known entities from `queryText` plus the segments.
 * 4. Run chunk and event search in parallel.
 *
 * Failures in either stage are logged through xbLog and turned into an empty
 * result, rather than rejecting the caller. Previously only embedding failures
 * were handled this way.
 *
 * @param {string}   queryText     extra query text used for entity extraction
 * @param {object[]} allEvents     summary events; when empty nothing is recalled
 * @param {object}   vectorConfig  embedding engine configuration
 * @param {{excludeLastAi?: boolean}} [options]
 * @returns {Promise<{events: object[], chunks: object[], elapsed: number, logText: string}>}
 *   `elapsed` is in milliseconds on every path.
 */
export async function recallMemory(queryText, allEvents, vectorConfig, options = {}) {
    const T0 = performance.now();
    const sinceStart = () => Math.round(performance.now() - T0);
    const { chat } = getContext();
    const store = getSummaryStore();

    if (!allEvents?.length) {
        return { events: [], chunks: [], elapsed: sinceStart(), logText: 'No events.' };
    }

    const segments = buildQuerySegments(chat, CONFIG.QUERY_MSG_COUNT, !!options?.excludeLastAi);

    let queryVector;
    let weights;
    try {
        const result = await embedWeightedQuery(segments, vectorConfig);
        queryVector = result?.vector;
        weights = result?.weights;
    } catch (e) {
        xbLog.error(MODULE_ID, '查询向量生成失败', e);
        return { events: [], chunks: [], elapsed: sinceStart(), logText: 'Query embedding failed.' };
    }
    if (!queryVector?.length) {
        return { events: [], chunks: [], elapsed: sinceStart(), logText: 'Empty query vector.' };
    }

    const lexicon = buildEntityLexicon(store, allEvents);
    const queryEntities = extractEntities([queryText, ...segments].join('\n'), lexicon);

    let chunkResults;
    let eventResults;
    try {
        // Promise.all rejects on the first failure; handle it like an embedding failure.
        [chunkResults, eventResults] = await Promise.all([
            searchChunks(queryVector, vectorConfig),
            searchEvents(queryVector, allEvents, vectorConfig, store, queryEntities),
        ]);
    } catch (e) {
        xbLog.error(MODULE_ID, '召回检索失败', e);
        return { events: [], chunks: [], elapsed: sinceStart(), logText: 'Recall search failed.' };
    }

    const elapsed = sinceStart();
    const logText = formatRecallLog({ elapsed, queryText, segments, weights, chunkResults, eventResults, allEvents, queryEntities });

    console.group('%c[Recall]', 'color: #7c3aed; font-weight: bold');
    console.log(`Elapsed: ${elapsed}ms | Entities: ${queryEntities.join(', ') || '(none)'}`);
    console.log(`L1: ${chunkResults.length} | L2: ${eventResults.length}/${allEvents.length}`);
    console.groupEnd();

    return { events: eventResults, chunks: chunkResults, elapsed, logText };
}
/**
 * Build a plain-text query from the last `count` chat messages, one
 * `speaker: text` line per message, oldest first.
 *
 * - Noise is stripped, and messages that end up empty are skipped, matching
 *   buildQuerySegments. The old trailing `.filter(Boolean)` could never fire
 *   because every mapped line was non-empty.
 * - Each message is capped at 500 UTF-16 units without splitting a surrogate
 *   pair.
 * - `count <= 0` yields an empty string. Previously `slice(-0)` returned the
 *   whole chat.
 *
 * @param {object[]} chat            chat messages (`name`, `is_user`, `mes`)
 * @param {number}   [count=2]       number of trailing messages to include
 * @param {boolean}  [excludeLastAi=false] drop the final message when it is not from the user
 * @returns {string}
 */
export function buildQueryText(chat, count = 2, excludeLastAi = false) {
    if (!chat?.length || !(count > 0)) return '';

    // Truncate to at most `max` UTF-16 units, backing off one unit rather than
    // leaving a lone high surrogate at the cut.
    const clip = (str, max) => {
        if (str.length <= max) return str;
        let end = max;
        const lastCode = str.charCodeAt(end - 1);
        if (lastCode >= 0xd800 && lastCode <= 0xdbff) end -= 1;
        return str.slice(0, end);
    };

    let messages = chat;
    if (excludeLastAi && !messages[messages.length - 1]?.is_user) {
        messages = messages.slice(0, -1);
    }

    return messages.slice(-count).map((m) => {
        const text = stripNoise(m.mes);
        if (!text) return '';
        const speaker = m.name || (m.is_user ? '用户' : '角色');
        return `${speaker}: ${clip(text, 500)}`;
    }).filter(Boolean).join('\n');
}