reuse HTTP client across scoring calls for connection pooling

Single reqwest::Client shared across all prompt_logprobs calls
instead of creating a new one per call. Keeps HTTP connections
alive for faster sequential requests.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-02 23:11:40 -04:00
parent 31302961e2
commit 4f19c02e50

View file

@ -83,8 +83,14 @@ pub async fn score_memories(
memories.len(), response_indices.len(), memories.len(), response_indices.len(),
))); )));
// Shared HTTP client for connection reuse across all scoring calls
let http = reqwest::Client::builder()
.pool_max_idle_per_host(2)
.build()
.unwrap_or_default();
// Baseline: logprobs with all memories present // Baseline: logprobs with all memories present
let baseline = get_response_logprobs(context, &context.entries, client, ui_tx).await?; let baseline = get_response_logprobs(context, &context.entries, client, &http, ui_tx).await?;
let _ = ui_tx.send(UiMessage::Debug(format!( let _ = ui_tx.send(UiMessage::Debug(format!(
"[training] baseline: {} response tokens scored", "[training] baseline: {} response tokens scored",
@ -110,7 +116,7 @@ pub async fn score_memories(
.map(|(_, e)| e.clone()) .map(|(_, e)| e.clone())
.collect(); .collect();
let without = get_response_logprobs(context, &filtered, client, ui_tx).await?; let without = get_response_logprobs(context, &filtered, client, &http, ui_tx).await?;
// Compute per-response divergence // Compute per-response divergence
let mut row = Vec::new(); let mut row = Vec::new();
@ -194,6 +200,7 @@ async fn get_response_logprobs(
context: &ContextState, context: &ContextState,
entries: &[ConversationEntry], entries: &[ConversationEntry],
client: &ApiClient, client: &ApiClient,
http: &reqwest::Client,
ui_tx: &UiSender, ui_tx: &UiSender,
) -> anyhow::Result<Vec<Vec<f64>>> { ) -> anyhow::Result<Vec<Vec<f64>>> {
// Build the fixed prefix (system prompt + personality) // Build the fixed prefix (system prompt + personality)
@ -235,7 +242,7 @@ async fn get_response_logprobs(
let mut msgs = prefix.clone(); let mut msgs = prefix.clone();
msgs.extend(chunk.iter().map(|e| e.api_message().clone())); msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
let result = call_prompt_logprobs(&msgs, client).await?; let result = call_prompt_logprobs(&msgs, client, http).await?;
all_responses.extend(result); all_responses.extend(result);
} }
@ -277,6 +284,7 @@ fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<
async fn call_prompt_logprobs( async fn call_prompt_logprobs(
msgs: &[Message], msgs: &[Message],
client: &ApiClient, client: &ApiClient,
http: &reqwest::Client,
) -> anyhow::Result<Vec<Vec<f64>>> { ) -> anyhow::Result<Vec<Vec<f64>>> {
let request = serde_json::json!({ let request = serde_json::json!({
"model": client.model, "model": client.model,
@ -286,7 +294,7 @@ async fn call_prompt_logprobs(
"stream": false, "stream": false,
}); });
let response = reqwest::Client::new() let response = http
.post(format!("{}/chat/completions", client.base_url())) .post(format!("{}/chat/completions", client.base_url()))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", client.api_key())) .header("Authorization", format!("Bearer {}", client.api_key()))