learn: nanosecond timestamps, token ranges for /score

Two related changes to the learn subsystem:

1. AST node timestamps are now non-optional — both Leaf and Branch
   variants carry a DateTime<Utc>. UNIX_EPOCH means "unset" (old entries
   deserialized from on-disk conversation logs).

   Training uses timestamps as unique keys for dedup, so we promote to
   nanosecond precision: node_timestamp_ns(), TrainData.timestamp_ns,
   FinetuneCandidate.timestamp_ns, mark_trained(ns).

2. build_token_ids() now also returns token-position ranges of assistant
   messages. These are passed to vLLM's /score endpoint via the new
   score_ranges field so only scored-position logprobs are returned —
   cuts bandwidth/compute when scoring small windows.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-16 11:48:37 -04:00
parent 5d9d3ffc5b
commit 2b632d568b
5 changed files with 130 additions and 44 deletions

View file

@@ -85,6 +85,19 @@ pub enum NodeBody {
Log(String), Log(String),
} }
fn default_timestamp() -> DateTime<Utc> {
DateTime::UNIX_EPOCH
}
/// Deserialize timestamp, treating both missing and null as UNIX_EPOCH.
fn deserialize_timestamp_or_epoch<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<DateTime<Utc>> = Option::deserialize(deserializer)?;
Ok(opt.unwrap_or(DateTime::UNIX_EPOCH))
}
/// A leaf node: typed content with cached token IDs. /// A leaf node: typed content with cached token IDs.
/// Token IDs are not serialized — they're recomputed on deserialization. /// Token IDs are not serialized — they're recomputed on deserialization.
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
@@ -92,7 +105,7 @@ pub struct NodeLeaf {
body: NodeBody, body: NodeBody,
#[serde(skip)] #[serde(skip)]
token_ids: Vec<u32>, token_ids: Vec<u32>,
timestamp: Option<DateTime<Utc>>, timestamp: DateTime<Utc>,
} }
impl<'de> Deserialize<'de> for NodeLeaf { impl<'de> Deserialize<'de> for NodeLeaf {
@@ -100,7 +113,8 @@ impl<'de> Deserialize<'de> for NodeLeaf {
#[derive(Deserialize)] #[derive(Deserialize)]
struct Raw { struct Raw {
body: NodeBody, body: NodeBody,
timestamp: Option<DateTime<Utc>>, #[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
} }
let raw = Raw::deserialize(deserializer)?; let raw = Raw::deserialize(deserializer)?;
let token_ids = if raw.body.is_prompt_visible() { let token_ids = if raw.body.is_prompt_visible() {
@@ -119,6 +133,8 @@ pub enum AstNode {
Branch { Branch {
role: Role, role: Role,
children: Vec<AstNode>, children: Vec<AstNode>,
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
/// Per-response memory attribution from full scoring matrix. /// Per-response memory attribution from full scoring matrix.
/// Maps memory key → divergence score for this response. /// Maps memory key → divergence score for this response.
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")] #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
@@ -252,18 +268,18 @@ impl NodeLeaf {
} else { } else {
vec![] vec![]
}; };
Self { body, token_ids, timestamp: None } Self { body, token_ids, timestamp: Utc::now() }
} }
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self { pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.timestamp = Some(ts); self.timestamp = ts;
self self
} }
pub fn body(&self) -> &NodeBody { &self.body } pub fn body(&self) -> &NodeBody { &self.body }
pub fn token_ids(&self) -> &[u32] { &self.token_ids } pub fn token_ids(&self) -> &[u32] { &self.token_ids }
pub fn tokens(&self) -> usize { self.token_ids.len() } pub fn tokens(&self) -> usize { self.token_ids.len() }
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp } pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
} }
impl AstNode { impl AstNode {
@@ -307,13 +323,14 @@ impl AstNode {
// -- Branch constructors -------------------------------------------------- // -- Branch constructors --------------------------------------------------
pub fn branch(role: Role, children: Vec<AstNode>) -> Self { pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
Self::Branch { role, children, memory_scores: Default::default() } Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
} }
pub fn system_msg(text: impl Into<String>) -> Self { pub fn system_msg(text: impl Into<String>) -> Self {
Self::Branch { Self::Branch {
role: Role::System, role: Role::System,
children: vec![Self::content(text)], children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(), memory_scores: Default::default(),
} }
} }
@@ -322,6 +339,7 @@ impl AstNode {
Self::Branch { Self::Branch {
role: Role::User, role: Role::User,
children: vec![Self::content(text)], children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(), memory_scores: Default::default(),
} }
} }
@@ -338,9 +356,10 @@ impl AstNode {
}; };
Self::Leaf(NodeLeaf { token_ids, ..leaf }) Self::Leaf(NodeLeaf { token_ids, ..leaf })
} }
Self::Branch { role, children, memory_scores, .. } => Self::Branch { Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
role, role,
children: children.into_iter().map(|c| c.retokenize()).collect(), children: children.into_iter().map(|c| c.retokenize()).collect(),
timestamp,
memory_scores, memory_scores,
}, },
} }
@@ -348,8 +367,8 @@ impl AstNode {
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self { pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
match &mut self { match &mut self {
Self::Leaf(leaf) => leaf.timestamp = Some(ts), Self::Leaf(leaf) => leaf.timestamp = ts,
Self::Branch { .. } => {} Self::Branch { timestamp, .. } => *timestamp = ts,
} }
self self
} }
@@ -1340,4 +1359,49 @@ mod tests {
assert_token_invariants(node); assert_token_invariants(node);
assert!(node.tokens() > 0); assert!(node.tokens() > 0);
} }
// -- Timestamp deserialization tests ------------------------------------------
#[test]
fn test_timestamp_null_becomes_epoch() {
// Old conversation.jsonl entries have "timestamp":null
// serde(default) only handles missing fields, not explicit nulls.
// We need to verify our deserialize handles this correctly.
let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
let node: AstNode = serde_json::from_str(json).unwrap();
let leaf = node.leaf().unwrap();
assert_eq!(leaf.timestamp(), DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_timestamp_missing_becomes_epoch() {
let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
let node: AstNode = serde_json::from_str(json).unwrap();
let leaf = node.leaf().unwrap();
assert_eq!(leaf.timestamp(), DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_branch_timestamp_null_becomes_epoch() {
let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}],"timestamp":null}}"#;
let node: AstNode = serde_json::from_str(json).unwrap();
match node {
AstNode::Branch { timestamp, .. } => {
assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
}
_ => panic!("expected Branch"),
}
}
#[test]
fn test_branch_timestamp_missing_becomes_epoch() {
let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}]}}"#;
let node: AstNode = serde_json::from_str(json).unwrap();
match node {
AstNode::Branch { timestamp, .. } => {
assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
}
_ => panic!("expected Branch"),
}
}
} }

View file

@@ -55,15 +55,15 @@ impl ConversationLog {
} }
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> { pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
// Read forward from the start to find first timestamp // Read forward from the start to find first non-epoch timestamp
let file = File::open(&self.path).ok()?; let file = File::open(&self.path).ok()?;
let mmap = unsafe { Mmap::map(&file).ok()? }; let mmap = unsafe { Mmap::map(&file).ok()? };
// Find first { ... } and parse
for line in mmap.split(|&b| b == b'\n') { for line in mmap.split(|&b| b == b'\n') {
if line.is_empty() { continue; } if line.is_empty() { continue; }
if let Ok(node) = serde_json::from_slice::<AstNode>(line) { if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
if let Some(leaf) = node.leaf() { if let Some(leaf) = node.leaf() {
if let Some(ts) = leaf.timestamp() { let ts = leaf.timestamp();
if ts != chrono::DateTime::UNIX_EPOCH {
return Some(ts); return Some(ts);
} }
} }

View file

@@ -53,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool {
/// ///
/// Includes all sections up to and including conversation entries in /// Includes all sections up to and including conversation entries in
/// `range`, with `filter` applied to conversation entries. /// `range`, with `filter` applied to conversation entries.
///
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
/// (start, end) token positions for each assistant message.
fn build_token_ids( fn build_token_ids(
context: &ContextState, context: &ContextState,
range: std::ops::Range<usize>, range: std::ops::Range<usize>,
filter: Filter, filter: Filter,
) -> Vec<u32> { ) -> (Vec<u32>, Vec<(usize, usize)>) {
use crate::agent::context::Ast; use crate::agent::context::Ast;
let mut ids = Vec::new(); let mut ids = Vec::new();
let mut assistant_ranges = Vec::new();
for node in context.system() { for node in context.system() {
ids.extend(node.token_ids()); ids.extend(node.token_ids());
} }
@@ -87,9 +92,16 @@ fn build_token_ids(
Filter::SkipAllMemories => is_memory(node), Filter::SkipAllMemories => is_memory(node),
}; };
if skip { continue; } if skip { continue; }
// Track assistant message boundaries
let is_asst = is_assistant(node);
let start = ids.len();
ids.extend(node.token_ids()); ids.extend(node.token_ids());
if is_asst {
assistant_ranges.push((start, ids.len()));
} }
ids }
(ids, assistant_ranges)
} }
// ── Score API ─────────────────────────────────────────────────── // ── Score API ───────────────────────────────────────────────────
@@ -114,6 +126,7 @@ async fn call_score(
http: &crate::agent::api::http::HttpClient, http: &crate::agent::api::http::HttpClient,
client: &ApiClient, client: &ApiClient,
prompt: &[u32], prompt: &[u32],
ranges: &[(usize, usize)],
priority: Option<i32>, priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> { ) -> anyhow::Result<Vec<ScoreResult>> {
let url = format!("{}/score", client.base_url()); let url = format!("{}/score", client.base_url());
@@ -123,6 +136,9 @@ async fn call_score(
"prompt": prompt, "prompt": prompt,
"logprobs": 1, "logprobs": 1,
}); });
if !ranges.is_empty() {
body["score_ranges"] = serde_json::json!(ranges);
}
if let Some(p) = priority { if let Some(p) = priority {
body["priority"] = serde_json::json!(p); body["priority"] = serde_json::json!(p);
} }
@@ -168,8 +184,10 @@ async fn score_divergence(
filter: Filter<'_>, filter: Filter<'_>,
priority: Option<i32>, priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> { ) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?; let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?; let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
let divs = divergence(&baseline, &without); let divs = divergence(&baseline, &without);
Ok((divs, baseline)) Ok((divs, baseline))
} }
@@ -208,21 +226,21 @@ pub async fn score_memories(
let http = http_client(); let http = http_client();
let activity = crate::agent::start_activity(agent, "scoring: baseline").await; let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
let baseline_tokens = { let (baseline_tokens, baseline_ranges) = {
let ctx = agent.context.lock().await; let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None) build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
}; };
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?; let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len()); dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
for (mem_idx, key) in memory_keys.iter().enumerate() { for (mem_idx, key) in memory_keys.iter().enumerate() {
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await; activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key); dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
let tokens = { let (tokens, ranges) = {
let ctx = agent.context.lock().await; let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key)) build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
}; };
let row = match call_score(&http, client, &tokens, Some(5)).await { let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
Ok(without) => { Ok(without) => {
let divs = divergence(&baseline, &without); let divs = divergence(&baseline, &without);
let max_div = divs.iter().cloned().fold(0.0f64, f64::max); let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@@ -466,8 +484,8 @@ pub struct FinetuneCandidate {
pub continuation_ids: Vec<u32>, pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated). /// What the model would have said without memories (if generated).
pub alternate_text: Option<String>, pub alternate_text: Option<String>,
/// Timestamp in millis for tracking trained status. /// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ms: i64, pub timestamp_ns: i64,
} }
/// Score and enrich finetune candidates with full context. /// Score and enrich finetune candidates with full context.
@@ -495,7 +513,7 @@ pub async fn score_finetune_candidates(
let node = &entries[entry_idx]; let node = &entries[entry_idx];
// Get timestamp and skip if already trained // Get timestamp and skip if already trained
let timestamp_ms = match node_timestamp_ms(node) { let timestamp_ns = match node_timestamp_ns(node) {
Some(ts) => { Some(ts) => {
if trained.contains(&ts) { if trained.contains(&ts) {
continue; // Already trained, skip continue; // Already trained, skip
@@ -520,7 +538,7 @@
}; };
// Build token IDs: context = everything before response, continuation = response // Build token IDs: context = everything before response, continuation = response
let context_ids = build_token_ids(context, 0..entry_idx, Filter::None); let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect(); let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
candidates.push(FinetuneCandidate { candidates.push(FinetuneCandidate {
@@ -530,7 +548,7 @@
context_ids, context_ids,
continuation_ids, continuation_ids,
alternate_text: None, alternate_text: None,
timestamp_ms, timestamp_ns,
}); });
} }
@@ -556,7 +574,7 @@ async fn generate_alternate(
use crate::agent::api::{SamplingParams, StreamToken}; use crate::agent::api::{SamplingParams, StreamToken};
// Build context tokens without memories, up to the response // Build context tokens without memories, up to the response
let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories); let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
// Add assistant turn start // Add assistant turn start
prompt.push(tokenizer::IM_START); prompt.push(tokenizer::IM_START);
@@ -616,7 +634,7 @@ pub fn set_alternates(enabled: bool) {
} }
} }
/// Load set of trained response timestamps (millis since epoch). /// Load set of trained response timestamps (nanos since epoch).
pub fn load_trained() -> HashSet<i64> { pub fn load_trained() -> HashSet<i64> {
let path = trained_path(); let path = trained_path();
match std::fs::read_to_string(&path) { match std::fs::read_to_string(&path) {
@@ -626,9 +644,9 @@ pub fn load_trained() -> HashSet<i64> {
} }
/// Mark a response as trained by its timestamp. /// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ms: i64) { pub fn mark_trained(timestamp_ns: i64) {
let mut trained = load_trained(); let mut trained = load_trained();
trained.insert(timestamp_ms); trained.insert(timestamp_ns);
let path = trained_path(); let path = trained_path();
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent); let _ = std::fs::create_dir_all(parent);
@@ -638,15 +656,19 @@ pub fn mark_trained(timestamp_ms: i64) {
} }
} }
/// Get timestamp in millis from an AstNode (for Branch, uses first child). /// Get timestamp in nanoseconds from an AstNode.
pub fn node_timestamp_ms(node: &AstNode) -> Option<i64> { /// Returns None for entries with default UNIX_EPOCH timestamp (old data)
/// or timestamps outside the representable nano range (pre-1677 or post-2262).
pub fn node_timestamp_ns(node: &AstNode) -> Option<i64> {
let ts = match node { let ts = match node {
AstNode::Leaf(leaf) => leaf.timestamp(), AstNode::Leaf(leaf) => leaf.timestamp(),
AstNode::Branch { children, .. } => { AstNode::Branch { timestamp, .. } => *timestamp,
children.first()?.leaf()?.timestamp() };
if ts == chrono::DateTime::UNIX_EPOCH {
None // Old entry without real timestamp
} else {
ts.timestamp_nanos_opt()
} }
}?;
Some(ts.timestamp_millis())
} }
// ── Training API ──────────────────────────────────────────────── // ── Training API ────────────────────────────────────────────────
@@ -662,7 +684,7 @@ struct TrainingSample {
pub struct TrainData { pub struct TrainData {
pub context_ids: Vec<u32>, pub context_ids: Vec<u32>,
pub continuation_ids: Vec<u32>, pub continuation_ids: Vec<u32>,
pub timestamp_ms: i64, pub timestamp_ns: i64,
} }
/// Send training samples to the server. /// Send training samples to the server.
@@ -703,7 +725,7 @@ pub async fn send_to_train(
// Mark all samples as trained // Mark all samples as trained
for s in &samples { for s in &samples {
mark_trained(s.timestamp_ms); mark_trained(s.timestamp_ns);
} }
let job_id = result.get("job_id") let job_id = result.get("job_id")

View file

@@ -31,8 +31,8 @@ pub struct FinetuneCandidate {
pub continuation_ids: Vec<u32>, pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated). /// What the model would have said without memories (if generated).
pub alternate_text: Option<String>, pub alternate_text: Option<String>,
/// Timestamp in millis for tracking trained status. /// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ms: i64, pub timestamp_ns: i64,
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@@ -53,7 +53,7 @@ impl From<crate::subconscious::learn::FinetuneCandidate> for FinetuneCandidate {
context_ids: c.context_ids, context_ids: c.context_ids,
continuation_ids: c.continuation_ids, continuation_ids: c.continuation_ids,
alternate_text: c.alternate_text, alternate_text: c.alternate_text,
timestamp_ms: c.timestamp_ms, timestamp_ns: c.timestamp_ns,
} }
} }
} }

View file

@@ -171,7 +171,7 @@ impl App {
.map(|c| crate::subconscious::learn::TrainData { .map(|c| crate::subconscious::learn::TrainData {
context_ids: c.context_ids.clone(), context_ids: c.context_ids.clone(),
continuation_ids: c.continuation_ids.clone(), continuation_ids: c.continuation_ids.clone(),
timestamp_ms: c.timestamp_ms, timestamp_ns: c.timestamp_ns,
}) })
.collect(); .collect();
@@ -487,7 +487,7 @@ async fn run(
app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent); app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent);
for c in &ms.finetune_candidates { for c in &ms.finetune_candidates {
let exists = app.finetune_candidates.iter() let exists = app.finetune_candidates.iter()
.any(|existing| existing.timestamp_ms == c.timestamp_ms); .any(|existing| existing.timestamp_ns == c.timestamp_ns);
if !exists { if !exists {
app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone())); app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone()));
} }
@@ -496,7 +496,7 @@ async fn run(
let mut rejected: Vec<_> = app.finetune_candidates.iter() let mut rejected: Vec<_> = app.finetune_candidates.iter()
.enumerate() .enumerate()
.filter(|(_, c)| c.status == learn::CandidateStatus::Rejected) .filter(|(_, c)| c.status == learn::CandidateStatus::Rejected)
.map(|(i, c)| (i, c.timestamp_ms)) .map(|(i, c)| (i, c.timestamp_ns))
.collect(); .collect();
if rejected.len() > 10 { if rejected.len() > 10 {
rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts)); rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts));