agent: add sampling parameters (temperature, top_p, top_k)
Move temperature from a per-call parameter to an Agent field, add top_p and top_k. All three are sent to the API via a new SamplingParams struct, displayed on the F5 thalamus screen. Defaults: temperature=0.6, top_p=0.95, top_k=20 (Qwen3.5 defaults). Also adds top_p and top_k to ChatRequest so they're sent in the API payload. Previously only temperature was sent. UI controls for adjusting these at runtime are not yet implemented. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
22f955ad9f
commit
dd009742ef
7 changed files with 53 additions and 8 deletions
|
|
@ -29,6 +29,14 @@ impl Drop for AbortOnDrop {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sampling parameters for model generation.
///
/// Bundles the knobs sent to the API in the chat payload
/// (`temperature`, `top_p`, `top_k` on `ChatRequest`). Kept `Copy`
/// so call sites can pass it by value without ceremony.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SamplingParams {
    /// Softmax temperature; higher → more random output.
    pub temperature: f32,
    /// Nucleus-sampling probability mass cutoff.
    pub top_p: f32,
    /// Top-k vocabulary truncation; 0 conventionally means disabled.
    pub top_k: u32,
}

impl Default for SamplingParams {
    /// Qwen3.5 recommended defaults (temperature=0.6, top_p=0.95,
    /// top_k=20), matching the hard-coded values used at the
    /// Agent/App construction sites — centralized here so the three
    /// call sites can use `SamplingParams::default()` instead of
    /// repeating the literals.
    fn default() -> Self {
        Self {
            temperature: 0.6,
            top_p: 0.95,
            top_k: 20,
        }
    }
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────
|
||||
// Stream events — yielded by backends, consumed by the runner
|
||||
// ─────────────────────────────────────────────────────────────
|
||||
|
|
@ -93,7 +101,7 @@ impl ApiClient {
|
|||
tools: Option<&[ToolDef]>,
|
||||
ui_tx: &UiSender,
|
||||
reasoning_effort: &str,
|
||||
temperature: Option<f32>,
|
||||
sampling: SamplingParams,
|
||||
priority: Option<i32>,
|
||||
) -> (mpsc::UnboundedReceiver<StreamEvent>, AbortOnDrop) {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
|
|
@ -110,7 +118,7 @@ impl ApiClient {
|
|||
let result = openai::stream_events(
|
||||
&client, &base_url, &api_key, &model,
|
||||
&messages, tools.as_deref(), &tx, &ui_tx,
|
||||
&reasoning_effort, temperature, priority,
|
||||
&reasoning_effort, sampling, priority,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
let _ = tx.send(StreamEvent::Error(e.to_string()));
|
||||
|
|
@ -126,11 +134,11 @@ impl ApiClient {
|
|||
tools: Option<&[ToolDef]>,
|
||||
ui_tx: &UiSender,
|
||||
reasoning_effort: &str,
|
||||
temperature: Option<f32>,
|
||||
sampling: SamplingParams,
|
||||
priority: Option<i32>,
|
||||
) -> Result<(Message, Option<Usage>)> {
|
||||
// Use the event stream and accumulate into a message.
|
||||
let (mut rx, _handle) = self.start_stream(messages, tools, ui_tx, reasoning_effort, temperature, priority);
|
||||
let (mut rx, _handle) = self.start_stream(messages, tools, ui_tx, reasoning_effort, sampling, priority);
|
||||
let mut content = String::new();
|
||||
let mut tool_calls: Vec<ToolCall> = Vec::new();
|
||||
let mut usage = None;
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ pub(super) async fn stream_events(
|
|||
tx: &mpsc::UnboundedSender<StreamEvent>,
|
||||
ui_tx: &UiSender,
|
||||
reasoning_effort: &str,
|
||||
temperature: Option<f32>,
|
||||
sampling: super::SamplingParams,
|
||||
priority: Option<i32>,
|
||||
) -> Result<()> {
|
||||
let request = ChatRequest {
|
||||
|
|
@ -35,7 +35,9 @@ pub(super) async fn stream_events(
|
|||
tool_choice: tools.map(|_| "auto".to_string()),
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
max_tokens: Some(16384),
|
||||
temperature: Some(temperature.unwrap_or(0.6)),
|
||||
temperature: Some(sampling.temperature),
|
||||
top_p: Some(sampling.top_p),
|
||||
top_k: Some(sampling.top_k),
|
||||
stream: Some(true),
|
||||
reasoning: if reasoning_effort != "none" && reasoning_effort != "default" {
|
||||
Some(ReasoningConfig {
|
||||
|
|
|
|||
|
|
@ -95,6 +95,10 @@ pub struct ChatRequest {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
/// OpenRouter reasoning control. Send both formats for compatibility:
|
||||
/// - reasoning.enabled (older format, still seen in examples)
|
||||
|
|
|
|||
|
|
@ -77,6 +77,10 @@ pub struct Agent {
|
|||
last_prompt_tokens: u32,
|
||||
/// Current reasoning effort level ("none", "low", "high").
|
||||
pub reasoning_effort: String,
|
||||
/// Sampling parameters — adjustable at runtime from the thalamus screen.
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
pub top_k: u32,
|
||||
/// Persistent conversation log — append-only record of all messages.
|
||||
conversation_log: Option<ConversationLog>,
|
||||
/// BPE tokenizer for token counting (cl100k_base — close enough
|
||||
|
|
@ -137,6 +141,9 @@ impl Agent {
|
|||
tool_defs,
|
||||
last_prompt_tokens: 0,
|
||||
reasoning_effort: "none".to_string(),
|
||||
temperature: 0.6,
|
||||
top_p: 0.95,
|
||||
top_k: 20,
|
||||
conversation_log,
|
||||
tokenizer,
|
||||
context,
|
||||
|
|
@ -288,12 +295,17 @@ impl Agent {
|
|||
let (mut rx, _stream_guard) = {
|
||||
let me = agent.lock().await;
|
||||
let api_messages = me.assemble_api_messages();
|
||||
let sampling = api::SamplingParams {
|
||||
temperature: me.temperature,
|
||||
top_p: me.top_p,
|
||||
top_k: me.top_k,
|
||||
};
|
||||
me.client.start_stream(
|
||||
&api_messages,
|
||||
Some(&me.tool_defs),
|
||||
ui_tx,
|
||||
&me.reasoning_effort,
|
||||
None,
|
||||
sampling,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
|
|
|||
|
|
@ -76,12 +76,17 @@ pub async fn call_api_with_tools(
|
|||
let mut msg_opt = None;
|
||||
let mut usage_opt = None;
|
||||
for attempt in 0..5 {
|
||||
let sampling = crate::agent::api::SamplingParams {
|
||||
temperature: temperature.unwrap_or(0.6),
|
||||
top_p: 0.95,
|
||||
top_k: 20,
|
||||
};
|
||||
match client.chat_completion_stream_temp(
|
||||
&messages,
|
||||
Some(&tool_defs),
|
||||
&ui_tx,
|
||||
&reasoning,
|
||||
temperature,
|
||||
sampling,
|
||||
Some(priority),
|
||||
).await {
|
||||
Ok((msg, usage)) => {
|
||||
|
|
|
|||
|
|
@ -273,6 +273,9 @@ pub struct App {
|
|||
pub(crate) needs_assistant_marker: bool,
|
||||
pub running_processes: u32,
|
||||
pub reasoning_effort: String,
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
pub top_k: u32,
|
||||
pub(crate) active_tools: crate::user::ui_channel::SharedActiveTools,
|
||||
pub(crate) active_pane: ActivePane,
|
||||
pub textarea: tui_textarea::TextArea<'static>,
|
||||
|
|
@ -310,6 +313,9 @@ impl App {
|
|||
turn_started: None, call_started: None, call_timeout_secs: 60,
|
||||
needs_assistant_marker: false, running_processes: 0,
|
||||
reasoning_effort: "none".to_string(),
|
||||
temperature: 0.6,
|
||||
top_p: 0.95,
|
||||
top_k: 20,
|
||||
active_tools, active_pane: ActivePane::Conversation,
|
||||
textarea: new_textarea(vec![String::new()]),
|
||||
input_history: Vec::new(), history_index: None,
|
||||
|
|
|
|||
|
|
@ -48,6 +48,14 @@ impl App {
|
|||
}
|
||||
lines.push(Line::raw(""));
|
||||
|
||||
// Sampling parameters
|
||||
lines.push(Line::styled("── Sampling ──", section));
|
||||
lines.push(Line::raw(""));
|
||||
lines.push(Line::raw(format!(" temperature: {:.2}", self.temperature)));
|
||||
lines.push(Line::raw(format!(" top_p: {:.2}", self.top_p)));
|
||||
lines.push(Line::raw(format!(" top_k: {}", self.top_k)));
|
||||
lines.push(Line::raw(""));
|
||||
|
||||
// Channel status from cached data
|
||||
lines.push(Line::styled("── Channels ──", section));
|
||||
lines.push(Line::raw(""));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue