learn: nanosecond timestamps, token ranges for /score

Two related changes to the learn subsystem:

1. AST node timestamps are now non-optional — both Leaf and Branch
   variants carry a DateTime<Utc>. UNIX_EPOCH means "unset" (old entries
   deserialized from on-disk conversation logs).

   Training uses timestamps as unique keys for dedup, so we promote to
   nanosecond precision: node_timestamp_ns(), TrainData.timestamp_ns,
   FinetuneCandidate.timestamp_ns, mark_trained(ns).

2. build_token_ids() now also returns token-position ranges of assistant
   messages. These are passed to vLLM's /score endpoint via the new
   score_ranges field so only scored-position logprobs are returned —
   cuts bandwidth/compute when scoring small windows.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-16 11:48:37 -04:00
parent 5d9d3ffc5b
commit 2b632d568b
5 changed files with 130 additions and 44 deletions

View file

@ -85,6 +85,19 @@ pub enum NodeBody {
Log(String),
}
/// Fallback timestamp used as the serde `default` for timestamp fields.
///
/// Returns the Unix epoch, which stands for "unset" on entries that were
/// written without a timestamp.
fn default_timestamp() -> DateTime<Utc> {
    DateTime::<Utc>::UNIX_EPOCH
}
/// Serde `deserialize_with` helper for timestamp fields.
///
/// Reads the value as an optional timestamp and substitutes `UNIX_EPOCH`
/// when it is an explicit `null`. A field that is absent altogether is not
/// routed through this function; pair it with `#[serde(default = "...")]`
/// so that missing fields also end up as `UNIX_EPOCH`.
fn deserialize_timestamp_or_epoch<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let parsed = Option::<DateTime<Utc>>::deserialize(deserializer)?;
    Ok(match parsed {
        Some(ts) => ts,
        // Explicit null in the input: fall back to the "unset" marker.
        None => DateTime::UNIX_EPOCH,
    })
}
/// A leaf node: typed content with cached token IDs.
/// Token IDs are not serialized — they're recomputed on deserialization.
#[derive(Debug, Clone, Serialize)]
@ -92,7 +105,7 @@ pub struct NodeLeaf {
body: NodeBody,
#[serde(skip)]
token_ids: Vec<u32>,
timestamp: Option<DateTime<Utc>>,
timestamp: DateTime<Utc>,
}
impl<'de> Deserialize<'de> for NodeLeaf {
@ -100,7 +113,8 @@ impl<'de> Deserialize<'de> for NodeLeaf {
#[derive(Deserialize)]
struct Raw {
body: NodeBody,
timestamp: Option<DateTime<Utc>>,
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
}
let raw = Raw::deserialize(deserializer)?;
let token_ids = if raw.body.is_prompt_visible() {
@ -119,6 +133,8 @@ pub enum AstNode {
Branch {
role: Role,
children: Vec<AstNode>,
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
timestamp: DateTime<Utc>,
/// Per-response memory attribution from full scoring matrix.
/// Maps memory key → divergence score for this response.
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
@ -252,18 +268,18 @@ impl NodeLeaf {
} else {
vec![]
};
Self { body, token_ids, timestamp: None }
Self { body, token_ids, timestamp: Utc::now() }
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.timestamp = Some(ts);
self.timestamp = ts;
self
}
pub fn body(&self) -> &NodeBody { &self.body }
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
pub fn tokens(&self) -> usize { self.token_ids.len() }
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp }
pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
}
impl AstNode {
@ -307,13 +323,14 @@ impl AstNode {
// -- Branch constructors --------------------------------------------------
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
Self::Branch { role, children, memory_scores: Default::default() }
Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
}
/// Build a `Role::System` branch whose only child is `Self::content(text)`,
/// stamped with the current UTC time and with no memory scores.
pub fn system_msg(text: impl Into<String>) -> Self {
    let children = vec![Self::content(text)];
    let timestamp = Utc::now();
    Self::Branch {
        role: Role::System,
        children,
        timestamp,
        memory_scores: Default::default(),
    }
}
@ -322,6 +339,7 @@ impl AstNode {
Self::Branch {
role: Role::User,
children: vec![Self::content(text)],
timestamp: Utc::now(),
memory_scores: Default::default(),
}
}
@ -338,9 +356,10 @@ impl AstNode {
};
Self::Leaf(NodeLeaf { token_ids, ..leaf })
}
Self::Branch { role, children, memory_scores, .. } => Self::Branch {
Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
role,
children: children.into_iter().map(|c| c.retokenize()).collect(),
timestamp,
memory_scores,
},
}
@ -348,8 +367,8 @@ impl AstNode {
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
match &mut self {
Self::Leaf(leaf) => leaf.timestamp = Some(ts),
Self::Branch { .. } => {}
Self::Leaf(leaf) => leaf.timestamp = ts,
Self::Branch { timestamp, .. } => *timestamp = ts,
}
self
}
@ -1340,4 +1359,49 @@ mod tests {
assert_token_invariants(node);
assert!(node.tokens() > 0);
}
// -- Timestamp deserialization tests ------------------------------------------
#[test]
fn test_timestamp_null_becomes_epoch() {
    // Old conversation.jsonl entries carry an explicit "timestamp":null.
    // serde(default) only covers absent fields, so the null case has to be
    // handled by the custom deserializer; this pins that behavior for leaves.
    let parsed: AstNode =
        serde_json::from_str(r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#).unwrap();
    let ts = parsed.leaf().unwrap().timestamp();
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_timestamp_missing_becomes_epoch() {
    // A leaf with no "timestamp" key at all must also come back as the epoch.
    let parsed: AstNode =
        serde_json::from_str(r#"{"Leaf":{"body":{"Content":"hello"}}}"#).unwrap();
    let ts = parsed.leaf().unwrap().timestamp();
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_branch_timestamp_null_becomes_epoch() {
    // Branch with an explicit "timestamp":null must deserialize to the epoch.
    let parsed: AstNode = serde_json::from_str(
        r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}],"timestamp":null}}"#,
    )
    .unwrap();
    let ts = match parsed {
        AstNode::Branch { timestamp, .. } => timestamp,
        _ => panic!("expected Branch"),
    };
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
#[test]
fn test_branch_timestamp_missing_becomes_epoch() {
    // Branch with no "timestamp" key at all must deserialize to the epoch.
    let parsed: AstNode = serde_json::from_str(
        r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}]}}"#,
    )
    .unwrap();
    let ts = match parsed {
        AstNode::Branch { timestamp, .. } => timestamp,
        _ => panic!("expected Branch"),
    };
    assert_eq!(ts, DateTime::<Utc>::UNIX_EPOCH);
}
}