Files
LittleWhiteBox/modules/story-summary/vector/llm/atom-extraction.js

252 lines
8.0 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// ============================================================================
// atom-extraction.js - batched concurrent extraction (batch size: CONCURRENCY) + staggered first batch + cancellation support + progress callback
// ============================================================================
import { callLLM, parseJson } from './llm-service.js';
import { xbLog } from '../../../../core/debug-core.js';
import { filterText } from '../utils/text-filter.js';
// Tag passed as the first argument to every xbLog call in this module.
const MODULE_ID = 'atom-extraction';
// Batch size used by batchExtractAtoms: the rounds of one batch are awaited together.
const CONCURRENCY = 10;
// Extra attempts after the first one (the retry loop runs attempt = 0..RETRY_COUNT).
const RETRY_COUNT = 2;
// Base wait between attempts, in ms; scaled by the attempt number after a thrown error.
const RETRY_DELAY = 500;
// Default `timeout` option forwarded to callLLM (presumably milliseconds - confirm in llm-service.js).
const DEFAULT_TIMEOUT = 20000;
const STAGGER_DELAY = 80; // first-batch stagger delay, in ms
// Module-wide cancellation flag: raised by cancelBatchExtraction(), cleared when batchExtractAtoms() starts a run.
let batchCancelled = false;
/**
 * Requests cancellation of the current extraction work by raising the
 * module-level `batchCancelled` flag. The flag is only lowered again when
 * batchExtractAtoms() starts a run on a non-empty chat.
 */
export function cancelBatchExtraction() {
batchCancelled = true;
}
/**
 * Reports the current value of the module-level cancellation flag.
 *
 * @returns {boolean} true once cancelBatchExtraction() has been called and
 *   no new batchExtractAtoms() run has reset the flag since.
 */
export function isBatchCancelled() {
return batchCancelled;
}
// System prompt sent with every extraction request. It tells the model to act as a
// narrative-anchor extractor, pull 4-8 key anchors out of one dialogue round and answer
// with JSON only, shaped as {"atoms":[{"t","s","v","f"}]}:
//   t = type (emo / loc / act / rev / ten / dec), s = subject, v = short value,
//   f = origin ("u" = user message, "a" = character reply).
// The text is runtime data: any edit changes what the model is asked to do.
// NOTE(review): the "(" opened in the first sentence is never closed - compare with the
// upstream copy of this file before touching the wording.
const SYSTEM_PROMPT = `你是叙事锚点提取器。从一轮对话(用户发言+角色回复中提取4-8个关键锚点。
只输出JSON
{"atoms":[{"t":"类型","s":"主体","v":"值","f":"来源"}]}
类型t
- emo: 情绪状态需要s主体
- loc: 地点/场景
- act: 关键动作需要s主体
- rev: 揭示/发现
- ten: 冲突/张力
- dec: 决定/承诺
规则:
- s: 主体(谁)
- v: 简洁值10字内
- f: "u"=用户发言中, "a"=角色回复中
- 只提取对未来检索有价值的锚点
- 无明显锚点返回空数组`;
/**
 * Renders a raw atom as the short sentence stored in its `semantic` field.
 *
 * @param {{t?: string, s?: string, v?: *, f?: string}} atom - atom as returned by the model
 * @param {string} userName - stands in for the subject when the atom comes from the user ("u")
 * @param {string} aiName - stands in for the subject for every other origin
 * @returns {string} type-specific phrase; unknown types fall back to "<subject> <value>"
 */
function buildSemantic(atom, userName, aiName) {
    // Use the speaker's name whenever the model left the subject empty.
    const fallbackName = atom.f === 'u' ? userName : aiName;
    const subject = atom.s || fallbackName;
    const kind = atom.t;

    if (kind === 'emo') {
        return `${subject}感到${atom.v}`;
    }
    if (kind === 'loc') {
        return `场景:${atom.v}`;
    }
    if (kind === 'act') {
        return `${subject}${atom.v}`;
    }
    if (kind === 'rev') {
        return `揭示:${atom.v}`;
    }
    if (kind === 'ten') {
        return `冲突:${atom.v}`;
    }
    if (kind === 'dec') {
        return `${subject}决定${atom.v}`;
    }
    return `${subject} ${atom.v}`;
}
/**
 * Promise-based delay.
 *
 * @param {number} ms - delay handed to setTimeout
 * @returns {Promise<void>} settles (with no value) once the timer fires
 */
const sleep = (ms) => new Promise((resolve) => {
    setTimeout(resolve, ms);
});
/**
 * Extracts narrative atoms for one dialogue round (optional user message plus AI reply),
 * retrying when the model answers with nothing, with unparsable JSON, without an `atoms`
 * array, or when the call throws.
 *
 * LLM and parsing failures never reject: once the attempts are used up the function
 * resolves to an empty array. It also resolves to an empty array when the AI reply is
 * blank or while the module-level `batchCancelled` flag is set.
 *
 * @param {{name?: string, mes?: string}|null} userMessage - user turn before the reply, if any
 * @param {{name?: string, mes?: string}|null} aiMessage - AI reply to analyse
 * @param {number} aiFloor - position of the AI message in the chat; used in atom ids and logs
 * @param {{timeout?: number}} [options] - `timeout` is forwarded to callLLM (default DEFAULT_TIMEOUT)
 * @returns {Promise<Array<{atomId: string, floor: number, type: string, subject: (string|null),
 *   value: string, source: string, semantic: string}>>}
 */
async function extractAtomsForRoundWithRetry(userMessage, aiMessage, aiFloor, options = {}) {
    // Without an AI reply there is nothing to extract.
    if (!aiMessage?.mes?.trim()) return [];

    const { timeout = DEFAULT_TIMEOUT } = options;
    const userName = userMessage?.name || '用户';
    const aiName = aiMessage.name || '角色';

    const parts = [];
    if (userMessage?.mes?.trim()) {
        parts.push(`【用户:${userName}\n${filterText(userMessage.mes)}`);
    }
    parts.push(`【角色:${aiName}\n${filterText(aiMessage.mes)}`);
    const input = parts.join('\n\n---\n\n');
    xbLog.info(MODULE_ID, `floor ${aiFloor} 发送输入 len=${input.length}`);

    // Caps a value at 30 UTF-16 units without leaving the first half of a surrogate pair
    // dangling at the end (a plain slice could cut an emoji or rare CJK character in two).
    const clipValue = (raw) => {
        const text = String(raw);
        if (text.length <= 30) return text;
        const lastUnit = text.charCodeAt(29);
        const endsOnHighSurrogate = lastUnit >= 0xD800 && lastUnit <= 0xDBFF;
        return text.slice(0, endsOnHighSurrogate ? 29 : 30);
    };

    for (let attempt = 0; attempt <= RETRY_COUNT; attempt++) {
        if (batchCancelled) return [];
        const canRetry = attempt < RETRY_COUNT;

        try {
            const response = await callLLM([
                { role: 'system', content: SYSTEM_PROMPT },
                { role: 'user', content: input },
            ], {
                temperature: 0.2,
                max_tokens: 500,
                timeout,
            });

            // Classify unusable answers once so that they share a single retry path.
            let parsed = null;
            let problem = null;
            if (!response || !String(response).trim()) {
                problem = '解析失败:响应为空';
            } else {
                try {
                    parsed = parseJson(response);
                } catch {
                    problem = '解析失败JSON 异常';
                }
                if (!problem && !Array.isArray(parsed?.atoms)) {
                    problem = '解析失败atoms 缺失';
                }
            }

            if (problem) {
                xbLog.warn(MODULE_ID, `floor ${aiFloor} ${problem}`);
                if (!canRetry) return [];
                await sleep(RETRY_DELAY);
                continue;
            }

            // Keep only atoms that carry both a type and a value; ids are numbered after filtering.
            return parsed.atoms
                .filter((a) => a?.t && a?.v)
                .map((a, idx) => ({
                    atomId: `atom-${aiFloor}-${idx}`,
                    floor: aiFloor,
                    type: a.t,
                    subject: a.s || null,
                    value: clipValue(a.v),
                    source: a.f === 'u' ? 'user' : 'ai',
                    semantic: buildSemantic(a, userName, aiName),
                }));
        } catch (e) {
            if (batchCancelled) return [];
            if (!canRetry) {
                xbLog.error(MODULE_ID, `floor ${aiFloor} 失败`, e);
                return [];
            }
            // Separator added: floor and attempt number used to run together ("floor 51次失败").
            xbLog.warn(MODULE_ID, `floor ${aiFloor} 第${attempt + 1}次失败,重试...`, e?.message);
            // Thrown errors back off linearly: 1x, 2x, ... RETRY_DELAY.
            await sleep(RETRY_DELAY * (attempt + 1));
        }
    }
    return [];
}
/**
 * Single-round extraction (used for incremental updates).
 *
 * Thin exported wrapper: forwards its arguments unchanged to
 * extractAtomsForRoundWithRetry and returns that call's result.
 *
 * NOTE(review): the delegate returns [] while the module-level `batchCancelled`
 * flag is set, and that flag is only cleared at the start of batchExtractAtoms.
 * Confirm that single-round calls made after a cancelled batch are meant to
 * come back empty.
 *
 * @param {object|null} userMessage - user turn preceding the reply, if any
 * @param {object|null} aiMessage - AI reply to analyse
 * @param {number} aiFloor - position of the AI message in the chat
 * @param {object} [options] - passed through as-is
 * @returns {Promise<Array>} extracted atoms, or an empty array
 */
export async function extractAtomsForRound(userMessage, aiMessage, aiFloor, options = {}) {
return extractAtomsForRoundWithRetry(userMessage, aiMessage, aiFloor, options);
}
/**
 * Batch extraction over a whole chat (the first batch is started staggered).
 *
 * Every non-user message is paired with the user message directly before it (if there
 * is one) and the pairs are processed in batches of CONCURRENCY. Requests of the first
 * batch start STAGGER_DELAY ms apart; later batches start all at once. A round that
 * yields no atoms is counted as failed. Once cancelBatchExtraction() has been called,
 * rounds that finish afterwards are neither collected nor counted, and no further
 * batch is started.
 *
 * @param {Array} chat - chat messages; entries with a truthy `is_user` are user turns
 * @param {Function} [onProgress] - (current, total, failed) => void, called after each counted round
 * @returns {Promise<Array>} atoms collected so far, in completion order (partial when cancelled)
 */
export async function batchExtractAtoms(chat, onProgress) {
    if (!chat?.length) return [];
    batchCancelled = false;

    const pairs = [];
    for (let i = 0; i < chat.length; i++) {
        const message = chat[i];
        // Skip holes / null entries instead of crashing on `.is_user`.
        if (!message || message.is_user) continue;
        const previous = i > 0 ? chat[i - 1] : null;
        pairs.push({ userMsg: previous?.is_user ? previous : null, aiMsg: message, aiFloor: i });
    }
    if (!pairs.length) return [];

    const allAtoms = [];
    let completed = 0;
    let failed = 0;

    // Runs one round and updates the counters. Shared by every batch so that the
    // first batch and the later ones treat cancellation and errors the same way.
    const runPair = async (pair, startDelay) => {
        if (startDelay > 0) {
            await sleep(startDelay);
        }
        if (batchCancelled) return;

        let atoms = null;
        try {
            atoms = await extractAtomsForRoundWithRetry(pair.userMsg, pair.aiMsg, pair.aiFloor, { timeout: DEFAULT_TIMEOUT });
        } catch (err) {
            if (!batchCancelled) {
                xbLog.error(MODULE_ID, `floor ${pair.aiFloor} 失败`, err);
            }
        }

        // Results that arrive after a cancel request are dropped and not counted.
        if (batchCancelled) return;

        if (atoms?.length) {
            allAtoms.push(...atoms);
        } else {
            failed++;
        }
        completed++;
        // Kept outside the try block: an error thrown by the callback must not be
        // mistaken for an extraction failure and counted a second time.
        onProgress?.(completed, pairs.length, failed);
    };

    for (let start = 0; start < pairs.length; start += CONCURRENCY) {
        if (batchCancelled) {
            xbLog.info(MODULE_ID, `批量提取已取消 (${completed}/${pairs.length})`);
            break;
        }

        const isFirstBatch = start === 0;
        const batch = pairs.slice(start, start + CONCURRENCY);
        // Only the first batch is staggered; later batches fire all requests at once.
        await Promise.all(
            batch.map((pair, idx) => runPair(pair, isFirstBatch ? idx * STAGGER_DELAY : 0)),
        );

        // Short pause between batches.
        if (start + CONCURRENCY < pairs.length && !batchCancelled) {
            await sleep(30);
        }
    }

    const status = batchCancelled ? '已取消' : '完成';
    xbLog.info(MODULE_ID, `批量提取${status}: ${allAtoms.length} atoms, ${completed}/${pairs.length}, ${failed} 失败`);
    return allAtoms;
}