agent: bundle sampling fields as SamplingParams on AgentState

Collapse the split `temperature` / `top_p` / `top_k` fields on
AgentState into a single `sampling: SamplingParams` struct, mirroring
how the wire-level fields flow into the Generate RPC. Adds
`max_tokens` to SamplingParams so it's actually plumbed end to end
(previously the client had a hardcoded 4096 fallback inside
`run_session_generate`).

AgentState construction sites now set `sampling: SamplingParams { ...
max_tokens: 4096 }` as the default. The assignment sites in
oneshot.rs / subconscious.rs / unconscious.rs switch from
`st.temperature = X` to `st.sampling.temperature = X`.

`stream_session_mm` takes `SamplingParams` directly; the
`sampling_max_tokens()` helper goes away. `pb::GenerateRequest` is
populated with `sampling.max_tokens` (and the other fields) in
`run_session_generate`. SamplingParams is `pub` so it can be
embedded in the public AgentState without a visibility warning.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-24 12:37:20 -04:00
commit be6ba4e9a5
6 changed files with 29 additions and 32 deletions

View file

@ -38,6 +38,21 @@ pub struct ReadoutManifest {
/// from pairing with the manifest fetched at startup.
pub type TokenReadout = Vec<Vec<f32>>;
/// Client-side sampling state. Mirrors the wire-level fields in
/// `GenerateRequest` (proto flattened its `SamplingParams` submessage
/// in so the server handler reads them directly), but stays as a
/// grouped struct on the client because UI / config / tests pass
/// these around together.
#[derive(Clone, Copy)]
pub struct SamplingParams {
pub temperature: f32,
pub top_p: f32,
pub top_k: u32,
/// Decode budget. 0 = prefill only; >0 = decode up to this many
/// tokens, stopping early on EOS / stop_token_ids.
pub max_tokens: u32,
}
/// A JoinHandle that aborts its task when dropped.
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
@ -47,13 +62,6 @@ impl Drop for AbortOnDrop {
}
}
/// Sampling parameters for model generation.
#[derive(Clone, Copy)]
pub(crate) struct SamplingParams {
pub temperature: f32,
pub top_p: f32,
pub top_k: u32,
}
// ─────────────────────────────────────────────────────────────
// Stream events — yielded by backends, consumed by the runner
@ -288,13 +296,12 @@ async fn run_session_generate(
} else {
Vec::new()
};
let max_tokens = sampling_max_tokens(&sampling);
let req = pb::GenerateRequest {
session_id: handle.session_id.clone(),
append_tokens: pending,
offset: handle.committed_len,
truncating: false,
max_tokens,
max_tokens: sampling.max_tokens,
logprobs_ranges: Vec::new(),
logprob_top_k: 0,
readout_ranges,
@ -422,10 +429,4 @@ async fn flush_pending(
Ok(())
}
fn sampling_max_tokens(_sampling: &SamplingParams) -> u32 {
// SamplingParams doesn't carry max_tokens today; 4096 mirrors
// the old server-side default and is a sensible interactive cap.
// TODO: plumb from the caller if we need bigger budgets.
4096
}

View file

@ -177,9 +177,7 @@ pub struct AgentState {
pub think_native: bool,
/// Tool-based thinking — add a "think" tool for structured reasoning.
pub think_tool: bool,
pub temperature: f32,
pub top_p: f32,
pub top_k: u32,
pub sampling: api::SamplingParams,
pub activities: Vec<ActivityEntry>,
next_activity_id: u64,
pub pending_yield: bool,
@ -241,9 +239,12 @@ impl Agent {
reasoning_effort: "none".to_string(),
think_native: true,
think_tool: false,
sampling: api::SamplingParams {
temperature: 0.6,
top_p: 0.95,
top_k: 20,
max_tokens: 4096,
},
activities: Vec::new(),
next_activity_id: 0,
pending_yield: false,
@ -312,9 +313,7 @@ impl Agent {
reasoning_effort: "none".to_string(),
think_native: st.think_native,
think_tool: st.think_tool,
temperature: st.temperature,
top_p: st.top_p,
top_k: st.top_k,
sampling: st.sampling,
activities: Vec::new(),
next_activity_id: 0,
pending_yield: false,
@ -424,11 +423,7 @@ impl Agent {
agent.client.stream_session_mm(
agent.grpc_session.clone(),
chunks,
api::SamplingParams {
temperature: st.temperature,
top_p: st.top_p,
top_k: st.top_k,
},
st.sampling,
st.priority,
readout_shape,
)

View file

@ -269,7 +269,7 @@ impl AutoAgent {
let mut st = agent.state.lock().await;
st.provenance = format!("standalone:{}", self.name);
st.tools = self.tools.clone();
st.temperature = self.temperature;
st.sampling.temperature = self.temperature;
st.priority = Some(self.priority);
}

View file

@ -631,7 +631,7 @@ impl Subconscious {
{
let mut st = forked.state.lock().await;
st.provenance = auto.name.clone();
st.temperature = auto.temperature;
st.sampling.temperature = auto.temperature;
// Surface agent gets near-interactive priority;
// other subconscious agents get lower priority.
st.priority = Some(if auto.name == "surface" { 1 } else { auto.priority });

View file

@ -321,7 +321,7 @@ pub async fn prepare_spawn(
let mut st = agent.state.lock().await;
st.provenance = auto.name.clone();
st.priority = Some(auto.priority);
st.temperature = auto.temperature;
st.sampling.temperature = auto.temperature;
}
let agent_clone = agent.clone();

View file

@ -43,6 +43,7 @@ where F: FnMut(&AstNode) -> bool,
temperature: 0.6,
top_p: 0.95,
top_k: 20,
max_tokens: 4096,
};
// Ephemeral per-call session — opens on first touch, drops when