From c1664bf76b20e35cc27cb8b74cd4d83dc25a7580 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Mon, 9 Mar 2026 01:19:04 -0400 Subject: [PATCH] search: composable algorithm pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Break search into composable stages that chain left-to-right: each stage takes seeds Vec<(String, f64)> and returns modified seeds. Available algorithms: spread — spreading activation through graph edges spectral — nearest neighbors in spectral embedding manifold — (placeholder) extrapolation along seed direction Stages accept inline params: spread,max_hops=4,edge_decay=0.5 memory-search gets --hook, --debug, --seen modes plus positional pipeline args. poc-memory search gets -p/--pipeline flags. Also: fix spectral decompose() to skip zero eigenvalues from disconnected components, filter degenerate zero-coord nodes from spectral projection, POC_AGENT bail-out for daemon agents, all debug output to stdout. Co-Authored-By: ProofOfConcept --- poc-memory/src/bin/memory-search.rs | 351 +++++++++++++++++++++++++--- poc-memory/src/main.rs | 114 ++++----- poc-memory/src/search.rs | 330 +++++++++++++++++++++----- poc-memory/src/spectral.rs | 79 ++++++- 4 files changed, 723 insertions(+), 151 deletions(-) diff --git a/poc-memory/src/bin/memory-search.rs b/poc-memory/src/bin/memory-search.rs index 48e8bdf..5feeab5 100644 --- a/poc-memory/src/bin/memory-search.rs +++ b/poc-memory/src/bin/memory-search.rs @@ -1,24 +1,76 @@ // memory-search: combined hook for session context loading + ambient memory retrieval // -// On first prompt per session: loads full memory context (identity, journal, etc.) -// On subsequent prompts: searches memory for relevant entries -// On post-compaction: reloads full context -// -// Reads JSON from stdin (Claude Code UserPromptSubmit hook format), -// outputs results for injection into the conversation. +// Modes: +// --hook Run as Claude Code UserPromptSubmit hook (reads stdin, injects into conversation) +// --debug Replay last stashed input, dump every stage to stdout +// --seen Show the seen set for current session +// (default) No-op (future: manual search modes) -use poc_memory::search; +use clap::Parser; +use poc_memory::search::{self, AlgoStage}; use poc_memory::store; -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; use std::fs; use std::io::{self, Read, Write}; use std::path::{Path, PathBuf}; use std::process::Command; use std::time::{Duration, SystemTime}; +#[derive(Parser)] +#[command(name = "memory-search")] +struct Args { + /// Run as Claude Code hook (reads stdin, outputs for injection) + #[arg(long)] + hook: bool, + + /// Debug mode: replay last stashed input, dump every stage + #[arg(short, long)] + debug: bool, + + /// Show the seen set and returned memories for this session + #[arg(long)] + seen: bool, + + /// Max results to return + #[arg(long, default_value = "5")] + max_results: usize, + + /// Algorithm pipeline stages: e.g. spread spectral,k=20 spread,max_hops=4 + /// Default: spread. + pipeline: Vec, +} + +const STASH_PATH: &str = "/tmp/claude-memory-search/last-input.json"; + fn main() { - let mut input = String::new(); - io::stdin().read_to_string(&mut input).unwrap_or_default(); + // Daemon agent calls set POC_AGENT=1 — skip memory search. + if std::env::var("POC_AGENT").is_ok() { + return; + } + + let args = Args::parse(); + + if args.seen { + show_seen(); + return; + } + + let input = if args.hook { + // Hook mode: read from stdin, stash for later debug runs + let mut buf = String::new(); + io::stdin().read_to_string(&mut buf).unwrap_or_default(); + fs::create_dir_all("/tmp/claude-memory-search").ok(); + fs::write(STASH_PATH, &buf).ok(); + buf + } else { + // All other modes: replay stashed input + fs::read_to_string(STASH_PATH).unwrap_or_else(|_| { + eprintln!("No stashed input at {}", STASH_PATH); + std::process::exit(1); + }) + }; + + let debug = args.debug || !args.hook; let json: serde_json::Value = match serde_json::from_str(&input) { Ok(v) => v, @@ -42,6 +94,16 @@ fn main() { let cookie_path = state_dir.join(format!("cookie-{}", session_id)); let is_first = !cookie_path.exists(); + if is_first || is_compaction { + // Reset seen set to keys that load-context will inject + let seen_path = state_dir.join(format!("seen-{}", session_id)); + fs::remove_file(&seen_path).ok(); + } + + if debug { + println!("[memory-search] session={} is_first={} is_compaction={}", session_id, is_first, is_compaction); + } + if is_first || is_compaction { // Create/touch the cookie let cookie = if is_first { @@ -52,52 +114,135 @@ fn main() { fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string() }; - // Load full memory context + if debug { println!("[memory-search] loading full context"); } + + // Load full memory context and pre-populate seen set with injected keys if let Ok(output) = Command::new("poc-memory").args(["load-context"]).output() { if output.status.success() { let ctx = String::from_utf8_lossy(&output.stdout); if !ctx.trim().is_empty() { - print!("{}", ctx); + // Extract keys from "--- KEY (group) ---" lines + for line in ctx.lines() { + if line.starts_with("--- ") && line.ends_with(" ---") { + let inner = &line[4..line.len() - 4]; + if let Some(paren) = inner.rfind(" (") { + let key = inner[..paren].trim(); + mark_seen(&state_dir, session_id, key); + } + } + } + if debug { println!("[memory-search] context loaded: {} bytes", ctx.len()); } + if args.hook { + print!("{}", ctx); + } } } } - // On first prompt, also bump lookup counter for the cookie - let _ = cookie; // used for tagging below - } - - // Always do ambient search (skip on very short or system prompts) - let word_count = prompt.split_whitespace().count(); - if word_count < 3 { - return; + let _ = cookie; } + // Skip system/AFK prompts for prefix in &["is AFK", "You're on your own", "IRC mention"] { if prompt.starts_with(prefix) { return; } } - let query = search::extract_query_terms(prompt, 3); - if query.is_empty() { - return; - } - let store = match store::Store::load() { Ok(s) => s, Err(_) => return, }; - let results = search::search(&query, &store); - if results.is_empty() { + // Search for node keys in last ~150k tokens of transcript + let transcript_path = json["transcript_path"].as_str().unwrap_or(""); + if debug { println!("[memory-search] transcript: {}", transcript_path); } + let terms = extract_weighted_terms(transcript_path, 150_000, &store); + + if debug { + println!("[memory-search] {} node keys found in transcript", terms.len()); + let mut by_weight: Vec<_> = terms.iter().collect(); + by_weight.sort_by(|a, b| b.1.total_cmp(a.1)); + for (term, weight) in by_weight.iter().take(20) { + println!(" {:.3} {}", weight, term); + } + } + + if terms.is_empty() { + if debug { println!("[memory-search] no node keys found, done"); } return; } + // Parse algorithm pipeline + let pipeline: Vec = if args.pipeline.is_empty() { + // Default: just spreading activation + vec![AlgoStage::parse("spread").unwrap()] + } else { + let mut stages = Vec::new(); + for arg in &args.pipeline { + match AlgoStage::parse(arg) { + Ok(s) => stages.push(s), + Err(e) => { + eprintln!("error: {}", e); + std::process::exit(1); + } + } + } + stages + }; + + if debug { + let names: Vec = pipeline.iter().map(|s| format!("{}", s.algo)).collect(); + println!("[memory-search] pipeline: {}", names.join(" → ")); + } + + // Extract seeds from terms + let graph = poc_memory::graph::build_graph_fast(&store); + let (seeds, direct_hits) = search::match_seeds(&terms, &store); + + if seeds.is_empty() { + if debug { println!("[memory-search] no seeds matched, done"); } + return; + } + + if debug { + println!("[memory-search] {} seeds", seeds.len()); + let mut sorted = seeds.clone(); + sorted.sort_by(|a, b| b.1.total_cmp(&a.1)); + for (key, score) in sorted.iter().take(20) { + println!(" {:.4} {}", score, key); + } + } + + let max_results = if debug { args.max_results.max(25) } else { args.max_results }; + let raw_results = search::run_pipeline(&pipeline, seeds, &graph, &store, debug, max_results); + + let results: Vec = raw_results.into_iter() + .map(|(key, activation)| { + let is_direct = direct_hits.contains(&key); + search::SearchResult { key, activation, is_direct, snippet: None } + }).collect(); + + if debug { + println!("[memory-search] {} search results", results.len()); + for r in results.iter().take(10) { + let marker = if r.is_direct { "→" } else { " " }; + println!(" {} [{:.4}] {}", marker, r.activation, r.key); + } + } + + if results.is_empty() { + if debug { println!("[memory-search] no results, done"); } + return; + } + + let seen = load_seen(&state_dir, session_id); + if debug { println!("[memory-search] {} keys in seen set", seen.len()); } + // Format results like poc-memory search output let search_output = search::format_results(&results); let cookie = fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string(); - let seen = load_seen(&state_dir, session_id); let mut result_output = String::new(); let mut count = 0; @@ -112,6 +257,7 @@ fn main() { if let Some(key) = extract_key_from_line(trimmed) { if seen.contains(&key) { continue; } mark_seen(&state_dir, session_id, &key); + mark_returned(&state_dir, session_id, &key); result_output.push_str(line); result_output.push('\n'); count += 1; @@ -121,9 +267,14 @@ fn main() { } } - if count == 0 { return; } + if count == 0 { + if debug { println!("[memory-search] all results already seen"); } + return; + } - println!("Recalled memories [{}]:", cookie); + if args.hook { + println!("Recalled memories [{}]:", cookie); + } print!("{}", result_output); // Clean up stale state files (opportunistic) @@ -131,6 +282,82 @@ fn main() { } +/// Reverse-scan the transcript JSONL, extracting text from user/assistant +/// messages until we accumulate `max_tokens` tokens of text content. +/// Then search for all node keys as substrings, weighted by position. +fn extract_weighted_terms( + path: &str, + max_tokens: usize, + store: &poc_memory::store::Store, +) -> BTreeMap { + if path.is_empty() { return BTreeMap::new(); } + + let content = match fs::read_to_string(path) { + Ok(c) => c, + Err(_) => return BTreeMap::new(), + }; + + // Collect text from messages, scanning backwards, until token budget hit + let mut message_texts: Vec = Vec::new(); + let mut token_count = 0; + + for line in content.lines().rev() { + if token_count >= max_tokens { break; } + + let obj: serde_json::Value = match serde_json::from_str(line) { + Ok(v) => v, + Err(_) => continue, + }; + + let msg_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or(""); + if msg_type != "user" && msg_type != "assistant" { continue; } + + let mut msg_text = String::new(); + let msg = obj.get("message").unwrap_or(&obj); + match msg.get("content") { + Some(serde_json::Value::String(s)) => { + msg_text.push_str(s); + } + Some(serde_json::Value::Array(arr)) => { + for block in arr { + if block.get("type").and_then(|v| v.as_str()) == Some("text") { + if let Some(t) = block.get("text").and_then(|v| v.as_str()) { + msg_text.push(' '); + msg_text.push_str(t); + } + } + } + } + _ => {} + } + + token_count += msg_text.len() / 4; + message_texts.push(msg_text); + } + + // Reverse so oldest is first (position weighting: later = more recent = higher) + message_texts.reverse(); + let all_text = message_texts.join(" ").to_lowercase(); + let text_len = all_text.len(); + if text_len == 0 { return BTreeMap::new(); } + + // Search for each node key as a substring (casefolded), accumulate position-weighted score + let mut terms = BTreeMap::new(); + for (key, _node) in &store.nodes { + let key_folded = key.to_lowercase(); + let mut pos = 0; + while let Some(found) = all_text[pos..].find(&key_folded) { + let abs_pos = pos + found; + let weight = (abs_pos + 1) as f64 / text_len as f64; + *terms.entry(key_folded.clone()).or_insert(0.0) += weight; + pos = abs_pos + key_folded.len(); + } + } + + terms +} + + fn extract_key_from_line(line: &str) -> Option { let after_bracket = line.find("] ")?; let rest = &line[after_bracket + 2..]; @@ -167,6 +394,70 @@ fn mark_seen(dir: &Path, session_id: &str, key: &str) { } } +fn mark_returned(dir: &Path, session_id: &str, key: &str) { + let path = dir.join(format!("returned-{}", session_id)); + if let Ok(mut f) = fs::OpenOptions::new().create(true).append(true).open(path) { + writeln!(f, "{}", key).ok(); + } +} + +fn load_returned(dir: &Path, session_id: &str) -> Vec { + let path = dir.join(format!("returned-{}", session_id)); + if path.exists() { + fs::read_to_string(path) + .unwrap_or_default() + .lines() + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect() + } else { + Vec::new() + } +} + +fn show_seen() { + let state_dir = PathBuf::from("/tmp/claude-memory-search"); + + // Read stashed input for session_id + let input = match fs::read_to_string(STASH_PATH) { + Ok(s) => s, + Err(_) => { + eprintln!("No stashed input at {}", STASH_PATH); + return; + } + }; + let json: serde_json::Value = match serde_json::from_str(&input) { + Ok(v) => v, + Err(_) => { + eprintln!("Failed to parse stashed input"); + return; + } + }; + let session_id = json["session_id"].as_str().unwrap_or(""); + if session_id.is_empty() { + eprintln!("No session_id in stashed input"); + return; + } + + println!("Session: {}", session_id); + + let cookie_path = state_dir.join(format!("cookie-{}", session_id)); + if let Ok(cookie) = fs::read_to_string(&cookie_path) { + println!("Cookie: {}", cookie.trim()); + } + + let returned = load_returned(&state_dir, session_id); + if !returned.is_empty() { + println!("\nReturned by search ({}):", returned.len()); + for key in &returned { + println!(" {}", key); + } + } + + let seen = load_seen(&state_dir, session_id); + println!("\nSeen set ({} total, {} pre-seeded):", seen.len(), seen.len() - returned.len()); +} + fn cleanup_stale_files(dir: &Path, max_age: Duration) { let entries = match fs::read_dir(dir) { Ok(e) => e, diff --git a/poc-memory/src/main.rs b/poc-memory/src/main.rs index 482e0b3..673c4ce 100644 --- a/poc-memory/src/main.rs +++ b/poc-memory/src/main.rs @@ -59,12 +59,21 @@ struct Cli { #[derive(Subcommand)] enum Command { /// Search memory (AND logic across terms) + /// + /// Pipeline: -p spread -p spectral,k=20 + /// Default pipeline: spread Search { /// Search terms query: Vec, - /// Show 15 results instead of 5, plus spectral neighbors + /// Algorithm pipeline stages (repeatable) + #[arg(short, long = "pipeline")] + pipeline: Vec, + /// Show more results #[arg(long)] expand: bool, + /// Show debug output for each pipeline stage + #[arg(long)] + debug: bool, }, /// Scan markdown files, index all memory units Init, @@ -469,8 +478,8 @@ fn main() { let cli = Cli::parse(); let result = match cli.command { - Command::Search { query, expand } - => cmd_search(&query, expand), + Command::Search { query, pipeline, expand, debug } + => cmd_search(&query, &pipeline, expand, debug), Command::Init => cmd_init(), Command::Migrate => cmd_migrate(), Command::Health => cmd_health(), @@ -575,8 +584,9 @@ fn main() { // ── Command implementations ───────────────────────────────────────── -fn cmd_search(terms: &[String], expand: bool) -> Result<(), String> { +fn cmd_search(terms: &[String], pipeline_args: &[String], expand: bool, debug: bool) -> Result<(), String> { use store::StoreView; + use std::collections::BTreeMap; if terms.is_empty() { return Err("search requires at least one term".into()); @@ -584,70 +594,68 @@ fn cmd_search(terms: &[String], expand: bool) -> Result<(), String> { let query: String = terms.join(" "); + // Parse pipeline (default: spread) + let pipeline: Vec = if pipeline_args.is_empty() { + vec![search::AlgoStage::parse("spread").unwrap()] + } else { + pipeline_args.iter() + .map(|a| search::AlgoStage::parse(a)) + .collect::, _>>()? + }; + + if debug { + let names: Vec = pipeline.iter().map(|s| format!("{}", s.algo)).collect(); + println!("[search] pipeline: {}", names.join(" → ")); + } + let view = store::AnyView::load()?; - let results = search::search(&query, &view); + let graph = graph::build_graph_fast(&view); + + // Build equal-weight terms from query + let terms: BTreeMap = query.split_whitespace() + .map(|t| (t.to_lowercase(), 1.0)) + .collect(); + + let (seeds, direct_hits) = search::match_seeds(&terms, &view); + + if seeds.is_empty() { + eprintln!("No results for '{}'", query); + return Ok(()); + } + + if debug { + println!("[search] {} seeds from query '{}'", seeds.len(), query); + for (key, score) in &seeds { + println!(" {:.4} {}", score, key); + } + } + + let max_results = if expand { 15 } else { 5 }; + let raw = search::run_pipeline(&pipeline, seeds, &graph, &view, debug, max_results); + + let results: Vec = raw.into_iter() + .map(|(key, activation)| { + let is_direct = direct_hits.contains(&key); + search::SearchResult { key, activation, is_direct, snippet: None } + }) + .collect(); if results.is_empty() { eprintln!("No results for '{}'", query); return Ok(()); } - let limit = if expand { 15 } else { 5 }; - - // Log retrieval to a small append-only file (avoid 6MB state.bin rewrite) + // Log retrieval store::Store::log_retrieval_static(&query, &results.iter().map(|r| r.key.clone()).collect::>()); - // Bump daily lookup counters (fast path, no store needed) - let bump_keys: Vec<&str> = results.iter().take(limit).map(|r| r.key.as_str()).collect(); + let bump_keys: Vec<&str> = results.iter().take(max_results).map(|r| r.key.as_str()).collect(); let _ = lookups::bump_many(&bump_keys); - let text_keys: std::collections::HashSet = results.iter() - .take(limit).map(|r| r.key.clone()).collect(); - - for (i, r) in results.iter().enumerate().take(limit) { + for (i, r) in results.iter().enumerate().take(max_results) { let marker = if r.is_direct { "→" } else { " " }; let weight = view.node_weight(&r.key); println!("{}{:2}. [{:.2}/{:.2}] {}", marker, i + 1, r.activation, weight, r.key); - if let Some(ref snippet) = r.snippet { - println!(" {}", snippet); - } - } - - if expand { - if let Ok(emb) = spectral::load_embedding() { - let seeds: Vec<&str> = results.iter() - .take(5) - .map(|r| r.key.as_str()) - .filter(|k| emb.coords.contains_key(*k)) - .collect(); - - if !seeds.is_empty() { - let spectral_hits = spectral::nearest_to_seeds(&emb, &seeds, 10); - let new_hits: Vec<_> = spectral_hits.into_iter() - .filter(|(k, _)| !text_keys.contains(k)) - .take(5) - .collect(); - - if !new_hits.is_empty() { - println!("\nSpectral neighbors (structural, not keyword):"); - for (k, _dist) in &new_hits { - let weight = view.node_weight(k); - println!(" ~ [{:.2}] {}", weight, k); - if let Some(content) = view.node_content(k) { - let snippet = util::first_n_chars( - content.lines() - .find(|l| !l.trim().is_empty() && !l.starts_with('#')) - .unwrap_or(""), - 100); - if !snippet.is_empty() { - println!(" {}", snippet); - } - } - } - } - } - } } Ok(()) diff --git a/poc-memory/src/search.rs b/poc-memory/src/search.rs index 31e2f21..ec1d6f6 100644 --- a/poc-memory/src/search.rs +++ b/poc-memory/src/search.rs @@ -1,13 +1,22 @@ -// Spreading activation search across the memory graph +// Memory search: composable algorithm pipeline. // -// Same model as the old system but richer: uses graph edge strengths, -// supports circumscription parameter for blending associative vs -// causal walks, and benefits from community-aware result grouping. +// Each algorithm is a stage: takes seeds Vec<(String, f64)>, returns +// new/modified seeds. Stages compose left-to-right in a pipeline. +// +// Available algorithms: +// spread — spreading activation through graph edges +// spectral — nearest neighbors in spectral embedding space +// manifold — extrapolation along direction defined by seeds (TODO) +// +// Seed extraction (matching query terms to node keys) is shared +// infrastructure, not an algorithm stage. use crate::store::StoreView; use crate::graph::Graph; +use crate::spectral; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::fmt; pub struct SearchResult { pub key: String, @@ -16,18 +25,211 @@ pub struct SearchResult { pub snippet: Option, } -/// Spreading activation with circumscription parameter. +/// A parsed algorithm stage with its parameters. +#[derive(Clone, Debug)] +pub struct AlgoStage { + pub algo: Algorithm, + pub params: HashMap, +} + +#[derive(Clone, Debug)] +pub enum Algorithm { + Spread, + Spectral, + Manifold, +} + +impl fmt::Display for Algorithm { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Algorithm::Spread => write!(f, "spread"), + Algorithm::Spectral => write!(f, "spectral"), + Algorithm::Manifold => write!(f, "manifold"), + } + } +} + +impl AlgoStage { + /// Parse "spread,max_hops=4,edge_decay=0.5" into an AlgoStage. + pub fn parse(s: &str) -> Result { + let mut parts = s.split(','); + let name = parts.next().unwrap_or(""); + let algo = match name { + "spread" => Algorithm::Spread, + "spectral" => Algorithm::Spectral, + "manifold" => Algorithm::Manifold, + _ => return Err(format!("unknown algorithm: {}", name)), + }; + let mut params = HashMap::new(); + for part in parts { + if let Some((k, v)) = part.split_once('=') { + params.insert(k.to_string(), v.to_string()); + } else { + return Err(format!("bad param (expected key=val): {}", part)); + } + } + Ok(AlgoStage { algo, params }) + } + + fn param_f64(&self, key: &str, default: f64) -> f64 { + self.params.get(key) + .and_then(|v| v.parse().ok()) + .unwrap_or(default) + } + + fn param_u32(&self, key: &str, default: u32) -> u32 { + self.params.get(key) + .and_then(|v| v.parse().ok()) + .unwrap_or(default) + } + + fn param_usize(&self, key: &str, default: usize) -> usize { + self.params.get(key) + .and_then(|v| v.parse().ok()) + .unwrap_or(default) + } +} + +/// Extract seeds from weighted terms by matching against node keys. /// -/// circ = 0.0: field mode — all edges (default, broad resonance) -/// circ = 1.0: causal mode — prefer causal edges +/// Returns (seeds, direct_hits) where direct_hits tracks which keys +/// were matched directly (vs found by an algorithm stage). +pub fn match_seeds( + terms: &BTreeMap, + store: &impl StoreView, +) -> (Vec<(String, f64)>, HashSet) { + let mut seeds: Vec<(String, f64)> = Vec::new(); + let mut direct_hits: HashSet = HashSet::new(); + + let mut key_map: HashMap = HashMap::new(); + store.for_each_node(|key, _content, weight| { + key_map.insert(key.to_lowercase(), (key.to_owned(), weight as f64)); + }); + + for (term, &term_weight) in terms { + if let Some((orig_key, node_weight)) = key_map.get(term) { + let score = term_weight * node_weight; + seeds.push((orig_key.clone(), score)); + direct_hits.insert(orig_key.clone()); + } + } + + (seeds, direct_hits) +} + +/// Run a pipeline of algorithm stages. +pub fn run_pipeline( + stages: &[AlgoStage], + seeds: Vec<(String, f64)>, + graph: &Graph, + store: &impl StoreView, + debug: bool, + max_results: usize, +) -> Vec<(String, f64)> { + let mut current = seeds; + + for stage in stages { + if debug { + println!("\n[search] === {} ({} seeds in) ===", stage.algo, current.len()); + } + + current = match stage.algo { + Algorithm::Spread => run_spread(¤t, graph, store, stage, debug), + Algorithm::Spectral => run_spectral(¤t, graph, stage, debug), + Algorithm::Manifold => { + if debug { println!(" (manifold not yet implemented, passing through)"); } + current + } + }; + + if debug { + println!("[search] {} → {} results", stage.algo, current.len()); + for (i, (key, score)) in current.iter().enumerate().take(15) { + let cutoff = if i + 1 == max_results { " <-- cutoff" } else { "" }; + println!(" [{:.4}] {}{}", score, key, cutoff); + } + if current.len() > 15 { + println!(" ... ({} more)", current.len() - 15); + } + } + } + + current.truncate(max_results); + current +} + +/// Spreading activation: propagate scores through graph edges. +/// +/// Tunable params: max_hops (default from store), edge_decay (default from store), +/// min_activation (default from store). +fn run_spread( + seeds: &[(String, f64)], + graph: &Graph, + store: &impl StoreView, + stage: &AlgoStage, + _debug: bool, +) -> Vec<(String, f64)> { + let store_params = store.params(); + let max_hops = stage.param_u32("max_hops", store_params.max_hops); + let edge_decay = stage.param_f64("edge_decay", store_params.edge_decay); + let min_activation = stage.param_f64("min_activation", store_params.min_activation); + + spreading_activation(seeds, graph, store, max_hops, edge_decay, min_activation) +} + +/// Spectral projection: find nearest neighbors in spectral embedding space. +/// +/// Tunable params: k (default 20, number of neighbors to find). +fn run_spectral( + seeds: &[(String, f64)], + graph: &Graph, + stage: &AlgoStage, + debug: bool, +) -> Vec<(String, f64)> { + let k = stage.param_usize("k", 20); + + let emb = match spectral::load_embedding() { + Ok(e) => e, + Err(e) => { + if debug { println!(" no spectral embedding: {}", e); } + return seeds.to_vec(); + } + }; + + let weighted_seeds: Vec<(&str, f64)> = seeds.iter() + .map(|(k, w)| (k.as_str(), *w)) + .collect(); + let projected = spectral::nearest_to_seeds_weighted( + &emb, &weighted_seeds, Some(graph), k, + ); + + if debug { + for (key, dist) in &projected { + let score = 1.0 / (1.0 + dist); + println!(" dist={:.6} score={:.4} {}", dist, score, key); + } + } + + // Merge: keep original seeds, add spectral results as new seeds + let seed_set: HashSet<&str> = seeds.iter().map(|(k, _)| k.as_str()).collect(); + let mut result = seeds.to_vec(); + for (key, dist) in projected { + if !seed_set.contains(key.as_str()) { + let score = 1.0 / (1.0 + dist); + result.push((key, score)); + } + } + result +} + fn spreading_activation( seeds: &[(String, f64)], graph: &Graph, store: &impl StoreView, - _circumscription: f64, + max_hops: u32, + edge_decay: f64, + min_activation: f64, ) -> Vec<(String, f64)> { - let params = store.params(); - let mut activation: HashMap = HashMap::new(); let mut queue: VecDeque<(String, f64, u32)> = VecDeque::new(); @@ -40,12 +242,12 @@ fn spreading_activation( } while let Some((key, act, depth)) = queue.pop_front() { - if depth >= params.max_hops { continue; } + if depth >= max_hops { continue; } for (neighbor, strength) in graph.neighbors(&key) { let neighbor_weight = store.node_weight(neighbor.as_str()); - let propagated = act * params.edge_decay * neighbor_weight * strength as f64; - if propagated < params.min_activation { continue; } + let propagated = act * edge_decay * neighbor_weight * strength as f64; + if propagated < min_activation { continue; } let current = activation.entry(neighbor.clone()).or_insert(0.0); if propagated > *current { @@ -60,55 +262,71 @@ fn spreading_activation( results } -/// Full search: find direct hits, spread activation, return ranked results -pub fn search(query: &str, store: &impl StoreView) -> Vec { +/// Search with weighted terms: exact key matching + spectral projection. +/// +/// Terms are matched against node keys. Matching nodes become seeds, +/// scored by term_weight × node_weight. Seeds are then projected into +/// spectral space to find nearby nodes, with link weights modulating distance. +pub fn search_weighted( + terms: &BTreeMap, + store: &impl StoreView, +) -> Vec { + search_weighted_inner(terms, store, false, 5) +} + +/// Like search_weighted but with debug output and configurable result count. +pub fn search_weighted_debug( + terms: &BTreeMap, + store: &impl StoreView, + max_results: usize, +) -> Vec { + search_weighted_inner(terms, store, true, max_results) +} + +fn search_weighted_inner( + terms: &BTreeMap, + store: &impl StoreView, + debug: bool, + max_results: usize, +) -> Vec { let graph = crate::graph::build_graph_fast(store); - let query_lower = query.to_lowercase(); - let query_tokens: Vec<&str> = query_lower.split_whitespace().collect(); - - let mut seeds: Vec<(String, f64)> = Vec::new(); - let mut snippets: HashMap = HashMap::new(); - - store.for_each_node(|key, content, weight| { - let content_lower = content.to_lowercase(); - - let exact_match = content_lower.contains(&query_lower); - let token_match = query_tokens.len() > 1 - && query_tokens.iter().all(|t| content_lower.contains(t)); - - if exact_match || token_match { - let activation = if exact_match { weight as f64 } else { weight as f64 * 0.85 }; - seeds.push((key.to_owned(), activation)); - - let snippet: String = content.lines() - .filter(|l| { - let ll = l.to_lowercase(); - if exact_match && ll.contains(&query_lower) { return true; } - query_tokens.iter().any(|t| ll.contains(t)) - }) - .take(3) - .map(|l| { - let t = l.trim(); - crate::util::truncate(t, 97, "...") - }) - .collect::>() - .join("\n "); - snippets.insert(key.to_owned(), snippet); - } - }); + let (seeds, direct_hits) = match_seeds(terms, store); if seeds.is_empty() { return Vec::new(); } - let direct_hits: HashSet = seeds.iter().map(|(k, _)| k.clone()).collect(); - let raw_results = spreading_activation(&seeds, &graph, store, 0.0); + if debug { + println!("\n[search] === SEEDS ({}) ===", seeds.len()); + let mut sorted_seeds = seeds.clone(); + sorted_seeds.sort_by(|a, b| b.1.total_cmp(&a.1)); + for (key, score) in sorted_seeds.iter().take(20) { + println!(" {:.4} {}", score, key); + } + } - raw_results.into_iter().map(|(key, activation)| { - let is_direct = direct_hits.contains(&key); - let snippet = snippets.get(&key).cloned(); - SearchResult { key, activation, is_direct, snippet } - }).collect() + // Default pipeline: spectral → spread (legacy behavior) + let pipeline = vec![ + AlgoStage { algo: Algorithm::Spectral, params: HashMap::new() }, + AlgoStage { algo: Algorithm::Spread, params: HashMap::new() }, + ]; + + let raw_results = run_pipeline(&pipeline, seeds, &graph, store, debug, max_results); + + raw_results.into_iter() + .take(max_results) + .map(|(key, activation)| { + let is_direct = direct_hits.contains(&key); + SearchResult { key, activation, is_direct, snippet: None } + }).collect() +} + +/// Search with equal-weight terms (for interactive use). +pub fn search(query: &str, store: &impl StoreView) -> Vec { + let terms: BTreeMap = query.split_whitespace() + .map(|t| (t.to_lowercase(), 1.0)) + .collect(); + search_weighted(&terms, store) } /// Extract meaningful search terms from natural language. diff --git a/poc-memory/src/spectral.rs b/poc-memory/src/spectral.rs index 0cbd7fd..c43de1e 100644 --- a/poc-memory/src/spectral.rs +++ b/poc-memory/src/spectral.rs @@ -113,12 +113,20 @@ pub fn decompose(graph: &Graph, k: usize) -> SpectralResult { let s = eig.S(); let u = eig.U(); - let k = k.min(n); let mut eigenvalues = Vec::with_capacity(k); let mut eigvecs = Vec::with_capacity(k); let s_col = s.column_vector(); - for col in 0..k { + + // Skip trivial eigenvalues (near-zero = null space from disconnected components). + // The number of zero eigenvalues equals the number of connected components. + let mut start = 0; + while start < n && s_col[start].abs() < 1e-8 { + start += 1; + } + + let k = k.min(n.saturating_sub(start)); + for col in start..start + k { eigenvalues.push(s_col[col]); let mut vec = Vec::with_capacity(n); for row in 0..n { @@ -287,24 +295,71 @@ pub fn nearest_to_seeds( seeds: &[&str], k: usize, ) -> Vec<(String, f64)> { - let seed_set: HashSet<&str> = seeds.iter().copied().collect(); + nearest_to_seeds_weighted(emb, &seeds.iter().map(|&s| (s, 1.0)).collect::>(), None, k) +} - let seed_coords: Vec<&Vec> = seeds.iter() - .filter_map(|s| emb.coords.get(*s)) +/// Find nearest neighbors to weighted seed nodes, using link weights. +/// +/// Each seed has a weight (from query term weighting). For candidates +/// directly linked to a seed, the spectral distance is scaled by +/// 1/link_strength — strong links make effective distance shorter. +/// Seed weight scales the contribution: high-weight seeds pull harder. +/// +/// Returns (key, effective_distance) sorted by distance ascending. +pub fn nearest_to_seeds_weighted( + emb: &SpectralEmbedding, + seeds: &[(&str, f64)], // (key, seed_weight) + graph: Option<&crate::graph::Graph>, + k: usize, +) -> Vec<(String, f64)> { + let seed_set: HashSet<&str> = seeds.iter().map(|(s, _)| *s).collect(); + + let seed_data: Vec<(&str, &Vec, f64)> = seeds.iter() + .filter_map(|(s, w)| { + emb.coords.get(*s) + .filter(|c| c.iter().any(|&v| v.abs() > 1e-12)) // skip degenerate seeds + .map(|c| (*s, c, *w)) + }) .collect(); - if seed_coords.is_empty() { + if seed_data.is_empty() { return vec![]; } - let weights = eigenvalue_weights(&emb.eigenvalues); + // Build seed→neighbor link strength lookup + let link_strengths: HashMap<(&str, &str), f32> = if let Some(g) = graph { + let mut map = HashMap::new(); + for &(seed_key, _) in seeds { + for (neighbor, strength) in g.neighbors(seed_key) { + map.insert((seed_key, neighbor.as_str()), strength); + } + } + map + } else { + HashMap::new() + }; + + let dim_weights = eigenvalue_weights(&emb.eigenvalues); let mut distances: Vec<(String, f64)> = emb.coords.iter() - .filter(|(k, _)| !seed_set.contains(k.as_str())) - .map(|(k, coords)| { - let min_dist = seed_coords.iter() - .map(|sc| weighted_distance(coords, sc, &weights)) + .filter(|(k, coords)| { + !seed_set.contains(k.as_str()) + && coords.iter().any(|&v| v.abs() > 1e-12) // skip degenerate zero-coord nodes + }) + .map(|(candidate_key, coords)| { + let min_dist = seed_data.iter() + .map(|(seed_key, sc, seed_weight)| { + let raw_dist = weighted_distance(coords, sc, &dim_weights); + + // Scale by link strength if directly connected + let link_scale = link_strengths + .get(&(*seed_key, candidate_key.as_str())) + .map(|&s| 1.0 / (1.0 + s as f64)) // strong link → smaller distance + .unwrap_or(1.0); + + raw_dist * link_scale / seed_weight + }) .fold(f64::MAX, f64::min); - (k.clone(), min_dist) + (candidate_key.clone(), min_dist) }) .collect();