// ═══════════════════════════════════════════════════════════════════════════
// Story Summary - Embedding Service
// 统一的向量生成接口(本地模型 / 在线服务)
// ═══════════════════════════════════════════════════════════════════════════

import { xbLog } from '../../../core/debug-core.js';

// Tag used for all xbLog calls in this module.
const MODULE_ID = 'embedding';
// ═══════════════════════════════════════════════════════════════════════════
// Local model configuration
// ═══════════════════════════════════════════════════════════════════════════

// Registry of embedding models runnable locally via the embedder worker.
//   id:   internal key (mirrors the map key)
//   name: user-facing label, includes download size
//   hfId: HuggingFace model identifier passed to the worker
//   dims: dimensionality of the embedding vectors this model produces
//   desc: short user-facing recommendation
export const LOCAL_MODELS = {
  'bge-small-zh': {
    id: 'bge-small-zh',
    name: '中文轻量 (51MB)',
    hfId: 'Xenova/bge-small-zh-v1.5',
    dims: 512,
    desc: '手机/低配适用',
  },
  'bge-base-zh': {
    id: 'bge-base-zh',
    name: '中文标准 (102MB)',
    hfId: 'Xenova/bge-base-zh-v1.5',
    dims: 768,
    desc: 'PC 推荐,效果更好',
  },
  'e5-small': {
    id: 'e5-small',
    name: '多语言 (118MB)',
    hfId: 'Xenova/multilingual-e5-small',
    dims: 384,
    desc: '非中文用户',
  },
};

// Model used whenever callers do not specify one explicitly.
export const DEFAULT_LOCAL_MODEL = 'bge-small-zh';
// ═══════════════════════════════════════════════════════════════════════════
// Online service configuration
// ═══════════════════════════════════════════════════════════════════════════

// Registry of supported online embedding providers.
//   id:             internal key (mirrors the map key)
//   name:           user-facing label
//   baseUrl:        fixed API root; empty string means user supplies the URL
//   canFetchModels: whether /v1/models can be queried (see fetchOnlineModels)
//   defaultModels:  models offered when the list cannot be fetched
//   customEmbed:    provider uses a non-OpenAI request/response shape
export const ONLINE_PROVIDERS = {
  siliconflow: {
    id: 'siliconflow',
    name: '硅基流动',
    baseUrl: 'https://api.siliconflow.cn',
    canFetchModels: false,
    defaultModels: [
      'BAAI/bge-m3',
      'BAAI/bge-large-zh-v1.5',
      'BAAI/bge-small-zh-v1.5',
    ],
  },
  cohere: {
    id: 'cohere',
    name: 'Cohere',
    baseUrl: 'https://api.cohere.ai',
    canFetchModels: false,
    defaultModels: [
      'embed-multilingual-v3.0',
      'embed-english-v3.0',
    ],
    // Cohere uses a different API shape (endpoint, body, response).
    customEmbed: true,
  },
  openai: {
    id: 'openai',
    name: 'OpenAI 兼容',
    baseUrl: '',
    canFetchModels: true,
    defaultModels: [],
  },
};
// ═══════════════════════════════════════════════════════════════════════════
// Local model state management
// ═══════════════════════════════════════════════════════════════════════════

// Models loaded in the worker, keyed by model id.
// NOTE(review): despite the name, values are the boolean `true` (set when the
// worker reports 'loaded'), not pipeline instances — the actual pipeline
// lives inside the worker.
const loadedPipelines = {};

// Id of the model currently downloading/loading (null when idle).
let downloadingModelId = null;
// AbortController for the in-flight download; see cancelDownload().
let downloadAbortController = null;

// Lazily-created module Worker that hosts the local embedding pipeline.
let embeddingWorker = null;
// Monotonic counter used to correlate worker requests with responses.
let workerRequestId = 0;
// requestId -> callback for pending worker requests.
const workerCallbacks = new Map();
/**
 * Lazily create (and memoize) the module Worker that hosts the embedding
 * pipeline. Incoming messages are routed to the pending caller registered
 * in workerCallbacks, keyed by the requestId echoed back by the worker.
 * @returns {Worker}
 */
function getWorker() {
  if (embeddingWorker) return embeddingWorker;

  const workerUrl = new URL('./embedder.worker.js', import.meta.url).href;
  embeddingWorker = new Worker(workerUrl, { type: 'module' });

  embeddingWorker.onmessage = (event) => {
    const { requestId, ...payload } = event.data || {};
    const deliver = workerCallbacks.get(requestId);
    if (!deliver) return;

    deliver(payload);

    // Terminal message types release the callback entry; 'progress' keeps it.
    const isTerminal =
      payload.type === 'result' ||
      payload.type === 'error' ||
      payload.type === 'loaded';
    if (isTerminal) {
      workerCallbacks.delete(requestId);
    }
  };

  return embeddingWorker;
}
/**
 * Send a message to the embedding worker and resolve with its reply:
 * the vectors for 'result', `true` for 'loaded'; rejects on 'error'.
 * Intermediate 'progress' messages are ignored here.
 * @param {Object} message - payload for the worker (requestId is appended)
 * @returns {Promise<number[][]|true>}
 */
function workerRequest(message) {
  const requestId = ++workerRequestId;
  const worker = getWorker();

  return new Promise((resolve, reject) => {
    workerCallbacks.set(requestId, (data) => {
      switch (data.type) {
        case 'error':
          reject(new Error(data.error));
          break;
        case 'result':
          resolve(data.vectors);
          break;
        case 'loaded':
          resolve(true);
          break;
        default:
          // e.g. 'progress' — not meaningful for plain requests.
          break;
      }
    });

    worker.postMessage({ ...message, requestId });
  });
}
// ═══════════════════════════════════════════════════════════════════════════
// Local model operations
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Report the status of a local model. Read-only: never triggers a download.
 * Statuses: 'ready' (loaded in worker), 'downloading', 'cached' (present in
 * IndexedDB), 'not_downloaded', or 'error' for unknown model ids.
 * @param {string} modelId - key into LOCAL_MODELS
 * @returns {Promise<{status: string, message: string}>}
 */
export async function checkLocalModelStatus(modelId = DEFAULT_LOCAL_MODEL) {
  const modelConfig = LOCAL_MODELS[modelId];

  if (!modelConfig) {
    return { status: 'error', message: '未知模型' };
  }
  if (loadedPipelines[modelId]) {
    // Already loaded into the worker.
    return { status: 'ready', message: '已就绪' };
  }
  if (downloadingModelId === modelId) {
    return { status: 'downloading', message: '下载中' };
  }
  if (await checkModelCache(modelConfig.hfId)) {
    // Present in IndexedDB; loading it would not hit the network.
    return { status: 'cached', message: '已缓存,可加载' };
  }
  return { status: 'not_downloaded', message: '未下载' };
}
/**
 * Best-effort check for a cached copy of the model in IndexedDB.
 * Never rejects — any failure resolves to `false`.
 * NOTE(review): the `name.includes('onnx')` fallback means a cache left by
 * ANY model can make every model look cached — confirm this is intended.
 * @param {string} hfId - HuggingFace model id, e.g. 'Xenova/bge-small-zh-v1.5'
 * @returns {Promise<boolean>}
 */
async function checkModelCache(hfId) {
  return new Promise((resolve) => {
    try {
      const request = indexedDB.open('transformers-cache', 1);

      request.onerror = () => resolve(false);
      // A fresh (just-created) database means nothing was cached.
      request.onupgradeneeded = () => resolve(false);

      request.onsuccess = (event) => {
        const db = event.target.result;
        const stores = Array.from(db.objectStoreNames);
        db.close();

        // Store names embed the model id with '/' flattened to '_'.
        const modelKey = hfId.replace('/', '_');
        const found = stores.some(
          (name) => name.includes(modelKey) || name.includes('onnx')
        );
        resolve(found);
      };
    } catch {
      // indexedDB unavailable or threw synchronously.
      resolve(false);
    }
  });
}
/**
 * Download and load a local model into the embedding worker.
 *
 * Concurrency behavior:
 *  - already loaded            -> reports 100% and resolves immediately
 *  - another model downloading -> throws
 *  - same model downloading    -> polls until that load finishes
 *
 * @param {string} modelId - key into LOCAL_MODELS
 * @param {Function} [onProgress] - progress callback (0-100)
 * @returns {Promise<boolean>} true once the worker reports 'loaded'
 * @throws {Error} on unknown model, conflicting download, or worker error
 */
export async function downloadLocalModel(modelId = DEFAULT_LOCAL_MODEL, onProgress) {
  const modelConfig = LOCAL_MODELS[modelId];
  if (!modelConfig) {
    throw new Error(`未知模型: ${modelId}`);
  }
  // Already loaded in the worker: nothing to do.
  if (loadedPipelines[modelId]) {
    onProgress?.(100);
    return true;
  }
  // A different model is mid-download; only one download at a time.
  if (downloadingModelId && downloadingModelId !== modelId) {
    throw new Error(`正在下载其他模型: ${downloadingModelId}`);
  }
  // Same model already downloading: poll every 200 ms until it loads, or
  // until the download slot is released (treated as a cancellation).
  if (downloadingModelId === modelId) {
    xbLog.info(MODULE_ID, `模型 ${modelId} 正在加载中...`);
    return new Promise((resolve, reject) => {
      const check = () => {
        if (loadedPipelines[modelId]) {
          resolve(true);
        } else if (downloadingModelId !== modelId) {
          reject(new Error('下载已取消'));
        } else {
          setTimeout(check, 200);
        }
      };
      check();
    });
  }
  downloadingModelId = modelId;
  // NOTE(review): this controller is aborted by cancelDownload(), but it is
  // never handed to the worker — aborting does not actually stop the
  // worker-side download. Confirm against embedder.worker.js.
  downloadAbortController = new AbortController();

  try {
    xbLog.info(MODULE_ID, `开始下载模型: ${modelId}`);

    // Registered manually (rather than via workerRequest) so that
    // intermediate 'progress' messages can be forwarded to onProgress.
    return await new Promise((resolve, reject) => {
      const requestId = ++workerRequestId;
      const worker = getWorker();

      workerCallbacks.set(requestId, (data) => {
        if (data.type === 'progress') {
          onProgress?.(data.percent);
        } else if (data.type === 'loaded') {
          // Mark as loaded; the pipeline object itself lives in the worker.
          loadedPipelines[modelId] = true;
          workerCallbacks.delete(requestId);
          resolve(true);
        } else if (data.type === 'error') {
          workerCallbacks.delete(requestId);
          reject(new Error(data.error));
        }
      });

      worker.postMessage({
        type: 'load',
        modelId,
        hfId: modelConfig.hfId,
        requestId
      });
    });
  } finally {
    // Always release the download slot, on success and failure alike.
    downloadingModelId = null;
    downloadAbortController = null;
  }
}
/**
 * Cancel the in-flight model download (if any) and release the download slot.
 * NOTE(review): aborting the controller does not stop the worker-side
 * download itself — see downloadLocalModel().
 */
export function cancelDownload() {
  if (downloadAbortController !== null && downloadAbortController !== undefined) {
    downloadAbortController.abort();
    xbLog.info(MODULE_ID, '下载已取消');
  }
  // Releasing the slot makes any waiters in downloadLocalModel() reject.
  downloadingModelId = null;
  downloadAbortController = null;
}
/**
 * Delete cached model data (IndexedDB + CacheStorage) and drop in-memory
 * loaded state.
 *
 * The on-disk caches are shared by all models and are always cleared
 * wholesale; `modelId` only scopes which in-memory entry is dropped.
 *
 * @param {string|null} modelId - drop only this model's in-memory state,
 *   or null to drop all models' state
 * @returns {Promise<boolean>} true on success
 * @throws rethrows any cache-deletion failure after logging it
 */
export async function deleteLocalModelCache(modelId = null) {
  try {
    // Remove the transformers.js IndexedDB cache.
    await new Promise((resolve, reject) => {
      const request = indexedDB.deleteDatabase('transformers-cache');
      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
      request.onblocked = () => {
        // Another tab holds the DB open; proceed rather than hang forever.
        xbLog.warn(MODULE_ID, 'IndexedDB 删除被阻塞');
        resolve();
      };
    });

    // Remove CacheStorage entries created by transformers.js / HF hub.
    if (window.caches) {
      const cacheNames = await window.caches.keys();
      for (const name of cacheNames) {
        if (name.includes('transformers') || name.includes('huggingface') || name.includes('xenova')) {
          await window.caches.delete(name);
        }
      }
    }

    // Drop in-memory loaded state.
    // BUG FIX: previously, a modelId that was NOT currently loaded fell
    // through to the else branch and wiped EVERY model's state. Now a given
    // modelId only ever affects its own entry.
    if (modelId) {
      delete loadedPipelines[modelId];
    } else {
      Object.keys(loadedPipelines).forEach(key => delete loadedPipelines[key]);
    }

    xbLog.info(MODULE_ID, '模型缓存已清除');
    return true;
  } catch (e) {
    xbLog.error(MODULE_ID, '清除缓存失败', e);
    throw e;
  }
}
/**
 * Generate embeddings with a local model, loading the model first if needed.
 * NOTE(review): the 'embed' message does not carry modelId, so the worker
 * presumably embeds with whichever model it currently has loaded — confirm
 * against embedder.worker.js if several models can be loaded at once.
 * @param {string[]} texts
 * @param {string} modelId - key into LOCAL_MODELS
 * @returns {Promise<number[][]>} one vector per input text
 */
