learn: nanosecond timestamps, token ranges for /score
Two related changes to the learn subsystem: 1. AST node timestamps are now non-optional — both Leaf and Branch variants carry a DateTime<Utc>. UNIX_EPOCH means "unset" (old entries deserialized from on-disk conversation logs). Training uses timestamps as unique keys for dedup, so we promote to nanosecond precision: node_timestamp_ns(), TrainData.timestamp_ns, FinetuneCandidate.timestamp_ns, mark_trained(ns). 2. build_token_ids() now also returns token-position ranges of assistant messages. These are passed to vLLM's /score endpoint via the new score_ranges field so only scored-position logprobs are returned — cuts bandwidth/compute when scoring small windows. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
5d9d3ffc5b
commit
2b632d568b
5 changed files with 130 additions and 44 deletions
|
|
@@ -85,6 +85,19 @@ pub enum NodeBody {
|
|||
Log(String),
|
||||
}
|
||||
|
||||
fn default_timestamp() -> DateTime<Utc> {
|
||||
DateTime::UNIX_EPOCH
|
||||
}
|
||||
|
||||
/// Deserialize timestamp, treating both missing and null as UNIX_EPOCH.
|
||||
fn deserialize_timestamp_or_epoch<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let opt: Option<DateTime<Utc>> = Option::deserialize(deserializer)?;
|
||||
Ok(opt.unwrap_or(DateTime::UNIX_EPOCH))
|
||||
}
|
||||
|
||||
/// A leaf node: typed content with cached token IDs.
|
||||
/// Token IDs are not serialized — they're recomputed on deserialization.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
|
|
@@ -92,7 +105,7 @@ pub struct NodeLeaf {
|
|||
body: NodeBody,
|
||||
#[serde(skip)]
|
||||
token_ids: Vec<u32>,
|
||||
timestamp: Option<DateTime<Utc>>,
|
||||
timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for NodeLeaf {
|
||||
|
|
@@ -100,7 +113,8 @@ impl<'de> Deserialize<'de> for NodeLeaf {
|
|||
#[derive(Deserialize)]
|
||||
struct Raw {
|
||||
body: NodeBody,
|
||||
timestamp: Option<DateTime<Utc>>,
|
||||
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
|
||||
timestamp: DateTime<Utc>,
|
||||
}
|
||||
let raw = Raw::deserialize(deserializer)?;
|
||||
let token_ids = if raw.body.is_prompt_visible() {
|
||||
|
|
@@ -119,6 +133,8 @@ pub enum AstNode {
|
|||
Branch {
|
||||
role: Role,
|
||||
children: Vec<AstNode>,
|
||||
#[serde(default = "default_timestamp", deserialize_with = "deserialize_timestamp_or_epoch")]
|
||||
timestamp: DateTime<Utc>,
|
||||
/// Per-response memory attribution from full scoring matrix.
|
||||
/// Maps memory key → divergence score for this response.
|
||||
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
|
||||
|
|
@@ -252,18 +268,18 @@ impl NodeLeaf {
|
|||
} else {
|
||||
vec![]
|
||||
};
|
||||
Self { body, token_ids, timestamp: None }
|
||||
Self { body, token_ids, timestamp: Utc::now() }
|
||||
}
|
||||
|
||||
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
||||
self.timestamp = Some(ts);
|
||||
self.timestamp = ts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn body(&self) -> &NodeBody { &self.body }
|
||||
pub fn token_ids(&self) -> &[u32] { &self.token_ids }
|
||||
pub fn tokens(&self) -> usize { self.token_ids.len() }
|
||||
pub fn timestamp(&self) -> Option<DateTime<Utc>> { self.timestamp }
|
||||
pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
|
||||
}
|
||||
|
||||
impl AstNode {
|
||||
|
|
@@ -307,13 +323,14 @@ impl AstNode {
|
|||
// -- Branch constructors --------------------------------------------------
|
||||
|
||||
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
|
||||
Self::Branch { role, children, memory_scores: Default::default() }
|
||||
Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
|
||||
}
|
||||
|
||||
pub fn system_msg(text: impl Into<String>) -> Self {
|
||||
Self::Branch {
|
||||
role: Role::System,
|
||||
children: vec![Self::content(text)],
|
||||
timestamp: Utc::now(),
|
||||
memory_scores: Default::default(),
|
||||
}
|
||||
}
|
||||
|
|
@@ -322,6 +339,7 @@ impl AstNode {
|
|||
Self::Branch {
|
||||
role: Role::User,
|
||||
children: vec![Self::content(text)],
|
||||
timestamp: Utc::now(),
|
||||
memory_scores: Default::default(),
|
||||
}
|
||||
}
|
||||
|
|
@@ -338,9 +356,10 @@ impl AstNode {
|
|||
};
|
||||
Self::Leaf(NodeLeaf { token_ids, ..leaf })
|
||||
}
|
||||
Self::Branch { role, children, memory_scores, .. } => Self::Branch {
|
||||
Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
|
||||
role,
|
||||
children: children.into_iter().map(|c| c.retokenize()).collect(),
|
||||
timestamp,
|
||||
memory_scores,
|
||||
},
|
||||
}
|
||||
|
|
@@ -348,8 +367,8 @@ impl AstNode {
|
|||
|
||||
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
|
||||
match &mut self {
|
||||
Self::Leaf(leaf) => leaf.timestamp = Some(ts),
|
||||
Self::Branch { .. } => {}
|
||||
Self::Leaf(leaf) => leaf.timestamp = ts,
|
||||
Self::Branch { timestamp, .. } => *timestamp = ts,
|
||||
}
|
||||
self
|
||||
}
|
||||
|
|
@@ -1340,4 +1359,49 @@ mod tests {
|
|||
assert_token_invariants(node);
|
||||
assert!(node.tokens() > 0);
|
||||
}
|
||||
|
||||
// -- Timestamp deserialization tests ------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_null_becomes_epoch() {
|
||||
// Old conversation.jsonl entries have "timestamp":null
|
||||
// serde(default) only handles missing fields, not explicit nulls.
|
||||
// We need to verify our deserialize handles this correctly.
|
||||
let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
|
||||
let node: AstNode = serde_json::from_str(json).unwrap();
|
||||
let leaf = node.leaf().unwrap();
|
||||
assert_eq!(leaf.timestamp(), DateTime::<Utc>::UNIX_EPOCH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_missing_becomes_epoch() {
|
||||
let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
|
||||
let node: AstNode = serde_json::from_str(json).unwrap();
|
||||
let leaf = node.leaf().unwrap();
|
||||
assert_eq!(leaf.timestamp(), DateTime::<Utc>::UNIX_EPOCH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_timestamp_null_becomes_epoch() {
|
||||
let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}],"timestamp":null}}"#;
|
||||
let node: AstNode = serde_json::from_str(json).unwrap();
|
||||
match node {
|
||||
AstNode::Branch { timestamp, .. } => {
|
||||
assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
|
||||
}
|
||||
_ => panic!("expected Branch"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_timestamp_missing_becomes_epoch() {
|
||||
let json = r#"{"Branch":{"role":"User","children":[{"Leaf":{"body":{"Content":"hi"}}}]}}"#;
|
||||
let node: AstNode = serde_json::from_str(json).unwrap();
|
||||
match node {
|
||||
AstNode::Branch { timestamp, .. } => {
|
||||
assert_eq!(timestamp, DateTime::<Utc>::UNIX_EPOCH);
|
||||
}
|
||||
_ => panic!("expected Branch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -55,15 +55,15 @@ impl ConversationLog {
|
|||
}
|
||||
|
||||
pub fn oldest_timestamp(&self) -> Option<chrono::DateTime<chrono::Utc>> {
|
||||
// Read forward from the start to find first timestamp
|
||||
// Read forward from the start to find first non-epoch timestamp
|
||||
let file = File::open(&self.path).ok()?;
|
||||
let mmap = unsafe { Mmap::map(&file).ok()? };
|
||||
// Find first { ... } and parse
|
||||
for line in mmap.split(|&b| b == b'\n') {
|
||||
if line.is_empty() { continue; }
|
||||
if let Ok(node) = serde_json::from_slice::<AstNode>(line) {
|
||||
if let Some(leaf) = node.leaf() {
|
||||
if let Some(ts) = leaf.timestamp() {
|
||||
let ts = leaf.timestamp();
|
||||
if ts != chrono::DateTime::UNIX_EPOCH {
|
||||
return Some(ts);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -53,13 +53,18 @@ fn is_assistant(node: &AstNode) -> bool {
|
|||
///
|
||||
/// Includes all sections up to and including conversation entries in
|
||||
/// `range`, with `filter` applied to conversation entries.
|
||||
///
|
||||
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
|
||||
/// (start, end) token positions for each assistant message.
|
||||
fn build_token_ids(
|
||||
context: &ContextState,
|
||||
range: std::ops::Range<usize>,
|
||||
filter: Filter,
|
||||
) -> Vec<u32> {
|
||||
) -> (Vec<u32>, Vec<(usize, usize)>) {
|
||||
use crate::agent::context::Ast;
|
||||
let mut ids = Vec::new();
|
||||
let mut assistant_ranges = Vec::new();
|
||||
|
||||
for node in context.system() {
|
||||
ids.extend(node.token_ids());
|
||||
}
|
||||
|
|
@@ -87,9 +92,16 @@ fn build_token_ids(
|
|||
Filter::SkipAllMemories => is_memory(node),
|
||||
};
|
||||
if skip { continue; }
|
||||
|
||||
// Track assistant message boundaries
|
||||
let is_asst = is_assistant(node);
|
||||
let start = ids.len();
|
||||
ids.extend(node.token_ids());
|
||||
if is_asst {
|
||||
assistant_ranges.push((start, ids.len()));
|
||||
}
|
||||
}
|
||||
ids
|
||||
(ids, assistant_ranges)
|
||||
}
|
||||
|
||||
// ── Score API ───────────────────────────────────────────────────
|
||||
|
|
@@ -114,6 +126,7 @@ async fn call_score(
|
|||
http: &crate::agent::api::http::HttpClient,
|
||||
client: &ApiClient,
|
||||
prompt: &[u32],
|
||||
ranges: &[(usize, usize)],
|
||||
priority: Option<i32>,
|
||||
) -> anyhow::Result<Vec<ScoreResult>> {
|
||||
let url = format!("{}/score", client.base_url());
|
||||
|
|
@@ -123,6 +136,9 @@ async fn call_score(
|
|||
"prompt": prompt,
|
||||
"logprobs": 1,
|
||||
});
|
||||
if !ranges.is_empty() {
|
||||
body["score_ranges"] = serde_json::json!(ranges);
|
||||
}
|
||||
if let Some(p) = priority {
|
||||
body["priority"] = serde_json::json!(p);
|
||||
}
|
||||
|
|
@@ -168,8 +184,10 @@ async fn score_divergence(
|
|||
filter: Filter<'_>,
|
||||
priority: Option<i32>,
|
||||
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
|
||||
let baseline = call_score(http, client, &build_token_ids(context, range.clone(), Filter::None), priority).await?;
|
||||
let without = call_score(http, client, &build_token_ids(context, range, filter), priority).await?;
|
||||
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
|
||||
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
|
||||
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
|
||||
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
|
||||
let divs = divergence(&baseline, &without);
|
||||
Ok((divs, baseline))
|
||||
}
|
||||
|
|
@@ -208,21 +226,21 @@ pub async fn score_memories(
|
|||
let http = http_client();
|
||||
|
||||
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
||||
let baseline_tokens = {
|
||||
let (baseline_tokens, baseline_ranges) = {
|
||||
let ctx = agent.context.lock().await;
|
||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
|
||||
};
|
||||
let baseline = call_score(&http, client, &baseline_tokens, Some(5)).await?;
|
||||
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
|
||||
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
||||
|
||||
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
||||
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
||||
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
||||
let tokens = {
|
||||
let (tokens, ranges) = {
|
||||
let ctx = agent.context.lock().await;
|
||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
|
||||
};
|
||||
let row = match call_score(&http, client, &tokens, Some(5)).await {
|
||||
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
|
||||
Ok(without) => {
|
||||
let divs = divergence(&baseline, &without);
|
||||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||||
|
|
@@ -466,8 +484,8 @@ pub struct FinetuneCandidate {
|
|||
pub continuation_ids: Vec<u32>,
|
||||
/// What the model would have said without memories (if generated).
|
||||
pub alternate_text: Option<String>,
|
||||
/// Timestamp in millis for tracking trained status.
|
||||
pub timestamp_ms: i64,
|
||||
/// Timestamp in nanos — used as unique key for trained-set dedup.
|
||||
pub timestamp_ns: i64,
|
||||
}
|
||||
|
||||
/// Score and enrich finetune candidates with full context.
|
||||
|
|
@@ -495,7 +513,7 @@ pub async fn score_finetune_candidates(
|
|||
let node = &entries[entry_idx];
|
||||
|
||||
// Get timestamp and skip if already trained
|
||||
let timestamp_ms = match node_timestamp_ms(node) {
|
||||
let timestamp_ns = match node_timestamp_ns(node) {
|
||||
Some(ts) => {
|
||||
if trained.contains(&ts) {
|
||||
continue; // Already trained, skip
|
||||
|
|
@@ -520,7 +538,7 @@ pub async fn score_finetune_candidates(
|
|||
};
|
||||
|
||||
// Build token IDs: context = everything before response, continuation = response
|
||||
let context_ids = build_token_ids(context, 0..entry_idx, Filter::None);
|
||||
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
|
||||
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
|
||||
|
||||
candidates.push(FinetuneCandidate {
|
||||
|
|
@@ -530,7 +548,7 @@ pub async fn score_finetune_candidates(
|
|||
context_ids,
|
||||
continuation_ids,
|
||||
alternate_text: None,
|
||||
timestamp_ms,
|
||||
timestamp_ns,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@@ -556,7 +574,7 @@ async fn generate_alternate(
|
|||
use crate::agent::api::{SamplingParams, StreamToken};
|
||||
|
||||
// Build context tokens without memories, up to the response
|
||||
let mut prompt = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
|
||||
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
|
||||
|
||||
// Add assistant turn start
|
||||
prompt.push(tokenizer::IM_START);
|
||||
|
|
@@ -616,7 +634,7 @@ pub fn set_alternates(enabled: bool) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Load set of trained response timestamps (millis since epoch).
|
||||
/// Load set of trained response timestamps (nanos since epoch).
|
||||
pub fn load_trained() -> HashSet<i64> {
|
||||
let path = trained_path();
|
||||
match std::fs::read_to_string(&path) {
|
||||
|
|
@@ -626,9 +644,9 @@ pub fn load_trained() -> HashSet<i64> {
|
|||
}
|
||||
|
||||
/// Mark a response as trained by its timestamp.
|
||||
pub fn mark_trained(timestamp_ms: i64) {
|
||||
pub fn mark_trained(timestamp_ns: i64) {
|
||||
let mut trained = load_trained();
|
||||
trained.insert(timestamp_ms);
|
||||
trained.insert(timestamp_ns);
|
||||
let path = trained_path();
|
||||
if let Some(parent) = path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
|
|
@@ -638,15 +656,19 @@ pub fn mark_trained(timestamp_ms: i64) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get timestamp in millis from an AstNode (for Branch, uses first child).
|
||||
pub fn node_timestamp_ms(node: &AstNode) -> Option<i64> {
|
||||
/// Get timestamp in nanoseconds from an AstNode.
|
||||
/// Returns None for entries with default UNIX_EPOCH timestamp (old data)
|
||||
/// or timestamps outside the representable nano range (pre-1677 or post-2262).
|
||||
pub fn node_timestamp_ns(node: &AstNode) -> Option<i64> {
|
||||
let ts = match node {
|
||||
AstNode::Leaf(leaf) => leaf.timestamp(),
|
||||
AstNode::Branch { children, .. } => {
|
||||
children.first()?.leaf()?.timestamp()
|
||||
}
|
||||
}?;
|
||||
Some(ts.timestamp_millis())
|
||||
AstNode::Branch { timestamp, .. } => *timestamp,
|
||||
};
|
||||
if ts == chrono::DateTime::UNIX_EPOCH {
|
||||
None // Old entry without real timestamp
|
||||
} else {
|
||||
ts.timestamp_nanos_opt()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Training API ────────────────────────────────────────────────
|
||||
|
|
@@ -662,7 +684,7 @@ struct TrainingSample {
|
|||
pub struct TrainData {
|
||||
pub context_ids: Vec<u32>,
|
||||
pub continuation_ids: Vec<u32>,
|
||||
pub timestamp_ms: i64,
|
||||
pub timestamp_ns: i64,
|
||||
}
|
||||
|
||||
/// Send training samples to the server.
|
||||
|
|
@@ -703,7 +725,7 @@ pub async fn send_to_train(
|
|||
|
||||
// Mark all samples as trained
|
||||
for s in &samples {
|
||||
mark_trained(s.timestamp_ms);
|
||||
mark_trained(s.timestamp_ns);
|
||||
}
|
||||
|
||||
let job_id = result.get("job_id")
|
||||
|
|
|
|||
|
|
@@ -31,8 +31,8 @@ pub struct FinetuneCandidate {
|
|||
pub continuation_ids: Vec<u32>,
|
||||
/// What the model would have said without memories (if generated).
|
||||
pub alternate_text: Option<String>,
|
||||
/// Timestamp in millis for tracking trained status.
|
||||
pub timestamp_ms: i64,
|
||||
/// Timestamp in nanos — used as unique key for trained-set dedup.
|
||||
pub timestamp_ns: i64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
|
|
@@ -53,7 +53,7 @@ impl From<crate::subconscious::learn::FinetuneCandidate> for FinetuneCandidate {
|
|||
context_ids: c.context_ids,
|
||||
continuation_ids: c.continuation_ids,
|
||||
alternate_text: c.alternate_text,
|
||||
timestamp_ms: c.timestamp_ms,
|
||||
timestamp_ns: c.timestamp_ns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -171,7 +171,7 @@ impl App {
|
|||
.map(|c| crate::subconscious::learn::TrainData {
|
||||
context_ids: c.context_ids.clone(),
|
||||
continuation_ids: c.continuation_ids.clone(),
|
||||
timestamp_ms: c.timestamp_ms,
|
||||
timestamp_ns: c.timestamp_ns,
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
@@ -487,7 +487,7 @@ async fn run(
|
|||
app.finetune_candidates.retain(|c| c.status != learn::CandidateStatus::Sent);
|
||||
for c in &ms.finetune_candidates {
|
||||
let exists = app.finetune_candidates.iter()
|
||||
.any(|existing| existing.timestamp_ms == c.timestamp_ms);
|
||||
.any(|existing| existing.timestamp_ns == c.timestamp_ns);
|
||||
if !exists {
|
||||
app.finetune_candidates.push(learn::FinetuneCandidate::from(c.clone()));
|
||||
}
|
||||
|
|
@@ -496,7 +496,7 @@ async fn run(
|
|||
let mut rejected: Vec<_> = app.finetune_candidates.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, c)| c.status == learn::CandidateStatus::Rejected)
|
||||
.map(|(i, c)| (i, c.timestamp_ms))
|
||||
.map(|(i, c)| (i, c.timestamp_ns))
|
||||
.collect();
|
||||
if rejected.len() > 10 {
|
||||
rejected.sort_by_key(|(_, ts)| std::cmp::Reverse(*ts));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue