Files
LittleWhiteBox/modules/story-summary/vector/llm/atom-extraction.js

265 lines
8.6 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// ============================================================================
// atom-extraction.js - parallel chunks of CONCURRENCY requests + staggered first chunk + cancellation + progress callback
// ============================================================================
import { callLLM, parseJson } from './llm-service.js';
import { xbLog } from '../../../../core/debug-core.js';
import { filterText } from '../utils/text-filter.js';
// Log tag passed as the first argument to every xbLog call in this module.
const MODULE_ID = 'atom-extraction';
// Number of round pairs dispatched together per chunk in batchExtractAtoms.
const CONCURRENCY = 10;
// Extra attempts after the first LLM call (total attempts per round = RETRY_COUNT + 1).
const RETRY_COUNT = 2;
// Base delay (ms, via sleep/setTimeout) between retries; exception retries scale it by the attempt number.
const RETRY_DELAY = 500;
// Per-request timeout forwarded to callLLM (presumably milliseconds — confirm against llm-service.js).
const DEFAULT_TIMEOUT = 20000;
const STAGGER_DELAY = 80; // Start offset (ms) between consecutive requests of the first chunk only
// Module-level cancellation flag. Set by cancelBatchExtraction(); only reset to false when
// batchExtractAtoms() starts on a non-empty chat, so a cancel stays in effect until then.
let batchCancelled = false;
/**
 * Request cancellation of the running batch extraction.
 * Chunks that have not started yet are skipped, and round extractions that poll the flag
 * stop before their next attempt.
 */
export function cancelBatchExtraction() {
batchCancelled = true;
}
/**
 * @returns {boolean} True if cancellation was requested since the last batch run started.
 */
export function isBatchCancelled() {
return batchCancelled;
}
/**
 * System prompt for the extractor, sent verbatim as the `system` message of every round.
 * It asks for strict JSON shaped like {"atoms":[{"t","s","v","f"}]} where
 *   t = emo | loc | act | rev | ten | dec (atom type),
 *   s = subject, v = short value, f = "u" (from the user turn) or "a" (from the AI turn).
 * The atom mapping in the round extractor and buildSemantic read exactly these keys,
 * so keep them in sync if the prompt changes.
 * NOTE(review): the "(" in the first line is never closed, and some punctuation looks lost
 * (e.g. after "类型t" and "只输出严格JSON") — confirm against the intended prompt text.
 */
const SYSTEM_PROMPT = `你是叙事锚点提取器。从一轮对话(用户发言+角色回复中提取4-8个关键锚点。
输入格式:
<round>
<user>...</user>
<assistant>...</assistant>
</round>
只输出严格JSON不要解释不要前后多余文字
{"atoms":[{"t":"类型","s":"主体","v":"值","f":"来源"}]}
类型t
- emo: 情绪状态需要s主体
- loc: 地点/场景
- act: 关键动作需要s主体
- rev: 揭示/发现
- ten: 冲突/张力
- dec: 决定/承诺
规则:
- s: 主体(谁)
- v: 简洁值10字内
- f: "u"=用户发言中, "a"=角色回复中
- 只提取对未来检索有价值的锚点
- 无明显锚点返回空数组`;
// Per-type phrase builders: (subject, value) => semantic text. A Map (not a plain object)
// so that type strings like "toString" cannot hit Object.prototype members.
const SEMANTIC_FORMATTERS = new Map([
    ['emo', (subject, value) => `${subject}感到${value}`],
    ['loc', (_subject, value) => `场景:${value}`],
    ['act', (subject, value) => `${subject}${value}`],
    ['rev', (_subject, value) => `揭示:${value}`],
    ['ten', (_subject, value) => `冲突:${value}`],
    ['dec', (subject, value) => `${subject}决定${value}`],
]);

/**
 * Render a raw atom as a short natural-language phrase (its `semantic` text).
 *
 * @param {object} atom - Raw atom from the model: `t` type, `s` subject, `v` value, `f` source.
 * @param {string} userName - Subject used when `s` is empty and the atom came from the user (`f === 'u'`).
 * @param {string} aiName - Subject used when `s` is empty for any other source.
 * @returns {string} Type-specific phrase; unknown types fall back to "subject value".
 */
function buildSemantic(atom, userName, aiName) {
    const fallbackSubject = atom.f === 'u' ? userName : aiName;
    const subject = atom.s || fallbackSubject;
    const format = SEMANTIC_FORMATTERS.get(atom.t);
    if (format === undefined) {
        return `${subject} ${atom.v}`;
    }
    return format(subject, atom.v);
}
/** Resolve (with undefined) after roughly `ms` milliseconds, using setTimeout. */
function sleep(ms) {
    return new Promise((resolve) => {
        setTimeout(resolve, ms);
    });
}
/**
 * Build the `<round>` pseudo-XML input for one user/AI pair.
 * The `<user>` part is omitted when the user message is missing or blank.
 */