async function embedLocal(texts, modelId = DEFAULT_LOCAL_MODEL) {
  if (!loadedPipelines[modelId]) {
    await downloadLocalModel(modelId);
  }

  return await workerRequest({ type: 'embed', texts });
}
/**
 * Whether the given local model has been loaded into the worker.
 * @param {string} modelId - key into LOCAL_MODELS
 * @returns {boolean}
 */
export function isLocalModelLoaded(modelId = DEFAULT_LOCAL_MODEL) {
  return Boolean(loadedPipelines[modelId]);
}
/**
 * Look up a local model's configuration entry.
 * @param {string} modelId - key into LOCAL_MODELS
 * @returns {Object|null} the LOCAL_MODELS entry, or null if unknown
 */
export function getLocalModelInfo(modelId = DEFAULT_LOCAL_MODEL) {
  const info = LOCAL_MODELS[modelId];
  return info ?? null;
}
// ═══════════════════════════════════════════════════════════════════════════
// Online service operations
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Send a one-item probe request to an online embedding service to verify
 * the credentials, model, and endpoint.
 *
 * @param {string} provider - key into ONLINE_PROVIDERS ('cohere' uses
 *   Cohere's own API shape; anything else is treated as OpenAI-compatible)
 * @param {{url?: string, key: string, model: string}} config
 * @returns {Promise<{success: true, dims: number}>} dims = vector length
 * @throws {Error} on missing config, HTTP error, empty vector, or network failure
 */
export async function testOnlineService(provider, config) {
  const { url, key, model } = config;

  if (!key) {
    throw new Error('请填写 API Key');
  }
  if (!model) {
    throw new Error('请选择模型');
  }

  const providerConfig = ONLINE_PROVIDERS[provider];
  // A provider's fixed base URL wins; custom (OpenAI-compatible) providers
  // fall back to the user-supplied URL. Trailing slashes are stripped.
  const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');

  if (!baseUrl) {
    throw new Error('请填写 API URL');
  }

  // The two API shapes differ only in endpoint, request body, and where the
  // vectors live in the response — branch on those instead of duplicating
  // the entire request/validation flow (previously two near-identical copies).
  const isCohere = provider === 'cohere';
  const endpoint = isCohere ? `${baseUrl}/v1/embed` : `${baseUrl}/v1/embeddings`;
  const requestBody = isCohere
    ? { model: model, texts: ['测试连接'], input_type: 'search_document' }
    : { model: model, input: ['测试连接'] };

  try {
    const response = await fetch(endpoint, {
      method: 'POST',
      headers: {
        'Authorization': `Bearer ${key}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(requestBody),
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`API 返回 ${response.status}: ${error}`);
    }

    const data = await response.json();
    const dims = isCohere
      ? (data.embeddings?.[0]?.length || 0)
      : (data.data?.[0]?.embedding?.length || 0);

    if (dims === 0) {
      throw new Error('API 返回的向量维度为 0');
    }

    return { success: true, dims };
  } catch (e) {
    // fetch() rejects with a TypeError on network/CORS/URL problems.
    if (e.name === 'TypeError' && e.message.includes('fetch')) {
      throw new Error('网络错误,请检查 URL 是否正确');
    }
    throw e;
  }
}
/**
 * Fetch the model list from an OpenAI-compatible endpoint and keep the
 * entries that look like embedding models; if none match, return the first
 * 20 models as a fallback.
 * @param {{url: string, key: string}} config
 * @returns {Promise<string[]>} model ids
 * @throws {Error} on missing config or a non-OK HTTP response
 */
export async function fetchOnlineModels(config) {
  const { url, key } = config;

  if (!url || !key) {
    throw new Error('请填写 URL 和 Key');
  }

  // Normalize: strip trailing slashes, then a trailing /v1 segment.
  const baseUrl = url.replace(/\/+$/, '').replace(/\/v1$/, '');

  const response = await fetch(`${baseUrl}/v1/models`, {
    headers: {
      'Authorization': `Bearer ${key}`,
      'Accept': 'application/json',
    },
  });

  if (!response.ok) {
    throw new Error(`获取模型列表失败: ${response.status}`);
  }

  const payload = await response.json();
  const allModels = (payload.data ?? []).map((m) => m.id).filter(Boolean);

  // Heuristic filter for embedding-capable models.
  const looksLikeEmbedding = (id) => {
    const lower = id.toLowerCase();
    return ['embed', 'bge', 'e5', 'gte'].some((tag) => lower.includes(tag));
  };
  const embeddingModels = allModels.filter(looksLikeEmbedding);

  return embeddingModels.length > 0 ? embeddingModels : allModels.slice(0, 20);
}
/**
 * Generate embeddings through an online provider, retrying transient
 * failures indefinitely with capped exponential backoff plus jitter.
 *
 * Retryable: 429 (rate limit), 403 (quota/verification, e.g. SiliconFlow
 * unverified accounts), 5xx server errors, and network-level failures.
 * Non-retryable API errors (e.g. 400 bad request, 401 bad key) and aborts
 * are thrown to the caller.
 *
 * @param {string[]} texts
 * @param {string} provider - key into ONLINE_PROVIDERS
 * @param {{url?: string, key: string, model: string}} config
 * @param {{signal?: AbortSignal}} [options] - abort support
 * @returns {Promise<number[][]>} one vector per input text
 */
async function embedOnline(texts, provider, config, options = {}) {
  const { url, key, model } = config;
  const signal = options?.signal;

  const providerConfig = ONLINE_PROVIDERS[provider];
  const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');

  // Short id to correlate the log lines of one logical request.
  const reqId = Math.random().toString(36).slice(2, 6);

  // Retry forever: exponential backoff with an upper bound and jitter.
  const BASE_WAIT_MS = 1200;
  const MAX_WAIT_MS = 15000;

  // Sleep that rejects immediately when the caller aborts.
  const sleepAbortable = (ms) => new Promise((resolve, reject) => {
    if (signal?.aborted) return reject(new DOMException('Aborted', 'AbortError'));
    const t = setTimeout(resolve, ms);
    if (signal) {
      signal.addEventListener('abort', () => {
        clearTimeout(t);
        reject(new DOMException('Aborted', 'AbortError'));
      }, { once: true });
    }
  });

  // Capped exponential backoff + jitter for the given (1-based) attempt.
  const backoffMs = (attempt) => {
    const exp = Math.min(MAX_WAIT_MS, BASE_WAIT_MS * Math.pow(2, Math.min(attempt, 6) - 1));
    return exp + Math.floor(Math.random() * 350);
  };

  // Statuses worth retrying forever: 429 rate limit, 403 quota/verification,
  // 5xx server errors.
  const retryableStatus = (s) => s === 429 || s === 403 || (s >= 500 && s <= 599);

  let attempt = 0;
  while (true) {
    attempt++;
    const startTime = Date.now();
    xbLog.info(MODULE_ID, `[embed ${reqId}] send ${texts.length} items (attempt ${attempt})`);

    try {
      let response;

      if (provider === 'cohere') {
        // Cohere's own embed API shape.
        response = await fetch(`${baseUrl}/v1/embed`, {
          method: 'POST',
          headers: {
            'Authorization': `Bearer ${key}`,
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            model: model,
            texts: texts,
            input_type: 'search_document',
          }),
          signal,
        });
      } else {
        // OpenAI-compatible shape.
        response = await fetch(`${baseUrl}/v1/embeddings`, {
          method: 'POST',
          headers: {
            'Authorization': `Bearer ${key}`,
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            model: model,
            input: texts,
          }),
          signal,
        });
      }

      xbLog.info(MODULE_ID, `[embed ${reqId}] status=${response.status} time=${Date.now() - startTime}ms`);

      if (!response.ok) {
        const errorText = await response.text().catch(() => '');

        if (retryableStatus(response.status)) {
          const waitMs = backoffMs(attempt);
          xbLog.warn(MODULE_ID, `[embed ${reqId}] retryable error ${response.status}, wait ${waitMs}ms`);
          await sleepAbortable(waitMs);
          continue;
        }

        // Non-recoverable API error (e.g. 400 bad params, 401 bad key).
        const err = new Error(`API 返回 ${response.status}: ${errorText}`);
        err.status = response.status;
        throw err;
      }

      const data = await response.json();

      if (provider === 'cohere') {
        return (data.embeddings || []).map(e => Array.isArray(e) ? e : Array.from(e));
      }
      return (data.data || []).map(item => {
        const embedding = item.embedding;
        return Array.isArray(embedding) ? embedding : Array.from(embedding);
      });
    } catch (e) {
      // Caller cancelled: propagate immediately.
      if (e?.name === 'AbortError') throw e;

      // BUG FIX: non-retryable API errors thrown above carry `.status`.
      // Previously this catch swallowed them and retried forever, so a
      // 400/401 never surfaced to the caller despite the intent to rethrow.
      if (e?.status !== undefined) throw e;

      // Network-level failure: back off and retry forever.
      const waitMs = backoffMs(attempt);
      xbLog.warn(MODULE_ID, `[embed ${reqId}] network/error, wait ${waitMs}ms then retry: ${e?.message || e}`);
      await sleepAbortable(waitMs);
    }
  }
}
// ═══════════════════════════════════════════════════════════════════════════
// Unified interface
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Generate embeddings through the configured engine (unified entry point).
 * Returns an empty array for empty/missing input without touching any engine.
 * @param {string[]} texts - texts to vectorize
 * @param {Object} config - { engine: 'local'|'online', local?, online? }
 * @param {Object} [options] - forwarded to the online engine (e.g. { signal })
 * @returns {Promise<number[][]>}
 * @throws {Error} on incomplete online config or an unknown engine
 */
export async function embed(texts, config, options = {}) {
  if (!texts?.length) return [];

  const { engine, local, online } = config;

  switch (engine) {
    case 'local': {
      const modelId = local?.modelId || DEFAULT_LOCAL_MODEL;
      return embedLocal(texts, modelId);
    }

    case 'online': {
      if (!online?.key || !online?.model) {
        throw new Error('在线服务配置不完整');
      }
      const provider = online?.provider || 'siliconflow';
      return embedOnline(texts, provider, online, options);
    }

    default:
      throw new Error(`未知的引擎类型: ${engine}`);
  }
}
/**
 * Embed several batches of texts, running online requests concurrently.
 * The local engine (or a single batch) falls back to sequential execution.
 * @param {string[][]} textBatches - batches of texts, each fed to embed()
 * @param {Object} config - same config shape as embed()
 * @param {number} [concurrency=3] - max in-flight online requests
 * @returns {Promise<number[][][]>} results aligned with textBatches order
 */
export async function embedBatchesConcurrent(textBatches, config, concurrency = 3) {
  const runSequentially = config.engine === 'local' || textBatches.length <= 1;

  if (runSequentially) {
    const out = [];
    for (const batch of textBatches) {
      out.push(await embed(batch, config));
    }
    return out;
  }

  const results = new Array(textBatches.length);
  let next = 0;

  // Each runner claims the next unprocessed batch index until none remain.
  const runner = async () => {
    while (next < textBatches.length) {
      const i = next;
      next += 1;
      results[i] = await embed(textBatches[i], config);
    }
  };

  const poolSize = Math.min(concurrency, textBatches.length);
  await Promise.all(Array.from({ length: poolSize }, () => runner()));

  return results;
}
/**
 * Build a stable identifier for the active engine/model combination, used
 * to check whether previously stored vectors are still compatible.
 * @param {Object} config - same shape as embed()'s config
 * @returns {string} e.g. 'local:bge-small-zh:512', 'online:cohere:embed-...',
 *   or 'unknown' for an unrecognized engine
 */
export function getEngineFingerprint(config) {
  switch (config.engine) {
    case 'local': {
      const modelId = config.local?.modelId || DEFAULT_LOCAL_MODEL;
      const dims = LOCAL_MODELS[modelId]?.dims || 512;
      return `local:${modelId}:${dims}`;
    }

    case 'online': {
      const provider = config.online?.provider || 'unknown';
      const model = config.online?.model || 'unknown';
      return `online:${provider}:${model}`;
    }

    default:
      return 'unknown';
  }
}