// ═══════════════════════════════════════════════════════════════════════════
// Story Summary - Embedding Service
// Unified vector-generation interface (local model / online service)
// ═══════════════════════════════════════════════════════════════════════════
import { xbLog } from '../../../core/debug-core.js';

// Module tag used for all xbLog output from this file.
const MODULE_ID = 'embedding';
// ═══════════════════════════════════════════════════════════════════════════
// Local model configuration
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Catalog of local (in-worker) embedding models, keyed by model id.
 * - hfId: Hugging Face repository id loaded by the embedding worker.
 * - dims: output vector dimensionality (used by getEngineFingerprint).
 * - name/desc: user-facing labels (kept in Chinese for the UI).
 */
export const LOCAL_MODELS = {
  'bge-small-zh': {
    id: 'bge-small-zh',
    name: '中文轻量 (51MB)',
    hfId: 'Xenova/bge-small-zh-v1.5',
    dims: 512,
    desc: '手机/低配适用',
  },
  'bge-base-zh': {
    id: 'bge-base-zh',
    name: '中文标准 (102MB)',
    hfId: 'Xenova/bge-base-zh-v1.5',
    dims: 768,
    desc: 'PC 推荐,效果更好',
  },
  'e5-small': {
    id: 'e5-small',
    name: '多语言 (118MB)',
    hfId: 'Xenova/multilingual-e5-small',
    dims: 384,
    desc: '非中文用户',
  },
};

// Model used when the caller does not specify one.
export const DEFAULT_LOCAL_MODEL = 'bge-small-zh';
// ═══════════════════════════════════════════════════════════════════════════
// Online service configuration
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Catalog of supported online embedding providers, keyed by provider id.
 * - baseUrl:        fixed API origin ('' means the user supplies the URL).
 * - canFetchModels: whether /v1/models can be queried for a model list.
 * - defaultModels:  models offered when the list cannot be fetched.
 * - customEmbed:    provider uses a non-OpenAI request/response schema.
 */
export const ONLINE_PROVIDERS = {
  siliconflow: {
    id: 'siliconflow',
    name: '硅基流动',
    baseUrl: 'https://api.siliconflow.cn',
    canFetchModels: false,
    defaultModels: [
      'BAAI/bge-m3',
      'BAAI/bge-large-zh-v1.5',
      'BAAI/bge-small-zh-v1.5',
    ],
  },
  cohere: {
    id: 'cohere',
    name: 'Cohere',
    baseUrl: 'https://api.cohere.ai',
    canFetchModels: false,
    defaultModels: [
      'embed-multilingual-v3.0',
      'embed-english-v3.0',
    ],
    // Cohere uses a different API format (handled specially when embedding).
    customEmbed: true,
  },
  openai: {
    id: 'openai',
    name: 'OpenAI 兼容',
    baseUrl: '',
    canFetchModels: true,
    defaultModels: [],
  },
};
// ═══════════════════════════════════════════════════════════════════════════
// Local model state management
// ═══════════════════════════════════════════════════════════════════════════

// Loaded-model flags: { modelId: true }. Note this stores a boolean marker,
// not the pipeline object — the actual pipeline lives inside the worker.
const loadedPipelines = {};

// Model currently being downloaded (null when idle).
let downloadingModelId = null;
let downloadAbortController = null;

// Worker for local embedding (created lazily, one per page).
let embeddingWorker = null;
// Monotonic id used to correlate worker replies with pending requests.
let workerRequestId = 0;
// Pending per-request callbacks, keyed by requestId.
const workerCallbacks = new Map();
/**
 * Lazily create (and memoize) the embedding Web Worker.
 *
 * Routes worker messages to the per-request callback registered in
 * `workerCallbacks`; terminal message types (result/error/loaded) remove
 * the callback, while streaming messages (progress) keep it alive.
 *
 * Fix: a crashed worker previously left every pending promise hanging
 * forever and the dead Worker instance cached. `onerror` now fails all
 * outstanding requests and drops the instance so the next call respawns it.
 *
 * @returns {Worker}
 */