function buildRoundInput(userMessage, aiMessage, userName, aiName) {
    const parts = [];
    if (userMessage?.mes?.trim()) {
        const userText = filterText(userMessage.mes);
        parts.push(`<user name="${userName}">\n${userText}\n</user>`);
    }
    const aiText = filterText(aiMessage.mes);
    parts.push(`<assistant name="${aiName}">\n${aiText}\n</assistant>`);
    return `<round>\n${parts.join('\n')}\n</round>`;
}

/**
 * Truncate `text` to at most `max` Unicode code points.
 * Unlike String#slice this never cuts a surrogate pair (emoji, rare CJK) in half.
 */
function truncateCodePoints(text, max) {
    if (text.length <= max) return text;
    return Array.from(text).slice(0, max).join('');
}

/**
 * Parse a raw model reply and return its `atoms` array.
 * Returns null (after logging the reason) when the reply is blank, is not valid JSON,
 * or has no `atoms` array.
 */
function parseAtomsPayload(rawText, aiFloor) {
    if (!rawText.trim()) {
        xbLog.warn(MODULE_ID, `floor ${aiFloor} 解析失败:响应为空`);
        return null;
    }
    let parsed;
    try {
        parsed = parseJson(rawText);
    } catch {
        xbLog.warn(MODULE_ID, `floor ${aiFloor} 解析失败JSON 异常`);
        return null;
    }
    if (!Array.isArray(parsed?.atoms)) {
        xbLog.warn(MODULE_ID, `floor ${aiFloor} atoms 缺失raw="${rawText.slice(0, 300)}"`);
        xbLog.warn(MODULE_ID, `floor ${aiFloor} 解析失败atoms 缺失`);
        return null;
    }
    return parsed.atoms;
}

/**
 * Extract atoms for one user/AI round, making up to RETRY_COUNT extra attempts.
 *
 * Unparseable replies are retried after RETRY_DELAY ms. Thrown errors are retried with a
 * linearly growing delay (RETRY_DELAY * attempt).
 *
 * @param {object|null} userMessage - Preceding user message (`name`, `mes`), or null.
 * @param {object} aiMessage - AI message (`name`, `mes`).
 * @param {number} aiFloor - Chat index of the AI message; used in atom ids and log lines.
 * @param {object} [options]
 * @param {number} [options.timeout=DEFAULT_TIMEOUT] - Forwarded to callLLM.
 * @param {boolean} [options.respectBatchCancel=true] - Stop early once cancelBatchExtraction()
 *   was called. Batch runs keep the default. Single-round callers disable it, because the
 *   flag otherwise stays set until the next batch run starts.
 * @returns {Promise<Array|null>} Atoms on success, which may be empty (also `[]` when the AI
 *   message is blank). Null when every attempt failed or the batch was cancelled.
 */
async function extractAtomsForRoundWithRetry(userMessage, aiMessage, aiFloor, options = {}) {
    const { timeout = DEFAULT_TIMEOUT, respectBatchCancel = true } = options;
    const isCancelled = () => respectBatchCancel && batchCancelled;
    if (!aiMessage?.mes?.trim()) return [];

    const userName = userMessage?.name || '用户';
    const aiName = aiMessage.name || '角色';
    const input = buildRoundInput(userMessage, aiMessage, userName, aiName);
    xbLog.info(MODULE_ID, `floor ${aiFloor} 发送输入 len=${input.length}`);

    for (let attempt = 0; attempt <= RETRY_COUNT; attempt++) {
        // Cancellation consistently reports "no result" (null), like the exception path below.
        if (isCancelled()) return null;
        const canRetry = attempt < RETRY_COUNT;
        try {
            const response = await callLLM([
                { role: 'system', content: SYSTEM_PROMPT },
                { role: 'user', content: input },
                // Prefilled assistant turn nudging the model to answer with JSON only.
                { role: 'assistant', content: '收到,开始提取并仅输出 JSON。' },
            ], {
                temperature: 0.2,
                max_tokens: 500,
                timeout,
            });
            const rawText = String(response || '');
            const rawAtoms = parseAtomsPayload(rawText, aiFloor);
            if (rawAtoms === null) {
                if (canRetry) {
                    await sleep(RETRY_DELAY);
                    continue;
                }
                return null;
            }
            const atoms = rawAtoms
                .filter(a => a?.t && a?.v)
                .map((a, idx) => ({
                    atomId: `atom-${aiFloor}-${idx}`,
                    floor: aiFloor,
                    type: a.t,
                    subject: a.s || null,
                    value: truncateCodePoints(String(a.v), 30),
                    source: a.f === 'u' ? 'user' : 'ai',
                    semantic: buildSemantic(a, userName, aiName),
                }));
            if (!atoms.length) {
                xbLog.warn(MODULE_ID, `floor ${aiFloor} atoms 为空raw="${rawText.slice(0, 300)}"`);
            }
            return atoms;
        } catch (e) {
            if (isCancelled()) return null;
            if (canRetry) {
                xbLog.warn(MODULE_ID, `floor ${aiFloor}${attempt + 1}次失败,重试...`, e?.message);
                await sleep(RETRY_DELAY * (attempt + 1));
                continue;
            }
            xbLog.error(MODULE_ID, `floor ${aiFloor} 失败`, e);
            return null;
        }
    }
    return null;
}
/**
 * Single-round extraction for incremental updates.
 *
 * Not affected by cancelBatchExtraction(): that flag is only cleared when a new batch run
 * starts, so honouring it here would silently disable incremental extraction after any
 * cancel. Pass `respectBatchCancel: true` in `options` to opt back in.
 *
 * @param {object|null} userMessage - Preceding user message, or null.
 * @param {object} aiMessage - AI message to extract from.
 * @param {number} aiFloor - Chat index of the AI message.
 * @param {object} [options] - Same options as the retrying extractor (e.g. `timeout`).
 * @returns {Promise<Array|null>} Atoms (possibly empty), or null when all attempts failed.
 */
export async function extractAtomsForRound(userMessage, aiMessage, aiFloor, options = {}) {
    return extractAtomsForRoundWithRetry(userMessage, aiMessage, aiFloor, { respectBatchCancel: false, ...options });
}
/**
 * Pair every AI message in `chat` with the message right before it when that one is a
 * user message. Null/undefined entries are skipped rather than throwing.
 */
function collectRoundPairs(chat) {
    const pairs = [];
    for (let i = 0; i < chat.length; i++) {
        const msg = chat[i];
        if (!msg || msg.is_user) continue;
        const prev = i > 0 ? chat[i - 1] : null;
        pairs.push({ userMsg: prev?.is_user ? prev : null, aiMsg: msg, aiFloor: i });
    }
    return pairs;
}

/**
 * Batch-extract atoms for every AI message in `chat`.
 *
 * Pairs are processed in chunks of CONCURRENCY, with a 30 ms pause between chunks. Inside
 * the first chunk only, request k starts k * STAGGER_DELAY ms late, so the requests do not
 * all fire at once. After cancelBatchExtraction(), no new chunk or request starts, and
 * results that arrive afterwards are discarded without being counted.
 *
 * @param {Array} chat - Chat messages; entries with a truthy `is_user` are user turns.
 * @param {Function} [onProgress] - (completed, total, failed) => void, called once per
 *   finished pair. `failed` counts rounds whose extraction returned null or threw. A round
 *   that yields no atoms (`[]`) is a success.
 * @returns {Promise<Array>} Extracted atoms, ordered by floor.
 */
export async function batchExtractAtoms(chat, onProgress) {
    if (!chat?.length) return [];
    batchCancelled = false;

    const pairs = collectRoundPairs(chat);
    if (!pairs.length) return [];

    const allAtoms = [];
    let completed = 0;
    let failed = 0;

    // Runs one pair after `startDelay` ms and records its outcome, unless cancelled meanwhile.
    const runPair = async (pair, startDelay) => {
        if (startDelay > 0) await sleep(startDelay);
        if (batchCancelled) return;
        let atoms;
        try {
            atoms = await extractAtomsForRoundWithRetry(pair.userMsg, pair.aiMsg, pair.aiFloor, { timeout: DEFAULT_TIMEOUT });
        } catch {
            atoms = null;
        }
        // Results landing after a cancel are dropped, for every chunk alike.
        if (batchCancelled) return;
        // null = gave up after retries; [] = nothing worth extracting, which is not a failure.
        if (Array.isArray(atoms)) {
            allAtoms.push(...atoms);
        } else {
            failed++;
        }
        completed++;
        onProgress?.(completed, pairs.length, failed);
    };

    for (let start = 0; start < pairs.length; start += CONCURRENCY) {
        if (batchCancelled) {
            xbLog.info(MODULE_ID, `批量提取已取消 (${completed}/${pairs.length})`);
            break;
        }
        const batch = pairs.slice(start, start + CONCURRENCY);
        const isFirstBatch = start === 0;
        await Promise.all(batch.map((pair, idx) => runPair(pair, isFirstBatch ? idx * STAGGER_DELAY : 0)));

        // Short pause between chunks.
        if (start + CONCURRENCY < pairs.length && !batchCancelled) {
            await sleep(30);
        }
    }

    // Completion order is arbitrary; a stable sort by floor keeps each round's atoms in order.
    allAtoms.sort((a, b) => a.floor - b.floor);

    const status = batchCancelled ? '已取消' : '完成';
    xbLog.info(MODULE_ID, `批量提取${status}: ${allAtoms.length} atoms, ${completed}/${pairs.length}, ${failed} 失败`);
    return allAtoms;
}