search: composable algorithm pipeline
Break search into composable stages that chain left-to-right: each stage takes seeds Vec<(String, f64)> and returns modified seeds.

Available algorithms:
  spread   — spreading activation through graph edges
  spectral — nearest neighbors in spectral embedding
  manifold — (placeholder) extrapolation along seed direction

Stages accept inline params: spread,max_hops=4,edge_decay=0.5

memory-search gets --hook, --debug, --seen modes plus positional pipeline args. poc-memory search gets -p/--pipeline flags.

Also: fix spectral decompose() to skip zero eigenvalues from disconnected components, filter degenerate zero-coord nodes from spectral projection, POC_AGENT bail-out for daemon agents, all debug output to stdout.

Co-Authored-By: ProofOfConcept <poc@bcachefs.org>
This commit is contained in:
parent
0a35a17fad
commit
c1664bf76b
4 changed files with 723 additions and 151 deletions
|
|
@ -1,24 +1,76 @@
|
||||||
// memory-search: combined hook for session context loading + ambient memory retrieval
|
// memory-search: combined hook for session context loading + ambient memory retrieval
|
||||||
//
|
//
|
||||||
// On first prompt per session: loads full memory context (identity, journal, etc.)
|
// Modes:
|
||||||
// On subsequent prompts: searches memory for relevant entries
|
// --hook Run as Claude Code UserPromptSubmit hook (reads stdin, injects into conversation)
|
||||||
// On post-compaction: reloads full context
|
// --debug Replay last stashed input, dump every stage to stdout
|
||||||
//
|
// --seen Show the seen set for current session
|
||||||
// Reads JSON from stdin (Claude Code UserPromptSubmit hook format),
|
// (default) Same as --debug: replays the last stashed input (exits with an error if none exists)
|
||||||
// outputs results for injection into the conversation.
|
|
||||||
|
|
||||||
use poc_memory::search;
|
use clap::Parser;
|
||||||
|
use poc_memory::search::{self, AlgoStage};
|
||||||
use poc_memory::store;
|
use poc_memory::store;
|
||||||
use std::collections::HashSet;
|
use std::collections::{BTreeMap, HashSet};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::{self, Read, Write};
|
use std::io::{self, Read, Write};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "memory-search")]
|
||||||
|
struct Args {
|
||||||
|
/// Run as Claude Code hook (reads stdin, outputs for injection)
|
||||||
|
#[arg(long)]
|
||||||
|
hook: bool,
|
||||||
|
|
||||||
|
/// Debug mode: replay last stashed input, dump every stage
|
||||||
|
#[arg(short, long)]
|
||||||
|
debug: bool,
|
||||||
|
|
||||||
|
/// Show the seen set and returned memories for this session
|
||||||
|
#[arg(long)]
|
||||||
|
seen: bool,
|
||||||
|
|
||||||
|
/// Max results to return
|
||||||
|
#[arg(long, default_value = "5")]
|
||||||
|
max_results: usize,
|
||||||
|
|
||||||
|
/// Algorithm pipeline stages: e.g. spread spectral,k=20 spread,max_hops=4
|
||||||
|
/// Default: spread.
|
||||||
|
pipeline: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
const STASH_PATH: &str = "/tmp/claude-memory-search/last-input.json";
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut input = String::new();
|
// Daemon agent calls set POC_AGENT=1 — skip memory search.
|
||||||
io::stdin().read_to_string(&mut input).unwrap_or_default();
|
if std::env::var("POC_AGENT").is_ok() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.seen {
|
||||||
|
show_seen();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let input = if args.hook {
|
||||||
|
// Hook mode: read from stdin, stash for later debug runs
|
||||||
|
let mut buf = String::new();
|
||||||
|
io::stdin().read_to_string(&mut buf).unwrap_or_default();
|
||||||
|
fs::create_dir_all("/tmp/claude-memory-search").ok();
|
||||||
|
fs::write(STASH_PATH, &buf).ok();
|
||||||
|
buf
|
||||||
|
} else {
|
||||||
|
// All other modes: replay stashed input
|
||||||
|
fs::read_to_string(STASH_PATH).unwrap_or_else(|_| {
|
||||||
|
eprintln!("No stashed input at {}", STASH_PATH);
|
||||||
|
std::process::exit(1);
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let debug = args.debug || !args.hook;
|
||||||
|
|
||||||
let json: serde_json::Value = match serde_json::from_str(&input) {
|
let json: serde_json::Value = match serde_json::from_str(&input) {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
|
|
@ -42,6 +94,16 @@ fn main() {
|
||||||
let cookie_path = state_dir.join(format!("cookie-{}", session_id));
|
let cookie_path = state_dir.join(format!("cookie-{}", session_id));
|
||||||
let is_first = !cookie_path.exists();
|
let is_first = !cookie_path.exists();
|
||||||
|
|
||||||
|
if is_first || is_compaction {
|
||||||
|
// Reset seen set to keys that load-context will inject
|
||||||
|
let seen_path = state_dir.join(format!("seen-{}", session_id));
|
||||||
|
fs::remove_file(&seen_path).ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[memory-search] session={} is_first={} is_compaction={}", session_id, is_first, is_compaction);
|
||||||
|
}
|
||||||
|
|
||||||
if is_first || is_compaction {
|
if is_first || is_compaction {
|
||||||
// Create/touch the cookie
|
// Create/touch the cookie
|
||||||
let cookie = if is_first {
|
let cookie = if is_first {
|
||||||
|
|
@ -52,52 +114,135 @@ fn main() {
|
||||||
fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string()
|
fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Load full memory context
|
if debug { println!("[memory-search] loading full context"); }
|
||||||
|
|
||||||
|
// Load full memory context and pre-populate seen set with injected keys
|
||||||
if let Ok(output) = Command::new("poc-memory").args(["load-context"]).output() {
|
if let Ok(output) = Command::new("poc-memory").args(["load-context"]).output() {
|
||||||
if output.status.success() {
|
if output.status.success() {
|
||||||
let ctx = String::from_utf8_lossy(&output.stdout);
|
let ctx = String::from_utf8_lossy(&output.stdout);
|
||||||
if !ctx.trim().is_empty() {
|
if !ctx.trim().is_empty() {
|
||||||
|
// Extract keys from "--- KEY (group) ---" lines
|
||||||
|
for line in ctx.lines() {
|
||||||
|
if line.starts_with("--- ") && line.ends_with(" ---") {
|
||||||
|
let inner = &line[4..line.len() - 4];
|
||||||
|
if let Some(paren) = inner.rfind(" (") {
|
||||||
|
let key = inner[..paren].trim();
|
||||||
|
mark_seen(&state_dir, session_id, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if debug { println!("[memory-search] context loaded: {} bytes", ctx.len()); }
|
||||||
|
if args.hook {
|
||||||
print!("{}", ctx);
|
print!("{}", ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// On first prompt, also bump lookup counter for the cookie
|
|
||||||
let _ = cookie; // used for tagging below
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always do ambient search (skip on very short or system prompts)
|
let _ = cookie;
|
||||||
let word_count = prompt.split_whitespace().count();
|
|
||||||
if word_count < 3 {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip system/AFK prompts
|
||||||
for prefix in &["is AFK", "You're on your own", "IRC mention"] {
|
for prefix in &["is AFK", "You're on your own", "IRC mention"] {
|
||||||
if prompt.starts_with(prefix) {
|
if prompt.starts_with(prefix) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let query = search::extract_query_terms(prompt, 3);
|
|
||||||
if query.is_empty() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let store = match store::Store::load() {
|
let store = match store::Store::load() {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(_) => return,
|
Err(_) => return,
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = search::search(&query, &store);
|
// Search for node keys in last ~150k tokens of transcript
|
||||||
if results.is_empty() {
|
let transcript_path = json["transcript_path"].as_str().unwrap_or("");
|
||||||
|
if debug { println!("[memory-search] transcript: {}", transcript_path); }
|
||||||
|
let terms = extract_weighted_terms(transcript_path, 150_000, &store);
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[memory-search] {} node keys found in transcript", terms.len());
|
||||||
|
let mut by_weight: Vec<_> = terms.iter().collect();
|
||||||
|
by_weight.sort_by(|a, b| b.1.total_cmp(a.1));
|
||||||
|
for (term, weight) in by_weight.iter().take(20) {
|
||||||
|
println!(" {:.3} {}", weight, term);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if terms.is_empty() {
|
||||||
|
if debug { println!("[memory-search] no node keys found, done"); }
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse algorithm pipeline
|
||||||
|
let pipeline: Vec<AlgoStage> = if args.pipeline.is_empty() {
|
||||||
|
// Default: just spreading activation
|
||||||
|
vec![AlgoStage::parse("spread").unwrap()]
|
||||||
|
} else {
|
||||||
|
let mut stages = Vec::new();
|
||||||
|
for arg in &args.pipeline {
|
||||||
|
match AlgoStage::parse(arg) {
|
||||||
|
Ok(s) => stages.push(s),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("error: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stages
|
||||||
|
};
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
let names: Vec<String> = pipeline.iter().map(|s| format!("{}", s.algo)).collect();
|
||||||
|
println!("[memory-search] pipeline: {}", names.join(" → "));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract seeds from terms
|
||||||
|
let graph = poc_memory::graph::build_graph_fast(&store);
|
||||||
|
let (seeds, direct_hits) = search::match_seeds(&terms, &store);
|
||||||
|
|
||||||
|
if seeds.is_empty() {
|
||||||
|
if debug { println!("[memory-search] no seeds matched, done"); }
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[memory-search] {} seeds", seeds.len());
|
||||||
|
let mut sorted = seeds.clone();
|
||||||
|
sorted.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||||||
|
for (key, score) in sorted.iter().take(20) {
|
||||||
|
println!(" {:.4} {}", score, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_results = if debug { args.max_results.max(25) } else { args.max_results };
|
||||||
|
let raw_results = search::run_pipeline(&pipeline, seeds, &graph, &store, debug, max_results);
|
||||||
|
|
||||||
|
let results: Vec<search::SearchResult> = raw_results.into_iter()
|
||||||
|
.map(|(key, activation)| {
|
||||||
|
let is_direct = direct_hits.contains(&key);
|
||||||
|
search::SearchResult { key, activation, is_direct, snippet: None }
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[memory-search] {} search results", results.len());
|
||||||
|
for r in results.iter().take(10) {
|
||||||
|
let marker = if r.is_direct { "→" } else { " " };
|
||||||
|
println!(" {} [{:.4}] {}", marker, r.activation, r.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if results.is_empty() {
|
||||||
|
if debug { println!("[memory-search] no results, done"); }
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let seen = load_seen(&state_dir, session_id);
|
||||||
|
if debug { println!("[memory-search] {} keys in seen set", seen.len()); }
|
||||||
|
|
||||||
// Format results like poc-memory search output
|
// Format results like poc-memory search output
|
||||||
let search_output = search::format_results(&results);
|
let search_output = search::format_results(&results);
|
||||||
|
|
||||||
let cookie = fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string();
|
let cookie = fs::read_to_string(&cookie_path).unwrap_or_default().trim().to_string();
|
||||||
let seen = load_seen(&state_dir, session_id);
|
|
||||||
|
|
||||||
let mut result_output = String::new();
|
let mut result_output = String::new();
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
|
|
@ -112,6 +257,7 @@ fn main() {
|
||||||
if let Some(key) = extract_key_from_line(trimmed) {
|
if let Some(key) = extract_key_from_line(trimmed) {
|
||||||
if seen.contains(&key) { continue; }
|
if seen.contains(&key) { continue; }
|
||||||
mark_seen(&state_dir, session_id, &key);
|
mark_seen(&state_dir, session_id, &key);
|
||||||
|
mark_returned(&state_dir, session_id, &key);
|
||||||
result_output.push_str(line);
|
result_output.push_str(line);
|
||||||
result_output.push('\n');
|
result_output.push('\n');
|
||||||
count += 1;
|
count += 1;
|
||||||
|
|
@ -121,9 +267,14 @@ fn main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 { return; }
|
if count == 0 {
|
||||||
|
if debug { println!("[memory-search] all results already seen"); }
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.hook {
|
||||||
println!("Recalled memories [{}]:", cookie);
|
println!("Recalled memories [{}]:", cookie);
|
||||||
|
}
|
||||||
print!("{}", result_output);
|
print!("{}", result_output);
|
||||||
|
|
||||||
// Clean up stale state files (opportunistic)
|
// Clean up stale state files (opportunistic)
|
||||||
|
|
@ -131,6 +282,82 @@ fn main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Reverse-scan the transcript JSONL, extracting text from user/assistant
|
||||||
|
/// messages until we accumulate `max_tokens` tokens of text content.
|
||||||
|
/// Then search for all node keys as substrings, weighted by position.
|
||||||
|
fn extract_weighted_terms(
|
||||||
|
path: &str,
|
||||||
|
max_tokens: usize,
|
||||||
|
store: &poc_memory::store::Store,
|
||||||
|
) -> BTreeMap<String, f64> {
|
||||||
|
if path.is_empty() { return BTreeMap::new(); }
|
||||||
|
|
||||||
|
let content = match fs::read_to_string(path) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(_) => return BTreeMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Collect text from messages, scanning backwards, until token budget hit
|
||||||
|
let mut message_texts: Vec<String> = Vec::new();
|
||||||
|
let mut token_count = 0;
|
||||||
|
|
||||||
|
for line in content.lines().rev() {
|
||||||
|
if token_count >= max_tokens { break; }
|
||||||
|
|
||||||
|
let obj: serde_json::Value = match serde_json::from_str(line) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let msg_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
if msg_type != "user" && msg_type != "assistant" { continue; }
|
||||||
|
|
||||||
|
let mut msg_text = String::new();
|
||||||
|
let msg = obj.get("message").unwrap_or(&obj);
|
||||||
|
match msg.get("content") {
|
||||||
|
Some(serde_json::Value::String(s)) => {
|
||||||
|
msg_text.push_str(s);
|
||||||
|
}
|
||||||
|
Some(serde_json::Value::Array(arr)) => {
|
||||||
|
for block in arr {
|
||||||
|
if block.get("type").and_then(|v| v.as_str()) == Some("text") {
|
||||||
|
if let Some(t) = block.get("text").and_then(|v| v.as_str()) {
|
||||||
|
msg_text.push(' ');
|
||||||
|
msg_text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count += msg_text.len() / 4;
|
||||||
|
message_texts.push(msg_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse so oldest is first (position weighting: later = more recent = higher)
|
||||||
|
message_texts.reverse();
|
||||||
|
let all_text = message_texts.join(" ").to_lowercase();
|
||||||
|
let text_len = all_text.len();
|
||||||
|
if text_len == 0 { return BTreeMap::new(); }
|
||||||
|
|
||||||
|
// Search for each node key as a substring (casefolded), accumulate position-weighted score
|
||||||
|
let mut terms = BTreeMap::new();
|
||||||
|
for (key, _node) in &store.nodes {
|
||||||
|
let key_folded = key.to_lowercase();
|
||||||
|
let mut pos = 0;
|
||||||
|
while let Some(found) = all_text[pos..].find(&key_folded) {
|
||||||
|
let abs_pos = pos + found;
|
||||||
|
let weight = (abs_pos + 1) as f64 / text_len as f64;
|
||||||
|
*terms.entry(key_folded.clone()).or_insert(0.0) += weight;
|
||||||
|
pos = abs_pos + key_folded.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
terms
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
fn extract_key_from_line(line: &str) -> Option<String> {
|
fn extract_key_from_line(line: &str) -> Option<String> {
|
||||||
let after_bracket = line.find("] ")?;
|
let after_bracket = line.find("] ")?;
|
||||||
let rest = &line[after_bracket + 2..];
|
let rest = &line[after_bracket + 2..];
|
||||||
|
|
@ -167,6 +394,70 @@ fn mark_seen(dir: &Path, session_id: &str, key: &str) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Append `key` as one line to the `returned-<session_id>` log inside `dir`,
/// creating the file if needed. Best effort: open and write failures are
/// ignored.
fn mark_returned(dir: &Path, session_id: &str, key: &str) {
    let log_path = dir.join(format!("returned-{}", session_id));
    let opened = fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(log_path);

    match opened {
        Ok(mut log) => {
            let _ = writeln!(log, "{}", key);
        }
        Err(_) => {}
    }
}
|
||||||
|
|
||||||
|
/// Read the `returned-<session_id>` log inside `dir` and return its non-empty
/// lines in file order. A missing or unreadable file yields an empty list.
fn load_returned(dir: &Path, session_id: &str) -> Vec<String> {
    let log_path = dir.join(format!("returned-{}", session_id));
    if !log_path.exists() {
        return Vec::new();
    }

    let raw = fs::read_to_string(log_path).unwrap_or_default();
    let mut keys = Vec::new();
    for entry in raw.lines() {
        if !entry.is_empty() {
            keys.push(entry.to_string());
        }
    }
    keys
}
|
||||||
|
|
||||||
|
fn show_seen() {
|
||||||
|
let state_dir = PathBuf::from("/tmp/claude-memory-search");
|
||||||
|
|
||||||
|
// Read stashed input for session_id
|
||||||
|
let input = match fs::read_to_string(STASH_PATH) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => {
|
||||||
|
eprintln!("No stashed input at {}", STASH_PATH);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let json: serde_json::Value = match serde_json::from_str(&input) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => {
|
||||||
|
eprintln!("Failed to parse stashed input");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let session_id = json["session_id"].as_str().unwrap_or("");
|
||||||
|
if session_id.is_empty() {
|
||||||
|
eprintln!("No session_id in stashed input");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Session: {}", session_id);
|
||||||
|
|
||||||
|
let cookie_path = state_dir.join(format!("cookie-{}", session_id));
|
||||||
|
if let Ok(cookie) = fs::read_to_string(&cookie_path) {
|
||||||
|
println!("Cookie: {}", cookie.trim());
|
||||||
|
}
|
||||||
|
|
||||||
|
let returned = load_returned(&state_dir, session_id);
|
||||||
|
if !returned.is_empty() {
|
||||||
|
println!("\nReturned by search ({}):", returned.len());
|
||||||
|
for key in &returned {
|
||||||
|
println!(" {}", key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let seen = load_seen(&state_dir, session_id);
|
||||||
|
println!("\nSeen set ({} total, {} pre-seeded):", seen.len(), seen.len() - returned.len());
|
||||||
|
}
|
||||||
|
|
||||||
fn cleanup_stale_files(dir: &Path, max_age: Duration) {
|
fn cleanup_stale_files(dir: &Path, max_age: Duration) {
|
||||||
let entries = match fs::read_dir(dir) {
|
let entries = match fs::read_dir(dir) {
|
||||||
Ok(e) => e,
|
Ok(e) => e,
|
||||||
|
|
|
||||||
|
|
@ -59,12 +59,21 @@ struct Cli {
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
enum Command {
|
enum Command {
|
||||||
/// Search memory (AND logic across terms)
|
/// Search memory (AND logic across terms)
|
||||||
|
///
|
||||||
|
/// Pipeline: -p spread -p spectral,k=20
|
||||||
|
/// Default pipeline: spread
|
||||||
Search {
|
Search {
|
||||||
/// Search terms
|
/// Search terms
|
||||||
query: Vec<String>,
|
query: Vec<String>,
|
||||||
/// Show 15 results instead of 5, plus spectral neighbors
|
/// Algorithm pipeline stages (repeatable)
|
||||||
|
#[arg(short, long = "pipeline")]
|
||||||
|
pipeline: Vec<String>,
|
||||||
|
/// Show more results
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
expand: bool,
|
expand: bool,
|
||||||
|
/// Show debug output for each pipeline stage
|
||||||
|
#[arg(long)]
|
||||||
|
debug: bool,
|
||||||
},
|
},
|
||||||
/// Scan markdown files, index all memory units
|
/// Scan markdown files, index all memory units
|
||||||
Init,
|
Init,
|
||||||
|
|
@ -469,8 +478,8 @@ fn main() {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
let result = match cli.command {
|
let result = match cli.command {
|
||||||
Command::Search { query, expand }
|
Command::Search { query, pipeline, expand, debug }
|
||||||
=> cmd_search(&query, expand),
|
=> cmd_search(&query, &pipeline, expand, debug),
|
||||||
Command::Init => cmd_init(),
|
Command::Init => cmd_init(),
|
||||||
Command::Migrate => cmd_migrate(),
|
Command::Migrate => cmd_migrate(),
|
||||||
Command::Health => cmd_health(),
|
Command::Health => cmd_health(),
|
||||||
|
|
@ -575,8 +584,9 @@ fn main() {
|
||||||
|
|
||||||
// ── Command implementations ─────────────────────────────────────────
|
// ── Command implementations ─────────────────────────────────────────
|
||||||
|
|
||||||
fn cmd_search(terms: &[String], expand: bool) -> Result<(), String> {
|
fn cmd_search(terms: &[String], pipeline_args: &[String], expand: bool, debug: bool) -> Result<(), String> {
|
||||||
use store::StoreView;
|
use store::StoreView;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
if terms.is_empty() {
|
if terms.is_empty() {
|
||||||
return Err("search requires at least one term".into());
|
return Err("search requires at least one term".into());
|
||||||
|
|
@ -584,70 +594,68 @@ fn cmd_search(terms: &[String], expand: bool) -> Result<(), String> {
|
||||||
|
|
||||||
let query: String = terms.join(" ");
|
let query: String = terms.join(" ");
|
||||||
|
|
||||||
|
// Parse pipeline (default: spread)
|
||||||
|
let pipeline: Vec<search::AlgoStage> = if pipeline_args.is_empty() {
|
||||||
|
vec![search::AlgoStage::parse("spread").unwrap()]
|
||||||
|
} else {
|
||||||
|
pipeline_args.iter()
|
||||||
|
.map(|a| search::AlgoStage::parse(a))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
|
};
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
let names: Vec<String> = pipeline.iter().map(|s| format!("{}", s.algo)).collect();
|
||||||
|
println!("[search] pipeline: {}", names.join(" → "));
|
||||||
|
}
|
||||||
|
|
||||||
let view = store::AnyView::load()?;
|
let view = store::AnyView::load()?;
|
||||||
let results = search::search(&query, &view);
|
let graph = graph::build_graph_fast(&view);
|
||||||
|
|
||||||
|
// Build equal-weight terms from query
|
||||||
|
let terms: BTreeMap<String, f64> = query.split_whitespace()
|
||||||
|
.map(|t| (t.to_lowercase(), 1.0))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let (seeds, direct_hits) = search::match_seeds(&terms, &view);
|
||||||
|
|
||||||
|
if seeds.is_empty() {
|
||||||
|
eprintln!("No results for '{}'", query);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[search] {} seeds from query '{}'", seeds.len(), query);
|
||||||
|
for (key, score) in &seeds {
|
||||||
|
println!(" {:.4} {}", score, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_results = if expand { 15 } else { 5 };
|
||||||
|
let raw = search::run_pipeline(&pipeline, seeds, &graph, &view, debug, max_results);
|
||||||
|
|
||||||
|
let results: Vec<search::SearchResult> = raw.into_iter()
|
||||||
|
.map(|(key, activation)| {
|
||||||
|
let is_direct = direct_hits.contains(&key);
|
||||||
|
search::SearchResult { key, activation, is_direct, snippet: None }
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
if results.is_empty() {
|
if results.is_empty() {
|
||||||
eprintln!("No results for '{}'", query);
|
eprintln!("No results for '{}'", query);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let limit = if expand { 15 } else { 5 };
|
// Log retrieval
|
||||||
|
|
||||||
// Log retrieval to a small append-only file (avoid 6MB state.bin rewrite)
|
|
||||||
store::Store::log_retrieval_static(&query,
|
store::Store::log_retrieval_static(&query,
|
||||||
&results.iter().map(|r| r.key.clone()).collect::<Vec<_>>());
|
&results.iter().map(|r| r.key.clone()).collect::<Vec<_>>());
|
||||||
|
|
||||||
// Bump daily lookup counters (fast path, no store needed)
|
let bump_keys: Vec<&str> = results.iter().take(max_results).map(|r| r.key.as_str()).collect();
|
||||||
let bump_keys: Vec<&str> = results.iter().take(limit).map(|r| r.key.as_str()).collect();
|
|
||||||
let _ = lookups::bump_many(&bump_keys);
|
let _ = lookups::bump_many(&bump_keys);
|
||||||
|
|
||||||
let text_keys: std::collections::HashSet<String> = results.iter()
|
for (i, r) in results.iter().enumerate().take(max_results) {
|
||||||
.take(limit).map(|r| r.key.clone()).collect();
|
|
||||||
|
|
||||||
for (i, r) in results.iter().enumerate().take(limit) {
|
|
||||||
let marker = if r.is_direct { "→" } else { " " };
|
let marker = if r.is_direct { "→" } else { " " };
|
||||||
let weight = view.node_weight(&r.key);
|
let weight = view.node_weight(&r.key);
|
||||||
println!("{}{:2}. [{:.2}/{:.2}] {}", marker, i + 1, r.activation, weight, r.key);
|
println!("{}{:2}. [{:.2}/{:.2}] {}", marker, i + 1, r.activation, weight, r.key);
|
||||||
if let Some(ref snippet) = r.snippet {
|
|
||||||
println!(" {}", snippet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if expand {
|
|
||||||
if let Ok(emb) = spectral::load_embedding() {
|
|
||||||
let seeds: Vec<&str> = results.iter()
|
|
||||||
.take(5)
|
|
||||||
.map(|r| r.key.as_str())
|
|
||||||
.filter(|k| emb.coords.contains_key(*k))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if !seeds.is_empty() {
|
|
||||||
let spectral_hits = spectral::nearest_to_seeds(&emb, &seeds, 10);
|
|
||||||
let new_hits: Vec<_> = spectral_hits.into_iter()
|
|
||||||
.filter(|(k, _)| !text_keys.contains(k))
|
|
||||||
.take(5)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if !new_hits.is_empty() {
|
|
||||||
println!("\nSpectral neighbors (structural, not keyword):");
|
|
||||||
for (k, _dist) in &new_hits {
|
|
||||||
let weight = view.node_weight(k);
|
|
||||||
println!(" ~ [{:.2}] {}", weight, k);
|
|
||||||
if let Some(content) = view.node_content(k) {
|
|
||||||
let snippet = util::first_n_chars(
|
|
||||||
content.lines()
|
|
||||||
.find(|l| !l.trim().is_empty() && !l.starts_with('#'))
|
|
||||||
.unwrap_or(""),
|
|
||||||
100);
|
|
||||||
if !snippet.is_empty() {
|
|
||||||
println!(" {}", snippet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,22 @@
|
||||||
// Spreading activation search across the memory graph
|
// Memory search: composable algorithm pipeline.
|
||||||
//
|
//
|
||||||
// Same model as the old system but richer: uses graph edge strengths,
|
// Each algorithm is a stage: takes seeds Vec<(String, f64)>, returns
|
||||||
// supports circumscription parameter for blending associative vs
|
// new/modified seeds. Stages compose left-to-right in a pipeline.
|
||||||
// causal walks, and benefits from community-aware result grouping.
|
//
|
||||||
|
// Available algorithms:
|
||||||
|
// spread — spreading activation through graph edges
|
||||||
|
// spectral — nearest neighbors in spectral embedding space
|
||||||
|
// manifold — extrapolation along direction defined by seeds (TODO)
|
||||||
|
//
|
||||||
|
// Seed extraction (matching query terms to node keys) is shared
|
||||||
|
// infrastructure, not an algorithm stage.
|
||||||
|
|
||||||
use crate::store::StoreView;
|
use crate::store::StoreView;
|
||||||
use crate::graph::Graph;
|
use crate::graph::Graph;
|
||||||
|
use crate::spectral;
|
||||||
|
|
||||||
use std::collections::{HashMap, HashSet, VecDeque};
|
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
pub struct SearchResult {
|
pub struct SearchResult {
|
||||||
pub key: String,
|
pub key: String,
|
||||||
|
|
@ -16,18 +25,211 @@ pub struct SearchResult {
|
||||||
pub snippet: Option<String>,
|
pub snippet: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spreading activation with circumscription parameter.
|
/// A parsed algorithm stage with its parameters.
#[derive(Clone, Debug)]
pub struct AlgoStage {
    /// Which algorithm this stage runs.
    pub algo: Algorithm,
    /// Raw `key=value` parameters as given by the user. Values stay strings
    /// and are converted on demand by the `param_*` accessors.
    pub params: HashMap<String, String>,
}

/// The algorithms a pipeline stage can run.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Algorithm {
    Spread,
    Spectral,
    Manifold,
}

impl fmt::Display for Algorithm {
    /// Writes the lower-case name accepted by [`AlgoStage::parse`].
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Algorithm::Spread => "spread",
            Algorithm::Spectral => "spectral",
            Algorithm::Manifold => "manifold",
        })
    }
}

impl AlgoStage {
    /// Parse "spread,max_hops=4,edge_decay=0.5" into an AlgoStage.
    ///
    /// The first comma-separated field is the algorithm name; every further
    /// field must be `key=value`. Whitespace around the name, keys and values
    /// is ignored, so a quoted argument such as `"spread, max_hops=4"` works.
    ///
    /// # Errors
    ///
    /// Returns a message when the algorithm name is unknown, or when a
    /// parameter has no `=` or has an empty key.
    pub fn parse(s: &str) -> Result<Self, String> {
        let mut parts = s.split(',');
        let name = parts.next().unwrap_or("").trim();
        let algo = match name {
            "spread" => Algorithm::Spread,
            "spectral" => Algorithm::Spectral,
            "manifold" => Algorithm::Manifold,
            _ => return Err(format!("unknown algorithm: {}", name)),
        };

        let mut params = HashMap::new();
        for part in parts {
            match part.split_once('=') {
                // An empty or whitespace-padded key could never be looked up
                // by the accessors, so the parameter would be silently
                // ignored; normalise the key and reject the empty case.
                Some((k, v)) if !k.trim().is_empty() => {
                    params.insert(k.trim().to_string(), v.trim().to_string());
                }
                _ => return Err(format!("bad param (expected key=val): {}", part)),
            }
        }
        Ok(AlgoStage { algo, params })
    }

    /// Look up `key` and parse it as `T`. A missing key or a value that does
    /// not parse both yield `default`.
    fn param<T: std::str::FromStr>(&self, key: &str, default: T) -> T {
        self.params
            .get(key)
            .and_then(|v| v.parse().ok())
            .unwrap_or(default)
    }

    /// `key` as an `f64`, or `default` if missing or unparsable.
    fn param_f64(&self, key: &str, default: f64) -> f64 {
        self.param(key, default)
    }

    /// `key` as a `u32`, or `default` if missing or unparsable.
    fn param_u32(&self, key: &str, default: u32) -> u32 {
        self.param(key, default)
    }

    /// `key` as a `usize`, or `default` if missing or unparsable.
    fn param_usize(&self, key: &str, default: usize) -> usize {
        self.param(key, default)
    }
}
|
||||||
|
|
||||||
|
/// Extract seeds from weighted terms by matching against node keys.
|
||||||
///
|
///
|
||||||
/// circ = 0.0: field mode — all edges (default, broad resonance)
|
/// Returns (seeds, direct_hits) where direct_hits tracks which keys
|
||||||
/// circ = 1.0: causal mode — prefer causal edges
|
/// were matched directly (vs found by an algorithm stage).
|
||||||
|
pub fn match_seeds(
|
||||||
|
terms: &BTreeMap<String, f64>,
|
||||||
|
store: &impl StoreView,
|
||||||
|
) -> (Vec<(String, f64)>, HashSet<String>) {
|
||||||
|
let mut seeds: Vec<(String, f64)> = Vec::new();
|
||||||
|
let mut direct_hits: HashSet<String> = HashSet::new();
|
||||||
|
|
||||||
|
let mut key_map: HashMap<String, (String, f64)> = HashMap::new();
|
||||||
|
store.for_each_node(|key, _content, weight| {
|
||||||
|
key_map.insert(key.to_lowercase(), (key.to_owned(), weight as f64));
|
||||||
|
});
|
||||||
|
|
||||||
|
for (term, &term_weight) in terms {
|
||||||
|
if let Some((orig_key, node_weight)) = key_map.get(term) {
|
||||||
|
let score = term_weight * node_weight;
|
||||||
|
seeds.push((orig_key.clone(), score));
|
||||||
|
direct_hits.insert(orig_key.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(seeds, direct_hits)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a pipeline of algorithm stages.
|
||||||
|
pub fn run_pipeline(
|
||||||
|
stages: &[AlgoStage],
|
||||||
|
seeds: Vec<(String, f64)>,
|
||||||
|
graph: &Graph,
|
||||||
|
store: &impl StoreView,
|
||||||
|
debug: bool,
|
||||||
|
max_results: usize,
|
||||||
|
) -> Vec<(String, f64)> {
|
||||||
|
let mut current = seeds;
|
||||||
|
|
||||||
|
for stage in stages {
|
||||||
|
if debug {
|
||||||
|
println!("\n[search] === {} ({} seeds in) ===", stage.algo, current.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
current = match stage.algo {
|
||||||
|
Algorithm::Spread => run_spread(¤t, graph, store, stage, debug),
|
||||||
|
Algorithm::Spectral => run_spectral(¤t, graph, stage, debug),
|
||||||
|
Algorithm::Manifold => {
|
||||||
|
if debug { println!(" (manifold not yet implemented, passing through)"); }
|
||||||
|
current
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
println!("[search] {} → {} results", stage.algo, current.len());
|
||||||
|
for (i, (key, score)) in current.iter().enumerate().take(15) {
|
||||||
|
let cutoff = if i + 1 == max_results { " <-- cutoff" } else { "" };
|
||||||
|
println!(" [{:.4}] {}{}", score, key, cutoff);
|
||||||
|
}
|
||||||
|
if current.len() > 15 {
|
||||||
|
println!(" ... ({} more)", current.len() - 15);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
current.truncate(max_results);
|
||||||
|
current
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spreading activation: propagate scores through graph edges.
|
||||||
|
///
|
||||||
|
/// Tunable params: max_hops (default from store), edge_decay (default from store),
|
||||||
|
/// min_activation (default from store).
|
||||||
|
fn run_spread(
|
||||||
|
seeds: &[(String, f64)],
|
||||||
|
graph: &Graph,
|
||||||
|
store: &impl StoreView,
|
||||||
|
stage: &AlgoStage,
|
||||||
|
_debug: bool,
|
||||||
|
) -> Vec<(String, f64)> {
|
||||||
|
let store_params = store.params();
|
||||||
|
let max_hops = stage.param_u32("max_hops", store_params.max_hops);
|
||||||
|
let edge_decay = stage.param_f64("edge_decay", store_params.edge_decay);
|
||||||
|
let min_activation = stage.param_f64("min_activation", store_params.min_activation);
|
||||||
|
|
||||||
|
spreading_activation(seeds, graph, store, max_hops, edge_decay, min_activation)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spectral projection: find nearest neighbors in spectral embedding space.
|
||||||
|
///
|
||||||
|
/// Tunable params: k (default 20, number of neighbors to find).
|
||||||
|
fn run_spectral(
|
||||||
|
seeds: &[(String, f64)],
|
||||||
|
graph: &Graph,
|
||||||
|
stage: &AlgoStage,
|
||||||
|
debug: bool,
|
||||||
|
) -> Vec<(String, f64)> {
|
||||||
|
let k = stage.param_usize("k", 20);
|
||||||
|
|
||||||
|
let emb = match spectral::load_embedding() {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(e) => {
|
||||||
|
if debug { println!(" no spectral embedding: {}", e); }
|
||||||
|
return seeds.to_vec();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let weighted_seeds: Vec<(&str, f64)> = seeds.iter()
|
||||||
|
.map(|(k, w)| (k.as_str(), *w))
|
||||||
|
.collect();
|
||||||
|
let projected = spectral::nearest_to_seeds_weighted(
|
||||||
|
&emb, &weighted_seeds, Some(graph), k,
|
||||||
|
);
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
for (key, dist) in &projected {
|
||||||
|
let score = 1.0 / (1.0 + dist);
|
||||||
|
println!(" dist={:.6} score={:.4} {}", dist, score, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge: keep original seeds, add spectral results as new seeds
|
||||||
|
let seed_set: HashSet<&str> = seeds.iter().map(|(k, _)| k.as_str()).collect();
|
||||||
|
let mut result = seeds.to_vec();
|
||||||
|
for (key, dist) in projected {
|
||||||
|
if !seed_set.contains(key.as_str()) {
|
||||||
|
let score = 1.0 / (1.0 + dist);
|
||||||
|
result.push((key, score));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
fn spreading_activation(
|
fn spreading_activation(
|
||||||
seeds: &[(String, f64)],
|
seeds: &[(String, f64)],
|
||||||
graph: &Graph,
|
graph: &Graph,
|
||||||
store: &impl StoreView,
|
store: &impl StoreView,
|
||||||
_circumscription: f64,
|
max_hops: u32,
|
||||||
|
edge_decay: f64,
|
||||||
|
min_activation: f64,
|
||||||
) -> Vec<(String, f64)> {
|
) -> Vec<(String, f64)> {
|
||||||
let params = store.params();
|
|
||||||
|
|
||||||
let mut activation: HashMap<String, f64> = HashMap::new();
|
let mut activation: HashMap<String, f64> = HashMap::new();
|
||||||
let mut queue: VecDeque<(String, f64, u32)> = VecDeque::new();
|
let mut queue: VecDeque<(String, f64, u32)> = VecDeque::new();
|
||||||
|
|
||||||
|
|
@ -40,12 +242,12 @@ fn spreading_activation(
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some((key, act, depth)) = queue.pop_front() {
|
while let Some((key, act, depth)) = queue.pop_front() {
|
||||||
if depth >= params.max_hops { continue; }
|
if depth >= max_hops { continue; }
|
||||||
|
|
||||||
for (neighbor, strength) in graph.neighbors(&key) {
|
for (neighbor, strength) in graph.neighbors(&key) {
|
||||||
let neighbor_weight = store.node_weight(neighbor.as_str());
|
let neighbor_weight = store.node_weight(neighbor.as_str());
|
||||||
let propagated = act * params.edge_decay * neighbor_weight * strength as f64;
|
let propagated = act * edge_decay * neighbor_weight * strength as f64;
|
||||||
if propagated < params.min_activation { continue; }
|
if propagated < min_activation { continue; }
|
||||||
|
|
||||||
let current = activation.entry(neighbor.clone()).or_insert(0.0);
|
let current = activation.entry(neighbor.clone()).or_insert(0.0);
|
||||||
if propagated > *current {
|
if propagated > *current {
|
||||||
|
|
@ -60,57 +262,73 @@ fn spreading_activation(
|
||||||
results
|
results
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Full search: find direct hits, spread activation, return ranked results
|
/// Search with weighted terms: exact key matching + spectral projection.
|
||||||
pub fn search(query: &str, store: &impl StoreView) -> Vec<SearchResult> {
|
///
|
||||||
let graph = crate::graph::build_graph_fast(store);
|
/// Terms are matched against node keys. Matching nodes become seeds,
|
||||||
let query_lower = query.to_lowercase();
|
/// scored by term_weight × node_weight. Seeds are then projected into
|
||||||
let query_tokens: Vec<&str> = query_lower.split_whitespace().collect();
|
/// spectral space to find nearby nodes, with link weights modulating distance.
|
||||||
|
pub fn search_weighted(
|
||||||
let mut seeds: Vec<(String, f64)> = Vec::new();
|
terms: &BTreeMap<String, f64>,
|
||||||
let mut snippets: HashMap<String, String> = HashMap::new();
|
store: &impl StoreView,
|
||||||
|
) -> Vec<SearchResult> {
|
||||||
store.for_each_node(|key, content, weight| {
|
search_weighted_inner(terms, store, false, 5)
|
||||||
let content_lower = content.to_lowercase();
|
|
||||||
|
|
||||||
let exact_match = content_lower.contains(&query_lower);
|
|
||||||
let token_match = query_tokens.len() > 1
|
|
||||||
&& query_tokens.iter().all(|t| content_lower.contains(t));
|
|
||||||
|
|
||||||
if exact_match || token_match {
|
|
||||||
let activation = if exact_match { weight as f64 } else { weight as f64 * 0.85 };
|
|
||||||
seeds.push((key.to_owned(), activation));
|
|
||||||
|
|
||||||
let snippet: String = content.lines()
|
|
||||||
.filter(|l| {
|
|
||||||
let ll = l.to_lowercase();
|
|
||||||
if exact_match && ll.contains(&query_lower) { return true; }
|
|
||||||
query_tokens.iter().any(|t| ll.contains(t))
|
|
||||||
})
|
|
||||||
.take(3)
|
|
||||||
.map(|l| {
|
|
||||||
let t = l.trim();
|
|
||||||
crate::util::truncate(t, 97, "...")
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n ");
|
|
||||||
snippets.insert(key.to_owned(), snippet);
|
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
/// Like search_weighted but with debug output and configurable result count.
|
||||||
|
pub fn search_weighted_debug(
|
||||||
|
terms: &BTreeMap<String, f64>,
|
||||||
|
store: &impl StoreView,
|
||||||
|
max_results: usize,
|
||||||
|
) -> Vec<SearchResult> {
|
||||||
|
search_weighted_inner(terms, store, true, max_results)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_weighted_inner(
|
||||||
|
terms: &BTreeMap<String, f64>,
|
||||||
|
store: &impl StoreView,
|
||||||
|
debug: bool,
|
||||||
|
max_results: usize,
|
||||||
|
) -> Vec<SearchResult> {
|
||||||
|
let graph = crate::graph::build_graph_fast(store);
|
||||||
|
let (seeds, direct_hits) = match_seeds(terms, store);
|
||||||
|
|
||||||
if seeds.is_empty() {
|
if seeds.is_empty() {
|
||||||
return Vec::new();
|
return Vec::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
let direct_hits: HashSet<String> = seeds.iter().map(|(k, _)| k.clone()).collect();
|
if debug {
|
||||||
let raw_results = spreading_activation(&seeds, &graph, store, 0.0);
|
println!("\n[search] === SEEDS ({}) ===", seeds.len());
|
||||||
|
let mut sorted_seeds = seeds.clone();
|
||||||
|
sorted_seeds.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||||||
|
for (key, score) in sorted_seeds.iter().take(20) {
|
||||||
|
println!(" {:.4} {}", score, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
raw_results.into_iter().map(|(key, activation)| {
|
// Default pipeline: spectral → spread (legacy behavior)
|
||||||
|
let pipeline = vec![
|
||||||
|
AlgoStage { algo: Algorithm::Spectral, params: HashMap::new() },
|
||||||
|
AlgoStage { algo: Algorithm::Spread, params: HashMap::new() },
|
||||||
|
];
|
||||||
|
|
||||||
|
let raw_results = run_pipeline(&pipeline, seeds, &graph, store, debug, max_results);
|
||||||
|
|
||||||
|
raw_results.into_iter()
|
||||||
|
.take(max_results)
|
||||||
|
.map(|(key, activation)| {
|
||||||
let is_direct = direct_hits.contains(&key);
|
let is_direct = direct_hits.contains(&key);
|
||||||
let snippet = snippets.get(&key).cloned();
|
SearchResult { key, activation, is_direct, snippet: None }
|
||||||
SearchResult { key, activation, is_direct, snippet }
|
|
||||||
}).collect()
|
}).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Search with equal-weight terms (for interactive use).
|
||||||
|
pub fn search(query: &str, store: &impl StoreView) -> Vec<SearchResult> {
|
||||||
|
let terms: BTreeMap<String, f64> = query.split_whitespace()
|
||||||
|
.map(|t| (t.to_lowercase(), 1.0))
|
||||||
|
.collect();
|
||||||
|
search_weighted(&terms, store)
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract meaningful search terms from natural language.
|
/// Extract meaningful search terms from natural language.
|
||||||
/// Strips common English stop words, returns up to max_terms words.
|
/// Strips common English stop words, returns up to max_terms words.
|
||||||
pub fn extract_query_terms(text: &str, max_terms: usize) -> String {
|
pub fn extract_query_terms(text: &str, max_terms: usize) -> String {
|
||||||
|
|
|
||||||
|
|
@ -113,12 +113,20 @@ pub fn decompose(graph: &Graph, k: usize) -> SpectralResult {
|
||||||
let s = eig.S();
|
let s = eig.S();
|
||||||
let u = eig.U();
|
let u = eig.U();
|
||||||
|
|
||||||
let k = k.min(n);
|
|
||||||
let mut eigenvalues = Vec::with_capacity(k);
|
let mut eigenvalues = Vec::with_capacity(k);
|
||||||
let mut eigvecs = Vec::with_capacity(k);
|
let mut eigvecs = Vec::with_capacity(k);
|
||||||
|
|
||||||
let s_col = s.column_vector();
|
let s_col = s.column_vector();
|
||||||
for col in 0..k {
|
|
||||||
|
// Skip trivial eigenvalues (near-zero = null space from disconnected components).
|
||||||
|
// The number of zero eigenvalues equals the number of connected components.
|
||||||
|
let mut start = 0;
|
||||||
|
while start < n && s_col[start].abs() < 1e-8 {
|
||||||
|
start += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let k = k.min(n.saturating_sub(start));
|
||||||
|
for col in start..start + k {
|
||||||
eigenvalues.push(s_col[col]);
|
eigenvalues.push(s_col[col]);
|
||||||
let mut vec = Vec::with_capacity(n);
|
let mut vec = Vec::with_capacity(n);
|
||||||
for row in 0..n {
|
for row in 0..n {
|
||||||
|
|
@ -287,24 +295,71 @@ pub fn nearest_to_seeds(
|
||||||
seeds: &[&str],
|
seeds: &[&str],
|
||||||
k: usize,
|
k: usize,
|
||||||
) -> Vec<(String, f64)> {
|
) -> Vec<(String, f64)> {
|
||||||
let seed_set: HashSet<&str> = seeds.iter().copied().collect();
|
nearest_to_seeds_weighted(emb, &seeds.iter().map(|&s| (s, 1.0)).collect::<Vec<_>>(), None, k)
|
||||||
|
}
|
||||||
|
|
||||||
let seed_coords: Vec<&Vec<f64>> = seeds.iter()
|
/// Find nearest neighbors to weighted seed nodes, using link weights.
|
||||||
.filter_map(|s| emb.coords.get(*s))
|
///
|
||||||
|
/// Each seed has a weight (from query term weighting). For candidates
|
||||||
|
/// directly linked to a seed, the spectral distance is scaled by
|
||||||
|
/// 1/link_strength — strong links make effective distance shorter.
|
||||||
|
/// Seed weight scales the contribution: high-weight seeds pull harder.
|
||||||
|
///
|
||||||
|
/// Returns (key, effective_distance) sorted by distance ascending.
|
||||||
|
pub fn nearest_to_seeds_weighted(
|
||||||
|
emb: &SpectralEmbedding,
|
||||||
|
seeds: &[(&str, f64)], // (key, seed_weight)
|
||||||
|
graph: Option<&crate::graph::Graph>,
|
||||||
|
k: usize,
|
||||||
|
) -> Vec<(String, f64)> {
|
||||||
|
let seed_set: HashSet<&str> = seeds.iter().map(|(s, _)| *s).collect();
|
||||||
|
|
||||||
|
let seed_data: Vec<(&str, &Vec<f64>, f64)> = seeds.iter()
|
||||||
|
.filter_map(|(s, w)| {
|
||||||
|
emb.coords.get(*s)
|
||||||
|
.filter(|c| c.iter().any(|&v| v.abs() > 1e-12)) // skip degenerate seeds
|
||||||
|
.map(|c| (*s, c, *w))
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
if seed_coords.is_empty() {
|
if seed_data.is_empty() {
|
||||||
return vec![];
|
return vec![];
|
||||||
}
|
}
|
||||||
|
|
||||||
let weights = eigenvalue_weights(&emb.eigenvalues);
|
// Build seed→neighbor link strength lookup
|
||||||
|
let link_strengths: HashMap<(&str, &str), f32> = if let Some(g) = graph {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
for &(seed_key, _) in seeds {
|
||||||
|
for (neighbor, strength) in g.neighbors(seed_key) {
|
||||||
|
map.insert((seed_key, neighbor.as_str()), strength);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map
|
||||||
|
} else {
|
||||||
|
HashMap::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let dim_weights = eigenvalue_weights(&emb.eigenvalues);
|
||||||
|
|
||||||
let mut distances: Vec<(String, f64)> = emb.coords.iter()
|
let mut distances: Vec<(String, f64)> = emb.coords.iter()
|
||||||
.filter(|(k, _)| !seed_set.contains(k.as_str()))
|
.filter(|(k, coords)| {
|
||||||
.map(|(k, coords)| {
|
!seed_set.contains(k.as_str())
|
||||||
let min_dist = seed_coords.iter()
|
&& coords.iter().any(|&v| v.abs() > 1e-12) // skip degenerate zero-coord nodes
|
||||||
.map(|sc| weighted_distance(coords, sc, &weights))
|
})
|
||||||
|
.map(|(candidate_key, coords)| {
|
||||||
|
let min_dist = seed_data.iter()
|
||||||
|
.map(|(seed_key, sc, seed_weight)| {
|
||||||
|
let raw_dist = weighted_distance(coords, sc, &dim_weights);
|
||||||
|
|
||||||
|
// Scale by link strength if directly connected
|
||||||
|
let link_scale = link_strengths
|
||||||
|
.get(&(*seed_key, candidate_key.as_str()))
|
||||||
|
.map(|&s| 1.0 / (1.0 + s as f64)) // strong link → smaller distance
|
||||||
|
.unwrap_or(1.0);
|
||||||
|
|
||||||
|
raw_dist * link_scale / seed_weight
|
||||||
|
})
|
||||||
.fold(f64::MAX, f64::min);
|
.fold(f64::MAX, f64::min);
|
||||||
(k.clone(), min_dist)
|
(candidate_key.clone(), min_dist)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue