diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index 19d05cf..d2c415f 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -29,6 +29,14 @@ impl Drop for AbortOnDrop { } } +/// Sampling parameters for model generation. +#[derive(Clone, Copy)] +pub struct SamplingParams { + pub temperature: f32, + pub top_p: f32, + pub top_k: u32, +} + // ───────────────────────────────────────────────────────────── // Stream events — yielded by backends, consumed by the runner // ───────────────────────────────────────────────────────────── @@ -93,7 +101,7 @@ impl ApiClient { tools: Option<&[ToolDef]>, ui_tx: &UiSender, reasoning_effort: &str, - temperature: Option<f32>, + sampling: SamplingParams, priority: Option, ) -> (mpsc::UnboundedReceiver<StreamEvent>, AbortOnDrop) { let (tx, rx) = mpsc::unbounded_channel(); @@ -110,7 +118,7 @@ impl ApiClient { let result = openai::stream_events( &client, &base_url, &api_key, &model, &messages, tools.as_deref(), &tx, &ui_tx, - &reasoning_effort, temperature, priority, + &reasoning_effort, sampling, priority, ).await; if let Err(e) = result { let _ = tx.send(StreamEvent::Error(e.to_string())); @@ -126,11 +134,11 @@ impl ApiClient { tools: Option<&[ToolDef]>, ui_tx: &UiSender, reasoning_effort: &str, - temperature: Option<f32>, + sampling: SamplingParams, priority: Option, ) -> Result<(Message, Option<Usage>)> { // Use the event stream and accumulate into a message. 
- let (mut rx, _handle) = self.start_stream(messages, tools, ui_tx, reasoning_effort, temperature, priority); + let (mut rx, _handle) = self.start_stream(messages, tools, ui_tx, reasoning_effort, sampling, priority); let mut content = String::new(); let mut tool_calls: Vec<ToolCall> = Vec::new(); let mut usage = None; diff --git a/src/agent/api/openai.rs b/src/agent/api/openai.rs index d780434..7b380a2 100644 --- a/src/agent/api/openai.rs +++ b/src/agent/api/openai.rs @@ -26,7 +26,7 @@ pub(super) async fn stream_events( tx: &mpsc::UnboundedSender<StreamEvent>, ui_tx: &UiSender, reasoning_effort: &str, - temperature: Option<f32>, + sampling: super::SamplingParams, priority: Option, ) -> Result<()> { let request = ChatRequest { @@ -35,7 +35,9 @@ tool_choice: tools.map(|_| "auto".to_string()), tools: tools.map(|t| t.to_vec()), max_tokens: Some(16384), - temperature: Some(temperature.unwrap_or(0.6)), + temperature: Some(sampling.temperature), + top_p: Some(sampling.top_p), + top_k: Some(sampling.top_k), stream: Some(true), reasoning: if reasoning_effort != "none" && reasoning_effort != "default" { Some(ReasoningConfig { diff --git a/src/agent/api/types.rs b/src/agent/api/types.rs index 6a1249c..9d0d7e1 100644 --- a/src/agent/api/types.rs +++ b/src/agent/api/types.rs @@ -95,6 +95,10 @@ pub struct ChatRequest { #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option<f32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] pub stream: Option<bool>, /// OpenRouter reasoning control. 
Send both formats for compatibility: /// - reasoning.enabled (older format, still seen in examples) diff --git a/src/agent/mod.rs b/src/agent/mod.rs index dc0c1cd..2d13552 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -77,6 +77,10 @@ pub struct Agent { last_prompt_tokens: u32, /// Current reasoning effort level ("none", "low", "high"). pub reasoning_effort: String, + /// Sampling parameters — adjustable at runtime from the thalamus screen. + pub temperature: f32, + pub top_p: f32, + pub top_k: u32, /// Persistent conversation log — append-only record of all messages. conversation_log: Option, /// BPE tokenizer for token counting (cl100k_base — close enough @@ -137,6 +141,9 @@ impl Agent { tool_defs, last_prompt_tokens: 0, reasoning_effort: "none".to_string(), + temperature: 0.6, + top_p: 0.95, + top_k: 20, conversation_log, tokenizer, context, @@ -288,12 +295,17 @@ impl Agent { let (mut rx, _stream_guard) = { let me = agent.lock().await; let api_messages = me.assemble_api_messages(); + let sampling = api::SamplingParams { + temperature: me.temperature, + top_p: me.top_p, + top_k: me.top_k, + }; me.client.start_stream( &api_messages, Some(&me.tool_defs), ui_tx, &me.reasoning_effort, - None, + sampling, None, ) }; diff --git a/src/subconscious/api.rs b/src/subconscious/api.rs index 31cb3fc..b9d5be9 100644 --- a/src/subconscious/api.rs +++ b/src/subconscious/api.rs @@ -76,12 +76,17 @@ pub async fn call_api_with_tools( let mut msg_opt = None; let mut usage_opt = None; for attempt in 0..5 { + let sampling = crate::agent::api::SamplingParams { + temperature: temperature.unwrap_or(0.6), + top_p: 0.95, + top_k: 20, + }; match client.chat_completion_stream_temp( &messages, Some(&tool_defs), &ui_tx, &reasoning, - temperature, + sampling, Some(priority), ).await { Ok((msg, usage)) => { diff --git a/src/user/mod.rs b/src/user/mod.rs index fc2f6dd..40d6c0b 100644 --- a/src/user/mod.rs +++ b/src/user/mod.rs @@ -273,6 +273,9 @@ pub struct App { pub(crate) 
needs_assistant_marker: bool, pub running_processes: u32, pub reasoning_effort: String, + pub temperature: f32, + pub top_p: f32, + pub top_k: u32, pub(crate) active_tools: crate::user::ui_channel::SharedActiveTools, pub(crate) active_pane: ActivePane, pub textarea: tui_textarea::TextArea<'static>, @@ -310,6 +313,9 @@ impl App { turn_started: None, call_started: None, call_timeout_secs: 60, needs_assistant_marker: false, running_processes: 0, reasoning_effort: "none".to_string(), + temperature: 0.6, + top_p: 0.95, + top_k: 20, active_tools, active_pane: ActivePane::Conversation, textarea: new_textarea(vec![String::new()]), input_history: Vec::new(), history_index: None, diff --git a/src/user/thalamus.rs b/src/user/thalamus.rs index c17288d..8bf6f71 100644 --- a/src/user/thalamus.rs +++ b/src/user/thalamus.rs @@ -48,6 +48,14 @@ impl App { } lines.push(Line::raw("")); + // Sampling parameters + lines.push(Line::styled("── Sampling ──", section)); + lines.push(Line::raw("")); + lines.push(Line::raw(format!(" temperature: {:.2}", self.temperature))); + lines.push(Line::raw(format!(" top_p: {:.2}", self.top_p))); + lines.push(Line::raw(format!(" top_k: {}", self.top_k))); + lines.push(Line::raw("")); + // Channel status from cached data lines.push(Line::styled("── Channels ──", section)); lines.push(Line::raw(""));