diff --git a/poc-memory/src/bin/memory-search.rs b/poc-memory/src/bin/memory-search.rs index 6e0fedc..5078e0b 100644 --- a/poc-memory/src/bin/memory-search.rs +++ b/poc-memory/src/bin/memory-search.rs @@ -192,10 +192,22 @@ fn main() { // Search for node keys in last ~150k tokens of transcript if debug { println!("[memory-search] transcript: {}", transcript_path); } - let terms = extract_weighted_terms(transcript_path, 150_000, &store); + let mut terms = extract_weighted_terms(transcript_path, 150_000, &store); + + // Also extract terms from the prompt itself (handles fresh sessions + // and queries about topics not yet mentioned in the transcript) + let prompt_terms = search::extract_query_terms(prompt, 8); + if !prompt_terms.is_empty() { + if debug { println!("[memory-search] prompt terms: {}", prompt_terms); } + for word in prompt_terms.split_whitespace() { + let lower = word.to_lowercase(); + // Prompt terms get weight 1.0 (same as direct mention) + terms.entry(lower).or_insert(1.0); + } + } if debug { - println!("[memory-search] {} node keys found in transcript", terms.len()); + println!("[memory-search] {} terms total", terms.len()); let mut by_weight: Vec<_> = terms.iter().collect(); by_weight.sort_by(|a, b| b.1.total_cmp(a.1)); for (term, weight) in by_weight.iter().take(20) { @@ -204,7 +216,7 @@ fn main() { } if terms.is_empty() { - if debug { println!("[memory-search] no node keys found, done"); } + if debug { println!("[memory-search] no terms found, done"); } return; } diff --git a/poc-memory/src/search.rs b/poc-memory/src/search.rs index fb4f269..a02fcd1 100644 --- a/poc-memory/src/search.rs +++ b/poc-memory/src/search.rs @@ -96,7 +96,12 @@ impl AlgoStage { } } -/// Extract seeds from weighted terms by matching against node keys. +/// Extract seeds from weighted terms by matching against node keys and content. +/// +/// Three matching strategies, in priority order: +/// 1. Exact key match: term matches a node key exactly → full weight +/// 2. Key component match: term matches a word in a hyphenated/underscored key → 0.5× weight +/// 3. Content match: term appears in node content → 0.2× weight (capped at 50 nodes) /// /// Returns (seeds, direct_hits) where direct_hits tracks which keys /// were matched directly (vs found by an algorithm stage). @@ -104,22 +109,62 @@ pub fn match_seeds( terms: &BTreeMap, store: &impl StoreView, ) -> (Vec<(String, f64)>, HashSet) { - let mut seeds: Vec<(String, f64)> = Vec::new(); + let mut seed_map: HashMap = HashMap::new(); let mut direct_hits: HashSet = HashSet::new(); + // Build key lookup: lowercase key → (original key, weight) let mut key_map: HashMap = HashMap::new(); + // Build component index: word → vec of (original key, weight) + let mut component_map: HashMap> = HashMap::new(); + store.for_each_node(|key, _content, weight| { - key_map.insert(key.to_lowercase(), (key.to_owned(), weight as f64)); + let lkey = key.to_lowercase(); + key_map.insert(lkey.clone(), (key.to_owned(), weight as f64)); + + // Split key on hyphens, underscores, dots, hashes for component matching + for component in lkey.split(|c: char| c == '-' || c == '_' || c == '.' || c == '#') { + if component.len() >= 3 { + component_map.entry(component.to_owned()) + .or_default() + .push((key.to_owned(), weight as f64)); + } + } }); for (term, &term_weight) in terms { + // Strategy 1: exact key match if let Some((orig_key, node_weight)) = key_map.get(term) { let score = term_weight * node_weight; - seeds.push((orig_key.clone(), score)); + *seed_map.entry(orig_key.clone()).or_insert(0.0) += score; direct_hits.insert(orig_key.clone()); + continue; } + + // Strategy 2: key component match (0.5× weight) + if let Some(matches) = component_map.get(term.as_str()) { + for (orig_key, node_weight) in matches { + let score = term_weight * node_weight * 0.5; + *seed_map.entry(orig_key.clone()).or_insert(0.0) += score; + direct_hits.insert(orig_key.clone()); + } + continue; + } + + // Strategy 3: content match (0.2× weight, limited to avoid O(n*m) explosion) + let term_lower = term.to_lowercase(); + if term_lower.len() < 3 { continue; } + let mut content_hits = 0; + store.for_each_node(|key, content, weight| { + if content_hits >= 50 { return; } + if content.to_lowercase().contains(&term_lower) { + let score = term_weight * weight as f64 * 0.2; + *seed_map.entry(key.to_owned()).or_insert(0.0) += score; + content_hits += 1; + } + }); } + let seeds: Vec<(String, f64)> = seed_map.into_iter().collect(); (seeds, direct_hits) }