// LittleWhiteBox/modules/story-summary/vector/embedder.js
// ═══════════════════════════════════════════════════════════════════════════
// Story Summary - Embedding Service
// Unified embedding interface (local model / online service)
// ═══════════════════════════════════════════════════════════════════════════
import { xbLog } from '../../../core/debug-core.js';
const MODULE_ID = 'embedding';
// ═══════════════════════════════════════════════════════════════════════════
// Local model configuration
// ═══════════════════════════════════════════════════════════════════════════
export const LOCAL_MODELS = {
'bge-small-zh': {
id: 'bge-small-zh',
name: '中文轻量 (51MB)',
hfId: 'Xenova/bge-small-zh-v1.5',
dims: 512,
desc: '手机/低配适用',
},
'bge-base-zh': {
id: 'bge-base-zh',
name: '中文标准 (102MB)',
hfId: 'Xenova/bge-base-zh-v1.5',
dims: 768,
desc: 'PC 推荐,效果更好',
},
'e5-small': {
id: 'e5-small',
name: '多语言 (118MB)',
hfId: 'Xenova/multilingual-e5-small',
dims: 384,
desc: '非中文用户',
},
};
export const DEFAULT_LOCAL_MODEL = 'bge-small-zh';
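/*
 * Usage sketch (illustrative only): look up a model entry to get its Hugging Face id
 * and embedding dimension, e.g. when sizing a vector store. Only LOCAL_MODELS and
 * DEFAULT_LOCAL_MODEL from this module are used; nothing else is assumed.
 *
 *   const { hfId, dims } = LOCAL_MODELS['bge-base-zh'];   // 'Xenova/bge-base-zh-v1.5', 768
 */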
// ═══════════════════════════════════════════════════════════════════════════
// Online service configuration
// ═══════════════════════════════════════════════════════════════════════════
export const ONLINE_PROVIDERS = {
siliconflow: {
id: 'siliconflow',
name: '硅基流动',
baseUrl: 'https://api.siliconflow.cn',
canFetchModels: false,
defaultModels: [
'BAAI/bge-m3',
'BAAI/bge-large-zh-v1.5',
'BAAI/bge-small-zh-v1.5',
],
},
cohere: {
id: 'cohere',
name: 'Cohere',
baseUrl: 'https://api.cohere.ai',
canFetchModels: false,
defaultModels: [
'embed-multilingual-v3.0',
'embed-english-v3.0',
],
// Cohere uses a different API format
customEmbed: true,
},
openai: {
id: 'openai',
name: 'OpenAI 兼容',
baseUrl: '',
canFetchModels: true,
defaultModels: [],
},
};
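/*
 * Shape sketch: the `online` config consumed by testOnlineService()/embedOnline()
 * destructures { url, key, model }, and embed() additionally reads `provider`.
 * A SiliconFlow setup might look like this (the key value is a placeholder):
 *
 *   const online = { provider: 'siliconflow', url: '', key: 'sk-xxxx', model: 'BAAI/bge-m3' };
 */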
// ═══════════════════════════════════════════════════════════════════════════
// Local model state management
// ═══════════════════════════════════════════════════════════════════════════
// Models loaded in the worker, keyed by modelId (the value is a readiness flag; the pipeline itself lives in the worker)
const loadedPipelines = {};
// Model currently being downloaded
let downloadingModelId = null;
let downloadAbortController = null;
// Worker for local embedding
let embeddingWorker = null;
let workerRequestId = 0;
const workerCallbacks = new Map();
function getWorker() {
if (!embeddingWorker) {
const workerPath = new URL('./embedder.worker.js', import.meta.url).href;
embeddingWorker = new Worker(workerPath, { type: 'module' });
embeddingWorker.onmessage = (e) => {
const { requestId, ...data } = e.data || {};
const callback = workerCallbacks.get(requestId);
if (callback) {
callback(data);
if (data.type === 'result' || data.type === 'error' || data.type === 'loaded') {
workerCallbacks.delete(requestId);
}
}
};
}
return embeddingWorker;
}
function workerRequest(message) {
return new Promise((resolve, reject) => {
const requestId = ++workerRequestId;
const worker = getWorker();
workerCallbacks.set(requestId, (data) => {
if (data.type === 'error') {
reject(new Error(data.error));
} else if (data.type === 'result') {
resolve(data.vectors);
} else if (data.type === 'loaded') {
resolve(true);
}
});
worker.postMessage({ ...message, requestId });
});
}
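/*
 * Worker protocol sketch, inferred from the handlers above (embedder.worker.js is not
 * shown here, so the exact worker-side fields are an assumption):
 *   main -> worker : { type: 'load', modelId, hfId, requestId } | { type: 'embed', texts, requestId }
 *   worker -> main : { type: 'progress', percent, requestId }
 *                    { type: 'loaded', requestId } | { type: 'result', vectors, requestId } | { type: 'error', error, requestId }
 */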
// ═══════════════════════════════════════════════════════════════════════════
// Local model operations
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Check the status of a given local model.
 * Reads the cache only; never triggers a download.
 */
export async function checkLocalModelStatus(modelId = DEFAULT_LOCAL_MODEL) {
const modelConfig = LOCAL_MODELS[modelId];
if (!modelConfig) {
return { status: 'error', message: '未知模型' };
}
// Already loaded in memory
if (loadedPipelines[modelId]) {
return { status: 'ready', message: '已就绪' };
}
// Currently downloading
if (downloadingModelId === modelId) {
return { status: 'downloading', message: '下载中' };
}
// Check the IndexedDB cache
const hasCache = await checkModelCache(modelConfig.hfId);
if (hasCache) {
return { status: 'cached', message: '已缓存,可加载' };
}
return { status: 'not_downloaded', message: '未下载' };
}
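/*
 * Usage sketch (the UI wiring is an assumption, not part of this module):
 *
 *   const { status, message } = await checkLocalModelStatus('bge-small-zh');
 *   // status: 'ready' | 'downloading' | 'cached' | 'not_downloaded' | 'error'
 *   statusLabel.textContent = message;
 */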
/**
 * Check whether IndexedDB holds a cache entry for the model
 */
async function checkModelCache(hfId) {
return new Promise((resolve) => {
try {
const request = indexedDB.open('transformers-cache', 1);
request.onerror = () => resolve(false);
request.onsuccess = (event) => {
const db = event.target.result;
const storeNames = Array.from(db.objectStoreNames);
db.close();
// Loose heuristic: look for a store that references this model (or any ONNX cache)
const modelKey = hfId.replace('/', '_');
const hasModel = storeNames.some(name =>
name.includes(modelKey) || name.includes('onnx')
);
resolve(hasModel);
};
request.onupgradeneeded = () => resolve(false);
} catch {
resolve(false);
}
});
}
/**
 * Download / load a local model
 * @param {string} modelId - Model ID
 * @param {Function} onProgress - Progress callback (0-100)
 * @returns {Promise<boolean>}
 */
export async function downloadLocalModel(modelId = DEFAULT_LOCAL_MODEL, onProgress) {
const modelConfig = LOCAL_MODELS[modelId];
if (!modelConfig) {
throw new Error(`未知模型: ${modelId}`);
}
// Already loaded
if (loadedPipelines[modelId]) {
onProgress?.(100);
return true;
}
// A different model is already downloading
if (downloadingModelId && downloadingModelId !== modelId) {
throw new Error(`正在下载其他模型: ${downloadingModelId}`);
}
// The same model is already downloading; wait for it to finish
if (downloadingModelId === modelId) {
xbLog.info(MODULE_ID, `模型 ${modelId} 正在加载中...`);
return new Promise((resolve, reject) => {
const check = () => {
if (loadedPipelines[modelId]) {
resolve(true);
} else if (downloadingModelId !== modelId) {
reject(new Error('下载已取消'));
} else {
setTimeout(check, 200);
}
};
check();
});
}
downloadingModelId = modelId;
downloadAbortController = new AbortController();
try {
xbLog.info(MODULE_ID, `开始下载模型: ${modelId}`);
return await new Promise((resolve, reject) => {
const requestId = ++workerRequestId;
const worker = getWorker();
workerCallbacks.set(requestId, (data) => {
if (data.type === 'progress') {
onProgress?.(data.percent);
} else if (data.type === 'loaded') {
loadedPipelines[modelId] = true;
workerCallbacks.delete(requestId);
resolve(true);
} else if (data.type === 'error') {
workerCallbacks.delete(requestId);
reject(new Error(data.error));
}
});
worker.postMessage({
type: 'load',
modelId,
hfId: modelConfig.hfId,
requestId
});
});
} finally {
downloadingModelId = null;
downloadAbortController = null;
}
}
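/*
 * Usage sketch: download with a progress callback (0-100). `progressBar` is a
 * hypothetical UI element, not something this module provides.
 *
 *   try {
 *     await downloadLocalModel('bge-small-zh', (pct) => { progressBar.value = pct; });
 *   } catch (e) {
 *     // e.g. another model is already downloading, or the download was cancelled
 *   }
 */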
export function cancelDownload() {
if (downloadAbortController) {
downloadAbortController.abort();
xbLog.info(MODULE_ID, '下载已取消');
}
downloadingModelId = null;
downloadAbortController = null;
}
/**
 * Delete cached model data.
 * Note: the IndexedDB / CacheStorage caches are cleared wholesale; modelId only scopes the in-memory flag.
 */
export async function deleteLocalModelCache(modelId = null) {
try {
// Delete the IndexedDB database
await new Promise((resolve, reject) => {
const request = indexedDB.deleteDatabase('transformers-cache');
request.onsuccess = () => resolve();
request.onerror = () => reject(request.error);
request.onblocked = () => {
xbLog.warn(MODULE_ID, 'IndexedDB 删除被阻塞');
resolve();
};
});
// Delete matching CacheStorage entries
if (window.caches) {
const cacheNames = await window.caches.keys();
for (const name of cacheNames) {
if (name.includes('transformers') || name.includes('huggingface') || name.includes('xenova')) {
await window.caches.delete(name);
}
}
}
// Clear in-memory pipeline flags
if (modelId && loadedPipelines[modelId]) {
delete loadedPipelines[modelId];
} else {
Object.keys(loadedPipelines).forEach(key => delete loadedPipelines[key]);
}
xbLog.info(MODULE_ID, '模型缓存已清除');
return true;
} catch (e) {
xbLog.error(MODULE_ID, '清除缓存失败', e);
throw e;
}
}
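/*
 * Usage sketch: clear everything (IndexedDB + CacheStorage + in-memory flags).
 * Because the underlying caches are shared, this affects all local models.
 *
 *   await deleteLocalModelCache();
 */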
/**
 * Generate embeddings with a local model
 */
async function embedLocal(texts, modelId = DEFAULT_LOCAL_MODEL) {
if (!loadedPipelines[modelId]) {
await downloadLocalModel(modelId);
}
return await workerRequest({ type: 'embed', texts });
}
export function isLocalModelLoaded(modelId = DEFAULT_LOCAL_MODEL) {
return !!loadedPipelines[modelId];
}
/**
 * Get local model info
 */
export function getLocalModelInfo(modelId = DEFAULT_LOCAL_MODEL) {
return LOCAL_MODELS[modelId] || null;
}
// ═══════════════════════════════════════════════════════════════════════════
// Online service operations
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Test the connection to an online service
 */
export async function testOnlineService(provider, config) {
const { url, key, model } = config;
if (!key) {
throw new Error('请填写 API Key');
}
if (!model) {
throw new Error('请选择模型');
}
const providerConfig = ONLINE_PROVIDERS[provider];
const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');
if (!baseUrl) {
throw new Error('请填写 API URL');
}
try {
if (provider === 'cohere') {
// Cohere uses a different API format
const response = await fetch(`${baseUrl}/v1/embed`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: model,
texts: ['测试连接'],
input_type: 'search_document',
}),
});
if (!response.ok) {
const error = await response.text();
throw new Error(`API 返回 ${response.status}: ${error}`);
}
const data = await response.json();
const dims = data.embeddings?.[0]?.length || 0;
if (dims === 0) {
throw new Error('API 返回的向量维度为 0');
}
return { success: true, dims };
} else {
// OpenAI-compatible format
const response = await fetch(`${baseUrl}/v1/embeddings`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: model,
input: ['测试连接'],
}),
});
if (!response.ok) {
const error = await response.text();
throw new Error(`API 返回 ${response.status}: ${error}`);
}
const data = await response.json();
const dims = data.data?.[0]?.embedding?.length || 0;
if (dims === 0) {
throw new Error('API 返回的向量维度为 0');
}
return { success: true, dims };
}
} catch (e) {
if (e.name === 'TypeError' && e.message.includes('fetch')) {
throw new Error('网络错误,请检查 URL 是否正确');
}
throw e;
}
}
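/*
 * Usage sketch (URL and key values are placeholders):
 *
 *   const { success, dims } = await testOnlineService('siliconflow', {
 *     url: '', key: 'sk-xxxx', model: 'BAAI/bge-m3',
 *   });
 *   // dims is the embedding dimension reported by the service; a failed test throws instead
 */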
/**
 * Fetch the list of online models (OpenAI-compatible)
 */
export async function fetchOnlineModels(config) {
const { url, key } = config;
if (!url || !key) {
throw new Error('请填写 URL 和 Key');
}
const baseUrl = url.replace(/\/+$/, '').replace(/\/v1$/, '');
const response = await fetch(`${baseUrl}/v1/models`, {
headers: {
'Authorization': `Bearer ${key}`,
'Accept': 'application/json',
},
});
if (!response.ok) {
throw new Error(`获取模型列表失败: ${response.status}`);
}
const data = await response.json();
const models = data.data?.map(m => m.id).filter(Boolean) || [];
// Filter to embedding-related models
const embeddingModels = models.filter(m => {
const lower = m.toLowerCase();
return lower.includes('embed') ||
lower.includes('bge') ||
lower.includes('e5') ||
lower.includes('gte');
});
return embeddingModels.length > 0 ? embeddingModels : models.slice(0, 20);
}
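/*
 * Usage sketch for an OpenAI-compatible endpoint (URL and key are placeholders):
 *
 *   const models = await fetchOnlineModels({ url: 'https://api.example.com/v1', key: 'sk-xxxx' });
 *   // a trailing '/v1' is stripped before '/v1/models' is appended, so both URL forms work
 */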
/**
 * Generate embeddings with an online service
 */
async function embedOnline(texts, provider, config, options = {}) {
const { url, key, model } = config;
const signal = options?.signal;
const providerConfig = ONLINE_PROVIDERS[provider];
const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');
const reqId = Math.random().toString(36).slice(2, 6);
// Retry indefinitely: exponential backoff with a cap, plus jitter
const BASE_WAIT_MS = 1200;
const MAX_WAIT_MS = 15000;
const sleepAbortable = (ms) => new Promise((resolve, reject) => {
if (signal?.aborted) return reject(new DOMException('Aborted', 'AbortError'));
const t = setTimeout(resolve, ms);
if (signal) {
signal.addEventListener('abort', () => {
clearTimeout(t);
reject(new DOMException('Aborted', 'AbortError'));
}, { once: true });
}
});
let attempt = 0;
while (true) {
attempt++;
const startTime = Date.now();
console.log(`[embed ${reqId}] send ${texts.length} items (attempt ${attempt})`);
try {
let response;
if (provider === 'cohere') {
response = await fetch(`${baseUrl}/v1/embed`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: model,
texts: texts,
input_type: 'search_document',
}),
signal,
});
} else {
response = await fetch(`${baseUrl}/v1/embeddings`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: model,
input: texts,
}),
signal,
});
}
console.log(`[embed ${reqId}] status=${response.status} time=${Date.now() - startTime}ms`);
// Status codes that warrant retrying indefinitely:
// - 429: rate limiting
// - 403: quota / risk control / unverified account (e.g. a SiliconFlow account without verification)
// - 5xx: server errors
const retryableStatus = (s) => s === 429 || s === 403 || (s >= 500 && s <= 599);
if (!response.ok) {
const errorText = await response.text().catch(() => '');
if (retryableStatus(response.status)) {
const exp = Math.min(MAX_WAIT_MS, BASE_WAIT_MS * Math.pow(2, Math.min(attempt, 6) - 1));
const jitter = Math.floor(Math.random() * 350);
const waitMs = exp + jitter;
console.warn(`[embed ${reqId}] retryable error ${response.status}, wait ${waitMs}ms`);
await sleepAbortable(waitMs);
continue;
}
// Non-recoverable errors: throw immediately (e.g. 400 bad parameters, 401 invalid key)
const err = new Error(`API 返回 ${response.status}: ${errorText}`);
err.status = response.status;
throw err;
}
const data = await response.json();
if (provider === 'cohere') {
return (data.embeddings || []).map(e => Array.isArray(e) ? e : Array.from(e));
}
return (data.data || []).map(item => {
const embedding = item.embedding;
return Array.isArray(embedding) ? embedding : Array.from(embedding);
});
} catch (e) {
// Cancelled: must exit immediately
if (e?.name === 'AbortError') throw e;
// Network error: retry indefinitely
const exp = Math.min(MAX_WAIT_MS, BASE_WAIT_MS * Math.pow(2, Math.min(attempt, 6) - 1));
const jitter = Math.floor(Math.random() * 350);
const waitMs = exp + jitter;
console.warn(`[embed ${reqId}] network/error, wait ${waitMs}ms then retry: ${e?.message || e}`);
await sleepAbortable(waitMs);
}
}
}
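/*
 * Backoff sketch: with BASE_WAIT_MS = 1200 and MAX_WAIT_MS = 15000, successive retry
 * waits (before the 0-349 ms jitter) are roughly:
 *   attempt 1 -> 1200 ms, 2 -> 2400 ms, 3 -> 4800 ms, 4 -> 9600 ms, 5+ -> 15000 ms (capped)
 * Pass an AbortSignal via options.signal to break out of the retry loop.
 */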
// ═══════════════════════════════════════════════════════════════════════════
// Unified interface
// ═══════════════════════════════════════════════════════════════════════════
/**
 * Generate embeddings (unified entry point)
 * @param {string[]} texts - Array of texts to embed
 * @param {Object} config - Configuration ({ engine, local, online })
 * @param {Object} [options] - Optional, e.g. { signal } to abort online requests
 * @returns {Promise<number[][]>}
 */
export async function embed(texts, config, options = {}) {
if (!texts?.length) return [];
const { engine, local, online } = config;
if (engine === 'local') {
const modelId = local?.modelId || DEFAULT_LOCAL_MODEL;
return await embedLocal(texts, modelId);
} else if (engine === 'online') {
const provider = online?.provider || 'siliconflow';
if (!online?.key || !online?.model) {
throw new Error('在线服务配置不完整');
}
return await embedOnline(texts, provider, online, options);
} else {
throw new Error(`未知的引擎类型: ${engine}`);
}
}
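/*
 * Usage sketch: the same call covers both engines; only `config` differs
 * (values below are illustrative):
 *
 *   const vectors = await embed(['chunk one', 'chunk two'], {
 *     engine: 'local',
 *     local: { modelId: 'bge-small-zh' },
 *   });
 *   // vectors: number[][], one embedding per input text
 */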
// Concurrent embed for online services (local falls back to sequential)
export async function embedBatchesConcurrent(textBatches, config, concurrency = 3) {
if (config.engine === 'local' || textBatches.length <= 1) {
const results = [];
for (const batch of textBatches) {
results.push(await embed(batch, config));
}
return results;
}
const results = new Array(textBatches.length);
let index = 0;
async function worker() {
while (index < textBatches.length) {
const i = index++;
results[i] = await embed(textBatches[i], config);
}
}
await Promise.all(
Array(Math.min(concurrency, textBatches.length))
.fill(null)
.map(() => worker())
);
return results;
}
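/*
 * Usage sketch: split a long list of chunks into batches and embed up to 3 batches in
 * parallel (online engine only; the local engine runs sequentially). `chunks` and the
 * batch size of 16 are assumptions made by the caller, not part of this module.
 *
 *   const batches = [];
 *   for (let i = 0; i < chunks.length; i += 16) batches.push(chunks.slice(i, i + 16));
 *   const perBatch = await embedBatchesConcurrent(batches, config, 3);
 *   const vectors = perBatch.flat();   // back to one embedding per chunk
 */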
/**
 * Get a unique fingerprint for the current engine, used to check whether stored vectors still match
 */
export function getEngineFingerprint(config) {
if (config.engine === 'local') {
const modelId = config.local?.modelId || DEFAULT_LOCAL_MODEL;
const modelConfig = LOCAL_MODELS[modelId];
return `local:${modelId}:${modelConfig?.dims || 512}`;
} else if (config.engine === 'online') {
const provider = config.online?.provider || 'unknown';
const model = config.online?.model || 'unknown';
return `online:${provider}:${model}`;
} else {
return 'unknown';
}
}
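/*
 * Usage sketch: compare the fingerprint stored alongside existing vectors with the
 * current one to decide whether a re-embed is needed (the storage layer is assumed):
 *
 *   if (storedFingerprint !== getEngineFingerprint(config)) {
 *     // engine or model changed: existing vectors are stale and must be regenerated
 *   }
 */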