function getWorker() {
  if (!embeddingWorker) {
    const workerPath = new URL('./embedder.worker.js', import.meta.url).href;
    embeddingWorker = new Worker(workerPath, { type: 'module' });

    embeddingWorker.onmessage = (e) => {
      const { requestId, ...data } = e.data || {};
      const callback = workerCallbacks.get(requestId);
      if (callback) {
        callback(data);
        // Only terminal replies release the callback; progress messages
        // must keep it registered for subsequent updates.
        if (data.type === 'result' || data.type === 'error' || data.type === 'loaded') {
          workerCallbacks.delete(requestId);
        }
      }
    };

    // Worker-level failure: reject every in-flight request instead of
    // leaving their promises pending forever, then discard the worker so
    // a later call can spawn a fresh one.
    embeddingWorker.onerror = (e) => {
      const message = e?.message || 'embedding worker crashed';
      xbLog.error(MODULE_ID, 'Worker error', e);
      for (const callback of workerCallbacks.values()) {
        callback({ type: 'error', error: message });
      }
      workerCallbacks.clear();
      embeddingWorker = null;
    };
  }
  return embeddingWorker;
}
/**
 * Send a message to the embedding worker and wait for its terminal reply.
 * Resolves with the vectors for `result`, with `true` for `loaded`, and
 * rejects for `error`. Non-terminal messages are ignored here.
 *
 * @param {Object} message - payload for the worker (requestId is appended)
 * @returns {Promise<number[][]|boolean>}
 */
function workerRequest(message) {
  const id = ++workerRequestId;
  const worker = getWorker();

  return new Promise((resolve, reject) => {
    workerCallbacks.set(id, (reply) => {
      switch (reply.type) {
        case 'error':
          reject(new Error(reply.error));
          break;
        case 'result':
          resolve(reply.vectors);
          break;
        case 'loaded':
          resolve(true);
          break;
        default:
          // e.g. progress — nothing to settle here
          break;
      }
    });

    worker.postMessage({ ...message, requestId: id });
  });
}
// ═══════════════════════════════════════════════════════════════════════════
// Local model operations
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Report the status of a local model.
 * Read-only: only inspects in-memory state and the cache — never downloads.
 *
 * @param {string} [modelId=DEFAULT_LOCAL_MODEL]
 * @returns {Promise<{status: string, message: string}>} status is one of
 *   'error' | 'ready' | 'downloading' | 'cached' | 'not_downloaded'
 *   (messages stay in Chinese for the UI)
 */
export async function checkLocalModelStatus(modelId = DEFAULT_LOCAL_MODEL) {
  const modelConfig = LOCAL_MODELS[modelId];
  if (!modelConfig) return { status: 'error', message: '未知模型' };

  // Already loaded into the worker.
  if (loadedPipelines[modelId]) return { status: 'ready', message: '已就绪' };

  // This exact model is mid-download.
  if (downloadingModelId === modelId) return { status: 'downloading', message: '下载中' };

  // Otherwise decide by probing the IndexedDB cache.
  const cached = await checkModelCache(modelConfig.hfId);
  return cached
    ? { status: 'cached', message: '已缓存,可加载' }
    : { status: 'not_downloaded', message: '未下载' };
}
/**
 * Probe IndexedDB for a cached copy of the given Hugging Face model.
 *
 * Never rejects — any failure (no IndexedDB, open error) resolves `false`.
 *
 * NOTE(review): `indexedDB.open('transformers-cache', 1)` CREATES the
 * database when it does not exist (onupgradeneeded fires) — a side effect
 * for a "check" function; confirm this is acceptable.
 * NOTE(review): the `name.includes('onnx')` fallback matches stores from
 * any cached model, so this can report `true` for a model that was never
 * downloaded — presumably intentional best-effort; verify.
 *
 * @param {string} hfId - Hugging Face repo id, e.g. 'Xenova/bge-small-zh-v1.5'
 * @returns {Promise<boolean>}
 */
