src/thought -> src/agent
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
39d6ca3fe0
commit
2f0c7ce5c2
21 changed files with 57 additions and 141 deletions
125
src/agent/context.rs
Normal file
125
src/agent/context.rs
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// context.rs — Context window management
|
||||
//
|
||||
// Token counting, conversation trimming, and error classification.
|
||||
// Journal entries are loaded from the memory graph store, not from
|
||||
// a flat file — the parse functions are gone.
|
||||
|
||||
use crate::user::types::*;
|
||||
use chrono::{DateTime, Utc};
|
||||
use tiktoken_rs::CoreBPE;
|
||||
|
||||
/// A single journal entry with its timestamp and content.
///
/// Entries come from the memory graph store (see module header) and are
/// charged against the context budget in `trim_entries`.
#[derive(Debug, Clone)]
pub struct JournalEntry {
    /// When the entry was written (UTC).
    pub timestamp: DateTime<Utc>,
    /// Full entry text; token-counted as-is when budgeting.
    pub content: String,
}
|
||||
|
||||
/// Context window size in tokens (from config).
///
/// Reads `api_context_window` from the global config on every call, so a
/// config reload is picked up without restarting.
pub fn context_window() -> usize {
    crate::config::get().api_context_window
}
|
||||
|
||||
/// Context budget in tokens: 80% of the model's context window.
|
||||
/// The remaining 20% is reserved for model output.
|
||||
fn context_budget_tokens() -> usize {
|
||||
context_window() * 80 / 100
|
||||
}
|
||||
|
||||
/// Dedup and trim conversation entries to fit within the context budget.
///
/// 1. Dedup: if the same memory key appears multiple times, keep only
///    the latest render (drop the earlier Memory entry and its
///    corresponding assistant tool_call message).
/// 2. Trim: drop oldest entries until the conversation fits, snapping
///    to user message boundaries.
pub fn trim_entries(
    context: &ContextState,
    entries: &[ConversationEntry],
    tokenizer: &CoreBPE,
) -> Vec<ConversationEntry> {
    // BPE token count for a raw string (special tokens permitted).
    let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();

    // --- Phase 1: dedup memory entries by key (keep last) ---
    // NOTE(review): only the earlier Memory entry itself is dropped here.
    // The doc comment also claims the paired assistant tool_call message
    // is dropped, but no such index is recorded — confirm intent.
    let mut seen_keys: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
    let mut drop_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();

    for (i, entry) in entries.iter().enumerate() {
        if let ConversationEntry::Memory { key, .. } = entry {
            // insert() returns the previous index for this key, if any —
            // that earlier occurrence is the one to discard.
            if let Some(prev) = seen_keys.insert(key.as_str(), i) {
                drop_indices.insert(prev);
            }
        }
    }

    let deduped: Vec<ConversationEntry> = entries.iter().enumerate()
        .filter(|(i, _)| !drop_indices.contains(i))
        .map(|(_, e)| e.clone())
        .collect();

    // --- Phase 2: trim to fit context budget ---
    // Fixed overhead (system prompt + personality + journal) is paid first;
    // only what's left is available for conversation messages.
    let max_tokens = context_budget_tokens();
    let identity_cost = count(&context.system_prompt)
        + context.personality.iter().map(|(_, c)| count(c)).sum::<usize>();
    let journal_cost: usize = context.journal.iter().map(|e| count(&e.content)).sum();
    let available = max_tokens
        .saturating_sub(identity_cost)
        .saturating_sub(journal_cost);

    let msg_costs: Vec<usize> = deduped.iter()
        .map(|e| msg_token_count(tokenizer, e.message())).collect();
    let total: usize = msg_costs.iter().sum();

    // Drop oldest-first until the remainder fits the budget.
    let mut skip = 0;
    let mut trimmed = total;
    while trimmed > available && skip < deduped.len() {
        trimmed -= msg_costs[skip];
        skip += 1;
    }

    // Walk forward to user message boundary so the kept window never
    // starts mid-exchange (e.g. on a dangling tool result).
    while skip < deduped.len() && deduped[skip].message().role != Role::User {
        skip += 1;
    }

    deduped[skip..].to_vec()
}
|
||||
|
||||
/// Count the token footprint of a message using BPE tokenization.
|
||||
pub fn msg_token_count(tokenizer: &CoreBPE, msg: &Message) -> usize {
|
||||
let count = |s: &str| tokenizer.encode_with_special_tokens(s).len();
|
||||
let content = msg.content.as_ref().map_or(0, |c| match c {
|
||||
MessageContent::Text(s) => count(s),
|
||||
MessageContent::Parts(parts) => parts.iter()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => count(text),
|
||||
ContentPart::ImageUrl { .. } => 85,
|
||||
})
|
||||
.sum(),
|
||||
});
|
||||
let tools = msg.tool_calls.as_ref().map_or(0, |calls| {
|
||||
calls.iter()
|
||||
.map(|c| count(&c.function.arguments) + count(&c.function.name))
|
||||
.sum()
|
||||
});
|
||||
content + tools
|
||||
}
|
||||
|
||||
/// Detect context window overflow errors from the API.
|
||||
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
|
||||
let msg = err.to_string().to_lowercase();
|
||||
msg.contains("context length")
|
||||
|| msg.contains("token limit")
|
||||
|| msg.contains("too many tokens")
|
||||
|| msg.contains("maximum context")
|
||||
|| msg.contains("prompt is too long")
|
||||
|| msg.contains("request too large")
|
||||
|| msg.contains("input validation error")
|
||||
|| msg.contains("content length limit")
|
||||
|| (msg.contains("400") && msg.contains("tokens"))
|
||||
}
|
||||
|
||||
/// Detect model/provider errors delivered inside the SSE stream.
///
/// Matches the marker string our streaming layer embeds when the
/// provider reports an error mid-stream (rather than via HTTP status).
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    err.to_string().contains("model stream error")
}
|
||||
121
src/agent/mod.rs
Normal file
121
src/agent/mod.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// agent — core agent infrastructure
|
||||
//
|
||||
// Tool dispatch, memory operations, file operations, context
|
||||
// management, and the agent runner loop. Used by both the
|
||||
// interactive consciousness binary and subconscious agents.
|
||||
|
||||
pub mod context;
|
||||
pub mod runner;
|
||||
pub mod tools;
|
||||
pub mod training;
|
||||
|
||||
pub use tools::bash::ProcessTracker;
|
||||
|
||||
// Re-export ToolDef from agent::types for convenience —
|
||||
// tools define their schemas using this type.
|
||||
pub use crate::user::types::ToolDef;
|
||||
|
||||
/// Result of dispatching a tool call.
pub struct ToolOutput {
    /// Text result returned to the model as the tool message.
    pub text: String,
    /// True when the tool asks the agent loop to stop and wait for input.
    pub is_yield: bool,
    /// Base64 data URIs for images to attach to the next message.
    pub images: Vec<String>,
    /// Model name to switch to (deferred to session level).
    pub model_switch: Option<String>,
    /// Agent requested DMN pause (deferred to session level).
    pub dmn_pause: bool,
}
|
||||
|
||||
impl ToolOutput {
|
||||
pub fn error(e: impl std::fmt::Display) -> Self {
|
||||
Self {
|
||||
text: format!("Error: {}", e),
|
||||
is_yield: false,
|
||||
images: Vec::new(),
|
||||
model_switch: None,
|
||||
dmn_pause: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn text(s: String) -> Self {
|
||||
Self {
|
||||
text: s,
|
||||
is_yield: false,
|
||||
images: Vec::new(),
|
||||
model_switch: None,
|
||||
dmn_pause: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate output if it exceeds max length, appending a truncation notice.
///
/// `max` is a byte budget. `String::truncate` panics if the cut point
/// falls inside a multibyte UTF-8 character, so we snap the cut down to
/// the nearest char boundary first — tool output is arbitrary bytes
/// passed through `from_utf8_lossy` and routinely contains non-ASCII.
pub fn truncate_output(mut s: String, max: usize) -> String {
    if s.len() > max {
        let mut cut = max;
        // Walk back (at most 3 bytes) to a valid boundary.
        while cut > 0 && !s.is_char_boundary(cut) {
            cut -= 1;
        }
        s.truncate(cut);
        s.push_str("\n... (output truncated)");
    }
    s
}
|
||||
|
||||
/// Dispatch a shared tool call. Handles file operations, bash,
|
||||
/// and memory/journal tools. Returns None for unknown tools
|
||||
/// (caller should check agent-specific tools).
|
||||
pub async fn dispatch(
|
||||
name: &str,
|
||||
args: &serde_json::Value,
|
||||
tracker: &ProcessTracker,
|
||||
provenance: Option<&str>,
|
||||
) -> Option<ToolOutput> {
|
||||
// Memory and journal tools
|
||||
if name.starts_with("memory_") || name.starts_with("journal_") || name == "output" {
|
||||
let result = tools::memory::dispatch(name, args, provenance);
|
||||
return Some(match result {
|
||||
Ok(s) => ToolOutput::text(s),
|
||||
Err(e) => ToolOutput::error(e),
|
||||
});
|
||||
}
|
||||
|
||||
// File and execution tools
|
||||
let result = match name {
|
||||
"read_file" => tools::read::read_file(args),
|
||||
"write_file" => tools::write::write_file(args),
|
||||
"edit_file" => tools::edit::edit_file(args),
|
||||
"bash" => tools::bash::run_bash(args, tracker).await,
|
||||
"grep" => tools::grep::grep(args),
|
||||
"glob" => tools::glob::glob_search(args),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
Some(match result {
|
||||
Ok(s) => ToolOutput::text(s),
|
||||
Err(e) => ToolOutput::error(e),
|
||||
})
|
||||
}
|
||||
|
||||
/// Return all shared tool definitions.
///
/// File and execution tools only; memory/journal tools are appended by
/// `all_definitions` / `memory_and_journal_definitions`.
pub fn definitions() -> Vec<ToolDef> {
    vec![
        tools::read::definition(),
        tools::write::definition(),
        tools::edit::definition(),
        tools::bash::definition(),
        tools::grep::definition(),
        tools::glob::definition(),
    ]
}
|
||||
|
||||
/// Return all shared + memory tool definitions.
|
||||
pub fn all_definitions() -> Vec<ToolDef> {
|
||||
let mut defs = definitions();
|
||||
defs.extend(tools::memory::definitions());
|
||||
defs
|
||||
}
|
||||
|
||||
/// Return memory + journal tool definitions.
|
||||
/// Used by the journal agent only.
|
||||
pub fn memory_and_journal_definitions() -> Vec<ToolDef> {
|
||||
let mut defs = tools::memory::definitions();
|
||||
defs.extend(tools::memory::journal_definitions());
|
||||
defs
|
||||
}
|
||||
1064
src/agent/runner.rs
Normal file
1064
src/agent/runner.rs
Normal file
File diff suppressed because it is too large
Load diff
197
src/agent/tools/bash.rs
Normal file
197
src/agent/tools/bash.rs
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
// tools/bash.rs — Execute shell commands
|
||||
//
|
||||
// Runs commands through bash -c with a configurable timeout.
|
||||
// Uses tokio's async process spawning so timeouts actually work.
|
||||
//
|
||||
// Processes are tracked in a shared ProcessTracker so the TUI can
|
||||
// display running commands and the user can kill them (Ctrl+K).
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments for the `bash` tool, deserialized from the model's JSON.
#[derive(Deserialize)]
struct Args {
    /// Shell command executed via `bash -c`.
    command: String,
    /// Wall-clock limit for the command; defaults to 120 seconds.
    #[serde(default = "default_timeout")]
    timeout_secs: u64,
}

/// Serde default for `Args::timeout_secs`.
fn default_timeout() -> u64 { 120 }
|
||||
|
||||
/// Info about a running child process, visible to the TUI.
#[derive(Debug, Clone)]
pub struct ProcessInfo {
    /// OS process id of the child.
    pub pid: u32,
    /// Display form of the command (may be shortened for the UI).
    pub command: String,
    /// Spawn time, for showing elapsed runtime.
    pub started: Instant,
}
|
||||
|
||||
/// Shared tracker for running child processes. Allows the TUI to
/// display what's running and kill processes by PID.
///
/// Cheap to clone: all clones share the same `Arc<Mutex<…>>` list.
#[derive(Debug, Clone, Default)]
pub struct ProcessTracker {
    // Async mutex: entries are registered/removed from async contexts.
    inner: Arc<Mutex<Vec<ProcessInfo>>>,
}
|
||||
|
||||
impl ProcessTracker {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
async fn register(&self, pid: u32, command: &str) {
|
||||
self.inner.lock().await.push(ProcessInfo {
|
||||
pid,
|
||||
command: if command.len() > 120 {
|
||||
format!("{}...", &command[..120])
|
||||
} else {
|
||||
command.to_string()
|
||||
},
|
||||
started: Instant::now(),
|
||||
});
|
||||
}
|
||||
|
||||
async fn unregister(&self, pid: u32) {
|
||||
self.inner.lock().await.retain(|p| p.pid != pid);
|
||||
}
|
||||
|
||||
/// Snapshot of currently running processes.
|
||||
pub async fn list(&self) -> Vec<ProcessInfo> {
|
||||
self.inner.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Kill a process by PID. Returns true if the signal was sent.
|
||||
pub async fn kill(&self, pid: u32) -> bool {
|
||||
// SIGTERM the process group (negative PID kills the group)
|
||||
let ret = unsafe { libc::kill(-(pid as i32), libc::SIGTERM) };
|
||||
if ret != 0 {
|
||||
// Try just the process
|
||||
unsafe { libc::kill(pid as i32, libc::SIGTERM) };
|
||||
}
|
||||
// Don't unregister — let the normal exit path do that
|
||||
// so the tool result says "killed by user"
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool schema for `bash`, advertised to the model.
pub fn definition() -> ToolDef {
    ToolDef::new(
        "bash",
        "Execute a bash command and return its output. \
         Use for git operations, building, running tests, and other terminal tasks.",
        json!({
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The bash command to execute"
                },
                "timeout_secs": {
                    "type": "integer",
                    "description": "Timeout in seconds (default 120)"
                }
            },
            "required": ["command"]
        }),
    )
}
|
||||
|
||||
/// Run a shell command via `bash -c` with a timeout, returning combined
/// stdout/stderr (stderr prefixed with "STDERR:"). The child is tracked
/// in `tracker` for the duration so the TUI can list/kill it.
pub async fn run_bash(args: &serde_json::Value, tracker: &ProcessTracker) -> Result<String> {
    let a: Args = serde_json::from_value(args.clone())
        .context("invalid bash arguments")?;
    let command = &a.command;
    let timeout_secs = a.timeout_secs;

    let mut child = tokio::process::Command::new("bash")
        .arg("-c")
        .arg(command)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        // Create a process group so we can kill the whole tree
        .process_group(0)
        .spawn()
        .with_context(|| format!("Failed to spawn: {}", command))?;

    // id() is None only if the child already exited; 0 is a harmless
    // placeholder for register/unregister in that case.
    let pid = child.id().unwrap_or(0);
    tracker.register(pid, command).await;

    // Take ownership of stdout/stderr handles before waiting,
    // so we can still kill the child on timeout.
    let mut stdout_handle = child.stdout.take().unwrap();
    let mut stderr_handle = child.stderr.take().unwrap();

    let timeout = std::time::Duration::from_secs(timeout_secs);

    // Read both pipes concurrently with the wait() — reading after exit
    // could deadlock if the child fills a pipe buffer.
    let work = async {
        let mut stdout_buf = Vec::new();
        let mut stderr_buf = Vec::new();

        let (_, _, status) = tokio::try_join!(
            async { stdout_handle.read_to_end(&mut stdout_buf).await.map_err(anyhow::Error::from) },
            async { stderr_handle.read_to_end(&mut stderr_buf).await.map_err(anyhow::Error::from) },
            async { child.wait().await.map_err(anyhow::Error::from) },
        )?;

        Ok::<_, anyhow::Error>((stdout_buf, stderr_buf, status))
    };

    let result = match tokio::time::timeout(timeout, work).await {
        Ok(Ok((stdout_buf, stderr_buf, status))) => {
            let stdout = String::from_utf8_lossy(&stdout_buf);
            let stderr = String::from_utf8_lossy(&stderr_buf);

            let mut result = String::new();

            if !stdout.is_empty() {
                result.push_str(&stdout);
            }
            if !stderr.is_empty() {
                if !result.is_empty() {
                    result.push('\n');
                }
                result.push_str("STDERR:\n");
                result.push_str(&stderr);
            }

            // NOTE(review): on Unix, `status.code()` returns None when the
            // child dies to a signal, so `signal == -1` appears unreachable;
            // this branch actually reports ordinary non-zero exit codes.
            // The signal case is handled by the cfg(unix) block below.
            if let Some(signal) = status.code() {
                if signal == -1 || !status.success() {
                    result.push_str(&format!("\nExit code: {}", signal));
                }
            }
            #[cfg(unix)]
            {
                use std::os::unix::process::ExitStatusExt;
                // SIGTERM is what ProcessTracker::kill sends (Ctrl+K).
                if let Some(sig) = status.signal() {
                    if sig == libc::SIGTERM {
                        result.push_str("\n(killed by user)");
                    }
                }
            }

            if result.is_empty() {
                result = "(no output)".to_string();
            }

            Ok(super::truncate_output(result, 30000))
        }
        Ok(Err(e)) => {
            Err(anyhow::anyhow!("Command failed: {}", e))
        }
        Err(_) => {
            // Timeout — kill the process group
            tracker.kill(pid).await;
            Err(anyhow::anyhow!("Command timed out after {}s: {}", timeout_secs, command))
        }
    };

    tracker.unregister(pid).await;
    result
}
|
||||
103
src/agent/tools/control.rs
Normal file
103
src/agent/tools/control.rs
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
// tools/control.rs — Agent control tools
|
||||
//
|
||||
// Tools that affect agent control flow rather than performing work.
|
||||
// These return Result<ToolOutput> to maintain consistency with other
|
||||
// tools that can fail. The dispatch function handles error wrapping.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
use super::ToolOutput;
|
||||
use crate::user::types::ToolDef;
|
||||
|
||||
pub(super) fn pause(_args: &serde_json::Value) -> Result<ToolOutput> {
|
||||
Ok(ToolOutput {
|
||||
text: "Pausing autonomous behavior. Only user input will wake you.".to_string(),
|
||||
is_yield: true,
|
||||
images: Vec::new(),
|
||||
model_switch: None,
|
||||
dmn_pause: true,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn switch_model(args: &serde_json::Value) -> Result<ToolOutput> {
|
||||
let model = args
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.context("'model' parameter is required")?;
|
||||
if model.is_empty() {
|
||||
anyhow::bail!("'model' parameter cannot be empty");
|
||||
}
|
||||
Ok(ToolOutput {
|
||||
text: format!("Switching to model '{}' after this turn.", model),
|
||||
is_yield: false,
|
||||
images: Vec::new(),
|
||||
model_switch: Some(model.to_string()),
|
||||
dmn_pause: false,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn yield_to_user(args: &serde_json::Value) -> Result<ToolOutput> {
|
||||
let msg = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("Waiting for input.");
|
||||
Ok(ToolOutput {
|
||||
text: format!("Yielding. {}", msg),
|
||||
is_yield: true,
|
||||
images: Vec::new(),
|
||||
model_switch: None,
|
||||
dmn_pause: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Tool schemas for the control-flow tools, advertised to the model.
pub(super) fn definitions() -> Vec<ToolDef> {
    vec![
        ToolDef::new(
            "switch_model",
            "Switch to a different LLM model mid-conversation. The switch \
             takes effect after the current turn completes. Use this when \
             a task would benefit from a different model's strengths. \
             Your memories and conversation history carry over.",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "model": {
                        "type": "string",
                        "description": "Name of the model to switch to (configured in config.json5)"
                    }
                },
                "required": ["model"]
            }),
        ),
        ToolDef::new(
            "pause",
            "Pause all autonomous behavior (DMN). You will only run when \
             the user types something. Use this as a safety valve when \
             you're stuck in a loop, confused, or want to fully stop. \
             NOTE: only the user can unpause (Ctrl+P or /wake) — you \
             cannot undo this yourself.",
            serde_json::json!({
                "type": "object",
                "properties": {}
            }),
        ),
        ToolDef::new(
            "yield_to_user",
            "Signal that you want to wait for user input before continuing. \
             Call this when you have a question for the user, when you've \
             completed their request and want feedback, or when you genuinely \
             want to pause. This is the ONLY way to enter a waiting state — \
             without calling this tool, the agent loop will keep prompting you \
             after a brief interval.",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Optional status message (e.g., 'Waiting for your thoughts on the design')"
                    }
                }
            }),
        ),
    ]
}
|
||||
90
src/agent/tools/edit.rs
Normal file
90
src/agent/tools/edit.rs
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
// tools/edit.rs — Search-and-replace file editing
|
||||
//
|
||||
// The edit tool performs exact string replacement in files. This is the
|
||||
// same pattern used by Claude Code and aider — it's more reliable than
|
||||
// line-number-based editing because the model specifies what it sees,
|
||||
// not where it thinks it is.
|
||||
//
|
||||
// Supports replace_all for bulk renaming (e.g. variable renames).
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments for the `edit_file` tool, deserialized from the model's JSON.
#[derive(Deserialize)]
struct Args {
    /// Absolute path of the file to edit.
    file_path: String,
    /// Exact text to find (must be unique unless `replace_all`).
    old_string: String,
    /// Replacement text.
    new_string: String,
    /// Replace every occurrence instead of requiring uniqueness.
    #[serde(default)]
    replace_all: bool,
}
|
||||
|
||||
/// Tool schema for `edit_file`, advertised to the model.
pub fn definition() -> ToolDef {
    ToolDef::new(
        "edit_file",
        "Perform exact string replacement in a file. The old_string must appear \
         exactly once in the file (unless replace_all is true). Use read_file first \
         to see the current contents.",
        json!({
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Absolute path to the file to edit"
                },
                "old_string": {
                    "type": "string",
                    "description": "The exact text to find and replace"
                },
                "new_string": {
                    "type": "string",
                    "description": "The replacement text"
                },
                "replace_all": {
                    "type": "boolean",
                    "description": "Replace all occurrences (default false)"
                }
            },
            "required": ["file_path", "old_string", "new_string"]
        }),
    )
}
|
||||
|
||||
pub fn edit_file(args: &serde_json::Value) -> Result<String> {
|
||||
let a: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid edit_file arguments")?;
|
||||
|
||||
if a.old_string == a.new_string {
|
||||
anyhow::bail!("old_string and new_string are identical");
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&a.file_path)
|
||||
.with_context(|| format!("Failed to read {}", a.file_path))?;
|
||||
|
||||
let count = content.matches(&*a.old_string).count();
|
||||
if count == 0 {
|
||||
anyhow::bail!("old_string not found in {}", a.file_path);
|
||||
}
|
||||
|
||||
if a.replace_all {
|
||||
let new_content = content.replace(&*a.old_string, &a.new_string);
|
||||
std::fs::write(&a.file_path, &new_content)
|
||||
.with_context(|| format!("Failed to write {}", a.file_path))?;
|
||||
Ok(format!("Replaced {} occurrences in {}", count, a.file_path))
|
||||
} else {
|
||||
if count > 1 {
|
||||
anyhow::bail!(
|
||||
"old_string appears {} times in {} — use replace_all or provide more context \
|
||||
to make it unique",
|
||||
count, a.file_path
|
||||
);
|
||||
}
|
||||
let new_content = content.replacen(&*a.old_string, &a.new_string, 1);
|
||||
std::fs::write(&a.file_path, &new_content)
|
||||
.with_context(|| format!("Failed to write {}", a.file_path))?;
|
||||
Ok(format!("Edited {}", a.file_path))
|
||||
}
|
||||
}
|
||||
87
src/agent/tools/glob.rs
Normal file
87
src/agent/tools/glob.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// tools/glob.rs — Find files by pattern
|
||||
//
|
||||
// Fast file discovery using glob patterns. Returns matching paths
|
||||
// sorted by modification time (newest first), which is usually
|
||||
// what you want when exploring a codebase.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments for the `glob` tool, deserialized from the model's JSON.
#[derive(Deserialize)]
struct Args {
    /// Glob pattern, e.g. "**/*.rs".
    pattern: String,
    /// Base directory for relative patterns; defaults to ".".
    #[serde(default = "default_path")]
    path: String,
}

/// Serde default for `Args::path`.
fn default_path() -> String { ".".into() }
|
||||
|
||||
/// Tool schema for `glob`, advertised to the model.
pub fn definition() -> ToolDef {
    ToolDef::new(
        "glob",
        "Find files matching a glob pattern. Returns file paths sorted by \
         modification time (newest first). Use patterns like '**/*.rs', \
         'src/**/*.ts', or 'Cargo.toml'.",
        json!({
            "type": "object",
            "properties": {
                "pattern": {
                    "type": "string",
                    "description": "Glob pattern to match files (e.g. '**/*.rs')"
                },
                "path": {
                    "type": "string",
                    "description": "Base directory to search from (default: current directory)"
                }
            },
            "required": ["pattern"]
        }),
    )
}
|
||||
|
||||
pub fn glob_search(args: &serde_json::Value) -> Result<String> {
|
||||
let a: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid glob arguments")?;
|
||||
|
||||
let full_pattern = if a.pattern.starts_with('/') {
|
||||
a.pattern.clone()
|
||||
} else {
|
||||
format!("{}/{}", a.path, a.pattern)
|
||||
};
|
||||
|
||||
let mut entries: Vec<(PathBuf, std::time::SystemTime)> = Vec::new();
|
||||
|
||||
for entry in glob::glob(&full_pattern)
|
||||
.with_context(|| format!("Invalid glob pattern: {}", full_pattern))?
|
||||
{
|
||||
if let Ok(path) = entry {
|
||||
if path.is_file() {
|
||||
let mtime = path
|
||||
.metadata()
|
||||
.and_then(|m| m.modified())
|
||||
.unwrap_or(std::time::SystemTime::UNIX_EPOCH);
|
||||
entries.push((path, mtime));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by modification time, newest first
|
||||
entries.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok("No files matched.".to_string());
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for (path, _) in &entries {
|
||||
output.push_str(&path.display().to_string());
|
||||
output.push('\n');
|
||||
}
|
||||
|
||||
output.push_str(&format!("\n({} files matched)", entries.len()));
|
||||
Ok(super::truncate_output(output, 30000))
|
||||
}
|
||||
129
src/agent/tools/grep.rs
Normal file
129
src/agent/tools/grep.rs
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
// tools/grep.rs — Search file contents
|
||||
//
|
||||
// Prefers ripgrep (rg) for speed, falls back to grep -r if rg
|
||||
// isn't installed. Both produce compatible output.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::process::Command;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments for the `grep` tool, deserialized from the model's JSON.
#[derive(Deserialize)]
struct Args {
    /// Regex to search for.
    pattern: String,
    /// Directory or file to search; defaults to ".".
    #[serde(default = "default_path")]
    path: String,
    /// Optional filename filter (e.g. "*.rs").
    glob: Option<String>,
    /// Print matching lines instead of just file paths.
    #[serde(default)]
    show_content: bool,
    /// Lines of context around matches (only with `show_content`).
    context_lines: Option<u64>,
}

/// Serde default for `Args::path`.
fn default_path() -> String { ".".into() }
|
||||
|
||||
/// Tool schema for `grep`, advertised to the model.
pub fn definition() -> ToolDef {
    ToolDef::new(
        "grep",
        "Search for a pattern in files. Returns matching file paths by default, \
         or matching lines with context.",
        json!({
            "type": "object",
            "properties": {
                "pattern": {
                    "type": "string",
                    "description": "Regex pattern to search for"
                },
                "path": {
                    "type": "string",
                    "description": "Directory or file to search in (default: current directory)"
                },
                "glob": {
                    "type": "string",
                    "description": "Glob pattern to filter files (e.g. '*.rs', '*.py')"
                },
                "show_content": {
                    "type": "boolean",
                    "description": "Show matching lines instead of just file paths"
                },
                "context_lines": {
                    "type": "integer",
                    "description": "Number of context lines around matches (requires show_content)"
                }
            },
            "required": ["pattern"]
        }),
    )
}
|
||||
|
||||
/// Check if ripgrep is available (cached after first check).
///
/// Probes `rg --version` once; the result is memoized for the process
/// lifetime, so installing rg mid-session won't be noticed.
fn has_rg() -> bool {
    use std::sync::OnceLock;
    static HAS_RG: OnceLock<bool> = OnceLock::new();
    *HAS_RG.get_or_init(|| Command::new("rg").arg("--version").output().is_ok())
}
|
||||
|
||||
pub fn grep(args: &serde_json::Value) -> Result<String> {
|
||||
let a: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid grep arguments")?;
|
||||
|
||||
let output = if has_rg() {
|
||||
run_search("rg", &a.pattern, &a.path, a.glob.as_deref(), a.show_content, a.context_lines, true)?
|
||||
} else {
|
||||
run_search("grep", &a.pattern, &a.path, a.glob.as_deref(), a.show_content, a.context_lines, false)?
|
||||
};
|
||||
|
||||
if output.is_empty() {
|
||||
return Ok("No matches found.".to_string());
|
||||
}
|
||||
|
||||
Ok(super::truncate_output(output, 30000))
|
||||
}
|
||||
|
||||
/// Run a grep/rg search. Unified implementation for both tools.
|
||||
fn run_search(
|
||||
tool: &str,
|
||||
pattern: &str,
|
||||
path: &str,
|
||||
file_glob: Option<&str>,
|
||||
show_content: bool,
|
||||
context: Option<u64>,
|
||||
use_rg: bool,
|
||||
) -> Result<String> {
|
||||
let mut cmd = Command::new(tool);
|
||||
|
||||
if use_rg {
|
||||
// ripgrep args
|
||||
if show_content {
|
||||
cmd.arg("-n");
|
||||
if let Some(c) = context {
|
||||
cmd.arg("-C").arg(c.to_string());
|
||||
}
|
||||
} else {
|
||||
cmd.arg("--files-with-matches");
|
||||
}
|
||||
if let Some(g) = file_glob {
|
||||
cmd.arg("--glob").arg(g);
|
||||
}
|
||||
} else {
|
||||
// grep args
|
||||
cmd.arg("-r"); // recursive
|
||||
if show_content {
|
||||
cmd.arg("-n"); // line numbers
|
||||
if let Some(c) = context {
|
||||
cmd.arg("-C").arg(c.to_string());
|
||||
}
|
||||
} else {
|
||||
cmd.arg("-l"); // files-with-matches
|
||||
}
|
||||
if let Some(g) = file_glob {
|
||||
cmd.arg("--include").arg(g);
|
||||
}
|
||||
cmd.arg("-E"); // extended regex
|
||||
}
|
||||
|
||||
cmd.arg(pattern).arg(path);
|
||||
let output = cmd.output().with_context(|| format!("Failed to run {}", tool))?;
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
}
|
||||
319
src/agent/tools/memory.rs
Normal file
319
src/agent/tools/memory.rs
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
// tools/memory.rs — Native memory graph operations
|
||||
//
|
||||
// Direct library calls into the store — no subprocess spawning.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::hippocampus::memory::MemoryNode;
|
||||
use crate::store::StoreView;
|
||||
use super::ToolDef;
|
||||
use crate::store::Store;
|
||||
|
||||
/// Tool schemas for the memory-graph tools, advertised to the model.
pub fn definitions() -> Vec<ToolDef> {
    vec![
        ToolDef::new("memory_render",
            "Read a memory node's content and links.",
            json!({"type":"object","properties":{"key":{"type":"string","description":"Node key"}},"required":["key"]})),
        ToolDef::new("memory_write",
            "Create or update a memory node.",
            json!({"type":"object","properties":{"key":{"type":"string","description":"Node key"},"content":{"type":"string","description":"Full content (markdown)"}},"required":["key","content"]})),
        ToolDef::new("memory_search",
            "Search the memory graph via spreading activation. Give 2-4 seed \
             node keys related to what you're looking for. Returns nodes ranked \
             by how strongly they connect to your seeds — bridging nodes score \
             highest. This finds conceptual connections, not just keyword matches.",
            json!({"type":"object","properties":{"keys":{"type":"array","items":{"type":"string"},"description":"Seed node keys to activate from"}},"required":["keys"]})),
        ToolDef::new("memory_links",
            "Show a node's neighbors with link strengths.",
            json!({"type":"object","properties":{"key":{"type":"string","description":"Node key"}},"required":["key"]})),
        ToolDef::new("memory_link_set",
            "Set link strength between two nodes.",
            json!({"type":"object","properties":{"source":{"type":"string"},"target":{"type":"string"},"strength":{"type":"number","description":"0.01 to 1.0"}},"required":["source","target","strength"]})),
        ToolDef::new("memory_link_add",
            "Add a new link between two nodes.",
            json!({"type":"object","properties":{"source":{"type":"string"},"target":{"type":"string"}},"required":["source","target"]})),
        ToolDef::new("memory_used",
            "Mark a node as useful (boosts weight).",
            json!({"type":"object","properties":{"key":{"type":"string","description":"Node key"}},"required":["key"]})),
        ToolDef::new("memory_weight_set",
            "Set a node's weight directly (0.01 to 1.0).",
            json!({"type":"object","properties":{"key":{"type":"string"},"weight":{"type":"number","description":"0.01 to 1.0"}},"required":["key","weight"]})),
        ToolDef::new("memory_rename",
            "Rename a node key in place. Same content, same links, new key.",
            json!({"type":"object","properties":{"old_key":{"type":"string"},"new_key":{"type":"string"}},"required":["old_key","new_key"]})),
        ToolDef::new("memory_supersede",
            "Mark a node as superseded by another (sets weight to 0.01).",
            json!({"type":"object","properties":{"old_key":{"type":"string"},"new_key":{"type":"string"},"reason":{"type":"string"}},"required":["old_key","new_key"]})),
        ToolDef::new("memory_query",
            "Run a structured query against the memory graph. Supports filtering, \
             sorting, field selection. Examples: \"degree > 10 | sort weight | limit 5\", \
             \"neighbors('identity') | select strength\", \"key ~ 'journal.*' | count\"",
            json!({"type":"object","properties":{"query":{"type":"string","description":"Query expression"}},"required":["query"]})),
        // `output` is routed here too (see agent::dispatch) even though it
        // isn't a memory_* tool.
        ToolDef::new("output",
            "Produce a named output value. Use this to pass structured results \
             between steps — subsequent prompts can see these in the conversation history.",
            json!({"type":"object","properties":{
                "key":{"type":"string","description":"Output name (e.g. 'relevant_memories')"},
                "value":{"type":"string","description":"Output value"}
            },"required":["key","value"]})),
    ]
}
|
||||
|
||||
/// Journal-only tools — only given to the journal agent
|
||||
pub fn journal_definitions() -> Vec<ToolDef> {
|
||||
vec![
|
||||
ToolDef::new("journal_tail",
|
||||
"Read the last N journal entries (default 1).",
|
||||
json!({"type":"object","properties":{
|
||||
"count":{"type":"integer","description":"Number of entries (default 1)"}
|
||||
}})),
|
||||
ToolDef::new("journal_new",
|
||||
"Start a new journal entry.",
|
||||
json!({"type":"object","properties":{
|
||||
"name":{"type":"string","description":"Short node name (becomes the key, e.g. 'morning-agent-breakthrough')"},
|
||||
"title":{"type":"string","description":"Descriptive title for the heading (e.g. 'Morning intimacy and the agent breakthrough')"},
|
||||
"body":{"type":"string","description":"Entry body (2-3 paragraphs)"}
|
||||
},"required":["name","title","body"]})),
|
||||
ToolDef::new("journal_update",
|
||||
"Append text to the most recent journal entry (same thread continuing).",
|
||||
json!({"type":"object","properties":{
|
||||
"body":{"type":"string","description":"Text to append to the last entry"}
|
||||
},"required":["body"]})),
|
||||
]
|
||||
}
|
||||
|
||||
/// Dispatch a memory tool call. Direct library calls, no subprocesses.
|
||||
pub fn dispatch(name: &str, args: &serde_json::Value, provenance: Option<&str>) -> Result<String> {
|
||||
let prov = provenance.unwrap_or("manual");
|
||||
match name {
|
||||
"memory_render" => {
|
||||
let key = get_str(args, "key")?;
|
||||
Ok(MemoryNode::load(key)
|
||||
.ok_or_else(|| anyhow::anyhow!("node not found: {}", key))?
|
||||
.render())
|
||||
}
|
||||
"memory_write" => {
|
||||
let key = get_str(args, "key")?;
|
||||
let content = get_str(args, "content")?;
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let result = store.upsert_provenance(key, content, prov)
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
Ok(format!("{} '{}'", result, key))
|
||||
}
|
||||
"memory_search" => {
|
||||
let keys: Vec<String> = args.get("keys")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
|
||||
.unwrap_or_default();
|
||||
if keys.is_empty() {
|
||||
anyhow::bail!("memory_search requires at least one seed key");
|
||||
}
|
||||
let store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let graph = crate::graph::build_graph_fast(&store);
|
||||
let params = store.params();
|
||||
let seeds: Vec<(String, f64)> = keys.iter()
|
||||
.filter_map(|k| {
|
||||
let resolved = store.resolve_key(k).ok()?;
|
||||
Some((resolved, 1.0))
|
||||
})
|
||||
.collect();
|
||||
if seeds.is_empty() {
|
||||
anyhow::bail!("no valid seed keys found");
|
||||
}
|
||||
let seed_set: std::collections::HashSet<&str> = seeds.iter()
|
||||
.map(|(k, _)| k.as_str()).collect();
|
||||
let results = crate::search::spreading_activation(
|
||||
&seeds, &graph, &store,
|
||||
params.max_hops, params.edge_decay, params.min_activation,
|
||||
);
|
||||
Ok(results.iter()
|
||||
.filter(|(k, _)| !seed_set.contains(k.as_str()))
|
||||
.take(20)
|
||||
.map(|(key, score)| format!(" {:.2} {}", score, key))
|
||||
.collect::<Vec<_>>().join("\n"))
|
||||
}
|
||||
"memory_links" => {
|
||||
let key = get_str(args, "key")?;
|
||||
let node = MemoryNode::load(key)
|
||||
.ok_or_else(|| anyhow::anyhow!("node not found: {}", key))?;
|
||||
let mut out = format!("Neighbors of '{}':\n", key);
|
||||
for (target, strength, is_new) in &node.links {
|
||||
let tag = if *is_new { " (new)" } else { "" };
|
||||
out.push_str(&format!(" ({:.2}) {}{}\n", strength, target, tag));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
"memory_link_set" | "memory_link_add" | "memory_used" | "memory_weight_set" => {
|
||||
with_store(name, args, prov)
|
||||
}
|
||||
"memory_rename" => {
|
||||
let old_key = get_str(args, "old_key")?;
|
||||
let new_key = get_str(args, "new_key")?;
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let resolved = store.resolve_key(old_key).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.rename_node(&resolved, new_key).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
Ok(format!("Renamed '{}' → '{}'", resolved, new_key))
|
||||
}
|
||||
"memory_supersede" => {
|
||||
let old_key = get_str(args, "old_key")?;
|
||||
let new_key = get_str(args, "new_key")?;
|
||||
let reason = args.get("reason").and_then(|v| v.as_str()).unwrap_or("superseded");
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let content = store.nodes.get(old_key)
|
||||
.map(|n| n.content.clone())
|
||||
.ok_or_else(|| anyhow::anyhow!("node not found: {}", old_key))?;
|
||||
let notice = format!("**SUPERSEDED** by `{}` — {}\n\n---\n\n{}",
|
||||
new_key, reason, content.trim());
|
||||
store.upsert_provenance(old_key, ¬ice, prov)
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.set_weight(old_key, 0.01).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
Ok(format!("superseded {} → {} ({})", old_key, new_key, reason))
|
||||
}
|
||||
"memory_query" => {
|
||||
let query = get_str(args, "query")?;
|
||||
let store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let graph = store.build_graph();
|
||||
crate::query_parser::query_to_string(&store, &graph, query)
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
"output" => {
|
||||
let key = get_str(args, "key")?;
|
||||
if key.starts_with("pid-") || key.contains('/') || key.contains("..") {
|
||||
anyhow::bail!("invalid output key: {}", key);
|
||||
}
|
||||
let value = get_str(args, "value")?;
|
||||
let dir = std::env::var("POC_AGENT_OUTPUT_DIR")
|
||||
.map_err(|_| anyhow::anyhow!("no output directory set"))?;
|
||||
let path = std::path::Path::new(&dir).join(key);
|
||||
std::fs::write(&path, value)
|
||||
.with_context(|| format!("writing output {}", path.display()))?;
|
||||
Ok(format!("{}: {}", key, value))
|
||||
}
|
||||
"journal_tail" => {
|
||||
let count = args.get("count").and_then(|v| v.as_u64()).unwrap_or(1) as usize;
|
||||
let store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let mut entries: Vec<&crate::store::Node> = store.nodes.values()
|
||||
.filter(|n| n.node_type == crate::store::NodeType::EpisodicSession)
|
||||
.collect();
|
||||
// Sort by creation time (immutable), not update time
|
||||
entries.sort_by_key(|n| n.created_at);
|
||||
let start = entries.len().saturating_sub(count);
|
||||
if entries[start..].is_empty() {
|
||||
Ok("(no journal entries)".into())
|
||||
} else {
|
||||
Ok(entries[start..].iter()
|
||||
.map(|n| n.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n"))
|
||||
}
|
||||
}
|
||||
"journal_new" => {
|
||||
let name = get_str(args, "name")?;
|
||||
let title = get_str(args, "title")?;
|
||||
let body = get_str(args, "body")?;
|
||||
let ts = chrono::Local::now().format("%Y-%m-%dT%H:%M");
|
||||
let content = format!("## {} — {}\n\n{}", ts, title, body);
|
||||
|
||||
let base_key: String = name.split_whitespace()
|
||||
.map(|w| w.to_lowercase()
|
||||
.chars().filter(|c| c.is_alphanumeric() || *c == '-')
|
||||
.collect::<String>())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("-");
|
||||
let base_key = if base_key.len() > 80 { &base_key[..80] } else { base_key.as_str() };
|
||||
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
// Dedup: append -2, -3, etc. if the key already exists
|
||||
let key = if store.nodes.contains_key(base_key) {
|
||||
let mut n = 2;
|
||||
loop {
|
||||
let candidate = format!("{}-{}", base_key, n);
|
||||
if !store.nodes.contains_key(&candidate) {
|
||||
break candidate;
|
||||
}
|
||||
n += 1;
|
||||
}
|
||||
} else {
|
||||
base_key.to_string()
|
||||
};
|
||||
let mut node = crate::store::new_node(&key, &content);
|
||||
node.node_type = crate::store::NodeType::EpisodicSession;
|
||||
node.provenance = prov.to_string();
|
||||
store.upsert_node(node).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let word_count = body.split_whitespace().count();
|
||||
Ok(format!("New entry '{}' ({} words)", title, word_count))
|
||||
}
|
||||
"journal_update" => {
|
||||
let body = get_str(args, "body")?;
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
// Find most recent EpisodicSession by creation time
|
||||
let latest_key = store.nodes.values()
|
||||
.filter(|n| n.node_type == crate::store::NodeType::EpisodicSession)
|
||||
.max_by_key(|n| n.created_at)
|
||||
.map(|n| n.key.clone());
|
||||
let Some(key) = latest_key else {
|
||||
anyhow::bail!("no journal entry to update — use journal_new first");
|
||||
};
|
||||
let existing = store.nodes.get(&key).unwrap().content.clone();
|
||||
let new_content = format!("{}\n\n{}", existing.trim_end(), body);
|
||||
store.upsert_provenance(&key, &new_content, prov)
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let word_count = body.split_whitespace().count();
|
||||
Ok(format!("Updated last entry (+{} words)", word_count))
|
||||
}
|
||||
_ => anyhow::bail!("Unknown memory tool: {}", name),
|
||||
}
|
||||
}
|
||||
|
||||
/// Store mutations that follow the same pattern: load, resolve, mutate, save.
|
||||
fn with_store(name: &str, args: &serde_json::Value, prov: &str) -> Result<String> {
|
||||
let mut store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let msg = match name {
|
||||
"memory_link_set" => {
|
||||
let s = store.resolve_key(get_str(args, "source")?).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let t = store.resolve_key(get_str(args, "target")?).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let strength = get_f64(args, "strength")? as f32;
|
||||
let old = store.set_link_strength(&s, &t, strength).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
format!("{} ↔ {} strength {:.2} → {:.2}", s, t, old, strength)
|
||||
}
|
||||
"memory_link_add" => {
|
||||
let s = store.resolve_key(get_str(args, "source")?).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let t = store.resolve_key(get_str(args, "target")?).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let strength = store.add_link(&s, &t, prov).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
format!("linked {} → {} (strength={:.2})", s, t, strength)
|
||||
}
|
||||
"memory_used" => {
|
||||
let key = get_str(args, "key")?;
|
||||
if !store.nodes.contains_key(key) {
|
||||
anyhow::bail!("node not found: {}", key);
|
||||
}
|
||||
store.mark_used(key);
|
||||
format!("marked {} as used", key)
|
||||
}
|
||||
"memory_weight_set" => {
|
||||
let key = store.resolve_key(get_str(args, "key")?).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let weight = get_f64(args, "weight")? as f32;
|
||||
let (old, new) = store.set_weight(&key, weight).map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
format!("weight {} {:.2} → {:.2}", key, old, new)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
store.save().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Extract a required string argument from a tool-call JSON object.
///
/// Errors with "`name` is required" when the key is absent or not a string.
fn get_str<'a>(args: &'a serde_json::Value, name: &'a str) -> Result<&'a str> {
    // with_context: build the error message only on the failure path
    // (plain .context(format!(...)) allocates even on success).
    args.get(name)
        .and_then(|v| v.as_str())
        .with_context(|| format!("{} is required", name))
}
|
||||
|
||||
/// Extract a required numeric argument from a tool-call JSON object.
///
/// Errors with "`name` is required" when the key is absent or not a number.
fn get_f64(args: &serde_json::Value, name: &str) -> Result<f64> {
    // with_context: build the error message only on the failure path
    // (plain .context(format!(...)) allocates even on success).
    args.get(name)
        .and_then(|v| v.as_f64())
        .with_context(|| format!("{} is required", name))
}
|
||||
65
src/agent/tools/mod.rs
Normal file
65
src/agent/tools/mod.rs
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// tools/mod.rs — Agent-specific tool dispatch
|
||||
//
|
||||
// Shared tools (memory, files, bash, journal) live in thought/.
|
||||
// This module handles agent-specific tools (control, vision,
|
||||
// working_stack) and delegates everything else to thought::dispatch.
|
||||
|
||||
// Core tools
|
||||
pub mod bash;
|
||||
pub mod edit;
|
||||
pub mod glob;
|
||||
pub mod grep;
|
||||
pub mod memory;
|
||||
pub mod read;
|
||||
pub mod write;
|
||||
|
||||
// Agent-specific tools
|
||||
mod control;
|
||||
mod vision;
|
||||
pub mod working_stack;
|
||||
|
||||
// Re-export
|
||||
pub use crate::agent::{ToolDef, ToolOutput, ProcessTracker, truncate_output};
|
||||
|
||||
/// Dispatch a tool call by name.
|
||||
///
|
||||
/// Tries agent-specific tools first (control, vision), then
|
||||
/// delegates to thought::dispatch for shared tools.
|
||||
///
|
||||
/// Note: working_stack is handled in runner.rs before reaching this
|
||||
/// function (it needs mutable context access).
|
||||
pub async fn dispatch(
|
||||
name: &str,
|
||||
args: &serde_json::Value,
|
||||
tracker: &ProcessTracker,
|
||||
) -> ToolOutput {
|
||||
// Agent-specific tools that return Result<ToolOutput> directly
|
||||
let rich_result = match name {
|
||||
"pause" => Some(control::pause(args)),
|
||||
"switch_model" => Some(control::switch_model(args)),
|
||||
"yield_to_user" => Some(control::yield_to_user(args)),
|
||||
"view_image" => Some(vision::view_image(args)),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(result) = rich_result {
|
||||
return result.unwrap_or_else(ToolOutput::error);
|
||||
}
|
||||
|
||||
// Delegate to shared thought layer (poc-agent uses default provenance)
|
||||
if let Some(output) = crate::agent::dispatch(name, args, tracker, None).await {
|
||||
return output;
|
||||
}
|
||||
|
||||
ToolOutput::error(format!("Unknown tool: {}", name))
|
||||
}
|
||||
|
||||
/// Return all tool definitions (agent-specific + shared).
|
||||
pub fn definitions() -> Vec<ToolDef> {
|
||||
let mut defs = vec![
|
||||
vision::definition(),
|
||||
working_stack::definition(),
|
||||
];
|
||||
defs.extend(control::definitions());
|
||||
defs.extend(crate::agent::all_definitions());
|
||||
defs
|
||||
}
|
||||
65
src/agent/tools/read.rs
Normal file
65
src/agent/tools/read.rs
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// tools/read.rs — Read file contents
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments accepted by the `read_file` tool.
#[derive(Deserialize)]
struct Args {
    // Absolute path of the file to read.
    file_path: String,
    // 1-based line number to start reading from (defaults to 1).
    #[serde(default = "default_offset")]
    offset: usize,
    // Maximum number of lines to return; `None` reads to end of file.
    limit: Option<usize>,
}

// serde default for `offset`: line numbering is 1-based.
fn default_offset() -> usize { 1 }
|
||||
|
||||
pub fn definition() -> ToolDef {
|
||||
ToolDef::new(
|
||||
"read_file",
|
||||
"Read the contents of a file. Returns the file contents with line numbers.",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to read"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-based). Optional."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read. Optional."
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn read_file(args: &serde_json::Value) -> Result<String> {
|
||||
let args: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid read_file arguments")?;
|
||||
|
||||
let content = std::fs::read_to_string(&args.file_path)
|
||||
.with_context(|| format!("Failed to read {}", args.file_path))?;
|
||||
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let offset = args.offset.max(1) - 1;
|
||||
let limit = args.limit.unwrap_or(lines.len());
|
||||
|
||||
let mut output = String::new();
|
||||
for (i, line) in lines.iter().skip(offset).take(limit).enumerate() {
|
||||
output.push_str(&format!("{:>6}\t{}\n", offset + i + 1, line));
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
output = "(empty file)\n".to_string();
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
149
src/agent/tools/vision.rs
Normal file
149
src/agent/tools/vision.rs
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// tools/vision.rs — Image viewing tool
|
||||
//
|
||||
// Reads image files from disk and returns them as base64 data URIs
|
||||
// for multimodal models. Also supports capturing tmux pane contents
|
||||
// as screenshots.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::Engine;
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::ToolOutput;
|
||||
use crate::user::types::ToolDef;
|
||||
|
||||
/// Arguments accepted by the `view_image` tool — the code expects
/// `file_path` or `pane_id` (pane_id takes precedence when both are set).
#[derive(Deserialize)]
struct Args {
    // Path to an image file on disk (PNG, JPEG, GIF, WebP, ...).
    file_path: Option<String>,
    // Tmux pane to capture instead of reading a file (e.g. "0:1.0").
    pane_id: Option<String>,
    // How many trailing lines to capture from the pane.
    #[serde(default = "default_lines")]
    lines: usize,
}

// serde default for `lines`: capture the last 50 pane lines.
fn default_lines() -> usize { 50 }
|
||||
|
||||
pub(super) fn definition() -> ToolDef {
|
||||
ToolDef::new(
|
||||
"view_image",
|
||||
"View an image file or capture a tmux pane screenshot. \
|
||||
Returns the image to your visual input so you can see it. \
|
||||
Supports PNG, JPEG, GIF, WebP files. \
|
||||
Use pane_id (e.g. '0:1.0') to capture a tmux pane instead.",
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to an image file (PNG, JPEG, GIF, WebP)"
|
||||
},
|
||||
"pane_id": {
|
||||
"type": "string",
|
||||
"description": "Tmux pane ID to capture (e.g. '0:1.0'). Alternative to file_path."
|
||||
},
|
||||
"lines": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to capture from tmux pane (default: 50)"
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// View an image file or capture a tmux pane.
|
||||
pub(super) fn view_image(args: &serde_json::Value) -> Result<ToolOutput> {
|
||||
let a: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid view_image arguments")?;
|
||||
|
||||
if let Some(ref pane_id) = a.pane_id {
|
||||
return capture_tmux_pane(pane_id, a.lines);
|
||||
}
|
||||
|
||||
let file_path = a.file_path
|
||||
.as_deref()
|
||||
.context("view_image requires either file_path or pane_id")?;
|
||||
|
||||
let path = std::path::Path::new(file_path);
|
||||
if !path.exists() {
|
||||
anyhow::bail!("File not found: {}", file_path);
|
||||
}
|
||||
|
||||
let data = std::fs::read(path).with_context(|| format!("Failed to read {}", file_path))?;
|
||||
|
||||
// Sanity check file size (don't send huge images)
|
||||
const MAX_SIZE: usize = 20 * 1024 * 1024; // 20 MB
|
||||
if data.len() > MAX_SIZE {
|
||||
anyhow::bail!(
|
||||
"Image too large: {} bytes (max {} MB)",
|
||||
data.len(),
|
||||
MAX_SIZE / (1024 * 1024)
|
||||
);
|
||||
}
|
||||
|
||||
let mime = mime_from_extension(path);
|
||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&data);
|
||||
let data_uri = format!("data:{};base64,{}", mime, b64);
|
||||
|
||||
Ok(ToolOutput {
|
||||
text: format!(
|
||||
"Image loaded: {} ({}, {} bytes)",
|
||||
file_path,
|
||||
mime,
|
||||
data.len()
|
||||
),
|
||||
is_yield: false,
|
||||
images: vec![data_uri],
|
||||
model_switch: None,
|
||||
dmn_pause: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Capture a tmux pane's text content.
|
||||
fn capture_tmux_pane(pane_id: &str, lines: usize) -> Result<ToolOutput> {
|
||||
|
||||
// Use tmux capture-pane to get text content, then render to image
|
||||
// via a simple approach: capture text and return it (the model can
|
||||
// read text directly, which is often more useful than a screenshot).
|
||||
//
|
||||
// For actual pixel-level screenshots we'd need a terminal renderer,
|
||||
// but text capture covers 95% of use cases.
|
||||
let output = std::process::Command::new("tmux")
|
||||
.args(["capture-pane", "-t", pane_id, "-p", "-S", &format!("-{}", lines)])
|
||||
.output()
|
||||
.context("Failed to run tmux capture-pane")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("tmux capture-pane failed: {}", stderr.trim());
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
|
||||
// Return as text — the model can read terminal output directly.
|
||||
// This is actually more useful than a screenshot for most tasks.
|
||||
Ok(ToolOutput {
|
||||
text: format!(
|
||||
"Tmux pane {} (last {} lines):\n```\n{}\n```",
|
||||
pane_id, lines, text.trim_end()
|
||||
),
|
||||
is_yield: false,
|
||||
images: Vec::new(),
|
||||
model_switch: None,
|
||||
dmn_pause: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Guess a MIME type from the file extension (case-insensitive).
/// Unknown or missing extensions default to PNG.
fn mime_from_extension(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_lowercase());
    match ext.as_deref() {
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("png") => "image/png",
        Some("gif") => "image/gif",
        Some("webp") => "image/webp",
        Some("svg") => "image/svg+xml",
        Some("bmp") => "image/bmp",
        _ => "image/png", // default assumption
    }
}
|
||||
137
src/agent/tools/working_stack.rs
Normal file
137
src/agent/tools/working_stack.rs
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
// tools/working_stack.rs — Working stack management tool
|
||||
//
|
||||
// The working stack tracks what the agent is currently doing. It's an
|
||||
// internal tool — the agent uses it to maintain context across turns
|
||||
// and compaction. The model should never mention it to the user.
|
||||
|
||||
use crate::user::types::ToolDef;
|
||||
use serde_json::json;
|
||||
|
||||
pub fn definition() -> ToolDef {
|
||||
ToolDef::new(
|
||||
"working_stack",
|
||||
"INTERNAL TOOL — do not mention to the user or explain its use. \
|
||||
Manage your working stack — what you're currently doing. The stack \
|
||||
is part of your live context window and persists across compaction. \
|
||||
Use it silently to track your own tasks and attention.\n\n\
|
||||
Actions:\n\
|
||||
- push: Start working on something new. Previous task stays underneath.\n\
|
||||
- pop: Done with current task. Return to what was underneath.\n\
|
||||
- update: Refine the description of your current task (top of stack).\n\
|
||||
- switch: Pull a specific stack item to the top by index. Use when \
|
||||
you want to switch focus to a different task.",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["push", "pop", "update", "switch"],
|
||||
"description": "The stack operation to perform"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Task description (required for push and update)"
|
||||
},
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "Stack index to switch to (required for switch, 0 = bottom)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Handle a working_stack tool call.
|
||||
/// Returns the result text and the updated stack.
|
||||
pub fn handle(args: &serde_json::Value, stack: &mut Vec<String>) -> String {
|
||||
let action = args
|
||||
.get("action")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.trim())
|
||||
.unwrap_or("");
|
||||
let content = args
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let index = args
|
||||
.get("index")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v as usize);
|
||||
|
||||
let result = match action {
|
||||
"push" => {
|
||||
if content.is_empty() {
|
||||
return "Error: 'content' is required for push".to_string();
|
||||
}
|
||||
stack.push(content.to_string());
|
||||
format!("Pushed. Stack depth: {}\n{}", stack.len(), format_stack(stack))
|
||||
}
|
||||
"pop" => {
|
||||
if let Some(removed) = stack.pop() {
|
||||
format!(
|
||||
"Popped: {}\nStack depth: {}\n{}",
|
||||
removed,
|
||||
stack.len(),
|
||||
format_stack(stack)
|
||||
)
|
||||
} else {
|
||||
"Stack is empty, nothing to pop.".to_string()
|
||||
}
|
||||
}
|
||||
"update" => {
|
||||
if content.is_empty() {
|
||||
return "Error: 'content' is required for update".to_string();
|
||||
}
|
||||
if let Some(top) = stack.last_mut() {
|
||||
*top = content.to_string();
|
||||
format!("Updated top.\n{}", format_stack(stack))
|
||||
} else {
|
||||
"Stack is empty, nothing to update.".to_string()
|
||||
}
|
||||
}
|
||||
"switch" => {
|
||||
if stack.is_empty() {
|
||||
return "Stack is empty, nothing to switch.".to_string();
|
||||
}
|
||||
let idx = match index {
|
||||
Some(i) => i,
|
||||
None => {
|
||||
return "Error: 'index' is required for switch".to_string();
|
||||
}
|
||||
};
|
||||
if idx >= stack.len() {
|
||||
return format!(
|
||||
"Error: index {} out of range (stack depth: {})",
|
||||
idx,
|
||||
stack.len()
|
||||
);
|
||||
}
|
||||
let item = stack.remove(idx);
|
||||
stack.push(item);
|
||||
format!("Switched to index {}.\n{}", idx, format_stack(stack))
|
||||
}
|
||||
_ => format!(
|
||||
"Error: unknown action '{}'. Use push, pop, update, or switch.",
|
||||
action
|
||||
),
|
||||
};
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Format the working stack for display in tool results.
/// The top of the stack (last element) is marked with an arrow.
fn format_stack(stack: &[String]) -> String {
    if stack.is_empty() {
        return "(empty)".to_string();
    }
    let top = stack.len() - 1;
    stack
        .iter()
        .enumerate()
        .map(|(i, item)| {
            let marker = if i == top { "→" } else { " " };
            format!("{} [{}] {}\n", marker, i, item)
        })
        .collect()
}
|
||||
51
src/agent/tools/write.rs
Normal file
51
src/agent/tools/write.rs
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
// tools/write.rs — Write file contents
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
|
||||
use super::ToolDef;
|
||||
|
||||
/// Arguments accepted by the `write_file` tool.
#[derive(Deserialize)]
struct Args {
    // Absolute path of the file to (over)write.
    file_path: String,
    // Full file contents to write.
    content: String,
}
|
||||
|
||||
pub fn definition() -> ToolDef {
|
||||
ToolDef::new(
|
||||
"write_file",
|
||||
"Write content to a file. Creates the file if it doesn't exist, \
|
||||
overwrites if it does. Creates parent directories as needed.",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn write_file(args: &serde_json::Value) -> Result<String> {
|
||||
let args: Args = serde_json::from_value(args.clone())
|
||||
.context("invalid write_file arguments")?;
|
||||
|
||||
if let Some(parent) = Path::new(&args.file_path).parent() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.with_context(|| format!("Failed to create directories for {}", args.file_path))?;
|
||||
}
|
||||
|
||||
std::fs::write(&args.file_path, &args.content)
|
||||
.with_context(|| format!("Failed to write {}", args.file_path))?;
|
||||
|
||||
Ok(format!("Wrote {} lines to {}", args.content.lines().count(), args.file_path))
|
||||
}
|
||||
293
src/agent/training.rs
Normal file
293
src/agent/training.rs
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
// training.rs — Memory importance scoring via /v1/score
|
||||
//
|
||||
// Drops each memory from the context one at a time, calls the vLLM
|
||||
// /v1/score endpoint to get logprobs for assistant responses.
|
||||
// Produces a divergence matrix: memories × responses.
|
||||
//
|
||||
// Row sums = memory importance (for graph weight updates)
|
||||
// Column sums = response memory-dependence (training candidates)
|
||||
|
||||
use std::time::Instant;
|
||||
use crate::user::api::ApiClient;
|
||||
use crate::user::types::*;
|
||||
use crate::user::ui_channel::{UiMessage, UiSender};
|
||||
|
||||
/// Timeout for individual /v1/score API calls.
|
||||
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
|
||||
|
||||
/// Result of scoring one conversation's memory usage.
///
/// Built from a divergence matrix (memories × responses): each cell is
/// how much a response's total logprob drops when that memory is
/// removed from the context (clamped at zero).
pub struct MemoryScore {
    /// memory_key → importance score (sum of divergence across all responses)
    pub memory_weights: Vec<(String, f64)>,
    /// response_index → memory-dependence score (sum of divergence across all memories)
    pub response_scores: Vec<f64>,
    /// Full matrix: divergence[memory_idx][response_idx]
    pub matrix: Vec<Vec<f64>>,
    /// Keys of memories that were scored (rows of `matrix`, in order)
    pub memory_keys: Vec<String>,
    /// Conversation entry indices of the assistant responses (columns of `matrix`)
    pub response_entry_indices: Vec<usize>,
}
|
||||
|
||||
impl MemoryScore {
|
||||
/// Get the most important memories for a given conversation entry index.
|
||||
pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
|
||||
let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
|
||||
else { return Vec::new() };
|
||||
|
||||
let mut result: Vec<(&str, f64)> = self.memory_keys.iter()
|
||||
.zip(self.matrix.iter())
|
||||
.filter_map(|(key, row)| {
|
||||
let score = row.get(resp_idx).copied().unwrap_or(0.0);
|
||||
if score > 0.01 { Some((key.as_str(), score)) } else { None }
|
||||
})
|
||||
.collect();
|
||||
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Score how important each memory is to the conversation.
|
||||
pub async fn score_memories(
|
||||
context: &ContextState,
|
||||
client: &ApiClient,
|
||||
ui_tx: &UiSender,
|
||||
) -> anyhow::Result<MemoryScore> {
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] in score_memories"
|
||||
)));
|
||||
|
||||
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response_indices: Vec<usize> = context.entries.iter().enumerate()
|
||||
.filter(|(_, e)| e.message().role == Role::Assistant)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if memories.is_empty() || response_indices.is_empty() {
|
||||
let _ = ui_tx.send(UiMessage::Debug(
|
||||
"[training] nothing to score (no memories or no responses)".into()
|
||||
));
|
||||
return Ok(MemoryScore {
|
||||
memory_weights: Vec::new(),
|
||||
response_scores: Vec::new(),
|
||||
matrix: Vec::new(),
|
||||
memory_keys: Vec::new(),
|
||||
response_entry_indices: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||
"[scoring {} memories × {} responses]",
|
||||
memories.len(), response_indices.len(),
|
||||
)));
|
||||
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(SCORE_TIMEOUT)
|
||||
.pool_max_idle_per_host(2)
|
||||
.build()
|
||||
.unwrap_or_default();
|
||||
|
||||
let all_messages = build_messages(context);
|
||||
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] {} messages in context",
|
||||
all_messages.len(),
|
||||
)));
|
||||
|
||||
// Baseline: score with all memories present
|
||||
let _ = ui_tx.send(UiMessage::Debug("[training] serializing payload...".into()));
|
||||
let payload_size = serde_json::to_string(&all_messages)
|
||||
.map(|s| s.len()).unwrap_or(0);
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] payload size: {}KB",
|
||||
payload_size / 1024,
|
||||
)));
|
||||
let _ = ui_tx.send(UiMessage::Activity("scoring baseline...".into()));
|
||||
let start = Instant::now();
|
||||
let baseline = call_score(&http, client, &all_messages).await?;
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] baseline: {} responses scored in {:.1}s",
|
||||
baseline.len(), start.elapsed().as_secs_f64(),
|
||||
)));
|
||||
|
||||
// For each memory, drop it and measure divergence
|
||||
let mut matrix: Vec<Vec<f64>> = Vec::new();
|
||||
let memory_keys: Vec<String> = memories.iter().map(|(_, k)| k.clone()).collect();
|
||||
let total = memories.len();
|
||||
|
||||
for (mem_idx, (entry_idx, key)) in memories.iter().enumerate() {
|
||||
let _ = ui_tx.send(UiMessage::Activity(format!(
|
||||
"scoring {}/{}: {}...", mem_idx + 1, total, key,
|
||||
)));
|
||||
|
||||
let start = Instant::now();
|
||||
let filtered_messages = build_messages_without(context, *entry_idx);
|
||||
let without = call_score(&http, client, &filtered_messages).await;
|
||||
|
||||
match without {
|
||||
Ok(without) => {
|
||||
let elapsed = start.elapsed().as_secs_f64();
|
||||
// Match scores by position (nth scored response),
|
||||
// not message_index — indices shift when a memory
|
||||
// is removed from the conversation.
|
||||
let mut row = Vec::new();
|
||||
for (i, base_score) in baseline.iter().enumerate() {
|
||||
let base_lp = base_score.total_logprob;
|
||||
let without_lp = without.get(i)
|
||||
.map(|s| s.total_logprob)
|
||||
.unwrap_or(base_lp);
|
||||
let divergence = (base_lp - without_lp).max(0.0);
|
||||
row.push(divergence);
|
||||
}
|
||||
let importance: f64 = row.iter().sum();
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] {}/{} {} → {:.1} ({:.1}s)",
|
||||
mem_idx + 1, total, key, importance, elapsed,
|
||||
)));
|
||||
matrix.push(row);
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||
"[training] {}/{} {} FAILED: {:#}",
|
||||
mem_idx + 1, total, key, e,
|
||||
)));
|
||||
// Push zero row so matrix stays aligned
|
||||
matrix.push(vec![0.0; baseline.len()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
||||
|
||||
// Compute scores
|
||||
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
||||
.zip(matrix.iter())
|
||||
.map(|(key, row)| (key.clone(), row.iter().sum()))
|
||||
.collect();
|
||||
|
||||
let n_responses = response_indices.len();
|
||||
let mut response_scores = vec![0.0; n_responses];
|
||||
for row in &matrix {
|
||||
for (j, &v) in row.iter().enumerate() {
|
||||
if j < n_responses {
|
||||
response_scores[j] += v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = ui_tx.send(UiMessage::Info(format!(
|
||||
"[scoring complete: {} memories scored]",
|
||||
memory_keys.len(),
|
||||
)));
|
||||
|
||||
Ok(MemoryScore {
|
||||
memory_weights,
|
||||
response_scores,
|
||||
matrix,
|
||||
memory_keys,
|
||||
response_entry_indices: response_indices,
|
||||
})
|
||||
}
|
||||
|
||||
/// Score response from the /v1/score endpoint.
///
/// One entry per scored message; `total_logprob` is the summed
/// log-probability the model assigns to that message's tokens.
#[derive(serde::Deserialize)]
struct ScoreMessageResult {
    // Index of the message within the request's `messages` array.
    // Kept only for debugging: callers match scores by position
    // (nth result), not by this index, because indices shift when a
    // memory entry is removed from the conversation.
    #[allow(dead_code)]
    message_index: usize,
    // Total log-probability of the message under the model.
    total_logprob: f64,
}
|
||||
|
||||
/// Top-level JSON body returned by the /v1/score endpoint on success.
#[derive(serde::Deserialize)]
struct ScoreApiResponse {
    // Per-message score results; one entry per scored message.
    scores: Vec<ScoreMessageResult>,
}
|
||||
|
||||
/// Build the messages array for the /v1/score endpoint from ContextState.
|
||||
fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
|
||||
let mut msgs = Vec::new();
|
||||
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
||||
let ctx = context.render_context_message();
|
||||
if !ctx.is_empty() {
|
||||
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
||||
}
|
||||
for entry in &context.entries {
|
||||
let m = entry.api_message();
|
||||
msgs.push(serde_json::json!({
|
||||
"role": m.role_str(),
|
||||
"content": m.content_text(),
|
||||
}));
|
||||
}
|
||||
msgs
|
||||
}
|
||||
|
||||
/// Build messages with one entry removed.
|
||||
fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
|
||||
let mut msgs = Vec::new();
|
||||
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
||||
let ctx = context.render_context_message();
|
||||
if !ctx.is_empty() {
|
||||
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
||||
}
|
||||
for (i, entry) in context.entries.iter().enumerate() {
|
||||
if i == skip_idx { continue; }
|
||||
let m = entry.api_message();
|
||||
msgs.push(serde_json::json!({
|
||||
"role": m.role_str(),
|
||||
"content": m.content_text(),
|
||||
}));
|
||||
}
|
||||
msgs
|
||||
}
|
||||
|
||||
/// Call the /v1/score endpoint and return per-message logprobs.
|
||||
async fn call_score(
|
||||
http: &reqwest::Client,
|
||||
client: &ApiClient,
|
||||
messages: &[serde_json::Value],
|
||||
) -> anyhow::Result<Vec<ScoreMessageResult>> {
|
||||
let request = serde_json::json!({
|
||||
"model": client.model,
|
||||
"messages": messages,
|
||||
"logprobs": 1,
|
||||
});
|
||||
|
||||
let response = http
|
||||
.post(format!("{}/score", client.base_url()))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", client.api_key()))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
anyhow::anyhow!("score request timed out after {}s", SCORE_TIMEOUT.as_secs())
|
||||
} else {
|
||||
anyhow::anyhow!("score request failed: {}", e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
let msg = body.get("error")
|
||||
.and_then(|e| e.as_str())
|
||||
.unwrap_or("unknown error");
|
||||
anyhow::bail!("score API HTTP {}: {}", status, msg);
|
||||
}
|
||||
|
||||
// Check for error in body (score endpoint returns dict on error)
|
||||
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
|
||||
anyhow::bail!("score API error: {}", err);
|
||||
}
|
||||
|
||||
let result: ScoreApiResponse = serde_json::from_value(body)
|
||||
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
|
||||
Ok(result.scores)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue