learn: nanosecond timestamps, token ranges for /score

Two related changes to the learn subsystem:

1. AST node timestamps are now non-optional — both Leaf and Branch
   variants carry a DateTime<Utc>. UNIX_EPOCH means "unset" (old entries
   deserialized from on-disk conversation logs).

   Training uses timestamps as unique keys for dedup, so we promote to
   nanosecond precision: node_timestamp_ns(), TrainData.timestamp_ns,
   FinetuneCandidate.timestamp_ns, mark_trained(ns); the keying is
   sketched after this list.

2. build_token_ids() now also returns the token-position ranges of
   assistant messages. These are passed to vLLM's /score endpoint via
   the new score_ranges field so that only logprobs at the scored
   positions are returned, which cuts bandwidth and compute when
   scoring small windows. An example request body is sketched below.
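
A minimal sketch of the nanosecond keying from change 1, under the
convention this commit adopts: UNIX_EPOCH means "unset" and the
nanosecond value is the dedup key. timestamp_key() and should_train()
are illustrative stand-ins for node_timestamp_ns() and the caller's
trained-set check; only the chrono calls and the UNIX_EPOCH convention
come from the diff below.

    use chrono::{DateTime, Utc};
    use std::collections::HashSet;

    /// Stand-in for node_timestamp_ns(): None for "unset" (UNIX_EPOCH)
    /// or dates outside chrono's representable nanosecond range.
    fn timestamp_key(ts: DateTime<Utc>) -> Option<i64> {
        if ts == DateTime::<Utc>::UNIX_EPOCH {
            return None; // old entry deserialized without a real timestamp
        }
        ts.timestamp_nanos_opt() // None before ~1677 or after ~2262
    }

    /// Stand-in for the dedup check done against load_trained().
    fn should_train(ts: DateTime<Utc>, trained: &HashSet<i64>) -> bool {
        match timestamp_key(ts) {
            Some(ns) => !trained.contains(&ns),
            None => false,
        }
    }

    fn main() {
        let mut trained = HashSet::new(); // stand-in for load_trained()
        let now = Utc::now();
        assert!(should_train(now, &trained));
        trained.insert(now.timestamp_nanos_opt().unwrap()); // like mark_trained(ns)
        assert!(!should_train(now, &trained));
        assert!(!should_train(DateTime::<Utc>::UNIX_EPOCH, &trained));
    }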
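
For change 2, the request that call_score() builds ends up shaped like
the body printed below. The concrete token ids and spans are invented
for the example; each score_ranges entry is a (start, end) pair of token
positions with end exclusive (ids.len() before and after the assistant
node's tokens), and the field assumes matching support on the serving
side's /score endpoint.

    use serde_json::json;

    fn main() {
        // Token ids and assistant spans as produced by build_token_ids();
        // the numbers here are made up.
        let prompt: Vec<u32> = vec![101, 2023, 2003, 1037, 7099];
        let assistant_ranges: Vec<(usize, usize)> = vec![(3, 5)];

        let mut body = json!({
            "prompt": prompt,
            "logprobs": 1,
        });
        if !assistant_ranges.is_empty() {
            // only logprobs at these positions need to come back
            body["score_ranges"] = json!(assistant_ranges);
        }
        body["priority"] = json!(5);

        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }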

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-16 11:48:37 -04:00
parent 5d9d3ffc5b
commit 2b632d568b
5 changed files with 130 additions and 44 deletions

@@ -53,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool {
///
/// Includes all sections up to and including conversation entries in
/// `range`, with `filter` applied to conversation entries.
///
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
/// (start, end) token positions for each assistant message.
fn build_token_ids(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> Vec<u32> {
) -> (Vec<u32>, Vec<(usize, usize)>) {
use crate::agent::context::Ast;
let mut ids = Vec::new();
let mut assistant_ranges = Vec::new();
for node in context.system() {
ids.extend(node.token_ids());
}
@@ -87,9 +92,16 @@ fn build_token_ids(
Filter::SkipAllMemories => is_memory(node),
};
if skip { continue; }
// Track assistant message boundaries
let is_asst = is_assistant(node);
let start = ids.len();
ids.extend(node.token_ids());
if is_asst {
assistant_ranges.push((start, ids.len()));
}
}
ids
(ids, assistant_ranges)
}
// ── Score API ───────────────────────────────────────────────────
@@ -114,6 +126,7 @@ async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
ranges: &[(usize, usize)],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
let url = format!("{}/score", client.base_url());
@@ -123,6 +136,9 @@
"prompt": prompt,
"logprobs": 1,
});
if !ranges.is_empty() {
body["score_ranges"] = serde_json::json!(ranges);
}
if let Some(p) = priority {
body["priority"] = serde_json::json!(p);
}
@@ -168,8 +184,10 @@ async fn score_divergence(
filter: Filter<'_>,
priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?;
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?;
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
let divs = divergence(&baseline, &without);
Ok((divs, baseline))
}
@@ -208,21 +226,21 @@ pub async fn score_memories(
let http = http_client();
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
let baseline_tokens = {
let (baseline_tokens, baseline_ranges) = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
};
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
for (mem_idx, key) in memory_keys.iter().enumerate() {
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
let tokens = {
let (tokens, ranges) = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
};
let row = match call_score(&http, client, &tokens, Some(5)).await {
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
Ok(without) => {
let divs = divergence(&baseline, &without);
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@@ -466,8 +484,8 @@ pub struct FinetuneCandidate {
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in millis for tracking trained status.
pub timestamp_ms: i64,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
/// Score and enrich finetune candidates with full context.
@@ -495,7 +513,7 @@ pub async fn score_finetune_candidates(
let node = &entries[entry_idx];
// Get timestamp and skip if already trained
let timestamp_ms = match node_timestamp_ms(node) {
let timestamp_ns = match node_timestamp_ns(node) {
Some(ts) => {
if trained.contains(&ts) {
continue; // Already trained, skip
@@ -520,7 +538,7 @@
};
// Build token IDs: context = everything before response, continuation = response
let context_ids = build_token_ids(context, 0..entry_idx, Filter::None);
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
candidates.push(FinetuneCandidate {
@@ -530,7 +548,7 @@
context_ids,
continuation_ids,
alternate_text: None,
timestamp_ms,
timestamp_ns,
});
}
@@ -556,7 +574,7 @@ async fn generate_alternate(
use crate::agent::api::{SamplingParams, StreamToken};
// Build context tokens without memories, up to the response
let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
// Add assistant turn start
prompt.push(tokenizer::IM_START);
@@ -616,7 +634,7 @@ pub fn set_alternates(enabled: bool) {
}
}
/// Load set of trained response timestamps (millis since epoch).
/// Load set of trained response timestamps (nanos since epoch).
pub fn load_trained() -> HashSet<i64> {
let path = trained_path();
match std::fs::read_to_string(&path) {
@@ -626,9 +644,9 @@ pub fn load_trained() -> HashSet<i64> {
}
/// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ms: i64) {
pub fn mark_trained(timestamp_ns: i64) {
let mut trained = load_trained();
trained.insert(timestamp_ms);
trained.insert(timestamp_ns);
let path = trained_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
@@ -638,15 +656,19 @@ pub fn mark_trained(timestamp_ms: i64) {
}
}
/// Get timestamp in millis from an AstNode (for Branch, uses first child).
pub fn node_timestamp_ms(node: &AstNode) -> Option<i64> {
/// Get timestamp in nanoseconds from an AstNode.
/// Returns None for entries with default UNIX_EPOCH timestamp (old data)
/// or timestamps outside the representable nano range (pre-1677 or post-2262).
pub fn node_timestamp_ns(node: &AstNode) -> Option<i64> {
let ts = match node {
AstNode::Leaf(leaf) => leaf.timestamp(),
AstNode::Branch { children, .. } => {
children.first()?.leaf()?.timestamp()
}
}?;
Some(ts.timestamp_millis())
AstNode::Branch { timestamp, .. } => *timestamp,
};
if ts == chrono::DateTime::UNIX_EPOCH {
None // Old entry without real timestamp
} else {
ts.timestamp_nanos_opt()
}
}
// ── Training API ────────────────────────────────────────────────
@@ -662,7 +684,7 @@ struct TrainingSample {
pub struct TrainData {
pub context_ids: Vec<u32>,
pub continuation_ids: Vec<u32>,
pub timestamp_ms: i64,
pub timestamp_ns: i64,
}
/// Send training samples to the server.
@@ -703,7 +725,7 @@ pub async fn send_to_train(
// Mark all samples as trained
for s in &samples {
mark_trained(s.timestamp_ms);
mark_trained(s.timestamp_ns);
}
let job_id = result.get("job_id")