fix(diffusion): exclude name1 from WHO/WHAT graph features

This commit is contained in:
2026-02-13 11:38:57 +08:00
parent 40f59d6571
commit 987e488c7b

View File

@@ -31,6 +31,7 @@
// ═══════════════════════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════════════════════
import { xbLog } from '../../../../core/debug-core.js'; import { xbLog } from '../../../../core/debug-core.js';
import { getContext } from '../../../../../../../extensions.js';
const MODULE_ID = 'diffusion'; const MODULE_ID = 'diffusion';
@@ -95,19 +96,20 @@ function cosineSimilarity(a, b) {
/** /**
* WHO channel: entity set = who edges.s edges.t * WHO channel: entity set = who edges.s edges.t
* @param {object} atom * @param {object} atom
* @param {Set<string>} excludeEntities - entities to exclude (e.g. name1)
* @returns {Set<string>} * @returns {Set<string>}
*/ */
function extractEntities(atom) { function extractEntities(atom, excludeEntities = new Set()) {
const set = new Set(); const set = new Set();
for (const w of (atom.who || [])) { for (const w of (atom.who || [])) {
const n = normalize(w); const n = normalize(w);
if (n) set.add(n); if (n && !excludeEntities.has(n)) set.add(n);
} }
for (const e of (atom.edges || [])) { for (const e of (atom.edges || [])) {
const s = normalize(e?.s); const s = normalize(e?.s);
const t = normalize(e?.t); const t = normalize(e?.t);
if (s) set.add(s); if (s && !excludeEntities.has(s)) set.add(s);
if (t) set.add(t); if (t && !excludeEntities.has(t)) set.add(t);
} }
return set; return set;
} }
@@ -115,14 +117,17 @@ function extractEntities(atom) {
/** /**
* WHAT channel: directed interaction pairs "A→B" (strict direction — option A) * WHAT channel: directed interaction pairs "A→B" (strict direction — option A)
* @param {object} atom * @param {object} atom
* @param {Set<string>} excludeEntities
* @returns {Set<string>} * @returns {Set<string>}
*/ */
function extractDirectedPairs(atom) { function extractDirectedPairs(atom, excludeEntities = new Set()) {
const set = new Set(); const set = new Set();
for (const e of (atom.edges || [])) { for (const e of (atom.edges || [])) {
const s = normalize(e?.s); const s = normalize(e?.s);
const t = normalize(e?.t); const t = normalize(e?.t);
if (s && t) set.add(`${s}\u2192${t}`); if (s && t && !excludeEntities.has(s) && !excludeEntities.has(t)) {
set.add(`${s}\u2192${t}`);
}
} }
return set; return set;
} }
@@ -201,12 +206,13 @@ function overlapCoefficient(a, b) {
/** /**
* Pre-extract features for all atoms * Pre-extract features for all atoms
* @param {object[]} allAtoms * @param {object[]} allAtoms
* @param {Set<string>} excludeEntities
* @returns {object[]} feature objects with entities/directedPairs/location/dynamics * @returns {object[]} feature objects with entities/directedPairs/location/dynamics
*/ */
function extractAllFeatures(allAtoms) { function extractAllFeatures(allAtoms, excludeEntities = new Set()) {
return allAtoms.map(atom => ({ return allAtoms.map(atom => ({
entities: extractEntities(atom), entities: extractEntities(atom, excludeEntities),
directedPairs: extractDirectedPairs(atom), directedPairs: extractDirectedPairs(atom, excludeEntities),
location: extractLocation(atom), location: extractLocation(atom),
dynamics: extractDynamics(atom), dynamics: extractDynamics(atom),
})); }));
@@ -258,13 +264,14 @@ function collectPairsFromIndex(index, pairSet, N) {
* Build weighted undirected graph over L0 atoms. * Build weighted undirected graph over L0 atoms.
* *
* @param {object[]} allAtoms * @param {object[]} allAtoms
* @param {Set<string>} excludeEntities
* @returns {{ neighbors: object[][], edgeCount: number, channelStats: object, buildTime: number }} * @returns {{ neighbors: object[][], edgeCount: number, channelStats: object, buildTime: number }}
*/ */
function buildGraph(allAtoms) { function buildGraph(allAtoms, excludeEntities = new Set()) {
const N = allAtoms.length; const N = allAtoms.length;
const T0 = performance.now(); const T0 = performance.now();
const features = extractAllFeatures(allAtoms); const features = extractAllFeatures(allAtoms, excludeEntities);
const { entityIndex, locationIndex } = buildInvertedIndices(features); const { entityIndex, locationIndex } = buildInvertedIndices(features);
// Candidate pairs: share ≥1 entity or same location // Candidate pairs: share ≥1 entity or same location
@@ -604,6 +611,11 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met
return []; return [];
} }
// Align with entity-lexicon hard rule: exclude name1 from graph features.
const { name1 } = getContext();
const excludeEntities = new Set();
if (name1) excludeEntities.add(normalize(name1));
// ─── 1. Build atom index ───────────────────────────────────────── // ─── 1. Build atom index ─────────────────────────────────────────
const atomById = new Map(); const atomById = new Map();
@@ -630,7 +642,7 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met
// ─── 2. Build graph ────────────────────────────────────────────── // ─── 2. Build graph ──────────────────────────────────────────────
const graph = buildGraph(allAtoms); const graph = buildGraph(allAtoms, excludeEntities);
if (graph.edgeCount === 0) { if (graph.edgeCount === 0) {
fillMetrics(metrics, { fillMetrics(metrics, {