refactor diffusion to r-sem edges with time window and add rVector I/O
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Build undirected weighted graph over all L0 atoms
|
||||
// Four channels: WHO/WHAT/WHERE/HOW (Jaccard/Overlap/ExactMatch)
|
||||
// Candidate edges: WHAT + R semantic; WHO/WHERE are reweight-only
|
||||
// 2. Personalized PageRank (Power Iteration)
|
||||
// Seeds weighted by rerankScore — Haveliwala (2002) topic-sensitive variant
|
||||
// α = 0.15 restart probability — Page et al. (1998)
|
||||
@@ -32,7 +32,6 @@
|
||||
|
||||
import { xbLog } from '../../../../core/debug-core.js';
|
||||
import { getContext } from '../../../../../../../extensions.js';
|
||||
import { tokenizeForIndex } from '../utils/tokenizer.js';
|
||||
|
||||
const MODULE_ID = 'diffusion';
|
||||
|
||||
@@ -47,22 +46,27 @@ const CONFIG = {
|
||||
MAX_ITER: 50, // hard iteration cap (typically converges in 15-25)
|
||||
|
||||
// Edge weight channel coefficients
|
||||
// Candidate generation uses WHAT/HOW only.
|
||||
// Candidate generation uses WHAT + R semantic only.
|
||||
// WHO/WHERE are reweight-only signals.
|
||||
GAMMA: {
|
||||
what: 0.45, // interaction pair overlap — Szymkiewicz-Simpson
|
||||
how: 0.30, // action-term co-occurrence — Jaccard
|
||||
who: 0.15, // endpoint entity overlap — Jaccard (reweight-only)
|
||||
where: 0.10, // location exact match — damped (reweight-only)
|
||||
what: 0.40, // interaction pair overlap
|
||||
rSem: 0.40, // semantic similarity over edges.r aggregate
|
||||
who: 0.10, // endpoint entity overlap (reweight-only)
|
||||
where: 0.05, // location exact match (reweight-only)
|
||||
time: 0.05, // temporal decay score
|
||||
},
|
||||
// R semantic candidate generation
|
||||
R_SEM_MIN_SIM: 0.62,
|
||||
R_SEM_TOPK: 8,
|
||||
TIME_WINDOW_MAX: 80,
|
||||
TIME_DECAY_DIVISOR: 12,
|
||||
WHERE_MAX_GROUP_SIZE: 16, // skip location-only pair expansion for over-common places
|
||||
WHERE_FREQ_DAMP_PIVOT: 6, // location freq <= pivot keeps full WHERE score
|
||||
WHERE_FREQ_DAMP_MIN: 0.20, // lower bound for damped WHERE contribution
|
||||
HOW_MAX_GROUP_SIZE: 24, // skip ultra-common action terms to avoid dense pair explosion
|
||||
|
||||
// Post-verification (Cosine Gate)
|
||||
COSINE_GATE: 0.46, // min cosine(queryVector, stateVector)
|
||||
SCORE_FLOOR: 0.12, // min finalScore = PPR_normalized × cosine
|
||||
SCORE_FLOOR: 0.10, // min finalScore = PPR_normalized × cosine
|
||||
DIFFUSION_CAP: 100, // max diffused nodes (excluding seeds)
|
||||
};
|
||||
|
||||
@@ -144,23 +148,14 @@ function extractLocation(atom) {
|
||||
return normalize(atom.where);
|
||||
}
|
||||
|
||||
/**
 * HOW channel: action terms extracted from the `r` field of an atom's edges.
 * Each relation string is tokenized, normalized, and collected into a set,
 * skipping tokens that are already known entity names.
 * @param {object} atom - L0 atom with an optional `edges` array of { r } objects
 * @param {Set<string>} excludeEntities - normalized entity tokens to exclude
 * @returns {Set<string>} normalized action tokens
 */
function extractActionTerms(atom, excludeEntities = new Set()) {
    const set = new Set();
    for (const e of (atom.edges || [])) {
        const rel = String(e?.r || '').trim();
        if (!rel) continue;
        for (const token of tokenizeForIndex(rel)) {
            const t = normalize(token);
            // Entity names carry the WHO signal; keep HOW disjoint from it.
            if (t && !excludeEntities.has(t)) set.add(t);
        }
    }
    return set;
    // NOTE(review): the closing brace below was missing in the diff view
    // (swallowed when the following helpers were inserted); restored here.
}
||||
/**
 * Absolute distance between the floor indices of two atoms.
 * A missing atom or a falsy `floor` value is treated as floor 0.
 * @param {object|null|undefined} a
 * @param {object|null|undefined} b
 * @returns {number} |floor(a) - floor(b)|
 */
function getFloorDistance(a, b) {
    const toFloor = (atom) => Number(atom?.floor || 0);
    return Math.abs(toFloor(a) - toFloor(b));
}
|
||||
|
||||
/**
 * Exponential temporal decay for a floor distance: exp(-d / divisor).
 * Distance 0 scores 1.0; larger distances decay toward 0.
 * @param {number} distance - non-negative floor distance
 * @returns {number} decay score in (0, 1]
 */
function getTimeScore(distance) {
    const divisor = CONFIG.TIME_DECAY_DIVISOR;
    return Math.exp(-(distance / divisor));
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
@@ -205,35 +200,31 @@ function overlapCoefficient(a, b) {
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Graph construction
|
||||
//
|
||||
// Candidate pairs discovered via inverted indices on entities and locations.
|
||||
// HOW-only pairs are still excluded from candidate generation to avoid O(N²);
|
||||
// all channel weights are evaluated for the entity/location candidate set.
|
||||
// All four channels evaluated for every candidate pair.
|
||||
// Candidate pairs discovered via WHAT inverted index and R semantic top-k.
|
||||
// WHO/WHERE are reweight-only signals and never create candidate pairs.
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/**
 * Pre-extract graph features for every L0 atom in one pass.
 * @param {object[]} allAtoms
 * @param {Set<string>} excludeEntities - normalized entity tokens to exclude
 * @returns {object[]} per-atom feature objects with
 *          entities / interactionPairs / location / actionTerms
 */
function extractAllFeatures(allAtoms, excludeEntities = new Set()) {
    const features = [];
    for (const atom of allAtoms) {
        features.push({
            entities: extractEntities(atom, excludeEntities),
            interactionPairs: extractInteractionPairs(atom, excludeEntities),
            location: extractLocation(atom),
            actionTerms: extractActionTerms(atom, excludeEntities),
        });
    }
    return features;
}
|
||||
|
||||
/**
|
||||
* Build inverted index: value → list of atom indices
|
||||
* @param {object[]} features
|
||||
* @returns {{ whatIndex: Map, howIndex: Map, locationFreq: Map }}
|
||||
* @returns {{ whatIndex: Map, locationFreq: Map }}
|
||||
*/
|
||||
function buildInvertedIndices(features) {
|
||||
const whatIndex = new Map();
|
||||
const howIndex = new Map();
|
||||
const locationFreq = new Map();
|
||||
|
||||
for (let i = 0; i < features.length; i++) {
|
||||
@@ -241,15 +232,11 @@ function buildInvertedIndices(features) {
|
||||
if (!whatIndex.has(pair)) whatIndex.set(pair, []);
|
||||
whatIndex.get(pair).push(i);
|
||||
}
|
||||
for (const action of features[i].actionTerms) {
|
||||
if (!howIndex.has(action)) howIndex.set(action, []);
|
||||
howIndex.get(action).push(i);
|
||||
}
|
||||
const loc = features[i].location;
|
||||
if (loc) locationFreq.set(loc, (locationFreq.get(loc) || 0) + 1);
|
||||
}
|
||||
|
||||
return { whatIndex, howIndex, locationFreq };
|
||||
return { whatIndex, locationFreq };
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -274,38 +261,88 @@ function collectPairsFromIndex(index, pairSet, N) {
|
||||
* Build weighted undirected graph over L0 atoms.
|
||||
*
|
||||
* @param {object[]} allAtoms
|
||||
* @param {object[]} stateVectors
|
||||
* @param {Set<string>} excludeEntities
|
||||
* @returns {{ neighbors: object[][], edgeCount: number, channelStats: object, buildTime: number }}
|
||||
*/
|
||||
function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
function buildGraph(allAtoms, stateVectors = [], excludeEntities = new Set()) {
|
||||
const N = allAtoms.length;
|
||||
const T0 = performance.now();
|
||||
|
||||
const features = extractAllFeatures(allAtoms, excludeEntities);
|
||||
const { whatIndex, howIndex, locationFreq } = buildInvertedIndices(features);
|
||||
const { whatIndex, locationFreq } = buildInvertedIndices(features);
|
||||
|
||||
// Candidate pairs: only WHAT/HOW can create edges
|
||||
// Candidate pairs: WHAT + R semantic
|
||||
const pairSetByWhat = new Set();
|
||||
const pairSetByHow = new Set();
|
||||
const pairSetByRSem = new Set();
|
||||
const rSemByPair = new Map();
|
||||
const pairSet = new Set();
|
||||
collectPairsFromIndex(whatIndex, pairSetByWhat, N);
|
||||
let skippedHowGroups = 0;
|
||||
for (const [term, indices] of howIndex.entries()) {
|
||||
if (!term) continue;
|
||||
if (indices.length > CONFIG.HOW_MAX_GROUP_SIZE) {
|
||||
skippedHowGroups++;
|
||||
continue;
|
||||
|
||||
const rVectorByAtomId = new Map(
|
||||
(stateVectors || [])
|
||||
.filter(v => v?.atomId && v?.rVector?.length)
|
||||
.map(v => [v.atomId, v.rVector])
|
||||
);
|
||||
const rVectors = allAtoms.map(a => rVectorByAtomId.get(a.atomId) || null);
|
||||
|
||||
const directedNeighbors = Array.from({ length: N }, () => []);
|
||||
let rSemSimSum = 0;
|
||||
let rSemSimCount = 0;
|
||||
let topKPrunedPairs = 0;
|
||||
let timeWindowFilteredPairs = 0;
|
||||
|
||||
// Enumerate only pairs within floor window to avoid O(N^2) full scan.
|
||||
const sortedByFloor = allAtoms
|
||||
.map((atom, idx) => ({ idx, floor: Number(atom?.floor || 0) }))
|
||||
.sort((a, b) => a.floor - b.floor);
|
||||
|
||||
for (let left = 0; left < sortedByFloor.length; left++) {
|
||||
const i = sortedByFloor[left].idx;
|
||||
const baseFloor = sortedByFloor[left].floor;
|
||||
|
||||
for (let right = left + 1; right < sortedByFloor.length; right++) {
|
||||
const floorDelta = sortedByFloor[right].floor - baseFloor;
|
||||
if (floorDelta > CONFIG.TIME_WINDOW_MAX) break;
|
||||
|
||||
const j = sortedByFloor[right].idx;
|
||||
const vi = rVectors[i];
|
||||
const vj = rVectors[j];
|
||||
if (!vi?.length || !vj?.length) continue;
|
||||
|
||||
const sim = cosineSimilarity(vi, vj);
|
||||
if (sim < CONFIG.R_SEM_MIN_SIM) continue;
|
||||
|
||||
directedNeighbors[i].push({ target: j, sim });
|
||||
directedNeighbors[j].push({ target: i, sim });
|
||||
rSemSimSum += sim;
|
||||
rSemSimCount++;
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = 0; i < N; i++) {
|
||||
const arr = directedNeighbors[i];
|
||||
if (!arr.length) continue;
|
||||
arr.sort((a, b) => b.sim - a.sim);
|
||||
if (arr.length > CONFIG.R_SEM_TOPK) {
|
||||
topKPrunedPairs += arr.length - CONFIG.R_SEM_TOPK;
|
||||
}
|
||||
for (const n of arr.slice(0, CONFIG.R_SEM_TOPK)) {
|
||||
const lo = Math.min(i, n.target);
|
||||
const hi = Math.max(i, n.target);
|
||||
const packed = lo * N + hi;
|
||||
pairSetByRSem.add(packed);
|
||||
const prev = rSemByPair.get(packed) || 0;
|
||||
if (n.sim > prev) rSemByPair.set(packed, n.sim);
|
||||
}
|
||||
const oneHowMap = new Map([[term, indices]]);
|
||||
collectPairsFromIndex(oneHowMap, pairSetByHow, N);
|
||||
}
|
||||
for (const p of pairSetByWhat) pairSet.add(p);
|
||||
for (const p of pairSetByHow) pairSet.add(p);
|
||||
for (const p of pairSetByRSem) pairSet.add(p);
|
||||
|
||||
// Compute edge weights for all candidates
|
||||
const neighbors = Array.from({ length: N }, () => []);
|
||||
let edgeCount = 0;
|
||||
const channelStats = { what: 0, where: 0, how: 0, who: 0 };
|
||||
const channelStats = { what: 0, where: 0, rSem: 0, who: 0 };
|
||||
let reweightWhoUsed = 0;
|
||||
let reweightWhereUsed = 0;
|
||||
|
||||
@@ -313,11 +350,18 @@ function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
const i = Math.floor(packed / N);
|
||||
const j = packed % N;
|
||||
|
||||
const distance = getFloorDistance(allAtoms[i], allAtoms[j]);
|
||||
if (distance > CONFIG.TIME_WINDOW_MAX) {
|
||||
timeWindowFilteredPairs++;
|
||||
continue;
|
||||
}
|
||||
const wTime = getTimeScore(distance);
|
||||
|
||||
const fi = features[i];
|
||||
const fj = features[j];
|
||||
|
||||
const wWhat = overlapCoefficient(fi.interactionPairs, fj.interactionPairs);
|
||||
const wHow = jaccard(fi.actionTerms, fj.actionTerms);
|
||||
const wRSem = rSemByPair.get(packed) || 0;
|
||||
const wWho = jaccard(fi.entities, fj.entities);
|
||||
let wWhere = 0.0;
|
||||
if (fi.location && fi.location === fj.location) {
|
||||
@@ -331,9 +375,10 @@ function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
|
||||
const weight =
|
||||
CONFIG.GAMMA.what * wWhat +
|
||||
CONFIG.GAMMA.how * wHow +
|
||||
CONFIG.GAMMA.rSem * wRSem +
|
||||
CONFIG.GAMMA.who * wWho +
|
||||
CONFIG.GAMMA.where * wWhere;
|
||||
CONFIG.GAMMA.where * wWhere +
|
||||
CONFIG.GAMMA.time * wTime;
|
||||
|
||||
if (weight > 0) {
|
||||
neighbors[i].push({ target: j, weight });
|
||||
@@ -341,7 +386,7 @@ function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
edgeCount++;
|
||||
|
||||
if (wWhat > 0) channelStats.what++;
|
||||
if (wHow > 0) channelStats.how++;
|
||||
if (wRSem > 0) channelStats.rSem++;
|
||||
if (wWho > 0) channelStats.who++;
|
||||
if (wWhere > 0) channelStats.where++;
|
||||
if (wWho > 0) reweightWhoUsed++;
|
||||
@@ -353,10 +398,10 @@ function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
|
||||
xbLog.info(MODULE_ID,
|
||||
`Graph: ${N} nodes, ${edgeCount} edges ` +
|
||||
`(candidate_by_what=${pairSetByWhat.size} candidate_by_how=${pairSetByHow.size}) ` +
|
||||
`(what=${channelStats.what} how=${channelStats.how} who=${channelStats.who} where=${channelStats.where}) ` +
|
||||
`(candidate_by_what=${pairSetByWhat.size} candidate_by_r_sem=${pairSetByRSem.size}) ` +
|
||||
`(what=${channelStats.what} r_sem=${channelStats.rSem} who=${channelStats.who} where=${channelStats.where}) ` +
|
||||
`(reweight_who_used=${reweightWhoUsed} reweight_where_used=${reweightWhereUsed}) ` +
|
||||
`(howSkippedGroups=${skippedHowGroups}) ` +
|
||||
`(time_window_filtered=${timeWindowFilteredPairs} topk_pruned=${topKPrunedPairs}) ` +
|
||||
`(${buildTime}ms)`
|
||||
);
|
||||
|
||||
@@ -370,7 +415,10 @@ function buildGraph(allAtoms, excludeEntities = new Set()) {
|
||||
buildTime,
|
||||
candidatePairs: pairSet.size,
|
||||
pairsFromWhat: pairSetByWhat.size,
|
||||
pairsFromHow: pairSetByHow.size,
|
||||
pairsFromRSem: pairSetByRSem.size,
|
||||
rSemAvgSim: rSemSimCount ? Number((rSemSimSum / rSemSimCount).toFixed(3)) : 0,
|
||||
timeWindowFilteredPairs,
|
||||
topKPrunedPairs,
|
||||
reweightWhoUsed,
|
||||
reweightWhereUsed,
|
||||
edgeDensity,
|
||||
@@ -646,7 +694,7 @@ function postVerify(pi, atomIds, atomById, seedAtomIds, vectorMap, queryVector)
|
||||
* @param {object[]} allAtoms - getStateAtoms() result
|
||||
* Each: { atomId, floor, semantic, edges, where }
|
||||
* @param {object[]} stateVectors - getAllStateVectors() result
|
||||
* Each: { atomId, floor, vector: Float32Array }
|
||||
* Each: { atomId, floor, vector: Float32Array, rVector?: Float32Array }
|
||||
* @param {Float32Array|number[]} queryVector - R2 weighted query vector
|
||||
* @param {object|null} metrics - metrics object (optional, mutated in-place)
|
||||
* @returns {object[]} Additional L0 atoms for l0Selected
|
||||
@@ -693,7 +741,7 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met
|
||||
|
||||
// ─── 2. Build graph ──────────────────────────────────────────────
|
||||
|
||||
const graph = buildGraph(allAtoms, excludeEntities);
|
||||
const graph = buildGraph(allAtoms, stateVectors, excludeEntities);
|
||||
|
||||
if (graph.edgeCount === 0) {
|
||||
fillMetrics(metrics, {
|
||||
@@ -703,7 +751,10 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met
|
||||
channelStats: graph.channelStats,
|
||||
candidatePairs: graph.candidatePairs,
|
||||
pairsFromWhat: graph.pairsFromWhat,
|
||||
pairsFromHow: graph.pairsFromHow,
|
||||
pairsFromRSem: graph.pairsFromRSem,
|
||||
rSemAvgSim: graph.rSemAvgSim,
|
||||
timeWindowFilteredPairs: graph.timeWindowFilteredPairs,
|
||||
topKPrunedPairs: graph.topKPrunedPairs,
|
||||
edgeDensity: graph.edgeDensity,
|
||||
reweightWhoUsed: graph.reweightWhoUsed,
|
||||
reweightWhereUsed: graph.reweightWhereUsed,
|
||||
@@ -755,7 +806,10 @@ export function diffuseFromSeeds(seeds, allAtoms, stateVectors, queryVector, met
|
||||
channelStats: graph.channelStats,
|
||||
candidatePairs: graph.candidatePairs,
|
||||
pairsFromWhat: graph.pairsFromWhat,
|
||||
pairsFromHow: graph.pairsFromHow,
|
||||
pairsFromRSem: graph.pairsFromRSem,
|
||||
rSemAvgSim: graph.rSemAvgSim,
|
||||
timeWindowFilteredPairs: graph.timeWindowFilteredPairs,
|
||||
topKPrunedPairs: graph.topKPrunedPairs,
|
||||
edgeDensity: graph.edgeDensity,
|
||||
reweightWhoUsed: graph.reweightWhoUsed,
|
||||
reweightWhereUsed: graph.reweightWhereUsed,
|
||||
@@ -826,10 +880,13 @@ function fillMetricsEmpty(metrics) {
|
||||
cosineGateNoVector: 0,
|
||||
finalCount: 0,
|
||||
scoreDistribution: { min: 0, max: 0, mean: 0 },
|
||||
byChannel: { what: 0, where: 0, how: 0, who: 0 },
|
||||
byChannel: { what: 0, where: 0, rSem: 0, who: 0 },
|
||||
candidatePairs: 0,
|
||||
pairsFromWhat: 0,
|
||||
pairsFromHow: 0,
|
||||
pairsFromRSem: 0,
|
||||
rSemAvgSim: 0,
|
||||
timeWindowFilteredPairs: 0,
|
||||
topKPrunedPairs: 0,
|
||||
edgeDensity: 0,
|
||||
reweightWhoUsed: 0,
|
||||
reweightWhereUsed: 0,
|
||||
@@ -856,10 +913,13 @@ function fillMetrics(metrics, data) {
|
||||
postGatePassRate: data.postGatePassRate || 0,
|
||||
finalCount: data.finalCount || 0,
|
||||
scoreDistribution: data.scoreDistribution || { min: 0, max: 0, mean: 0 },
|
||||
byChannel: data.channelStats || { what: 0, where: 0, how: 0, who: 0 },
|
||||
byChannel: data.channelStats || { what: 0, where: 0, rSem: 0, who: 0 },
|
||||
candidatePairs: data.candidatePairs || 0,
|
||||
pairsFromWhat: data.pairsFromWhat || 0,
|
||||
pairsFromHow: data.pairsFromHow || 0,
|
||||
pairsFromRSem: data.pairsFromRSem || 0,
|
||||
rSemAvgSim: data.rSemAvgSim || 0,
|
||||
timeWindowFilteredPairs: data.timeWindowFilteredPairs || 0,
|
||||
topKPrunedPairs: data.topKPrunedPairs || 0,
|
||||
edgeDensity: data.edgeDensity || 0,
|
||||
reweightWhoUsed: data.reweightWhoUsed || 0,
|
||||
reweightWhereUsed: data.reweightWhereUsed || 0,
|
||||
|
||||
Reference in New Issue
Block a user