link-audit: parallelize Sonnet calls with rayon
Build all batch prompts up front, run the Sonnet calls in parallel via rayon's `par_iter()`, then process the results sequentially. Also fix a temp-file name collision under parallel calls by including the thread ID (in addition to the PID) in the filename.
This commit is contained in:
parent
e33328e515
commit
ad4e622ab9
1 changed file with 44 additions and 23 deletions
|
|
@ -37,7 +37,9 @@ fn agent_results_dir() -> PathBuf {
|
|||
/// Call Sonnet via claude CLI. Returns the response text.
|
||||
pub(crate) fn call_sonnet(prompt: &str, _timeout_secs: u64) -> Result<String, String> {
|
||||
// Write prompt to temp file (claude CLI needs file input for large prompts)
|
||||
let tmp = std::env::temp_dir().join(format!("poc-digest-{}.txt", std::process::id()));
|
||||
// Use thread ID + PID to avoid collisions under parallel rayon calls
|
||||
let tmp = std::env::temp_dir().join(format!("poc-digest-{}-{:?}.txt",
|
||||
std::process::id(), std::thread::current().id()));
|
||||
fs::write(&tmp, prompt)
|
||||
.map_err(|e| format!("write temp prompt: {}", e))?;
|
||||
|
||||
|
|
@ -1975,32 +1977,53 @@ pub fn link_audit(store: &mut Store, apply: bool) -> Result<AuditStats, String>
|
|||
println!("{} batches (avg {} links/batch)\n", total_batches,
|
||||
if total_batches > 0 { total / total_batches } else { 0 });
|
||||
|
||||
use rayon::prelude::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
// Build all batch prompts up front
|
||||
let batch_data: Vec<(usize, Vec<LinkInfo>, String)> = batches.iter().enumerate()
|
||||
.map(|(batch_idx, batch_indices)| {
|
||||
let batch_infos: Vec<LinkInfo> = batch_indices.iter().map(|&i| {
|
||||
let l = &links[i];
|
||||
LinkInfo {
|
||||
rel_idx: l.rel_idx,
|
||||
source_key: l.source_key.clone(),
|
||||
target_key: l.target_key.clone(),
|
||||
source_content: l.source_content.clone(),
|
||||
target_content: l.target_content.clone(),
|
||||
strength: l.strength,
|
||||
target_sections: l.target_sections.clone(),
|
||||
}
|
||||
}).collect();
|
||||
let prompt = build_audit_prompt(&batch_infos, batch_idx + 1, total_batches);
|
||||
(batch_idx, batch_infos, prompt)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Progress counter
|
||||
let done = AtomicUsize::new(0);
|
||||
|
||||
// Run batches in parallel via rayon
|
||||
let batch_results: Vec<_> = batch_data.par_iter()
|
||||
.map(|(batch_idx, batch_infos, prompt)| {
|
||||
let response = call_sonnet(prompt, 300);
|
||||
let completed = done.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
eprint!("\r Batches: {}/{} done", completed, total_batches);
|
||||
(*batch_idx, batch_infos, response)
|
||||
})
|
||||
.collect();
|
||||
eprintln!(); // newline after progress
|
||||
|
||||
// Process results sequentially
|
||||
let mut stats = AuditStats {
|
||||
kept: 0, deleted: 0, retargeted: 0, weakened: 0, strengthened: 0, errors: 0,
|
||||
};
|
||||
|
||||
// Track changes to apply at the end
|
||||
let mut deletions: Vec<usize> = Vec::new();
|
||||
let mut retargets: Vec<(usize, String)> = Vec::new();
|
||||
let mut strength_changes: Vec<(usize, f32)> = Vec::new();
|
||||
|
||||
for (batch_idx, batch_indices) in batches.iter().enumerate() {
|
||||
let batch_links: Vec<&LinkInfo> = batch_indices.iter()
|
||||
.map(|&i| &links[i])
|
||||
.collect();
|
||||
|
||||
let batch_infos: Vec<LinkInfo> = batch_links.iter().map(|l| LinkInfo {
|
||||
rel_idx: l.rel_idx,
|
||||
source_key: l.source_key.clone(),
|
||||
target_key: l.target_key.clone(),
|
||||
source_content: l.source_content.clone(),
|
||||
target_content: l.target_content.clone(),
|
||||
strength: l.strength,
|
||||
target_sections: l.target_sections.clone(),
|
||||
}).collect();
|
||||
|
||||
let prompt = build_audit_prompt(&batch_infos, batch_idx + 1, total_batches);
|
||||
let response = match call_sonnet(&prompt, 300) {
|
||||
for (batch_idx, batch_infos, response) in &batch_results {
|
||||
let response = match response {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
eprintln!(" Batch {}: error: {}", batch_idx + 1, e);
|
||||
|
|
@ -2009,9 +2032,8 @@ pub fn link_audit(store: &mut Store, apply: bool) -> Result<AuditStats, String>
|
|||
}
|
||||
};
|
||||
|
||||
let actions = parse_audit_response(&response, batch_infos.len());
|
||||
let actions = parse_audit_response(response, batch_infos.len());
|
||||
|
||||
// Count unresponded links as kept
|
||||
let mut responded: std::collections::HashSet<usize> = std::collections::HashSet::new();
|
||||
|
||||
for (idx, action) in &actions {
|
||||
|
|
@ -2048,7 +2070,6 @@ pub fn link_audit(store: &mut Store, apply: bool) -> Result<AuditStats, String>
|
|||
}
|
||||
}
|
||||
|
||||
// Count unresponded as kept
|
||||
for i in 0..batch_infos.len() {
|
||||
if !responded.contains(&i) {
|
||||
stats.kept += 1;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue