forked from kent/consciousness
agent: bundle sampling fields as SamplingParams on AgentState
Collapse the split `temperature` / `top_p` / `top_k` fields on
AgentState into a single `sampling: SamplingParams` struct, mirroring
how the wire-level fields flow into the Generate RPC. Adds
`max_tokens` to SamplingParams so it's actually plumbed end to end
(previously the client had a hardcoded 4096 fallback inside
`run_session_generate`).
AgentState construction sites now set `sampling: SamplingParams { ...
max_tokens: 4096 }` as the default. The assignment sites in
oneshot.rs / subconscious.rs / unconscious.rs switch from
`st.temperature = X` to `st.sampling.temperature = X`.
`stream_session_mm` takes `SamplingParams` directly; the
`sampling_max_tokens()` helper goes away. `pb::GenerateRequest` is
populated with `sampling.max_tokens` (and the other fields) in
`run_session_generate`. SamplingParams is `pub` so it can be
embedded in the public AgentState without a visibility warning.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
8d9c9e9f7b
commit
be6ba4e9a5
6 changed files with 29 additions and 32 deletions
|
|
@ -38,6 +38,21 @@ pub struct ReadoutManifest {
|
|||
/// from pairing with the manifest fetched at startup.
|
||||
pub type TokenReadout = Vec<Vec<f32>>;
|
||||
|
||||
/// Client-side sampling state. Mirrors the wire-level fields in
|
||||
/// `GenerateRequest` (proto flattened its `SamplingParams` submessage
|
||||
/// in so the server handler reads them directly), but stays as a
|
||||
/// grouped struct on the client because UI / config / tests pass
|
||||
/// these around together.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct SamplingParams {
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
pub top_k: u32,
|
||||
/// Decode budget. 0 = prefill only; >0 = decode up to this many
|
||||
/// tokens, stopping early on EOS / stop_token_ids.
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
/// A JoinHandle that aborts its task when dropped.
|
||||
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||
|
||||
|
|
@ -47,13 +62,6 @@ impl Drop for AbortOnDrop {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sampling parameters for model generation.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct SamplingParams {
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
pub top_k: u32,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────
|
||||
// Stream events — yielded by backends, consumed by the runner
|
||||
|
|
@ -288,13 +296,12 @@ async fn run_session_generate(
|
|||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let max_tokens = sampling_max_tokens(&sampling);
|
||||
let req = pb::GenerateRequest {
|
||||
session_id: handle.session_id.clone(),
|
||||
append_tokens: pending,
|
||||
offset: handle.committed_len,
|
||||
truncating: false,
|
||||
max_tokens,
|
||||
max_tokens: sampling.max_tokens,
|
||||
logprobs_ranges: Vec::new(),
|
||||
logprob_top_k: 0,
|
||||
readout_ranges,
|
||||
|
|
@ -422,10 +429,4 @@ async fn flush_pending(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn sampling_max_tokens(_sampling: &SamplingParams) -> u32 {
|
||||
// SamplingParams doesn't carry max_tokens today; 4096 mirrors
|
||||
// the old server-side default and is a sensible interactive cap.
|
||||
// TODO: plumb from the caller if we need bigger budgets.
|
||||
4096
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue