forked from kent/consciousness
match_seeds() previously only found nodes whose keys exactly matched search terms. This meant searches like "formal verification" or "bcachefs plan" returned nothing — no nodes are keyed with those exact strings. Three-tier matching strategy: 1. Exact key match (full weight) — unchanged 2. Key component match (0.5× weight) — split keys on -/_/./#, match individual words. "plan" now finds "the-plan", "verification" finds "c-to-rust-verification-workflow", etc. 3. Content match (0.2× weight, capped at 50 hits) — search node content for terms that didn't match any key. Catches nodes whose keys are opaque but whose content is relevant. Also adds prompt-based seeding to the hook pipeline: extract_query_terms from the user's prompt and merge into the term set. Previously the hook only seeded from transcript scanning (finding node keys as substrings in conversation history), which meant fresh sessions or queries about new topics produced no search results at all.
842 lines
30 KiB
Rust
842 lines
30 KiB
Rust
// Memory search: composable algorithm pipeline.
|
||
//
|
||
// Each algorithm is a stage: takes seeds Vec<(String, f64)>, returns
|
||
// new/modified seeds. Stages compose left-to-right in a pipeline.
|
||
//
|
||
// Available algorithms:
//   spread     — spreading activation through graph edges
//   spectral   — nearest neighbors in spectral embedding space
//   manifold   — extrapolation along direction defined by seeds
//   confluence — multi-source reachability scoring
//   geodesic   — straightest paths between seed pairs in spectral space
|
||
//
|
||
// Seed extraction (matching query terms to node keys) is shared
|
||
// infrastructure, not an algorithm stage.
|
||
|
||
use crate::store::StoreView;
|
||
use crate::graph::Graph;
|
||
use crate::spectral;
|
||
|
||
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
|
||
use std::fmt;
|
||
|
||
/// A single ranked hit returned by the search pipeline.
pub struct SearchResult {
    /// Node key that matched.
    pub key: String,
    /// Final activation score (higher ranks earlier).
    pub activation: f64,
    /// True when this key was matched directly from query terms
    /// (vs discovered by an algorithm stage) — see `match_seeds`.
    pub is_direct: bool,
    /// Optional content excerpt. The pipeline itself sets this to `None`
    /// (see `search_weighted_inner`); callers may fill it in later.
    pub snippet: Option<String>,
}
|
||
|
||
/// A parsed algorithm stage with its parameters.
///
/// Produced by [`AlgoStage::parse`] from strings like
/// `"spread,max_hops=4,edge_decay=0.5"`.
#[derive(Clone, Debug)]
pub struct AlgoStage {
    /// Which algorithm this stage runs.
    pub algo: Algorithm,
    /// Raw `key=val` parameters; read through the `param_*` helpers,
    /// which fall back to defaults on missing or unparsable values.
    pub params: HashMap<String, String>,
}
|
||
|
||
/// The available pipeline algorithms.
///
/// `Display` renders the lowercase name accepted by [`AlgoStage::parse`].
#[derive(Clone, Debug)]
pub enum Algorithm {
    /// Spreading activation through graph edges.
    Spread,
    /// Nearest neighbors in spectral embedding space.
    Spectral,
    /// Extrapolation along the direction defined by seeds.
    Manifold,
    /// Multi-source reachability scoring.
    Confluence,
    /// Straightest paths between seed pairs in spectral space.
    Geodesic,
}
|
||
|
||
impl fmt::Display for Algorithm {
|
||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||
match self {
|
||
Algorithm::Spread => write!(f, "spread"),
|
||
Algorithm::Spectral => write!(f, "spectral"),
|
||
Algorithm::Manifold => write!(f, "manifold"),
|
||
Algorithm::Confluence => write!(f, "confluence"),
|
||
Algorithm::Geodesic => write!(f, "geodesic"),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl AlgoStage {
|
||
/// Parse "spread,max_hops=4,edge_decay=0.5" into an AlgoStage.
|
||
pub fn parse(s: &str) -> Result<Self, String> {
|
||
let mut parts = s.split(',');
|
||
let name = parts.next().unwrap_or("");
|
||
let algo = match name {
|
||
"spread" => Algorithm::Spread,
|
||
"spectral" => Algorithm::Spectral,
|
||
"manifold" => Algorithm::Manifold,
|
||
"confluence" => Algorithm::Confluence,
|
||
"geodesic" => Algorithm::Geodesic,
|
||
_ => return Err(format!("unknown algorithm: {}", name)),
|
||
};
|
||
let mut params = HashMap::new();
|
||
for part in parts {
|
||
if let Some((k, v)) = part.split_once('=') {
|
||
params.insert(k.to_string(), v.to_string());
|
||
} else {
|
||
return Err(format!("bad param (expected key=val): {}", part));
|
||
}
|
||
}
|
||
Ok(AlgoStage { algo, params })
|
||
}
|
||
|
||
fn param_f64(&self, key: &str, default: f64) -> f64 {
|
||
self.params.get(key)
|
||
.and_then(|v| v.parse().ok())
|
||
.unwrap_or(default)
|
||
}
|
||
|
||
fn param_u32(&self, key: &str, default: u32) -> u32 {
|
||
self.params.get(key)
|
||
.and_then(|v| v.parse().ok())
|
||
.unwrap_or(default)
|
||
}
|
||
|
||
fn param_usize(&self, key: &str, default: usize) -> usize {
|
||
self.params.get(key)
|
||
.and_then(|v| v.parse().ok())
|
||
.unwrap_or(default)
|
||
}
|
||
}
|
||
|
||
/// Extract seeds from weighted terms by matching against node keys and content.
|
||
///
|
||
/// Three matching strategies, in priority order:
|
||
/// 1. Exact key match: term matches a node key exactly → full weight
|
||
/// 2. Key component match: term matches a word in a hyphenated/underscored key → 0.5× weight
|
||
/// 3. Content match: term appears in node content → 0.2× weight (capped at 50 nodes)
|
||
///
|
||
/// Returns (seeds, direct_hits) where direct_hits tracks which keys
|
||
/// were matched directly (vs found by an algorithm stage).
|
||
pub fn match_seeds(
|
||
terms: &BTreeMap<String, f64>,
|
||
store: &impl StoreView,
|
||
) -> (Vec<(String, f64)>, HashSet<String>) {
|
||
let mut seed_map: HashMap<String, f64> = HashMap::new();
|
||
let mut direct_hits: HashSet<String> = HashSet::new();
|
||
|
||
// Build key lookup: lowercase key → (original key, weight)
|
||
let mut key_map: HashMap<String, (String, f64)> = HashMap::new();
|
||
// Build component index: word → vec of (original key, weight)
|
||
let mut component_map: HashMap<String, Vec<(String, f64)>> = HashMap::new();
|
||
|
||
store.for_each_node(|key, _content, weight| {
|
||
let lkey = key.to_lowercase();
|
||
key_map.insert(lkey.clone(), (key.to_owned(), weight as f64));
|
||
|
||
// Split key on hyphens, underscores, dots, hashes for component matching
|
||
for component in lkey.split(|c: char| c == '-' || c == '_' || c == '.' || c == '#') {
|
||
if component.len() >= 3 {
|
||
component_map.entry(component.to_owned())
|
||
.or_default()
|
||
.push((key.to_owned(), weight as f64));
|
||
}
|
||
}
|
||
});
|
||
|
||
for (term, &term_weight) in terms {
|
||
// Strategy 1: exact key match
|
||
if let Some((orig_key, node_weight)) = key_map.get(term) {
|
||
let score = term_weight * node_weight;
|
||
*seed_map.entry(orig_key.clone()).or_insert(0.0) += score;
|
||
direct_hits.insert(orig_key.clone());
|
||
continue;
|
||
}
|
||
|
||
// Strategy 2: key component match (0.5× weight)
|
||
if let Some(matches) = component_map.get(term.as_str()) {
|
||
for (orig_key, node_weight) in matches {
|
||
let score = term_weight * node_weight * 0.5;
|
||
*seed_map.entry(orig_key.clone()).or_insert(0.0) += score;
|
||
direct_hits.insert(orig_key.clone());
|
||
}
|
||
continue;
|
||
}
|
||
|
||
// Strategy 3: content match (0.2× weight, limited to avoid O(n*m) explosion)
|
||
let term_lower = term.to_lowercase();
|
||
if term_lower.len() < 3 { continue; }
|
||
let mut content_hits = 0;
|
||
store.for_each_node(|key, content, weight| {
|
||
if content_hits >= 50 { return; }
|
||
if content.to_lowercase().contains(&term_lower) {
|
||
let score = term_weight * weight as f64 * 0.2;
|
||
*seed_map.entry(key.to_owned()).or_insert(0.0) += score;
|
||
content_hits += 1;
|
||
}
|
||
});
|
||
}
|
||
|
||
let seeds: Vec<(String, f64)> = seed_map.into_iter().collect();
|
||
(seeds, direct_hits)
|
||
}
|
||
|
||
/// Run a pipeline of algorithm stages.
|
||
pub fn run_pipeline(
|
||
stages: &[AlgoStage],
|
||
seeds: Vec<(String, f64)>,
|
||
graph: &Graph,
|
||
store: &impl StoreView,
|
||
debug: bool,
|
||
max_results: usize,
|
||
) -> Vec<(String, f64)> {
|
||
let mut current = seeds;
|
||
|
||
for stage in stages {
|
||
if debug {
|
||
println!("\n[search] === {} ({} seeds in) ===", stage.algo, current.len());
|
||
}
|
||
|
||
current = match stage.algo {
|
||
Algorithm::Spread => run_spread(¤t, graph, store, stage, debug),
|
||
Algorithm::Spectral => run_spectral(¤t, graph, stage, debug),
|
||
Algorithm::Manifold => run_manifold(¤t, graph, stage, debug),
|
||
Algorithm::Confluence => run_confluence(¤t, graph, store, stage, debug),
|
||
Algorithm::Geodesic => run_geodesic(¤t, graph, stage, debug),
|
||
};
|
||
|
||
if debug {
|
||
println!("[search] {} → {} results", stage.algo, current.len());
|
||
for (i, (key, score)) in current.iter().enumerate().take(15) {
|
||
let cutoff = if i + 1 == max_results { " <-- cutoff" } else { "" };
|
||
println!(" [{:.4}] {}{}", score, key, cutoff);
|
||
}
|
||
if current.len() > 15 {
|
||
println!(" ... ({} more)", current.len() - 15);
|
||
}
|
||
}
|
||
}
|
||
|
||
current.truncate(max_results);
|
||
current
|
||
}
|
||
|
||
/// Spreading activation: propagate scores through graph edges.
///
/// Tunable params: max_hops (default from store), edge_decay (default from
/// store), min_activation (default is the store's min_activation × 0.1 —
/// note: NOT the store value itself; presumably a deliberately looser floor
/// for pipeline use — confirm before tightening).
fn run_spread(
    seeds: &[(String, f64)],
    graph: &Graph,
    store: &impl StoreView,
    stage: &AlgoStage,
    _debug: bool,
) -> Vec<(String, f64)> {
    let store_params = store.params();
    let max_hops = stage.param_u32("max_hops", store_params.max_hops);
    let edge_decay = stage.param_f64("edge_decay", store_params.edge_decay);
    // Floor is 10% of the store's threshold, not the threshold itself.
    let min_activation = stage.param_f64("min_activation", store_params.min_activation * 0.1);

    spreading_activation(seeds, graph, store, max_hops, edge_decay, min_activation)
}
|
||
|
||
/// Spectral projection: find nearest neighbors in spectral embedding space.
|
||
///
|
||
/// Tunable params: k (default 20, number of neighbors to find).
|
||
fn run_spectral(
|
||
seeds: &[(String, f64)],
|
||
graph: &Graph,
|
||
stage: &AlgoStage,
|
||
debug: bool,
|
||
) -> Vec<(String, f64)> {
|
||
let k = stage.param_usize("k", 20);
|
||
|
||
let emb = match spectral::load_embedding() {
|
||
Ok(e) => e,
|
||
Err(e) => {
|
||
if debug { println!(" no spectral embedding: {}", e); }
|
||
return seeds.to_vec();
|
||
}
|
||
};
|
||
|
||
let weighted_seeds: Vec<(&str, f64)> = seeds.iter()
|
||
.map(|(k, w)| (k.as_str(), *w))
|
||
.collect();
|
||
let projected = spectral::nearest_to_seeds_weighted(
|
||
&emb, &weighted_seeds, Some(graph), k,
|
||
);
|
||
|
||
if debug {
|
||
for (key, dist) in &projected {
|
||
let score = 1.0 / (1.0 + dist);
|
||
println!(" dist={:.6} score={:.4} {}", dist, score, key);
|
||
}
|
||
}
|
||
|
||
// Merge: keep original seeds, add spectral results as new seeds
|
||
let seed_set: HashSet<&str> = seeds.iter().map(|(k, _)| k.as_str()).collect();
|
||
let mut result = seeds.to_vec();
|
||
for (key, dist) in projected {
|
||
if !seed_set.contains(key.as_str()) {
|
||
let score = 1.0 / (1.0 + dist);
|
||
result.push((key, score));
|
||
}
|
||
}
|
||
result
|
||
}
|
||
|
||
/// Confluence: multi-source reachability scoring.
|
||
///
|
||
/// Unlike spreading activation (which takes max activation from any source),
|
||
/// confluence rewards nodes reachable from *multiple* seeds. For each candidate
|
||
/// node within k hops, score = sum of (seed_weight * edge_decay^distance) across
|
||
/// all seeds that can reach it. Nodes at the intersection of multiple seeds'
|
||
/// neighborhoods score highest.
|
||
///
|
||
/// This naturally handles mixed seeds: unrelated seeds activate disjoint
|
||
/// neighborhoods that don't overlap, so their results separate naturally.
|
||
///
|
||
/// Tunable params: max_hops (default 3), edge_decay (default 0.5),
|
||
/// min_sources (default 2, minimum number of distinct seeds that must reach a node).
|
||
fn run_confluence(
|
||
seeds: &[(String, f64)],
|
||
graph: &Graph,
|
||
store: &impl StoreView,
|
||
stage: &AlgoStage,
|
||
debug: bool,
|
||
) -> Vec<(String, f64)> {
|
||
let max_hops = stage.param_u32("max_hops", 3);
|
||
let edge_decay = stage.param_f64("edge_decay", 0.5);
|
||
let min_sources = stage.param_usize("min_sources", 2);
|
||
|
||
// For each seed, BFS outward collecting (node → activation) at each distance
|
||
// Track which seeds contributed to each node's score
|
||
let mut node_scores: HashMap<String, f64> = HashMap::new();
|
||
let mut node_sources: HashMap<String, HashSet<usize>> = HashMap::new();
|
||
|
||
for (seed_idx, (seed_key, seed_weight)) in seeds.iter().enumerate() {
|
||
let mut visited: HashMap<String, f64> = HashMap::new();
|
||
let mut queue: VecDeque<(String, u32)> = VecDeque::new();
|
||
|
||
visited.insert(seed_key.clone(), *seed_weight);
|
||
queue.push_back((seed_key.clone(), 0));
|
||
|
||
while let Some((key, depth)) = queue.pop_front() {
|
||
if depth >= max_hops { continue; }
|
||
|
||
let act = visited[&key];
|
||
|
||
for (neighbor, strength) in graph.neighbors(&key) {
|
||
let neighbor_weight = store.node_weight(neighbor.as_str());
|
||
let propagated = act * edge_decay * neighbor_weight * strength as f64;
|
||
if propagated < 0.001 { continue; }
|
||
|
||
if !visited.contains_key(neighbor.as_str()) || visited[neighbor.as_str()] < propagated {
|
||
visited.insert(neighbor.clone(), propagated);
|
||
queue.push_back((neighbor.clone(), depth + 1));
|
||
}
|
||
}
|
||
}
|
||
|
||
// Accumulate into global scores (additive across seeds)
|
||
for (key, act) in visited {
|
||
*node_scores.entry(key.clone()).or_insert(0.0) += act;
|
||
node_sources.entry(key).or_default().insert(seed_idx);
|
||
}
|
||
}
|
||
|
||
// Filter to nodes reached by min_sources distinct seeds
|
||
let mut results: Vec<(String, f64)> = node_scores.into_iter()
|
||
.filter(|(key, _)| {
|
||
node_sources.get(key).map(|s| s.len()).unwrap_or(0) >= min_sources
|
||
})
|
||
.collect();
|
||
|
||
if debug {
|
||
// Show source counts
|
||
for (key, score) in results.iter().take(15) {
|
||
let sources = node_sources.get(key).map(|s| s.len()).unwrap_or(0);
|
||
println!(" [{:.4}] {} (from {} seeds)", score, key, sources);
|
||
}
|
||
}
|
||
|
||
results.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||
results
|
||
}
|
||
|
||
/// Geodesic: straightest paths between seed pairs in spectral space.
///
/// For each pair of seeds, walk the graph from one to the other, at each
/// step choosing the neighbor whose spectral direction most aligns with
/// the target direction. Nodes along these geodesic paths score higher
/// the more paths pass through them and the straighter those paths are.
///
/// Falls back to passing the seeds through unchanged when no embedding
/// loads or fewer than two seeds have usable spectral coordinates.
///
/// Tunable params: max_path (default 6), k (default 20 results).
fn run_geodesic(
    seeds: &[(String, f64)],
    graph: &Graph,
    stage: &AlgoStage,
    debug: bool,
) -> Vec<(String, f64)> {
    let max_path = stage.param_usize("max_path", 6);
    let k = stage.param_usize("k", 20);

    // No embedding → no notion of direction; seeds pass through untouched.
    let emb = match spectral::load_embedding() {
        Ok(e) => e,
        Err(e) => {
            if debug { println!(" no spectral embedding: {}", e); }
            return seeds.to_vec();
        }
    };

    // Filter seeds to those with valid spectral coords.
    // An all-zero coordinate vector is treated as "no embedding for this node"
    // (same convention used in run_manifold and geodesic_walk).
    let valid_seeds: Vec<(&str, f64, &Vec<f64>)> = seeds.iter()
        .filter_map(|(key, weight)| {
            emb.coords.get(key.as_str())
                .filter(|c| c.iter().any(|&v| v.abs() > 1e-12))
                .map(|c| (key.as_str(), *weight, c))
        })
        .collect();

    // A geodesic needs two endpoints.
    if valid_seeds.len() < 2 {
        if debug { println!(" need ≥2 seeds with spectral coords, have {}", valid_seeds.len()); }
        return seeds.to_vec();
    }

    // For each pair of seeds, find the geodesic path.
    // path_counts accumulates pair_weight × alignment for every intermediate node.
    let mut path_counts: HashMap<String, f64> = HashMap::new();
    let seed_set: HashSet<&str> = seeds.iter().map(|(k, _)| k.as_str()).collect();

    for i in 0..valid_seeds.len() {
        for j in (i + 1)..valid_seeds.len() {
            let (key_a, weight_a, coords_a) = &valid_seeds[i];
            let (key_b, weight_b, coords_b) = &valid_seeds[j];
            let pair_weight = weight_a * weight_b;

            // Walk from A toward B
            let path_ab = geodesic_walk(
                key_a, coords_a, coords_b, graph, &emb, max_path,
            );
            // Walk from B toward A
            let path_ba = geodesic_walk(
                key_b, coords_b, coords_a, graph, &emb, max_path,
            );

            // Score nodes on both paths (nodes found from both directions score double)
            for (node, alignment) in path_ab.iter().chain(path_ba.iter()) {
                if !seed_set.contains(node.as_str()) {
                    *path_counts.entry(node.clone()).or_insert(0.0) += pair_weight * alignment;
                }
            }
        }
    }

    if debug && !path_counts.is_empty() {
        println!(" {} pairs examined, {} distinct nodes on paths",
            valid_seeds.len() * (valid_seeds.len() - 1) / 2,
            path_counts.len());
    }

    // Merge with original seeds: keep every seed, append the k best path nodes.
    let mut results = seeds.to_vec();
    let mut path_results: Vec<(String, f64)> = path_counts.into_iter().collect();
    path_results.sort_by(|a, b| b.1.total_cmp(&a.1));
    path_results.truncate(k);

    for (key, score) in path_results {
        if !seed_set.contains(key.as_str()) {
            results.push((key, score));
        }
    }

    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}
|
||
|
||
/// Walk from `start` toward `target_coords` in spectral space, choosing
|
||
/// the neighbor at each step whose direction most aligns with the target.
|
||
/// Returns (node_key, alignment_score) for each intermediate node.
|
||
fn geodesic_walk(
|
||
start: &str,
|
||
start_coords: &[f64],
|
||
target_coords: &[f64],
|
||
graph: &Graph,
|
||
emb: &spectral::SpectralEmbedding,
|
||
max_steps: usize,
|
||
) -> Vec<(String, f64)> {
|
||
let mut path = Vec::new();
|
||
let mut current = start.to_string();
|
||
let mut current_coords = start_coords.to_vec();
|
||
let mut visited: HashSet<String> = HashSet::new();
|
||
visited.insert(current.clone());
|
||
|
||
for _ in 0..max_steps {
|
||
// Direction we want to travel: from current toward target
|
||
let direction: Vec<f64> = target_coords.iter()
|
||
.zip(current_coords.iter())
|
||
.map(|(t, c)| t - c)
|
||
.collect();
|
||
|
||
let dir_norm = direction.iter().map(|d| d * d).sum::<f64>().sqrt();
|
||
if dir_norm < 1e-12 { break; } // arrived
|
||
|
||
// Among neighbors with spectral coords, find the one most aligned
|
||
let mut best: Option<(String, Vec<f64>, f64)> = None;
|
||
|
||
for (neighbor, _strength) in graph.neighbors(¤t) {
|
||
if visited.contains(neighbor.as_str()) { continue; }
|
||
|
||
let neighbor_coords = match emb.coords.get(neighbor.as_str()) {
|
||
Some(c) if c.iter().any(|&v| v.abs() > 1e-12) => c,
|
||
_ => continue,
|
||
};
|
||
|
||
// Direction to this neighbor
|
||
let step: Vec<f64> = neighbor_coords.iter()
|
||
.zip(current_coords.iter())
|
||
.map(|(n, c)| n - c)
|
||
.collect();
|
||
|
||
let step_norm = step.iter().map(|s| s * s).sum::<f64>().sqrt();
|
||
if step_norm < 1e-12 { continue; }
|
||
|
||
// Cosine similarity between desired direction and step direction
|
||
let dot: f64 = direction.iter().zip(step.iter()).map(|(d, s)| d * s).sum();
|
||
let alignment = dot / (dir_norm * step_norm);
|
||
|
||
if alignment > 0.0 { // only consider forward-facing neighbors
|
||
if best.as_ref().map(|(_, _, a)| alignment > *a).unwrap_or(true) {
|
||
best = Some((neighbor.clone(), neighbor_coords.clone(), alignment));
|
||
}
|
||
}
|
||
}
|
||
|
||
match best {
|
||
Some((next_key, next_coords, alignment)) => {
|
||
path.push((next_key.clone(), alignment));
|
||
visited.insert(next_key.clone());
|
||
current = next_key;
|
||
current_coords = next_coords;
|
||
}
|
||
None => break, // no forward-facing neighbors
|
||
}
|
||
}
|
||
|
||
path
|
||
}
|
||
|
||
/// Manifold: extrapolation along the direction defined by seeds.
///
/// Instead of finding what's *near* the seeds in spectral space (proximity),
/// find what's in the *direction* the seeds define. Given a weighted centroid
/// of seeds and the principal direction they span, find nodes that continue
/// along that direction.
///
/// With fewer than two distinct seeds there is no direction; scoring then
/// falls back to proximity to the centroid (see the `else` branch below).
///
/// Tunable params: k (default 20 results).
fn run_manifold(
    seeds: &[(String, f64)],
    graph: &Graph,
    stage: &AlgoStage,
    debug: bool,
) -> Vec<(String, f64)> {
    let k = stage.param_usize("k", 20);

    // No embedding → seeds pass through untouched.
    let emb = match spectral::load_embedding() {
        Ok(e) => e,
        Err(e) => {
            if debug { println!(" no spectral embedding: {}", e); }
            return seeds.to_vec();
        }
    };

    // Collect seeds with valid spectral coordinates
    // (all-zero vectors are treated as "no embedding for this node").
    let seed_data: Vec<(&str, f64, &Vec<f64>)> = seeds.iter()
        .filter_map(|(key, weight)| {
            emb.coords.get(key.as_str())
                .filter(|c| c.iter().any(|&v| v.abs() > 1e-12))
                .map(|c| (key.as_str(), *weight, c))
        })
        .collect();

    if seed_data.is_empty() {
        if debug { println!(" no seeds with spectral coords"); }
        return seeds.to_vec();
    }

    let dims = emb.dims;

    // Compute weighted centroid of seeds
    let mut centroid = vec![0.0f64; dims];
    let mut total_weight = 0.0;
    for (_, weight, coords) in &seed_data {
        for (i, &c) in coords.iter().enumerate() {
            centroid[i] += c * weight;
        }
        total_weight += weight;
    }
    if total_weight > 0.0 {
        for c in &mut centroid {
            *c /= total_weight;
        }
    }

    // Compute principal direction via power iteration on seed covariance.
    // Initialize with the two most separated seeds (largest spectral distance).
    // With < 2 seeds `direction` stays all-zero and the proximity fallback runs.
    let mut direction = vec![0.0f64; dims];
    if seed_data.len() >= 2 {
        // Find the two seeds furthest apart in spectral space
        let mut best_dist = 0.0f64;
        for i in 0..seed_data.len() {
            for j in (i + 1)..seed_data.len() {
                let dist: f64 = seed_data[i].2.iter().zip(seed_data[j].2.iter())
                    .map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
                if dist > best_dist {
                    best_dist = dist;
                    for d in 0..dims {
                        direction[d] = seed_data[j].2[d] - seed_data[i].2[d];
                    }
                }
            }
        }

        // Power iteration: 3 rounds on the weighted covariance matrix
        // (new_dir = Σ weight · (dev·direction) · dev, then renormalize).
        for _ in 0..3 {
            let mut new_dir = vec![0.0f64; dims];
            for (_, weight, coords) in &seed_data {
                let dev: Vec<f64> = coords.iter().zip(centroid.iter()).map(|(c, m)| c - m).collect();
                let dot: f64 = dev.iter().zip(direction.iter()).map(|(d, v)| d * v).sum();
                for d in 0..dims {
                    new_dir[d] += weight * dot * dev[d];
                }
            }
            // Normalize
            let norm = new_dir.iter().map(|d| d * d).sum::<f64>().sqrt();
            if norm > 1e-12 {
                for d in &mut new_dir { *d /= norm; }
            }
            direction = new_dir;
        }
    }

    let dir_norm = direction.iter().map(|d| d * d).sum::<f64>().sqrt();

    let seed_set: HashSet<&str> = seeds.iter().map(|(k, _)| k.as_str()).collect();

    // Score each non-seed node by projection onto the direction from centroid
    let mut candidates: Vec<(String, f64)> = emb.coords.iter()
        .filter(|(key, coords)| {
            !seed_set.contains(key.as_str())
                && coords.iter().any(|&v| v.abs() > 1e-12)
        })
        .map(|(key, coords)| {
            let deviation: Vec<f64> = coords.iter().zip(centroid.iter())
                .map(|(c, m)| c - m)
                .collect();

            let score = if dir_norm > 1e-12 {
                // Project onto direction: how far along the principal axis
                let projection: f64 = deviation.iter().zip(direction.iter())
                    .map(|(d, v)| d * v)
                    .sum::<f64>() / dir_norm;

                // Distance from the axis (perpendicular component)
                let proj_vec: Vec<f64> = direction.iter()
                    .map(|&d| d * projection / dir_norm)
                    .collect();
                let perp_dist: f64 = deviation.iter().zip(proj_vec.iter())
                    .map(|(d, p)| (d - p).powi(2))
                    .sum::<f64>()
                    .sqrt();

                // Score: prefer nodes far along the direction but close to the axis
                // Use absolute projection (both directions from centroid are interesting)
                let along = projection.abs();
                if perp_dist < 1e-12 {
                    along
                } else {
                    along / (1.0 + perp_dist)
                }
            } else {
                // No direction (single seed or all seeds coincide): use distance from centroid
                let dist: f64 = deviation.iter().map(|d| d * d).sum::<f64>().sqrt();
                1.0 / (1.0 + dist)
            };

            // Bonus for being connected to seeds in the graph
            // (0.1 × edge strength per seed neighbor).
            let graph_bonus: f64 = graph.neighbors(key).iter()
                .filter(|(n, _)| seed_set.contains(n.as_str()))
                .map(|(_, s)| *s as f64 * 0.1)
                .sum();

            (key.clone(), score + graph_bonus)
        })
        .collect();

    candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
    candidates.truncate(k);

    if debug {
        for (key, score) in candidates.iter().take(15) {
            println!(" [{:.4}] {}", score, key);
        }
    }

    // Merge with original seeds
    let mut results = seeds.to_vec();
    for (key, score) in candidates {
        results.push((key, score));
    }
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}
|
||
|
||
/// Simultaneous wavefront spreading activation.
|
||
///
|
||
/// All seeds emit at once. At each hop, activations from all sources
|
||
/// sum at each node, and the combined activation map propagates on
|
||
/// the next hop. This creates interference patterns — nodes where
|
||
/// multiple wavefronts overlap get reinforced and radiate stronger.
|
||
fn spreading_activation(
|
||
seeds: &[(String, f64)],
|
||
graph: &Graph,
|
||
store: &impl StoreView,
|
||
max_hops: u32,
|
||
edge_decay: f64,
|
||
min_activation: f64,
|
||
) -> Vec<(String, f64)> {
|
||
let mut activation: HashMap<String, f64> = HashMap::new();
|
||
|
||
// Initialize wavefront from all seeds
|
||
let mut frontier: HashMap<String, f64> = HashMap::new();
|
||
for (key, act) in seeds {
|
||
*frontier.entry(key.clone()).or_insert(0.0) += act;
|
||
*activation.entry(key.clone()).or_insert(0.0) += act;
|
||
}
|
||
|
||
// Propagate hop by hop — all sources simultaneously
|
||
// Node weight does NOT gate traversal — only edge_decay and edge strength.
|
||
// Node weight is applied at the end for ranking.
|
||
for _hop in 0..max_hops {
|
||
let mut next_frontier: HashMap<String, f64> = HashMap::new();
|
||
|
||
for (key, act) in &frontier {
|
||
for (neighbor, strength) in graph.neighbors(key) {
|
||
let propagated = act * edge_decay * strength as f64;
|
||
if propagated < min_activation { continue; }
|
||
|
||
*next_frontier.entry(neighbor.clone()).or_insert(0.0) += propagated;
|
||
}
|
||
}
|
||
|
||
if next_frontier.is_empty() { break; }
|
||
|
||
// Merge into total activation and advance frontier
|
||
for (key, act) in &next_frontier {
|
||
*activation.entry(key.clone()).or_insert(0.0) += act;
|
||
}
|
||
frontier = next_frontier;
|
||
}
|
||
|
||
// Apply node weight for ranking, not traversal
|
||
let mut results: Vec<_> = activation.into_iter()
|
||
.map(|(key, act)| {
|
||
let weight = store.node_weight(&key);
|
||
(key, act * weight)
|
||
})
|
||
.collect();
|
||
results.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||
results
|
||
}
|
||
|
||
/// Search with weighted terms: exact key matching + spectral projection.
|
||
///
|
||
/// Terms are matched against node keys. Matching nodes become seeds,
|
||
/// scored by term_weight × node_weight. Seeds are then projected into
|
||
/// spectral space to find nearby nodes, with link weights modulating distance.
|
||
pub fn search_weighted(
|
||
terms: &BTreeMap<String, f64>,
|
||
store: &impl StoreView,
|
||
) -> Vec<SearchResult> {
|
||
search_weighted_inner(terms, store, false, 5)
|
||
}
|
||
|
||
/// Like search_weighted but with debug output and configurable result count.
|
||
pub fn search_weighted_debug(
|
||
terms: &BTreeMap<String, f64>,
|
||
store: &impl StoreView,
|
||
max_results: usize,
|
||
) -> Vec<SearchResult> {
|
||
search_weighted_inner(terms, store, true, max_results)
|
||
}
|
||
|
||
fn search_weighted_inner(
|
||
terms: &BTreeMap<String, f64>,
|
||
store: &impl StoreView,
|
||
debug: bool,
|
||
max_results: usize,
|
||
) -> Vec<SearchResult> {
|
||
let graph = crate::graph::build_graph_fast(store);
|
||
let (seeds, direct_hits) = match_seeds(terms, store);
|
||
|
||
if seeds.is_empty() {
|
||
return Vec::new();
|
||
}
|
||
|
||
if debug {
|
||
println!("\n[search] === SEEDS ({}) ===", seeds.len());
|
||
let mut sorted_seeds = seeds.clone();
|
||
sorted_seeds.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||
for (key, score) in sorted_seeds.iter().take(20) {
|
||
println!(" {:.4} {}", score, key);
|
||
}
|
||
}
|
||
|
||
// Default pipeline: spectral → spread (legacy behavior)
|
||
let pipeline = vec![
|
||
AlgoStage { algo: Algorithm::Spectral, params: HashMap::new() },
|
||
AlgoStage { algo: Algorithm::Spread, params: HashMap::new() },
|
||
];
|
||
|
||
let raw_results = run_pipeline(&pipeline, seeds, &graph, store, debug, max_results);
|
||
|
||
raw_results.into_iter()
|
||
.take(max_results)
|
||
.map(|(key, activation)| {
|
||
let is_direct = direct_hits.contains(&key);
|
||
SearchResult { key, activation, is_direct, snippet: None }
|
||
}).collect()
|
||
}
|
||
|
||
/// Search with equal-weight terms (for interactive use).
|
||
pub fn search(query: &str, store: &impl StoreView) -> Vec<SearchResult> {
|
||
let terms: BTreeMap<String, f64> = query.split_whitespace()
|
||
.map(|t| (t.to_lowercase(), 1.0))
|
||
.collect();
|
||
search_weighted(&terms, store)
|
||
}
|
||
|
||
/// Extract meaningful search terms from natural language.
|
||
/// Strips common English stop words, returns up to max_terms words.
|
||
pub fn extract_query_terms(text: &str, max_terms: usize) -> String {
|
||
const STOP_WORDS: &[&str] = &[
|
||
"the", "a", "an", "is", "are", "was", "were", "do", "does", "did",
|
||
"have", "has", "had", "will", "would", "could", "should", "can",
|
||
"may", "might", "shall", "been", "being", "to", "of", "in", "for",
|
||
"on", "with", "at", "by", "from", "as", "but", "or", "and", "not",
|
||
"no", "if", "then", "than", "that", "this", "it", "its", "my",
|
||
"your", "our", "we", "you", "i", "me", "he", "she", "they", "them",
|
||
"what", "how", "why", "when", "where", "about", "just", "let",
|
||
"want", "tell", "show", "think", "know", "see", "look", "make",
|
||
"get", "go", "some", "any", "all", "very", "really", "also", "too",
|
||
"so", "up", "out", "here", "there",
|
||
];
|
||
|
||
text.to_lowercase()
|
||
.split(|c: char| !c.is_alphanumeric())
|
||
.filter(|w| !w.is_empty() && w.len() > 2 && !STOP_WORDS.contains(w))
|
||
.take(max_terms)
|
||
.collect::<Vec<_>>()
|
||
.join(" ")
|
||
}
|
||
|
||
/// Format search results as text lines (for hook consumption).
|
||
pub fn format_results(results: &[SearchResult]) -> String {
|
||
let mut out = String::new();
|
||
for (i, r) in results.iter().enumerate().take(5) {
|
||
let marker = if r.is_direct { "→" } else { " " };
|
||
out.push_str(&format!("{}{:2}. [{:.2}/{:.2}] {}",
|
||
marker, i + 1, r.activation, r.activation, r.key));
|
||
out.push('\n');
|
||
if let Some(ref snippet) = r.snippet {
|
||
out.push_str(&format!(" {}\n", snippet));
|
||
}
|
||
}
|
||
out
|
||
}
|