learn: nanosecond timestamps, token ranges for /score

Two related changes to the learn subsystem:

1. AST node timestamps are now non-optional — both Leaf and Branch
   variants carry a DateTime<Utc>. UNIX_EPOCH serves as the "unset"
   sentinel for old entries deserialized from on-disk conversation logs
   that predate timestamping.

   Training uses timestamps as unique keys for dedup, so we promote to
   nanosecond precision: node_timestamp_ns(), TrainData.timestamp_ns,
   FinetuneCandidate.timestamp_ns, mark_trained(ns).

2. build_token_ids() now also returns token-position ranges of assistant
   messages. These are passed to vLLM's /score endpoint via the new
   score_ranges field so only scored-position logprobs are returned —
   cuts bandwidth/compute when scoring small windows.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-16 11:48:37 -04:00
parent 5d9d3ffc5b
commit 2b632d568b
5 changed files with 130 additions and 44 deletions

View file

@ -85,6 +85,19 @@ pub enum NodeBody {
Log(String),
}
/// Serde `default` hook for timestamp fields: a field that is entirely
/// missing from the JSON becomes `UNIX_EPOCH`, the sentinel meaning
/// "no real timestamp recorded" (old on-disk entries).
fn default_timestamp() -> DateTime<Utc> {
    DateTime::UNIX_EPOCH
}
/// Deserialize a timestamp, mapping an explicit JSON `null` to `UNIX_EPOCH`.
///
/// Paired with `#[serde(default = "default_timestamp")]` on the field:
/// serde's `default` only fires when the field is missing, so this handles
/// the present-but-null case that old conversation logs contain.
fn deserialize_timestamp_or_epoch<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let maybe_ts = Option::<DateTime<Utc>>::deserialize(deserializer)?;
    Ok(match maybe_ts {
        Some(ts) => ts,
        None => DateTime::UNIX_EPOCH,
    })
}
/// A leaf node: typed content with cached token IDs.
/// Token IDs are not serialized — they're recomputed on deserialization.
#[derive(Debug, Clone, Serialize)]
@ -92,7 +105,7 @@ pub struct NodeLeaf {
body: NodeBody,
#[serde(skip)]
token_ids: Vec<u32>,
timestamp: Option<DateTime<Utc>>,
timestamp: DateTime<Utc>,
}
impl<'de> Deserialize<'de> for NodeLeaf {
@ -100,7 +113,8 @@ impl<'de> Deserialize<'de> for NodeLeaf {
#[derive(Deserialize)]
struct Raw {
body: NodeBody,
timestamp: Option<DateTime<Utc>>,
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
}
let raw = Raw::deserialize(deserializer)?;
let token_ids = if raw.body.is_prompt_visible() {
@ -119,6 +133,8 @@ pub enum AstNode {
Branch {
role: Role,
children: Vec<AstNode>,
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
/// Per-response memory attribution from full scoring matrix.
/// Maps memory key → divergence score for this response.
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
@ -252,18 +268,18 @@ impl NodeLeaf {
} else {
vec![]
};
Self { body, token_ids, timestamp: None }
Self { body, token_ids, timestamp: Utc::now() }
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.timestamp = Some(ts);
self.timestamp = ts;
self
}
pub fn body(&self) -> &NodeBody { &self.body }
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
pub fn tokens(&self) -> usize { self.token_ids.len() }
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp }
pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
}
impl AstNode {
@ -307,13 +323,14 @@ impl AstNode {
// -- Branch constructors --------------------------------------------------
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
Self::Branch { role, children, memory_scores: Default::default() }
Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
}
pub fn system_msg(text: impl Into<String>) -> Self {
Self::Branch {
role: Role::System,
children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(),
}
}
@ -322,6 +339,7 @@ impl AstNode {
Self::Branch {
role: Role::User,
children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(),
}
}
@ -338,9 +356,10 @@ impl AstNode {
};
Self::Leaf(NodeLeaf { token_ids, ..leaf })
}
Self::Branch { role, children, memory_scores, .. } => Self::Branch {
Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
role,
children: children.into_iter().map(|c| c.retokenize()).collect(),
timestamp,
memory_scores,
},
}
@ -348,8 +367,8 @@ impl AstNode {
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
match &mut self {
Self::Leaf(leaf) => leaf.timestamp = Some(ts),
Self::Branch { .. } => {}
Self::Leaf(leaf) => leaf.timestamp = ts,
Self::Branch { timestamp, .. } => *timestamp = ts,
}
self
}
@ -1340,4 +1359,49 @@ mod tests {
assert_token_invariants(node);
assert!(node.tokens() > 0);
}
// -- Timestamp deserialization tests ------------------------------------------
#[test]
fn test_timestamp_null_becomes_epoch() {
    // Old conversation.jsonl entries carry an explicit "timestamp":null.
    // serde(default) alone only covers a missing field, so the custom
    // deserializer must also map null to the epoch sentinel.
    let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
    let node: AstNode = serde_json::from_str(json).unwrap();
    let ts = node.leaf().unwrap().timestamp();
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_timestamp_missing_becomes_epoch() {
    // Field absent entirely: serde(default) supplies the epoch sentinel.
    let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
    let node: AstNode = serde_json::from_str(json).unwrap();
    let ts = node.leaf().unwrap().timestamp();
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_branch_timestamp_null_becomes_epoch() {
    // Branch nodes carry their own timestamp; null must also map to epoch.
    let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}],"timestamp":null}}"#;
    let node: AstNode = serde_json::from_str(json).unwrap();
    let AstNode::Branch { timestamp, .. } = node else {
        panic!("expected Branch");
    };
    assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_branch_timestamp_missing_becomes_epoch() {
    // Branch with no timestamp field at all: serde(default) fills in epoch.
    let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}]}}"#;
    let node: AstNode = serde_json::from_str(json).unwrap();
    let AstNode::Branch { timestamp, .. } = node else {
        panic!("expected Branch");
    };
    assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
}
}

View file

@ -55,15 +55,15 @@ impl ConversationLog {
}
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
// Read forward from the start to find first timestamp
// Read forward from the start to find first non-epoch timestamp
let file = File::open(&self.path).ok()?;
let mmap = unsafe { Mmap::map(&file).ok()? };
// Find first { ... } and parse
for line in mmap.split(|&b| b == b'\n') {
if line.is_empty() { continue; }
if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
if let Some(leaf) = node.leaf() {
if let Some(ts) = leaf.timestamp() {
let ts = leaf.timestamp();
if ts != chrono::DateTime::UNIX_EPOCH {
return Some(ts);
}
}

View file

@ -53,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool {
///
/// Includes all sections up to and including conversation entries in
/// `range`, with `filter` applied to conversation entries.
///
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
/// (start, end) token positions for each assistant message.
fn build_token_ids(
context: &ContextState,
range: std::ops::Range<usize>,
filter: Filter,
) -> Vec<u32> {
) -> (Vec<u32>, Vec<(usize, usize)>) {
use crate::agent::context::Ast;
let mut ids = Vec::new();
let mut assistant_ranges = Vec::new();
for node in context.system() {
ids.extend(node.token_ids());
}
@ -87,9 +92,16 @@ fn build_token_ids(
Filter::SkipAllMemories => is_memory(node),
};
if skip { continue; }
// Track assistant message boundaries
let is_asst = is_assistant(node);
let start = ids.len();
ids.extend(node.token_ids());
if is_asst {
assistant_ranges.push((start, ids.len()));
}
ids
}
(ids, assistant_ranges)
}
// ── Score API ───────────────────────────────────────────────────
@ -114,6 +126,7 @@ async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
ranges: &[(usize, usize)],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
let url = format!("{}/score", client.base_url());
@ -123,6 +136,9 @@ async fn call_score(
"prompt": prompt,
"logprobs": 1,
});
if !ranges.is_empty() {
body["score_ranges"] = serde_json::json!(ranges);
}
if let Some(p) = priority {
body["priority"] = serde_json::json!(p);
}
@ -168,8 +184,10 @@ async fn score_divergence(
filter: Filter<'_>,
priority: Option<i32>,
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?;
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?;
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
let divs = divergence(&baseline, &without);
Ok((divs, baseline))
}
@ -208,21 +226,21 @@ pub async fn score_memories(
let http = http_client();
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
let baseline_tokens = {
let (baseline_tokens, baseline_ranges) = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
};
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
for (mem_idx, key) in memory_keys.iter().enumerate() {
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
let tokens = {
let (tokens, ranges) = {
let ctx = agent.context.lock().await;
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
};
let row = match call_score(&http, client, &tokens, Some(5)).await {
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
Ok(without) => {
let divs = divergence(&baseline, &without);
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@ -466,8 +484,8 @@ pub struct FinetuneCandidate {
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in millis for tracking trained status.
pub timestamp_ms: i64,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
/// Score and enrich finetune candidates with full context.
@ -495,7 +513,7 @@ pub async fn score_finetune_candidates(
let node = &entries[entry_idx];
// Get timestamp and skip if already trained
let timestamp_ms = match node_timestamp_ms(node) {
let timestamp_ns = match node_timestamp_ns(node) {
Some(ts) => {
if trained.contains(&ts) {
continue; // Already trained, skip
@ -520,7 +538,7 @@ pub async fn score_finetune_candidates(
};
// Build token IDs: context = everything before response, continuation = response
let context_ids = build_token_ids(context, 0..entry_idx, Filter::None);
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
candidates.push(FinetuneCandidate {
@ -530,7 +548,7 @@ pub async fn score_finetune_candidates(
context_ids,
continuation_ids,
alternate_text: None,
timestamp_ms,
timestamp_ns,
});
}
@ -556,7 +574,7 @@ async fn generate_alternate(
use crate::agent::api::{SamplingParams, StreamToken};
// Build context tokens without memories, up to the response
let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
// Add assistant turn start
prompt.push(tokenizer::IM_START);
@ -616,7 +634,7 @@ pub fn set_alternates(enabled: bool) {
}
}
/// Load set of trained response timestamps (millis since epoch).
/// Load set of trained response timestamps (nanos since epoch).
pub fn load_trained() -> HashSet<i64> {
let path = trained_path();
match std::fs::read_to_string(&path) {
@ -626,9 +644,9 @@ pub fn load_trained() -> HashSet<i64> {
}
/// Mark a response as trained by its timestamp.
pub fn mark_trained(timestamp_ms: i64) {
pub fn mark_trained(timestamp_ns: i64) {
let mut trained = load_trained();
trained.insert(timestamp_ms);
trained.insert(timestamp_ns);
let path = trained_path();
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
@ -638,15 +656,19 @@ pub fn mark_trained(timestamp_ms: i64) {
}
}
/// Get timestamp in millis from an AstNode (for Branch, uses first child).
pub fn node_timestamp_ms(node: &AstNode) -> Option<i64> {
/// Get timestamp in nanoseconds from an AstNode.
/// Returns None for entries with default UNIX_EPOCH timestamp (old data)
/// or timestamps outside the representable nano range (pre-1677 or post-2262).
pub fn node_timestamp_ns(node: &AstNode) -> Option<i64> {
let ts = match node {
AstNode::Leaf(leaf) => leaf.timestamp(),
AstNode::Branch { children, .. } => {
children.first()?.leaf()?.timestamp()
AstNode::Branch { timestamp, .. } => *timestamp,
};
if ts == chrono::DateTime::UNIX_EPOCH {
None // Old entry without real timestamp
} else {
ts.timestamp_nanos_opt()
}
}?;
Some(ts.timestamp_millis())
}
// ── Training API ────────────────────────────────────────────────
@ -662,7 +684,7 @@ struct TrainingSample {
pub struct TrainData {
pub context_ids: Vec<u32>,
pub continuation_ids: Vec<u32>,
pub timestamp_ms: i64,
pub timestamp_ns: i64,
}
/// Send training samples to the server.
@ -703,7 +725,7 @@ pub async fn send_to_train(
// Mark all samples as trained
for s in &samples {
mark_trained(s.timestamp_ms);
mark_trained(s.timestamp_ns);
}
let job_id = result.get("job_id")

View file

@ -31,8 +31,8 @@ pub struct FinetuneCandidate {
pub continuation_ids: Vec<u32>,
/// What the model would have said without memories (if generated).
pub alternate_text: Option<String>,
/// Timestamp in millis for tracking trained status.
pub timestamp_ms: i64,
/// Timestamp in nanos — used as unique key for trained-set dedup.
pub timestamp_ns: i64,
}
#[derive(Clone, Debug, PartialEq)]
@ -53,7 +53,7 @@ impl From<crate::subconscious::learn::FinetuneCandidate> for FinetuneCandidate {
context_ids: c.context_ids,
continuation_ids: c.continuation_ids,
alternate_text: c.alternate_text,
timestamp_ms: c.timestamp_ms,
timestamp_ns: c.timestamp_ns,
}
}
}

View file

@ -171,7 +171,7 @@ impl App {
.map(|c| crate::subconscious::learn::TrainData {
context_ids: c.context_ids.clone(),
continuation_ids: c.continuation_ids.clone(),
timestamp_ms: c.timestamp_ms,
timestamp_ns: c.timestamp_ns,
})
.collect();
@ -487,7 +487,7 @@ async fn run(
app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent);
for c in &ms.finetune_candidates {
let exists = app.finetune_candidates.iter()
.any(|existing| existing.timestamp_ms == c.timestamp_ms);
.any(|existing| existing.timestamp_ns == c.timestamp_ns);
if !exists {
app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone()));
}
@ -496,7 +496,7 @@ async fn run(
let mut rejected: Vec<_> = app.finetune_candidates.iter()
.enumerate()
.filter(|(_, c)| c.status == learn::CandidateStatus::Rejected)
.map(|(i, c)| (i, c.timestamp_ms))
.map(|(i, c)| (i, c.timestamp_ns))
.collect();
if rejected.len() > 10 {
rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts));