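/*
 * NOTE: the lines below are residue from the web file viewer this source was
 * copied from; they are kept here as a comment so the module still parses.
 *
 * Files
 * LittleWhiteBox/modules/story-summary/vector/retrieval/recall.js
 *
 * 590 lines
 * 26 KiB
 * JavaScript
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */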
// ═══════════════════════════════════════════════════════════════════════════
// Story Summary - Recall Engine (v2 - LLM Augmented)
// Pure vector-based recall; LLM Query Expansion replaces BM25
// ═══════════════════════════════════════════════════════════════════════════
import { getAllChunkVectors, getAllEventVectors, getChunksByFloors, getMeta } from '../storage/chunk-store.js';
import { getEngineFingerprint } from '../utils/embedder.js';
import { xbLog } from '../../../../core/debug-core.js';
import { getContext } from '../../../../../../../extensions.js';
import { getSummaryStore } from '../../data/store.js';
import { filterText } from '../utils/text-filter.js';
import {
searchStateAtoms,
buildL0FloorBonus,
stateToVirtualChunks,
mergeAndSparsify,
} from '../pipeline/state-recall.js';
// Newly added: LLM modules
import { expandQueryCached, buildSearchText } from '../llm/query-expansion.js';
import { embed } from '../llm/siliconflow.js';
// Module tag passed as the first argument to every xbLog call in this file.
const MODULE_ID = 'recall';
// ═══════════════════════════════════════════════════════════════════════════
// Configuration
// ═══════════════════════════════════════════════════════════════════════════
// Tunable constants for the recall pipeline (thresholds, caps and weights).
const CONFIG = {
// Query
// NOTE(review): none of the three QUERY_* values is referenced in the visible
// code; recallMemory passes a literal 6000 ms timeout to expandQueryCached and
// a literal round count of 3 to getLastRounds — confirm which values are intended.
QUERY_MSG_COUNT: 2,
QUERY_MAX_CHARS: 100,
QUERY_EXPANSION_TIMEOUT: 3000,
// Causal chain
// Default maxDepth for traceCausalAncestors.
CAUSAL_CHAIN_MAX_DEPTH: 10,
// Cap on the number of ancestors traceCausalAncestors returns.
CAUSAL_INJECT_MAX: 30,
// Candidate pool sizes (top-N by score kept before MMR selection)
CANDIDATE_CHUNKS: 150,
CANDIDATE_EVENTS: 100,
// Final output sizes (the `k` handed to mmrSelect)
MAX_CHUNKS: 40,
MAX_EVENTS: 80,
// Similarity thresholds
MIN_SIMILARITY_CHUNK: 0.55,
// Lower threshold applied to chunks whose floor is greater than lastSummarizedFloor.
MIN_SIMILARITY_CHUNK_RECENT: 0.45,
MIN_SIMILARITY_EVENT: 0.60,
// MMR
// Weight of relevance versus redundancy in mmrSelect (score = λ·rel − (1−λ)·div).
MMR_LAMBDA: 0.72,
// L0 weighting
// Passed to buildL0FloorBonus as its second argument.
L0_FLOOR_BONUS_FACTOR: 0.10,
// Passed to mergeAndSparsify as its third argument.
FLOOR_MAX_CHUNKS: 2,
};
// ═══════════════════════════════════════════════════════════════════════════
// 工具函数
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Cosine similarity between two numeric vectors.
 *
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} dot(a, b) / (|a| * |b|); 0 when either vector is missing
 *   or empty, when the lengths differ, or when either squared norm is falsy.
 */
function cosineSimilarity(a, b) {
  const comparable = Boolean(a?.length) && Boolean(b?.length) && a.length === b.length;
  if (!comparable) return 0;

  let dotProduct = 0;
  let squaredNormA = 0;
  let squaredNormB = 0;
  let idx = 0;
  while (idx < a.length) {
    const x = a[idx];
    const y = b[idx];
    dotProduct += x * y;
    squaredNormA += x * x;
    squaredNormB += y * y;
    idx += 1;
  }

  // A zero (or NaN) norm would make the quotient meaningless; report "no similarity".
  if (!squaredNormA || !squaredNormB) return 0;
  return dotProduct / (Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB));
}
/**
 * Canonicalise a name for set-membership comparison.
 *
 * Falsy input becomes '', the value is stringified, NFKC-normalised, stripped
 * of U+200B–U+200D and U+FEFF, trimmed and lower-cased.
 *
 * @param {*} s - Value to normalise.
 * @returns {string} The normalised key.
 */
function normalize(s) {
  const raw = String(s || '');
  const folded = raw.normalize('NFKC');
  const withoutInvisible = folded.replace(/[\u200B-\u200D\uFEFF]/g, '');
  return withoutInvisible.trim().toLowerCase();
}
/**
 * Extract a "(#N)" or "(#N-M)" marker from a summary and convert it to a
 * zero-based inclusive range.
 *
 * @param {*} summary - Text that may contain the marker.
 * @returns {{start: number, end: number} | null} The range with each bound
 *   decremented by one and clamped at 0, or null when no marker is found.
 */
function parseFloorRange(summary) {
  if (!summary) return null;

  const found = /\(#(\d+)(?:-(\d+))?\)/.exec(String(summary));
  if (found === null) return null;

  const [, firstDigits, lastDigits] = found;
  // Markers are 1-based; the returned bounds are 0-based and never negative.
  const toZeroBased = (digits) => Math.max(0, parseInt(digits, 10) - 1);

  return {
    start: toZeroBased(firstDigits),
    end: toZeroBased(lastDigits ? lastDigits : firstDigits),
  };
}
/**
 * Prepare message text for recall: run it through the imported `filterText`
 * helper, remove every "[tts:...]" tag (case-insensitive) and trim the result.
 *
 * @param {*} text - Raw message text, handed to `filterText` unchanged.
 * @returns {string} The cleaned text.
 */
function cleanForRecall(text) {
  const filtered = filterText(text);
  const withoutTtsTags = filtered.replace(/\[tts:[^\]]*\]/gi, '');
  return withoutTtsTags.trim();
}
// ═══════════════════════════════════════════════════════════════════════════
// MMR 选择
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Greedy Maximal Marginal Relevance selection.
 *
 * Each round picks the not-yet-selected candidate that maximises
 * `lambda * getScore(c) - (1 - lambda) * maxSim(c)`, where `maxSim(c)` is the
 * highest cosine similarity (floored at 0) between `c` and anything already
 * picked. Ties keep the earliest candidate. Candidates whose `_id` was already
 * picked are skipped, so duplicate ids are returned at most once.
 *
 * `maxSim` is maintained incrementally: after each pick, every remaining
 * candidate is compared against that one pick only. This yields the same
 * selection as recomputing against all picks every round, but needs O(k·n)
 * similarity computations instead of O(k²·n).
 *
 * @param {Array<{_id: *}>} candidates - Pool to select from (not mutated).
 * @param {number} k - Maximum number of items to return.
 * @param {number} lambda - Relevance weight; (1 - lambda) weights redundancy.
 * @param {(c: object) => (ArrayLike<number>|undefined)} getVector - Vector
 *   accessor; assumed to be a pure function of its argument.
 * @param {(c: object) => number} getScore - Relevance accessor; assumed pure.
 * @returns {Array<object>} Selected candidates in pick order.
 */
function mmrSelect(candidates, k, lambda, getVector, getScore) {
  const pool = candidates || [];
  const selected = [];
  const ids = new Set();
  // maxSimToSelected[i]: best similarity so far between pool[i] and any pick.
  const maxSimToSelected = new Array(pool.length).fill(0);

  while (selected.length < k && pool.length) {
    let bestIdx = -1;
    let bestScore = -Infinity;
    for (let i = 0; i < pool.length; i++) {
      const c = pool[i];
      if (ids.has(c._id)) continue;
      const score = lambda * getScore(c) - (1 - lambda) * maxSimToSelected[i];
      if (score > bestScore) {
        bestScore = score;
        bestIdx = i;
      }
    }
    // Nothing left to pick (pool exhausted, or every remaining score is NaN).
    if (bestIdx === -1) break;

    const best = pool[bestIdx];
    selected.push(best);
    ids.add(best._id);
    // No further round will run, so the redundancy update would be wasted work.
    if (selected.length >= k) break;

    const bestVector = getVector(best);
    for (let i = 0; i < pool.length; i++) {
      const c = pool[i];
      if (ids.has(c._id)) continue;
      const vC = getVector(c);
      // Candidates without a vector never accrue a redundancy penalty.
      if (!vC?.length) continue;
      const sim = cosineSimilarity(vC, bestVector);
      if (sim > maxSimToSelected[i]) maxSimToSelected[i] = sim;
    }
  }
  return selected;
}
// ═══════════════════════════════════════════════════════════════════════════
// 因果链追溯
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Index events by their `id`.
 *
 * @param {Iterable<object>|null|undefined} allEvents - Events to index; a
 *   falsy value is treated as an empty list.
 * @returns {Map<*, object>} id → event. Entries without a truthy `id` are
 *   skipped; when an id repeats, the later event wins.
 */
function buildEventIndex(allEvents) {
  const index = new Map();
  const source = allEvents || [];
  for (const candidate of source) {
    const hasId = Boolean(candidate?.id);
    if (!hasId) continue;
    index.set(candidate.id, candidate);
  }
  return index;
}
/**
 * Walk the `causedBy` links of the recalled events and collect their ancestors.
 *
 * Only ids matching /^evt-\d+$/ that exist in `eventIndex` are followed, and
 * the walk stops once the depth exceeds `maxDepth`. For every ancestor the
 * smallest depth seen is kept, together with the ids of all recalled events
 * whose chain reached it.
 *
 * @param {Iterable<{event?: object}>|null} recalledEvents - Recall hits whose
 *   `event.causedBy` lists seed the walk.
 * @param {Map<string, object>} eventIndex - id → event lookup.
 * @param {number} [maxDepth=CONFIG.CAUSAL_CHAIN_MAX_DEPTH] - Deepest level followed.
 * @returns {Array<{event: object, depth: number, chainFrom: string[]}>}
 *   Ancestors ordered by number of referencing recalled events (descending),
 *   then depth (ascending), truncated to CONFIG.CAUSAL_INJECT_MAX entries.
 */
function traceCausalAncestors(recalledEvents, eventIndex, maxDepth = CONFIG.CAUSAL_CHAIN_MAX_DEPTH) {
  const ancestors = new Map();
  const eventIdPattern = /^evt-\d+$/;
  const toId = (raw) => String(raw || '').trim();

  const climb = (ancestorId, level, originId) => {
    if (level > maxDepth || !eventIdPattern.test(ancestorId)) return;
    const ancestor = eventIndex.get(ancestorId);
    if (!ancestor) return;

    const record = ancestors.get(ancestorId);
    if (record) {
      // Seen before: keep the shallowest depth and note the additional origin.
      if (level < record.depth) record.depth = level;
      if (!record.chainFrom.includes(originId)) record.chainFrom.push(originId);
    } else {
      ancestors.set(ancestorId, { event: ancestor, depth: level, chainFrom: [originId] });
    }

    for (const causeRef of (ancestor.causedBy || [])) {
      climb(toId(causeRef), level + 1, originId);
    }
  };

  for (const recalled of recalledEvents || []) {
    const originId = recalled?.event?.id;
    if (!originId) continue;
    for (const causeRef of (recalled.event?.causedBy || [])) {
      climb(toId(causeRef), 1, originId);
    }
  }

  const ranked = [...ancestors.values()];
  ranked.sort(
    (left, right) => (right.chainFrom.length - left.chainFrom.length) || (left.depth - right.depth)
  );
  return ranked.slice(0, CONFIG.CAUSAL_INJECT_MAX);
}
// ═══════════════════════════════════════════════════════════════════════════
// Query 构建
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Take the trailing messages of a chat that span the last `roundCount` user turns.
 *
 * Messages are collected from the end backwards until `roundCount` messages
 * with a truthy `is_user` have been included; the result is returned in
 * chronological order. The input array is not modified.
 *
 * @param {Array<object>|null|undefined} chat - Chat messages, oldest first.
 * @param {number} [roundCount=3] - Number of user messages to reach back over.
 * @param {boolean} [excludeLastAi=false] - When true and the final message is
 *   not a user message, that final message is ignored.
 * @returns {Array<object>} The selected messages (empty for an empty chat).
 */
function getLastRounds(chat, roundCount = 3, excludeLastAi = false) {
  if (!chat?.length) return [];

  const history = [...chat];
  const lastIsUser = Boolean(history[history.length - 1]?.is_user);
  if (excludeLastAi && history.length > 0 && !lastIsUser) {
    history.pop();
  }

  // Gather newest-first, then flip back to chronological order.
  const gathered = [];
  let userTurns = 0;
  let cursor = history.length - 1;
  while (cursor >= 0 && userTurns < roundCount) {
    const message = history[cursor];
    gathered.push(message);
    if (message?.is_user) userTurns += 1;
    cursor -= 1;
  }
  return gathered.reverse();
}
// ═══════════════════════════════════════════════════════════════════════════
// L2 Events 检索(纯向量)
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Vector search over the stored event vectors of the current chat.
 *
 * Returns an empty list when there is no chat id or query vector, when the
 * stored fingerprint differs from the current engine fingerprint, or when no
 * event vectors are stored. Each event is scored as cosine similarity plus:
 *   - the bonus of the first floor in the event's "(#N-M)" range that has an
 *     entry in `l0FloorBonus`, and
 *   - 0.05 when any normalised participant is contained in `entitySet`.
 * Events scoring below CONFIG.MIN_SIMILARITY_EVENT are dropped, the best
 * CONFIG.CANDIDATE_EVENTS are kept, and mmrSelect picks up to CONFIG.MAX_EVENTS.
 *
 * @param {number[]} queryVector - Embedding of the query text.
 * @param {Array<object>|null} allEvents - Events to score.
 * @param {object} vectorConfig - Passed to getEngineFingerprint.
 * @param {Set<string>} entitySet - Normalised entity names from query expansion.
 * @param {Map<number, number>} l0FloorBonus - Per-floor score bonus.
 * @returns {Promise<Array<{event: object, similarity: number, _recallType: string, _rawSim: number}>>}
 *   Hits in selection order; `_recallType` is 'DIRECT' on an entity match, else 'SIMILAR'.
 */
async function searchEvents(queryVector, allEvents, vectorConfig, entitySet, l0FloorBonus) {
  const { chatId } = getContext();
  if (!(chatId && queryVector?.length)) return [];

  const meta = await getMeta(chatId);
  const currentFingerprint = getEngineFingerprint(vectorConfig);
  // Vectors built by a different engine are not comparable with the query vector.
  if (meta.fingerprint && meta.fingerprint !== currentFingerprint) return [];

  const storedVectors = await getAllEventVectors(chatId);
  const vectorByEventId = new Map();
  for (const stored of storedVectors) {
    vectorByEventId.set(stored.eventId, stored.vector);
  }
  if (!vectorByEventId.size) return [];

  // Bonus of the first floor inside the event's range that has an L0 entry.
  const floorBonusOf = (summary) => {
    const range = parseFloorRange(summary);
    if (!range) return 0;
    for (let floor = range.start; floor <= range.end; floor++) {
      if (l0FloorBonus.has(floor)) return l0FloorBonus.get(floor);
    }
    return 0;
  };

  const scoreEvent = (event) => {
    const vector = vectorByEventId.get(event.id);
    const rawSim = vector ? cosineSimilarity(queryVector, vector) : 0;
    const hasEntity = (event.participants || []).some((name) => entitySet.has(normalize(name)));
    const bonus = floorBonusOf(event.summary) + (hasEntity ? 0.05 : 0);
    return {
      _id: event.id,
      event,
      similarity: rawSim + bonus,
      _rawSim: rawSim,
      _hasEntity: hasEntity,
      vector,
    };
  };

  const aboveThreshold = (allEvents || [])
    .map(scoreEvent)
    .filter((entry) => entry.similarity >= CONFIG.MIN_SIMILARITY_EVENT);
  aboveThreshold.sort((left, right) => right.similarity - left.similarity);
  const candidates = aboveThreshold.slice(0, CONFIG.CANDIDATE_EVENTS);

  // MMR keeps the result set from being dominated by near-duplicates.
  const picked = mmrSelect(
    candidates,
    CONFIG.MAX_EVENTS,
    CONFIG.MMR_LAMBDA,
    (entry) => entry.vector,
    (entry) => entry.similarity
  );

  return picked.map(({ event, similarity, _hasEntity, _rawSim }) => ({
    event,
    similarity,
    _recallType: _hasEntity ? 'DIRECT' : 'SIMILAR',
    _rawSim,
  }));
}
// ═══════════════════════════════════════════════════════════════════════════
// L1 Chunks 检索(纯向量)
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Vector search over the stored chunk vectors of the current chat.
 *
 * Returns an empty list when there is no chat id or query vector, when the
 * stored fingerprint differs from the current engine fingerprint, or when no
 * chunk vectors are stored. Floor and chunk index are parsed from ids shaped
 * like "c-<floor>-<idx>" (both default to 0 when the id does not match). The
 * score is cosine similarity plus the L0 bonus of the chunk's floor. Chunks on
 * floors after `lastSummarizedFloor` use the lower
 * CONFIG.MIN_SIMILARITY_CHUNK_RECENT threshold. After MMR selection only the
 * best-scoring chunk per floor is kept, and the full chunk records are loaded.
 *
 * @param {number[]} queryVector - Embedding of the query text.
 * @param {object} vectorConfig - Passed to getEngineFingerprint.
 * @param {Map<number, number>} l0FloorBonus - Per-floor score bonus.
 * @param {number} lastSummarizedFloor - Floors greater than this count as recent.
 * @returns {Promise<Array<{chunkId: *, floor: number, chunkIdx: number, speaker: *, isUser: *, text: *, similarity: number}>>}
 *   Hits sorted by similarity (descending); hits whose chunk record is missing are omitted.
 */
async function searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor) {
  const { chatId } = getContext();
  if (!(chatId && queryVector?.length)) return [];

  const meta = await getMeta(chatId);
  const currentFingerprint = getEngineFingerprint(vectorConfig);
  // Vectors built by a different engine are not comparable with the query vector.
  if (meta.fingerprint && meta.fingerprint !== currentFingerprint) return [];

  const storedVectors = await getAllChunkVectors(chatId);
  if (!storedVectors.length) return [];

  const scoreChunk = (stored) => {
    const idParts = String(stored.chunkId).match(/c-(\d+)-(\d+)/);
    const floor = idParts ? parseInt(idParts[1], 10) : 0;
    const chunkIdx = idParts ? parseInt(idParts[2], 10) : 0;
    const baseSim = cosineSimilarity(queryVector, stored.vector);
    const l0Bonus = l0FloorBonus.get(floor) || 0;
    return {
      _id: stored.chunkId,
      chunkId: stored.chunkId,
      floor,
      chunkIdx,
      similarity: baseSim + l0Bonus,
      _baseSim: baseSim,
      vector: stored.vector,
    };
  };

  // Recent (not yet summarised) floors are admitted with a lower threshold.
  const thresholdFor = (floor) => (
    floor > lastSummarizedFloor ? CONFIG.MIN_SIMILARITY_CHUNK_RECENT : CONFIG.MIN_SIMILARITY_CHUNK
  );

  const aboveThreshold = storedVectors
    .map(scoreChunk)
    .filter((entry) => entry.similarity >= thresholdFor(entry.floor));
  aboveThreshold.sort((left, right) => right.similarity - left.similarity);
  const candidates = aboveThreshold.slice(0, CONFIG.CANDIDATE_CHUNKS);

  const picked = mmrSelect(
    candidates,
    CONFIG.MAX_CHUNKS,
    CONFIG.MMR_LAMBDA,
    (entry) => entry.vector,
    (entry) => entry.similarity
  );

  // Per-floor sparsification: keep only the best-scoring chunk of each floor.
  const bestPerFloor = new Map();
  for (const entry of picked) {
    const incumbent = bestPerFloor.get(entry.floor);
    if (incumbent && !(entry.similarity > incumbent.similarity)) continue;
    bestPerFloor.set(entry.floor, entry);
  }
  const sparse = [...bestPerFloor.values()];
  sparse.sort((left, right) => right.similarity - left.similarity);

  // Load the full chunk records for the surviving floors.
  const floors = Array.from(new Set(sparse.map((entry) => entry.floor)));
  const records = await getChunksByFloors(chatId, floors);
  const recordById = new Map();
  for (const record of records) {
    recordById.set(record.chunkId, record);
  }

  const results = [];
  for (const entry of sparse) {
    const record = recordById.get(entry.chunkId);
    if (!record) continue;
    results.push({
      chunkId: entry.chunkId,
      floor: entry.floor,
      chunkIdx: entry.chunkIdx,
      speaker: record.speaker,
      isUser: record.isUser,
      text: record.text,
      similarity: entry.similarity,
    });
  }
  return results;
}
// ═══════════════════════════════════════════════════════════════════════════
// 日志格式化
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Render a human-readable, multi-line report of one recall run.
 *
 * Sections, in order: header with elapsed time, query-expansion lists (or a
 * fallback line when `expansion` is falsy), recall counts for L0/L1/L2 and the
 * causal chain, and the first five events with their similarity.
 *
 * @param {object} report
 * @param {number} report.elapsed - Total time in milliseconds.
 * @param {{entities?: string[], implicit?: string[], queries?: string[]}|null} report.expansion
 * @param {Array<{floor: number}>|null} report.l0Results - L0 hits.
 * @param {Array<object>|null} report.chunkResults - L1 hits.
 * @param {Array<{event?: object, similarity?: number, _recallType?: string}>|null} report.eventResults
 * @param {Array<object>|null} report.causalEvents - Causal-chain additions.
 * @returns {string} The report, lines joined with '\n' and ending in a newline.
 */
function formatRecallLog({ elapsed, expansion, l0Results, chunkResults, eventResults, causalEvents }) {
  const report = [];
  const emit = (...rows) => {
    report.push(...rows);
  };

  emit(
    '╔══════════════════════════════════════════════════════════════╗',
    '║ 记忆召回报告 (v2) ║',
    '╠══════════════════════════════════════════════════════════════╣',
    `║ 总耗时: ${elapsed}ms `,
    '╚══════════════════════════════════════════════════════════════╝',
    '',
    '┌─────────────────────────────────────────────────────────────┐',
    '│ 【Query Expansion】LLM 语义翻译 │',
    '└─────────────────────────────────────────────────────────────┘'
  );

  // Query-expansion section: one line per non-empty list.
  if (!expansion) {
    emit(' (未启用或失败)');
  } else {
    if (expansion.entities?.length) emit(` 实体: ${expansion.entities.join(' | ')}`);
    if (expansion.implicit?.length) emit(` 隐含: ${expansion.implicit.join(' | ')}`);
    if (expansion.queries?.length) emit(` 短句: ${expansion.queries.join(' | ')}`);
  }

  emit(
    '',
    '┌─────────────────────────────────────────────────────────────┐',
    '│ 【召回统计】 │',
    '└─────────────────────────────────────────────────────────────┘'
  );

  // L0: hit count plus the distinct floors involved (first ten shown).
  const distinctFloors = Array.from(new Set((l0Results || []).map((hit) => hit.floor)));
  distinctFloors.sort((x, y) => x - y);
  emit(` L0 Atoms: ${l0Results?.length || 0}`);
  if (distinctFloors.length) {
    emit(` 影响楼层: ${distinctFloors.slice(0, 10).join(', ')}${distinctFloors.length > 10 ? '...' : ''}`);
  }

  // L1
  emit(` L1 Chunks: ${chunkResults?.length || 0}`);

  // L2: total plus a breakdown by recall type.
  let directCount = 0;
  let similarCount = 0;
  for (const hit of eventResults || []) {
    if (hit._recallType === 'DIRECT') directCount += 1;
    else if (hit._recallType === 'SIMILAR') similarCount += 1;
  }
  emit(` L2 Events: ${eventResults?.length || 0} 条 (实体命中: ${directCount}, 相似: ${similarCount})`);

  // Causal chain
  if (causalEvents?.length) {
    emit(` 因果链: ${causalEvents.length}`);
  }

  // Top events
  if (eventResults?.length) {
    emit(
      '',
      '┌─────────────────────────────────────────────────────────────┐',
      '│ 【Top 5 Events】 │',
      '└─────────────────────────────────────────────────────────────┘'
    );
    const leading = eventResults.slice(0, 5);
    for (const [i, e] of leading.entries()) {
      const ev = e.event || {};
      const title = (ev.title || '').slice(0, 20).padEnd(20);
      const sim = (e.similarity || 0).toFixed(2);
      const type = e._recallType === 'DIRECT' ? '⭐' : '○';
      emit(` ${i + 1}. ${type} ${title} sim=${sim}`);
    }
  }

  emit('');
  return report.join('\n');
}
// ═══════════════════════════════════════════════════════════════════════════
// 主函数
// ═══════════════════════════════════════════════════════════════════════════
/** Expansion object with an empty list for every field the recall pipeline reads. */
function createEmptyRecallExpansion() {
  return { entities: [], implicit: [], queries: [] };
}

/**
 * Validate a raw query-expansion result and make sure `entities`, `implicit`
 * and `queries` are arrays (any other fields are kept as they are).
 *
 * @param {*} raw - Value resolved by expandQueryCached.
 * @returns {{entities: Array, implicit: Array, queries: Array}} Safe expansion.
 * @throws {TypeError} When `raw` is not an object.
 */
function normalizeRecallExpansion(raw) {
  if (!raw || typeof raw !== 'object') {
    throw new TypeError('Query expansion returned a non-object result');
  }
  const asList = (value) => (Array.isArray(value) ? value : []);
  return {
    ...raw,
    entities: asList(raw.entities),
    implicit: asList(raw.implicit),
    queries: asList(raw.queries),
  };
}

/** Result for early exits; carries the same keys as a successful recall. */
function createEmptyRecallResult(logText, elapsed, expansion = createEmptyRecallExpansion()) {
  return {
    events: [],
    causalEvents: [],
    chunks: [],
    expansion,
    queryEntities: expansion.entities,
    elapsed,
    logText,
  };
}

/**
 * Recall memories relevant to the latest chat rounds.
 *
 * Pipeline: LLM query expansion → embedding → L0 state-atom recall →
 * L1 chunk and L2 event vector search (in parallel) → causal-ancestor tracing.
 * Query expansion and L0 recall are best-effort: a failure is logged and the
 * pipeline continues without them. An embedding failure ends the recall with
 * an empty result.
 *
 * @param {string} queryText - Currently not read by this function; the query
 *   is built from the chat history instead.
 * @param {Array<object>} allEvents - All known events; when empty the function
 *   returns immediately with an empty result.
 * @param {object} vectorConfig - Vector engine configuration, forwarded to the
 *   search helpers.
 * @param {{pendingUserMessage?: string|null, excludeLastAi?: boolean}} [options]
 *   `pendingUserMessage` is appended to the recent rounds as a user message;
 *   `excludeLastAi` is forwarded to getLastRounds.
 * @returns {Promise<{events: Array, causalEvents: Array, chunks: Array, expansion: object, queryEntities: Array, elapsed: number, logText: string}>}
 *   Every return path provides all of these keys.
 */
export async function recallMemory(queryText, allEvents, vectorConfig, options = {}) {
  const T0 = performance.now();
  const elapsedMs = () => Math.round(performance.now() - T0);
  const { chat } = getContext();
  const store = getSummaryStore();
  const lastSummarizedFloor = store?.lastSummarizedMesId ?? -1;
  const { pendingUserMessage = null, excludeLastAi = false } = options || {};

  if (!allEvents?.length) {
    return createEmptyRecallResult('No events.', 0);
  }

  // ═══════════════════════════════════════════════════════════════════════
  // Step 1: Query expansion (LLM semantic translation)
  // ═══════════════════════════════════════════════════════════════════════
  const lastRounds = getLastRounds(chat, 3, excludeLastAi);
  if (pendingUserMessage) {
    lastRounds.push({ is_user: true, mes: pendingUserMessage });
  }

  let expansion = createEmptyRecallExpansion();
  try {
    // NOTE(review): the timeout is a literal 6000 ms although
    // CONFIG.QUERY_EXPANSION_TIMEOUT (3000) exists — confirm which is intended.
    const rawExpansion = await expandQueryCached(lastRounds, { timeout: 6000 });
    // Validate BEFORE adopting the value so that a malformed result degrades to
    // the raw-text fallback instead of crashing later in the pipeline.
    expansion = normalizeRecallExpansion(rawExpansion);
    xbLog.info(MODULE_ID, `Query Expansion: e=${expansion.entities.length} i=${expansion.implicit.length} q=${expansion.queries.length}`);
  } catch (e) {
    xbLog.warn(MODULE_ID, 'Query Expansion 失败,降级使用原始文本', e);
  }

  // Fall back to the filtered raw text of the recent rounds when expansion yields nothing.
  const searchText = buildSearchText(expansion);
  const finalSearchText = searchText || lastRounds.map(m => filterText(m.mes || '').slice(0, 200)).join(' ');

  // ═══════════════════════════════════════════════════════════════════════
  // Step 2: Embedding
  // ═══════════════════════════════════════════════════════════════════════
  let queryVector;
  try {
    const [vec] = await embed([finalSearchText], { timeout: 10000 });
    queryVector = vec;
  } catch (e) {
    xbLog.error(MODULE_ID, '向量化失败', e);
    return createEmptyRecallResult('Embedding failed.', elapsedMs(), expansion);
  }
  if (!queryVector?.length) {
    return createEmptyRecallResult('Empty query vector.', elapsedMs(), expansion);
  }

  // ═══════════════════════════════════════════════════════════════════════
  // Step 3: L0 recall (best-effort)
  // ═══════════════════════════════════════════════════════════════════════
  let l0Results = [];
  let l0FloorBonus = new Map();
  let l0VirtualChunks = [];
  try {
    l0Results = await searchStateAtoms(queryVector, vectorConfig);
    l0FloorBonus = buildL0FloorBonus(l0Results, CONFIG.L0_FLOOR_BONUS_FACTOR);
    l0VirtualChunks = stateToVirtualChunks(l0Results);
  } catch (e) {
    xbLog.warn(MODULE_ID, 'L0 召回失败', e);
  }

  // ═══════════════════════════════════════════════════════════════════════
  // Step 4: L1 + L2 recall (in parallel)
  // ═══════════════════════════════════════════════════════════════════════
  const entitySet = new Set(expansion.entities.map(normalize));
  const [chunkResults, eventResults] = await Promise.all([
    searchChunks(queryVector, vectorConfig, l0FloorBonus, lastSummarizedFloor),
    searchEvents(queryVector, allEvents, vectorConfig, entitySet, l0FloorBonus),
  ]);

  // Merge the L0 virtual chunks with the L1 chunks.
  const mergedChunks = mergeAndSparsify(l0VirtualChunks, chunkResults, CONFIG.FLOOR_MAX_CHUNKS);

  // ═══════════════════════════════════════════════════════════════════════
  // Step 5: Causal-chain tracing
  // ═══════════════════════════════════════════════════════════════════════
  const eventIndex = buildEventIndex(allEvents);
  const causalMap = traceCausalAncestors(eventResults, eventIndex);
  const recalledIdSet = new Set(eventResults.map(x => x?.event?.id).filter(Boolean));
  // Ancestors that were already recalled directly are not reported twice.
  const causalEvents = causalMap
    .filter(x => x?.event?.id && !recalledIdSet.has(x.event.id))
    .map(x => ({
      event: x.event,
      similarity: 0,
      _recallType: 'CAUSAL',
      _causalDepth: x.depth,
      chainFrom: x.chainFrom,
    }));

  // ═══════════════════════════════════════════════════════════════════════
  // Result
  // ═══════════════════════════════════════════════════════════════════════
  const elapsed = elapsedMs();
  const logText = formatRecallLog({
    elapsed,
    expansion,
    l0Results,
    chunkResults: mergedChunks,
    eventResults,
    causalEvents,
  });

  console.group('%c[Recall v2]', 'color: #7c3aed; font-weight: bold');
  console.log(`Elapsed: ${elapsed}ms`);
  console.log(`Expansion: ${expansion.entities.join(', ')} | ${expansion.implicit.join(', ')}`);
  console.log(`L0: ${l0Results.length} | L1: ${mergedChunks.length} | L2: ${eventResults.length} | Causal: ${causalEvents.length}`);
  console.groupEnd();

  return {
    events: eventResults,
    causalEvents,
    chunks: mergedChunks,
    expansion,
    queryEntities: expansion.entities,
    elapsed,
    logText,
  };
}
// ═══════════════════════════════════════════════════════════════════════════
// 辅助导出
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Build a plain-text query from the last `count` chat messages.
 *
 * Each kept message becomes one "<speaker>: <text>" line, where the text is
 * cleaned with cleanForRecall and cut to 500 characters, and the speaker is
 * `m.name` or a fallback label depending on `m.is_user`. Messages whose
 * cleaned text is empty (including missing entries or a missing `mes`) are
 * skipped rather than emitted as a bare "<speaker>: " line.
 *
 * @param {Array<object>|null|undefined} chat - Chat messages, oldest first.
 * @param {number} [count=2] - How many trailing messages to consider; a value
 *   that is not greater than 0 yields an empty string.
 * @param {boolean} [excludeLastAi=false] - When true and the final message is
 *   not a user message, that message is dropped before taking the tail.
 * @returns {string} The lines joined with '\n' ('' when nothing qualifies).
 */
export function buildQueryText(chat, count = 2, excludeLastAi = false) {
  if (!chat?.length) return '';
  // slice(-0) would return the WHOLE array, so a non-positive count is handled up front.
  if (!(count > 0)) return '';

  let messages = chat;
  if (excludeLastAi && !messages[messages.length - 1]?.is_user) {
    messages = messages.slice(0, -1);
  }

  const lines = [];
  for (const m of messages.slice(-count)) {
    const text = cleanForRecall(m?.mes || '');
    if (!text) continue;
    const speaker = m.name || (m.is_user ? '用户' : '角色');
    lines.push(`${speaker}: ${text.slice(0, 500)}`);
  }
  return lines.join('\n');
}