async function checkModelCache(hfId) {
  return new Promise((resolve) => {
    try {
      const request = indexedDB.open('transformers-cache', 1);
      request.onerror = () => resolve(false);
      request.onsuccess = (event) => {
        const db = event.target.result;
        const storeNames = Array.from(db.objectStoreNames);
        db.close();
        // Store names are assumed to derive from the repo id with '/' → '_'.
        const modelKey = hfId.replace('/', '_');
        const hasModel = storeNames.some(name =>
          name.includes(modelKey) || name.includes('onnx')
        );
        resolve(hasModel);
      };
      // Upgrade fired → the DB did not exist before → nothing is cached.
      request.onupgradeneeded = () => resolve(false);
    } catch {
      resolve(false);
    }
  });
}
/**
 * Download (if needed) and load a local model inside the embedding worker.
 *
 * Concurrency behavior:
 * - model already loaded     → resolves immediately (reports 100%);
 * - different model loading  → throws;
 * - same model loading       → polls every 200ms until that load settles.
 *
 * NOTE(review): `downloadAbortController` is created but never handed to
 * the worker, so `cancelDownload()` only resets bookkeeping — it does not
 * actually stop an in-worker download; confirm intended.
 * NOTE(review): a waiter on the polling path observes any failure as
 * `downloadingModelId` being cleared and rejects with '下载已取消', even
 * when the download failed rather than being cancelled.
 *
 * @param {string} modelId - model ID (key of LOCAL_MODELS)
 * @param {Function} onProgress - progress callback (0-100)
 * @returns {Promise<boolean>} true once the model is loaded
 * @throws {Error} on unknown model, conflicting download, or worker failure
 */
export async function downloadLocalModel(modelId = DEFAULT_LOCAL_MODEL, onProgress) {
  const modelConfig = LOCAL_MODELS[modelId];
  if (!modelConfig) {
    throw new Error(`未知模型: ${modelId}`);
  }
  // Already loaded — nothing to do.
  if (loadedPipelines[modelId]) {
    onProgress?.(100);
    return true;
  }
  // A different model is downloading — refuse to start a second one.
  if (downloadingModelId && downloadingModelId !== modelId) {
    throw new Error(`正在下载其他模型: ${downloadingModelId}`);
  }
  // Same model already downloading — wait for completion by polling.
  if (downloadingModelId === modelId) {
    xbLog.info(MODULE_ID, `模型 ${modelId} 正在加载中...`);
    return new Promise((resolve, reject) => {
      const check = () => {
        if (loadedPipelines[modelId]) {
          resolve(true);
        } else if (downloadingModelId !== modelId) {
          reject(new Error('下载已取消'));
        } else {
          setTimeout(check, 200);
        }
      };
      check();
    });
  }
  downloadingModelId = modelId;
  downloadAbortController = new AbortController();

  try {
    xbLog.info(MODULE_ID, `开始下载模型: ${modelId}`);

    return await new Promise((resolve, reject) => {
      const requestId = ++workerRequestId;
      const worker = getWorker();

      workerCallbacks.set(requestId, (data) => {
        if (data.type === 'progress') {
          // Streaming update — keep the callback registered.
          onProgress?.(data.percent);
        } else if (data.type === 'loaded') {
          // Boolean flag only; the pipeline itself stays inside the worker.
          loadedPipelines[modelId] = true;
          workerCallbacks.delete(requestId);
          resolve(true);
        } else if (data.type === 'error') {
          workerCallbacks.delete(requestId);
          reject(new Error(data.error));
        }
      });

      worker.postMessage({
        type: 'load',
        modelId,
        hfId: modelConfig.hfId,
        requestId
      });
    });
  } finally {
    // Always clear download bookkeeping, on success and failure alike.
    downloadingModelId = null;
    downloadAbortController = null;
  }
}
/**
 * Cancel the bookkeeping for an in-flight model download.
 * Aborts the controller (if one exists) and resets the downloading state,
 * which makes waiters polling inside downloadLocalModel reject.
 */
export function cancelDownload() {
  const controller = downloadAbortController;
  downloadAbortController = null;
  downloadingModelId = null;

  if (controller) {
    controller.abort();
    xbLog.info(MODULE_ID, '下载已取消');
  }
}
/**
 * Delete cached model data (IndexedDB + CacheStorage) and drop loaded flags.
 *
 * The on-disk caches are cleared wholesale regardless of `modelId`; the
 * parameter only controls which in-memory loaded flag is removed (one
 * specific model, or all of them when null/absent).
 *
 * @param {string|null} [modelId=null]
 * @returns {Promise<boolean>} true on success
 * @throws {Error} when cache deletion fails
 */
export async function deleteLocalModelCache(modelId = null) {
  try {
    // 1. Drop the whole transformers IndexedDB database.
    await new Promise((resolve, reject) => {
      const request = indexedDB.deleteDatabase('transformers-cache');
      request.onsuccess = () => resolve();
      request.onerror = () => reject(request.error);
      request.onblocked = () => {
        // Another tab still holds the DB open; treat as best-effort success.
        xbLog.warn(MODULE_ID, 'IndexedDB 删除被阻塞');
        resolve();
      };
    });

    // 2. Purge CacheStorage entries that belong to transformers.js.
    if (window.caches) {
      const names = await window.caches.keys();
      const isModelCache = (n) =>
        n.includes('transformers') || n.includes('huggingface') || n.includes('xenova');
      for (const name of names.filter(isModelCache)) {
        await window.caches.delete(name);
      }
    }

    // 3. Drop in-memory loaded flags.
    if (modelId && loadedPipelines[modelId]) {
      delete loadedPipelines[modelId];
    } else {
      for (const key of Object.keys(loadedPipelines)) {
        delete loadedPipelines[key];
      }
    }

    xbLog.info(MODULE_ID, '模型缓存已清除');
    return true;
  } catch (e) {
    xbLog.error(MODULE_ID, '清除缓存失败', e);
    throw e;
  }
}
/**
 * Generate embeddings with a local model, loading it first when necessary.
 * @param {string[]} texts
 * @param {string} [modelId=DEFAULT_LOCAL_MODEL]
 * @returns {Promise<number[][]>}
 */
async function embedLocal(texts, modelId = DEFAULT_LOCAL_MODEL) {
  const alreadyLoaded = Boolean(loadedPipelines[modelId]);
  if (!alreadyLoaded) {
    await downloadLocalModel(modelId);
  }
  return workerRequest({ type: 'embed', texts });
}
/**
 * Whether the given local model has been loaded into the worker.
 * @param {string} [modelId=DEFAULT_LOCAL_MODEL]
 * @returns {boolean}
 */
export function isLocalModelLoaded(modelId = DEFAULT_LOCAL_MODEL) {
  return Boolean(loadedPipelines[modelId]);
}
/**
 * Look up the configuration entry for a local model.
 * @param {string} [modelId=DEFAULT_LOCAL_MODEL]
 * @returns {Object|null} the LOCAL_MODELS entry, or null when unknown
 */
export function getLocalModelInfo(modelId = DEFAULT_LOCAL_MODEL) {
  const entry = LOCAL_MODELS[modelId];
  return entry ?? null;
}
// ═══════════════════════════════════════════════════════════════════════════
// Online service operations
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Test connectivity to an online embedding service with a one-item request.
 *
 * Fix: the Cohere and OpenAI-compatible branches previously duplicated the
 * entire fetch/validate flow; only the request/response schema differs, so
 * the flow is now shared and the schema selected up front.
 *
 * @param {string} provider - key of ONLINE_PROVIDERS ('siliconflow'|'cohere'|'openai')
 * @param {{url?: string, key: string, model: string}} config
 * @returns {Promise<{success: true, dims: number}>} dims = returned vector length
 * @throws {Error} on missing config, HTTP error, zero-dim response, or network failure
 */
