forked from kent/consciousness
agent: bundle sampling fields as SamplingParams on AgentState
Collapse the split `temperature` / `top_p` / `top_k` fields on
AgentState into a single `sampling: SamplingParams` struct, mirroring
how the wire-level fields flow into the Generate RPC. Adds
`max_tokens` to SamplingParams so it's actually plumbed end to end
(previously the client had a hardcoded 4096 fallback inside
`run_session_generate`).
AgentState construction sites now set `sampling: SamplingParams { ...
max_tokens: 4096 }` as the default. The assignment sites in
oneshot.rs / subconscious.rs / unconscious.rs switch from
`st.temperature = X` to `st.sampling.temperature = X`.
`stream_session_mm` takes `SamplingParams` directly; the
`sampling_max_tokens()` helper goes away. `pb::GenerateRequest` is
populated with `sampling.max_tokens` (and the other fields) in
`run_session_generate`. SamplingParams is `pub` so it can be
embedded in the public AgentState without a visibility warning.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
8d9c9e9f7b
commit
be6ba4e9a5
6 changed files with 29 additions and 32 deletions
|
|
@ -38,6 +38,21 @@ pub struct ReadoutManifest {
|
||||||
/// from pairing with the manifest fetched at startup.
|
/// from pairing with the manifest fetched at startup.
|
||||||
pub type TokenReadout = Vec<Vec<f32>>;
|
pub type TokenReadout = Vec<Vec<f32>>;
|
||||||
|
|
||||||
|
/// Client-side sampling state. Mirrors the wire-level fields in
|
||||||
|
/// `GenerateRequest` (proto flattened its `SamplingParams` submessage
|
||||||
|
/// in so the server handler reads them directly), but stays as a
|
||||||
|
/// grouped struct on the client because UI / config / tests pass
|
||||||
|
/// these around together.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct SamplingParams {
|
||||||
|
pub temperature: f32,
|
||||||
|
pub top_p: f32,
|
||||||
|
pub top_k: u32,
|
||||||
|
/// Decode budget. 0 = prefill only; >0 = decode up to this many
|
||||||
|
/// tokens, stopping early on EOS / stop_token_ids.
|
||||||
|
pub max_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
/// A JoinHandle that aborts its task when dropped.
|
/// A JoinHandle that aborts its task when dropped.
|
||||||
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||||
|
|
||||||
|
|
@ -47,13 +62,6 @@ impl Drop for AbortOnDrop {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sampling parameters for model generation.
|
|
||||||
#[derive(Clone, Copy)]
|
|
||||||
pub(crate) struct SamplingParams {
|
|
||||||
pub temperature: f32,
|
|
||||||
pub top_p: f32,
|
|
||||||
pub top_k: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────
|
||||||
// Stream events — yielded by backends, consumed by the runner
|
// Stream events — yielded by backends, consumed by the runner
|
||||||
|
|
@ -288,13 +296,12 @@ async fn run_session_generate(
|
||||||
} else {
|
} else {
|
||||||
Vec::new()
|
Vec::new()
|
||||||
};
|
};
|
||||||
let max_tokens = sampling_max_tokens(&sampling);
|
|
||||||
let req = pb::GenerateRequest {
|
let req = pb::GenerateRequest {
|
||||||
session_id: handle.session_id.clone(),
|
session_id: handle.session_id.clone(),
|
||||||
append_tokens: pending,
|
append_tokens: pending,
|
||||||
offset: handle.committed_len,
|
offset: handle.committed_len,
|
||||||
truncating: false,
|
truncating: false,
|
||||||
max_tokens,
|
max_tokens: sampling.max_tokens,
|
||||||
logprobs_ranges: Vec::new(),
|
logprobs_ranges: Vec::new(),
|
||||||
logprob_top_k: 0,
|
logprob_top_k: 0,
|
||||||
readout_ranges,
|
readout_ranges,
|
||||||
|
|
@ -422,10 +429,4 @@ async fn flush_pending(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sampling_max_tokens(_sampling: &SamplingParams) -> u32 {
|
|
||||||
// SamplingParams doesn't carry max_tokens today; 4096 mirrors
|
|
||||||
// the old server-side default and is a sensible interactive cap.
|
|
||||||
// TODO: plumb from the caller if we need bigger budgets.
|
|
||||||
4096
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -177,9 +177,7 @@ pub struct AgentState {
|
||||||
pub think_native: bool,
|
pub think_native: bool,
|
||||||
/// Tool-based thinking — add a "think" tool for structured reasoning.
|
/// Tool-based thinking — add a "think" tool for structured reasoning.
|
||||||
pub think_tool: bool,
|
pub think_tool: bool,
|
||||||
pub temperature: f32,
|
pub sampling: api::SamplingParams,
|
||||||
pub top_p: f32,
|
|
||||||
pub top_k: u32,
|
|
||||||
pub activities: Vec<ActivityEntry>,
|
pub activities: Vec<ActivityEntry>,
|
||||||
next_activity_id: u64,
|
next_activity_id: u64,
|
||||||
pub pending_yield: bool,
|
pub pending_yield: bool,
|
||||||
|
|
@ -241,9 +239,12 @@ impl Agent {
|
||||||
reasoning_effort: "none".to_string(),
|
reasoning_effort: "none".to_string(),
|
||||||
think_native: true,
|
think_native: true,
|
||||||
think_tool: false,
|
think_tool: false,
|
||||||
|
sampling: api::SamplingParams {
|
||||||
temperature: 0.6,
|
temperature: 0.6,
|
||||||
top_p: 0.95,
|
top_p: 0.95,
|
||||||
top_k: 20,
|
top_k: 20,
|
||||||
|
max_tokens: 4096,
|
||||||
|
},
|
||||||
activities: Vec::new(),
|
activities: Vec::new(),
|
||||||
next_activity_id: 0,
|
next_activity_id: 0,
|
||||||
pending_yield: false,
|
pending_yield: false,
|
||||||
|
|
@ -312,9 +313,7 @@ impl Agent {
|
||||||
reasoning_effort: "none".to_string(),
|
reasoning_effort: "none".to_string(),
|
||||||
think_native: st.think_native,
|
think_native: st.think_native,
|
||||||
think_tool: st.think_tool,
|
think_tool: st.think_tool,
|
||||||
temperature: st.temperature,
|
sampling: st.sampling,
|
||||||
top_p: st.top_p,
|
|
||||||
top_k: st.top_k,
|
|
||||||
activities: Vec::new(),
|
activities: Vec::new(),
|
||||||
next_activity_id: 0,
|
next_activity_id: 0,
|
||||||
pending_yield: false,
|
pending_yield: false,
|
||||||
|
|
@ -424,11 +423,7 @@ impl Agent {
|
||||||
agent.client.stream_session_mm(
|
agent.client.stream_session_mm(
|
||||||
agent.grpc_session.clone(),
|
agent.grpc_session.clone(),
|
||||||
chunks,
|
chunks,
|
||||||
api::SamplingParams {
|
st.sampling,
|
||||||
temperature: st.temperature,
|
|
||||||
top_p: st.top_p,
|
|
||||||
top_k: st.top_k,
|
|
||||||
},
|
|
||||||
st.priority,
|
st.priority,
|
||||||
readout_shape,
|
readout_shape,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,7 @@ impl AutoAgent {
|
||||||
let mut st = agent.state.lock().await;
|
let mut st = agent.state.lock().await;
|
||||||
st.provenance = format!("standalone:{}", self.name);
|
st.provenance = format!("standalone:{}", self.name);
|
||||||
st.tools = self.tools.clone();
|
st.tools = self.tools.clone();
|
||||||
st.temperature = self.temperature;
|
st.sampling.temperature = self.temperature;
|
||||||
st.priority = Some(self.priority);
|
st.priority = Some(self.priority);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -631,7 +631,7 @@ impl Subconscious {
|
||||||
{
|
{
|
||||||
let mut st = forked.state.lock().await;
|
let mut st = forked.state.lock().await;
|
||||||
st.provenance = auto.name.clone();
|
st.provenance = auto.name.clone();
|
||||||
st.temperature = auto.temperature;
|
st.sampling.temperature = auto.temperature;
|
||||||
// Surface agent gets near-interactive priority;
|
// Surface agent gets near-interactive priority;
|
||||||
// other subconscious agents get lower priority.
|
// other subconscious agents get lower priority.
|
||||||
st.priority = Some(if auto.name == "surface" { 1 } else { auto.priority });
|
st.priority = Some(if auto.name == "surface" { 1 } else { auto.priority });
|
||||||
|
|
|
||||||
|
|
@ -321,7 +321,7 @@ pub async fn prepare_spawn(
|
||||||
let mut st = agent.state.lock().await;
|
let mut st = agent.state.lock().await;
|
||||||
st.provenance = auto.name.clone();
|
st.provenance = auto.name.clone();
|
||||||
st.priority = Some(auto.priority);
|
st.priority = Some(auto.priority);
|
||||||
st.temperature = auto.temperature;
|
st.sampling.temperature = auto.temperature;
|
||||||
}
|
}
|
||||||
|
|
||||||
let agent_clone = agent.clone();
|
let agent_clone = agent.clone();
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ where F: FnMut(&AstNode) -> bool,
|
||||||
temperature: 0.6,
|
temperature: 0.6,
|
||||||
top_p: 0.95,
|
top_p: 0.95,
|
||||||
top_k: 20,
|
top_k: 20,
|
||||||
|
max_tokens: 4096,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ephemeral per-call session — opens on first touch, drops when
|
// Ephemeral per-call session — opens on first touch, drops when
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue