From 987e488c7bab9b1f93651315be99df7d9ea2a60d Mon Sep 17 00:00:00 2001 From: bielie Date: Fri, 13 Feb 2026 11:38:57 +0800 Subject: [PATCH] fix(diffusion): exclude name1 from WHO/WHAT graph features --- .../vector/retrieval/diffusion.js | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/modules/story-summary/vector/retrieval/diffusion.js b/modules/story-summary/vector/retrieval/diffusion.js index c59ccda..6fb55a3 100644 --- a/modules/story-summary/vector/retrieval/diffusion.js +++ b/modules/story-summary/vector/retrieval/diffusion.js @@ -31,6 +31,7 @@ // ═══════════════════════════════════════════════════════════════════════════ import { xbLog } from '../../../../core/debug-core.js'; +import { getContext } from '../../../../../../../extensions.js'; const MODULE_ID = 'diffusion'; @@ -95,19 +96,20 @@ function cosineSimilarity(a, b) { /** * WHO channel: entity set = who ∪ edges.s ∪ edges.t * @param {object} atom + * @param {Set} excludeEntities - entities to exclude (e.g. name1) * @returns {Set} */ -function extractEntities(atom) { +function extractEntities(atom, excludeEntities = new Set()) { const set = new Set(); for (const w of (atom.who || [])) { const n = normalize(w); - if (n) set.add(n); + if (n && !excludeEntities.has(n)) set.add(n); } for (const e of (atom.edges || [])) { const s = normalize(e?.s); const t = normalize(e?.t); - if (s) set.add(s); - if (t) set.add(t); + if (s && !excludeEntities.has(s)) set.add(s); + if (t && !excludeEntities.has(t)) set.add(t); } return set; } @@ -115,14 +117,17 @@ function extractEntities(atom) { /** * WHAT channel: directed interaction pairs "A→B" (strict direction — option A) * @param {object} atom + * @param {Set} excludeEntities * @returns {Set} */ -function extractDirectedPairs(atom) { +function extractDirectedPairs(atom, excludeEntities = new Set()) { const set = new Set(); for (const e of (atom.edges || [])) { const s = normalize(e?.s); const t = normalize(e?.t); - if (s && t) set.add(`${s}\u2192${t}`); + if (s && t && !excludeEntities.has(s) && !excludeEntities.has(t)) { + set.add(`${s}\u2192${t}`); + } } return set; } @@ -201,12 +206,13 @@ function overlapCoefficient(a, b) { /** * Pre-extract features for all atoms * @param {object[]} allAtoms + * @param {Set} excludeEntities * @returns {object[]} feature objects with entities/directedPairs/location/dynamics */ -function extractAllFeatures(allAtoms) { +function extractAllFeatures(allAtoms, excludeEntities = new Set()) { return allAtoms.map(atom => ({ - entities: extractEntities(atom), - directedPairs: extractDirectedPairs(atom), + entities: extractEntities(atom, excludeEntities), + directedPairs: extractDirectedPairs(atom, excludeEntities), location: extractLocation(atom), dynamics: extractDynamics(atom), })); @@ -258,13 +264,14 @@ function collectPairsFromIndex(index, pairSet, N) { * Build weighted undirected graph over L0 atoms. * * @param {object[]} allAtoms + * @param {Set} excludeEntities * @returns {{ neighbors: object[][], edgeCount: number, channelStats: object, buildTime: number }} */ -function buildGraph(allAtoms) { +function buildGraph(allAtoms, excludeEntities = new Set()) { const N = allAtoms.length; const T0 = performance.now(); - const features = extractAllFeatures(allAtoms); + const features = extractAllFeatures(allAtoms, excludeEntities); const { entityIndex, locationIndex } = buildInvertedIndices(features); // Candidate pairs: share ≥1 entity or same location @@ -604,6 +611,11 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met return []; } + // Align with entity-lexicon hard rule: exclude name1 from graph features. + const { name1 } = getContext(); + const excludeEntities = new Set(); + if (name1) excludeEntities.add(normalize(name1)); + // ─── 1. Build atom index ───────────────────────────────────────── const atomById = new Map(); @@ -630,7 +642,7 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met // ─── 2. Build graph ────────────────────────────────────────────── - const graph = buildGraph(allAtoms); + const graph = buildGraph(allAtoms, excludeEntities); if (graph.edgeCount === 0) { fillMetrics(metrics, {