diff --git a/poc-memory/src/search.rs b/poc-memory/src/search.rs index 8cd8afa..a334643 100644 --- a/poc-memory/src/search.rs +++ b/poc-memory/src/search.rs @@ -522,28 +522,42 @@ fn run_manifold( } } - // Compute principal direction: weighted PCA axis 1 - // For each seed, its deviation from centroid contributes to the direction + // Compute principal direction via power iteration on seed covariance. + // Initialize with the two most separated seeds (largest spectral distance). let mut direction = vec![0.0f64; dims]; if seed_data.len() >= 2 { - // Use power iteration to find dominant direction of seed spread - // Initialize with the vector from first seed to last seed - let first = seed_data.first().unwrap().2; - let last = seed_data.last().unwrap().2; - for i in 0..dims { - direction[i] = last[i] - first[i]; - } - - // One round of power iteration on the covariance matrix - let mut new_dir = vec![0.0f64; dims]; - for (_, weight, coords) in &seed_data { - let dev: Vec = coords.iter().zip(centroid.iter()).map(|(c, m)| c - m).collect(); - let dot: f64 = dev.iter().zip(direction.iter()).map(|(d, v)| d * v).sum(); - for i in 0..dims { - new_dir[i] += weight * dot * dev[i]; + // Find the two seeds furthest apart in spectral space + let mut best_dist = 0.0f64; + for i in 0..seed_data.len() { + for j in (i + 1)..seed_data.len() { + let dist: f64 = seed_data[i].2.iter().zip(seed_data[j].2.iter()) + .map(|(a, b)| (a - b).powi(2)).sum::().sqrt(); + if dist > best_dist { + best_dist = dist; + for d in 0..dims { + direction[d] = seed_data[j].2[d] - seed_data[i].2[d]; + } + } } } - direction = new_dir; + + // Power iteration: 3 rounds on the weighted covariance matrix + for _ in 0..3 { + let mut new_dir = vec![0.0f64; dims]; + for (_, weight, coords) in &seed_data { + let dev: Vec = coords.iter().zip(centroid.iter()).map(|(c, m)| c - m).collect(); + let dot: f64 = dev.iter().zip(direction.iter()).map(|(d, v)| d * v).sum(); + for d in 0..dims { + new_dir[d] += weight * dot * dev[d]; + } + } + // Normalize + let norm = new_dir.iter().map(|d| d * d).sum::().sqrt(); + if norm > 1e-12 { + for d in &mut new_dir { *d /= norm; } + } + direction = new_dir; + } } let dir_norm = direction.iter().map(|d| d * d).sum::().sqrt();