diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs index 9fa3d5e..a81f0a4 100644 --- a/src/subconscious/learn.rs +++ b/src/subconscious/learn.rs @@ -121,14 +121,18 @@ async fn call_score( http: &crate::agent::api::http::HttpClient, client: &ApiClient, messages: &[serde_json::Value], + priority: Option, ) -> anyhow::Result> { let url = format!("{}/score", client.base_url()); let auth = format!("Bearer {}", client.api_key()); - let body = serde_json::json!({ + let mut body = serde_json::json!({ "model": client.model, "messages": messages, "logprobs": 1, }); + if let Some(p) = priority { + body["priority"] = serde_json::json!(p); + } let response = http .send_json("POST", &url, &[ ("authorization", &auth), @@ -169,9 +173,10 @@ async fn score_divergence( context: &ContextState, range: std::ops::Range, filter: Filter<'_>, + priority: Option, ) -> anyhow::Result<(Vec, Vec)> { - let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None)).await?; - let without = call_score(http, client, &build_messages(context, range, filter)).await?; + let baseline = call_score(http, client, &build_messages(context, range.clone(), Filter::None), priority).await?; + let without = call_score(http, client, &build_messages(context, range, filter), priority).await?; let divs = divergence(&baseline, &without); Ok((divs, baseline)) } @@ -232,7 +237,7 @@ pub async fn score_memories( let http = http_client(); let range = 0..context.conversation().len(); - let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None)).await?; + let baseline = call_score(&http, client, &build_messages(context, range.clone(), Filter::None), Some(5)).await?; let total = memory_keys.len(); let mut matrix: Vec> = Vec::new(); @@ -242,7 +247,7 @@ pub async fn score_memories( "scoring {}/{}: {}...", mem_idx + 1, total, key, ); let msgs = build_messages(context, range.clone(), Filter::SkipKey(key)); - match call_score(&http, client, &msgs).await { + match call_score(&http, client, &msgs, Some(5)).await { Ok(without) => matrix.push(divergence(&baseline, &without)), Err(e) => { dbglog!( @@ -312,7 +317,7 @@ pub async fn score_memory( } let http = http_client(); - let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key)).await?; + let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await?; Ok(divs.iter().sum()) } @@ -389,7 +394,7 @@ where } let _scoring = crate::agent::start_activity(agent, format!("scoring: {}", key)).await; - match score_divergence(&http, client, context, range, Filter::SkipKey(key)).await { + match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await { Ok((divs, _)) => { let n_responses = divs.len(); let max_div = divs.iter().cloned().fold(0.0f64, f64::max); @@ -435,7 +440,7 @@ pub async fn score_finetune( } let http = http_client(); - let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories).await?; + let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories, Some(5)).await?; let mut results: Vec<(usize, f64)> = response_positions.iter() .enumerate()