// amygdala.rs — F8 amygdala screen: live per-token concept-readout // projections from the vLLM server's readout.safetensors. // // Left panel: top-K concepts by magnitude at the currently-selected // layer, as horizontal bars. The concept names come from the manifest // fetched at agent startup; the values come from the per-token readout // pushed onto agent.readout by the streaming token handler. // // Bottom: scrolling history of the last few tokens' top concept. // // Keys: // 1..9 select layer index (1 = first layer in the manifest) // t toggle between "current" (last token) and "mean over recent" use ratatui::{ layout::{Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Gauge, Paragraph, Wrap}, Frame, }; use ratatui::crossterm::event::{Event, KeyCode}; use super::{App, ScreenView}; use crate::agent::api::ReadoutManifest; use crate::agent::readout::ReadoutEntry; const TOP_K: usize = 20; pub(crate) struct AmygdalaScreen { selected_layer: usize, mode: DisplayMode, } #[derive(Clone, Copy, PartialEq)] enum DisplayMode { /// Values from the single most recent token. Current, /// Mean over all tokens currently in the ring buffer. MeanRecent, } impl AmygdalaScreen { pub fn new() -> Self { Self { selected_layer: 0, mode: DisplayMode::Current, } } } impl ScreenView for AmygdalaScreen { fn label(&self) -> &'static str { "amygdala" } fn tick(&mut self, frame: &mut Frame, area: Rect, events: &[Event], app: &mut App) { for event in events { if let Event::Key(key) = event { match key.code { KeyCode::Char(c) if c.is_ascii_digit() && c != '0' => { let idx = (c as u8 - b'1') as usize; self.selected_layer = idx; } KeyCode::Char('t') => { self.mode = match self.mode { DisplayMode::Current => DisplayMode::MeanRecent, DisplayMode::MeanRecent => DisplayMode::Current, }; } _ => {} } } } // Snapshot the shared buffer with a short lock. let snapshot = match app.agent.readout.lock() { Ok(buf) => { if !buf.is_enabled() { render_disabled(frame, area); return; } let manifest = buf.manifest.clone().unwrap(); let entries: Vec = buf.recent.iter().cloned().collect(); (manifest, entries) } Err(_) => { render_disabled(frame, area); return; } }; let (manifest, entries) = snapshot; // Bound the selected layer to what the manifest actually has. let n_layers = manifest.layers.len(); if self.selected_layer >= n_layers { self.selected_layer = 0; } // Compute the values to display: either the latest token's row // for the selected layer, or the mean across recent tokens. let values: Option> = match self.mode { DisplayMode::Current => entries .last() .and_then(|e| e.readout.get(self.selected_layer).cloned()), DisplayMode::MeanRecent => mean_layer(&entries, self.selected_layer), }; let layout = Layout::default() .direction(Direction::Vertical) .constraints([ Constraint::Length(3), // header Constraint::Min(10), // bars Constraint::Length(6), // recent tokens ]) .split(area); render_header(frame, layout[0], &manifest, self.selected_layer, self.mode, entries.len()); match values { Some(v) => render_bars(frame, layout[1], &manifest.concepts, &v), None => render_empty_bars(frame, layout[1]), } render_recent(frame, layout[2], &entries, self.selected_layer, &manifest.concepts); } } fn render_disabled(frame: &mut Frame, area: Rect) { let text = Paragraph::new(Line::from(vec![ Span::raw("readout disabled — server did not return a manifest. "), Span::styled("Start vLLM with ", Style::default().fg(Color::DarkGray)), Span::styled("VLLM_READOUT_MANIFEST", Style::default().fg(Color::Yellow)), Span::styled(" + ", Style::default().fg(Color::DarkGray)), Span::styled("VLLM_READOUT_VECTORS", Style::default().fg(Color::Yellow)), Span::styled(".", Style::default().fg(Color::DarkGray)), ])) .wrap(Wrap { trim: true }) .block(Block::default().borders(Borders::ALL).title("amygdala")); frame.render_widget(text, area); } fn render_header(frame: &mut Frame, area: Rect, manifest: &ReadoutManifest, selected: usize, mode: DisplayMode, n_tokens: usize) { let mode_str = match mode { DisplayMode::Current => "current", DisplayMode::MeanRecent => "mean(recent)", }; let layer = manifest.layers.get(selected).copied().unwrap_or(0); let mut spans = vec![ Span::styled("layer ", Style::default().fg(Color::DarkGray)), Span::styled( format!("{}/{} ", selected + 1, manifest.layers.len()), Style::default().add_modifier(Modifier::BOLD), ), Span::styled("(index ", Style::default().fg(Color::DarkGray)), Span::styled(format!("{}", layer), Style::default().fg(Color::Cyan)), Span::styled(") ", Style::default().fg(Color::DarkGray)), Span::styled("mode ", Style::default().fg(Color::DarkGray)), Span::styled(mode_str, Style::default().fg(Color::Yellow)), Span::styled(" ", Style::default()), Span::styled( format!("{} toks in ring", n_tokens), Style::default().fg(Color::DarkGray), ), ]; spans.push(Span::raw(" ")); spans.push(Span::styled( format!("[1-{}] layer [t] toggle mode", manifest.layers.len().min(9)), Style::default().fg(Color::DarkGray), )); let para = Paragraph::new(Line::from(spans)) .block(Block::default().borders(Borders::ALL).title("amygdala")); frame.render_widget(para, area); } fn render_bars(frame: &mut Frame, area: Rect, concepts: &[String], values: &[f32]) { // Sort indices by |value| descending, take top K. let mut indexed: Vec<(usize, f32)> = values.iter() .enumerate().map(|(i, v)| (i, *v)).collect(); indexed.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()) .unwrap_or(std::cmp::Ordering::Equal)); indexed.truncate(TOP_K.min(concepts.len())); let inner = Block::default().borders(Borders::ALL) .title("top concepts"); let inner_area = inner.inner(area); frame.render_widget(inner, area); if inner_area.height == 0 || indexed.is_empty() { return; } // Find the max absolute value so bars are comparable. let max_abs = indexed.iter().map(|(_, v)| v.abs()) .fold(0.0_f32, f32::max) .max(1e-6); let rows = (inner_area.height as usize).min(indexed.len()); let row_constraints: Vec = std::iter::repeat(Constraint::Length(1)).take(rows).collect(); let chunks = Layout::default() .direction(Direction::Vertical) .constraints(row_constraints) .split(inner_area); for (i, (c_idx, v)) in indexed.iter().take(rows).enumerate() { let label = concepts.get(*c_idx).cloned() .unwrap_or_else(|| format!("c{}", c_idx)); let ratio = (v.abs() / max_abs).clamp(0.0, 1.0); let color = if *v >= 0.0 { Color::Green } else { Color::Red }; let gauge = Gauge::default() .ratio(ratio as f64) .gauge_style(Style::default().fg(color).bg(Color::Reset)) .label(format!("{:<26} {:+.3}", truncate_name(&label, 26), v)); frame.render_widget(gauge, chunks[i]); } } fn render_empty_bars(frame: &mut Frame, area: Rect) { let para = Paragraph::new(Line::from(Span::styled( "waiting for tokens…", Style::default().fg(Color::DarkGray), ))) .block(Block::default().borders(Borders::ALL).title("top concepts")); frame.render_widget(para, area); } fn render_recent(frame: &mut Frame, area: Rect, entries: &[ReadoutEntry], layer: usize, concepts: &[String]) { let mut lines: Vec = Vec::new(); for entry in entries.iter().rev().take(4) { let row = match entry.readout.get(layer) { Some(r) => r, None => continue, }; // top concept at this layer for this token let (best_idx, best_val) = row.iter().enumerate() .fold((0, 0.0_f32), |acc, (i, v)| { if v.abs() > acc.1.abs() { (i, *v) } else { acc } }); let name = concepts.get(best_idx).cloned() .unwrap_or_else(|| format!("c{}", best_idx)); let tok_str = format!("t{:>5}", entry.token_id); lines.push(Line::from(vec![ Span::styled(tok_str, Style::default().fg(Color::DarkGray)), Span::raw(" "), Span::styled( format!("{:<24}", truncate_name(&name, 24)), Style::default().fg( if best_val >= 0.0 { Color::Green } else { Color::Red }, ), ), Span::styled( format!(" {:+.3}", best_val), Style::default().add_modifier(Modifier::BOLD), ), ])); } let para = Paragraph::new(lines) .block(Block::default().borders(Borders::ALL).title("recent tokens — top concept")); frame.render_widget(para, area); } fn mean_layer(entries: &[ReadoutEntry], layer: usize) -> Option> { let rows: Vec<&Vec> = entries.iter() .filter_map(|e| e.readout.get(layer)) .collect(); if rows.is_empty() { return None; } let n_concepts = rows[0].len(); let mut acc = vec![0.0_f32; n_concepts]; for r in &rows { for (i, v) in r.iter().enumerate() { acc[i] += *v; } } let n = rows.len() as f32; for v in &mut acc { *v /= n; } Some(acc) } fn truncate_name(s: &str, max: usize) -> String { if s.len() <= max { s.to_string() } else { format!("{}…", &s[..max.saturating_sub(1)]) } }