export async function testOnlineService(provider, config) {
  const { url, key, model } = config;

  if (!key) {
    throw new Error('请填写 API Key');
  }
  if (!model) {
    throw new Error('请选择模型');
  }

  const providerConfig = ONLINE_PROVIDERS[provider];
  // A provider-pinned base URL wins; custom providers use the user URL.
  const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');

  if (!baseUrl) {
    throw new Error('请填写 API URL');
  }

  // Cohere has its own request/response schema; everything else is
  // OpenAI-compatible.
  const isCohere = provider === 'cohere';
  const endpoint = isCohere ? `${baseUrl}/v1/embed` : `${baseUrl}/v1/embeddings`;
  const body = isCohere
    ? { model, texts: ['测试连接'], input_type: 'search_document' }
    : { model, input: ['测试连接'] };

  try {
    const response = await fetch(endpoint, {
      method: 'POST',
      headers: {
        'Authorization': `Bearer ${key}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(body),
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`API 返回 ${response.status}: ${error}`);
    }

    const data = await response.json();
    const dims = isCohere
      ? (data.embeddings?.[0]?.length || 0)
      : (data.data?.[0]?.embedding?.length || 0);

    if (dims === 0) {
      throw new Error('API 返回的向量维度为 0');
    }

    return { success: true, dims };
  } catch (e) {
    // fetch() rejects with a TypeError on network-level failures.
    if (e.name === 'TypeError' && e.message.includes('fetch')) {
      throw new Error('网络错误,请检查 URL 是否正确');
    }
    throw e;
  }
}
/**
 * Fetch the model list from an OpenAI-compatible endpoint and keep entries
 * that look like embedding models.
 *
 * @param {{url: string, key: string}} config
 * @returns {Promise<string[]>} embedding-ish model ids, or the first 20 ids
 *   when none match the heuristic
 * @throws {Error} on missing config or a non-OK HTTP response
 */
export async function fetchOnlineModels(config) {
  const { url, key } = config;

  if (!url || !key) {
    throw new Error('请填写 URL 和 Key');
  }

  // Normalize: strip trailing slashes and a trailing /v1 (we add /v1 back).
  const baseUrl = url.replace(/\/+$/, '').replace(/\/v1$/, '');

  const response = await fetch(`${baseUrl}/v1/models`, {
    headers: {
      'Authorization': `Bearer ${key}`,
      'Accept': 'application/json',
    },
  });

  if (!response.ok) {
    throw new Error(`获取模型列表失败: ${response.status}`);
  }

  const payload = await response.json();
  const allModels = (payload.data ?? []).map((m) => m.id).filter(Boolean);

  // Heuristic: common naming of embedding model families.
  const looksLikeEmbedding = (id) => {
    const lower = id.toLowerCase();
    return ['embed', 'bge', 'e5', 'gte'].some((tag) => lower.includes(tag));
  };

  const embeddingModels = allModels.filter(looksLikeEmbedding);
  return embeddingModels.length > 0 ? embeddingModels : allModels.slice(0, 20);
}
/**
 * Generate embeddings through an online service, with exponential-backoff
 * retries (3 attempts; waits of 1s and 2s before attempts 2 and 3).
 *
 * Fixes:
 * - the Cohere and OpenAI-compatible branches duplicated the whole request
 *   construction inside the retry loop; the endpoint/body are now built
 *   once and only the transport is retried;
 * - a malformed response (missing `embeddings`/`data`) previously threw an
 *   opaque "cannot read properties of undefined" TypeError; it now throws
 *   a clear error (still subject to the same retry policy).
 *
 * @param {string[]} texts
 * @param {string} provider - key of ONLINE_PROVIDERS
 * @param {{url?: string, key: string, model: string}} config
 * @returns {Promise<number[][]>} one plain-array vector per input text
 * @throws {Error} after the final failed attempt
 */
async function embedOnline(texts, provider, config) {
  const { url, key, model } = config;

  const providerConfig = ONLINE_PROVIDERS[provider];
  const baseUrl = (providerConfig?.baseUrl || url || '').replace(/\/+$/, '');
  const isCohere = provider === 'cohere';

  // Provider-specific request shape, built once outside the retry loop.
  const endpoint = isCohere ? `${baseUrl}/v1/embed` : `${baseUrl}/v1/embeddings`;
  const requestBody = isCohere
    ? { model, texts, input_type: 'search_document' }
    : { model, input: texts };

  // Short random id so interleaved request logs can be correlated.
  const reqId = Math.random().toString(36).slice(2, 6);
  const maxRetries = 3;

  for (let attempt = 1; attempt <= maxRetries; attempt++) {
    const startTime = Date.now();
    console.log(`[embed ${reqId}] send ${texts.length} items${attempt > 1 ? ` (retry ${attempt}/${maxRetries})` : ''}`);

    try {
      const response = await fetch(endpoint, {
        method: 'POST',
        headers: {
          'Authorization': `Bearer ${key}`,
          'Content-Type': 'application/json',
        },
        body: JSON.stringify(requestBody),
      });

      console.log(`[embed ${reqId}] status=${response.status} time=${Date.now() - startTime}ms`);

      if (!response.ok) {
        const error = await response.text();
        throw new Error(`API 返回 ${response.status}: ${error}`);
      }

      const data = await response.json();
      const rows = isCohere
        ? data.embeddings
        : data.data?.map((item) => item.embedding);

      if (!Array.isArray(rows)) {
        throw new Error('API 响应格式异常: 缺少向量数据');
      }

      console.log(`[embed ${reqId}] done items=${rows.length} total=${Date.now() - startTime}ms`);
      // Typed arrays → plain arrays so results serialize/store cleanly.
      return rows.map((v) => (Array.isArray(v) ? v : Array.from(v)));
    } catch (e) {
      console.warn(`[embed ${reqId}] failed attempt=${attempt} time=${Date.now() - startTime}ms`, e.message);

      if (attempt < maxRetries) {
        const waitTime = Math.pow(2, attempt - 1) * 1000;
        console.log(`[embed ${reqId}] wait ${waitTime}ms then retry`);
        await new Promise(r => setTimeout(r, waitTime));
        continue;
      }

      console.error(`[embed ${reqId}] final failure`, e);
      throw e;
    }
  }
}
// ═══════════════════════════════════════════════════════════════════════════
// Unified interface
// ═══════════════════════════════════════════════════════════════════════════

/**
 * Generate embeddings through the configured engine (unified entry point).
 * @param {string[]} texts - texts to vectorize
 * @param {Object} config - { engine: 'local'|'online', local?, online? }
 * @returns {Promise<number[][]>}
 * @throws {Error} on unknown engine or incomplete online configuration
 */
export async function embed(texts, config) {
  if (!texts?.length) return [];

  const { engine, local, online } = config;

  switch (engine) {
    case 'local':
      return embedLocal(texts, local?.modelId || DEFAULT_LOCAL_MODEL);

    case 'online': {
      const provider = online?.provider || 'siliconflow';
      if (!online?.key || !online?.model) {
        throw new Error('在线服务配置不完整');
      }
      return embedOnline(texts, provider, online);
    }

    default:
      throw new Error(`未知的引擎类型: ${engine}`);
  }
}
// (Stray doc comment in the original — "unique identifier of the current
// engine, used to check vector compatibility" — describes
// getEngineFingerprint, which is defined further below.)

/**
 * Embed several batches, running up to `concurrency` online requests in
 * parallel. The local engine (or a single batch) degrades to sequential
 * processing.
 *
 * @param {string[][]} textBatches
 * @param {Object} config - same shape as embed()'s config
 * @param {number} [concurrency=3]
 * @returns {Promise<number[][][]>} results aligned with textBatches order
 */
export async function embedBatchesConcurrent(textBatches, config, concurrency = 3) {
  // Sequential path: local embedding serializes in the worker anyway.
  if (config.engine === 'local' || textBatches.length <= 1) {
    const sequential = [];
    for (const batch of textBatches) {
      sequential.push(await embed(batch, config));
    }
    return sequential;
  }

  const results = new Array(textBatches.length);
  let nextIndex = 0;

  // Each runner claims the next unprocessed batch until none remain.
  // Safe without locks: the claim (read + increment) is synchronous.
  const runner = async () => {
    while (nextIndex < textBatches.length) {
      const i = nextIndex++;
      results[i] = await embed(textBatches[i], config);
    }
  };

  const poolSize = Math.min(concurrency, textBatches.length);
  await Promise.all(Array.from({ length: poolSize }, () => runner()));

  return results;
}
/**
 * Unique identifier of the current engine + model, used to detect whether
 * stored vectors match the active configuration.
 * @param {Object} config - { engine, local?, online? }
 * @returns {string} e.g. 'local:bge-small-zh:512', 'online:siliconflow:BAAI/bge-m3',
 *   or 'unknown' for an unrecognized engine
 */
export function getEngineFingerprint(config) {
  switch (config.engine) {
    case 'local': {
      const modelId = config.local?.modelId || DEFAULT_LOCAL_MODEL;
      const dims = LOCAL_MODELS[modelId]?.dims || 512;
      return `local:${modelId}:${dims}`;
    }
    case 'online': {
      const provider = config.online?.provider || 'unknown';
      const model = config.online?.model || 'unknown';
      return `online:${provider}:${model}`;
    }
    default:
      return 'unknown';
  }
}