agent: bundle sampling fields as SamplingParams on AgentState

Collapse the split `temperature` / `top_p` / `top_k` fields on
AgentState into a single `sampling: SamplingParams` struct, mirroring
how the wire-level fields flow into the Generate RPC. Adds
`max_tokens` to SamplingParams so it's actually plumbed end to end
(previously the client had a hardcoded 4096 fallback inside
`run_session_generate`).

AgentState construction sites now set `sampling: SamplingParams { ...
max_tokens: 4096 }` as the default. The assignment sites in
oneshot.rs / subconscious.rs / unconscious.rs switch from
`st.temperature = X` to `st.sampling.temperature = X`.

`stream_session_mm` takes `SamplingParams` directly; the
`sampling_max_tokens()` helper goes away. `pb::GenerateRequest` is
populated with `sampling.max_tokens` (and the other fields) in
`run_session_generate`. SamplingParams is `pub` so it can be
embedded in the public AgentState without a visibility warning.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-24 12:37:20 -04:00
commit be6ba4e9a5
6 changed files with 29 additions and 32 deletions

View file

@ -38,6 +38,21 @@ pub struct ReadoutManifest {
/// from pairing with the manifest fetched at startup.
pub type TokenReadout = Vec<Vec<f32>>;
/// Client-side sampling state. Mirrors the wire-level fields in
/// `GenerateRequest` (proto flattened its `SamplingParams` submessage
/// in so the server handler reads them directly), but stays as a
/// grouped struct on the client because UI / config / tests pass
/// these around together.
#[derive(Clone, Copy)]
pub struct SamplingParams {
pub temperature: f32,
pub top_p: f32,
pub top_k: u32,
/// Decode budget. 0 = prefill only; >0 = decode up to this many
/// tokens, stopping early on EOS / stop_token_ids.
pub max_tokens: u32,
}
/// A JoinHandle that aborts its task when dropped.
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
@ -47,13 +62,6 @@ impl Drop for AbortOnDrop {
}
}
/// Sampling parameters for model generation.
#[derive(Clone, Copy)]
pub(crate) struct SamplingParams {
pub temperature: f32,
pub top_p: f32,
pub top_k: u32,
}
// ─────────────────────────────────────────────────────────────
// Stream events — yielded by backends, consumed by the runner
@ -288,13 +296,12 @@ async fn run_session_generate(
} else {
Vec::new()
};
let max_tokens = sampling_max_tokens(&sampling);
let req = pb::GenerateRequest {
session_id: handle.session_id.clone(),
append_tokens: pending,
offset: handle.committed_len,
truncating: false,
max_tokens,
max_tokens: sampling.max_tokens,
logprobs_ranges: Vec::new(),
logprob_top_k: 0,
readout_ranges,
@ -422,10 +429,4 @@ async fn flush_pending(
Ok(())
}
fn sampling_max_tokens(_sampling: &SamplingParams) -> u32 {
// SamplingParams doesn't carry max_tokens today; 4096 mirrors
// the old server-side default and is a sensible interactive cap.
// TODO: plumb from the caller if we need bigger budgets.
4096
}