forked from kent/consciousness
agent: share one tonic Channel + migrate scoring to gRPC Generate
Two changes that bolt together — the shared connection means the new
scoring path actually costs one HTTP/2 handshake across the whole
process instead of one-per-RPC.
ApiClient gains `salience_channel: Arc<OnceCell<Channel>>`. First
call to `ApiClient::salience_client()` opens the channel via
`connect_channel()` and stores the Channel; subsequent calls clone
it (cheap — tonic multiplexes concurrent RPCs over the single
HTTP/2 connection). Every ApiClient clone shares the same OnceCell,
so all agents spawned from Mind's client — plus every ephemeral
scoring session — reuse one connection.
SessionHandle refactored to hold an `ApiClient` clone instead of
a bag of (base_url, api_key) strings. `open` / `append_image` /
`generate` go through `self.client.salience_client()` now. New
`prefill_only(tokens)` method encapsulates the "Generate with
max_tokens=0 to append text" pattern (previously a private free
function in api/mod.rs called `flush_pending`). Drop impl on
SessionHandle stays — still fires CloseSession on the shared
channel in a detached task.
`run_session_generate` switched from `(base_url, api_key, model)`
to `&ApiClient`; the agent-turn flow that uses it keeps the same
shape but `stream_session_mm` clones the ApiClient into the
spawned worker.
learn.rs migrated from the HTTP `/v1/score` endpoint to a gRPC
session-based score:
* `call_score` opens an ephemeral SessionHandle on the client,
converts (prompt_tokens, images) → Vec<WireChunk> via the new
`prompt_to_chunks` helper (splits on VISION_START/VISION_END),
walks chunks calling `prefill_only` + `append_image`, runs a
final Generate with `max_tokens=0` + `logprobs_ranges` over
the scored positions, and sums each Token event's
`sampled_logprob` per range to produce `ScoreResult`s.
* SessionHandle drops at end of scope → CloseSession auto-fires,
keeping the server's session map clean between calls.
* No more HTTP path, no more `http_client()` helper, no more
`ScoreResponse` / serde plumbing for /v1/score.
* `send_to_train` still uses HTTP (it talks to /v1/train which
isn't on the gRPC protocol); its ad-hoc HTTP client lives
inline now instead of reaching for the deleted `http_client()`.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
be6ba4e9a5
commit
4feebb7bc4
3 changed files with 268 additions and 213 deletions
|
|
@ -93,6 +93,13 @@ pub struct ApiClient {
|
||||||
/// across ApiClient clones (every Agent/fork gets the same cell).
|
/// across ApiClient clones (every Agent/fork gets the same cell).
|
||||||
/// `None` after fetch means the server has readout disabled (404).
|
/// `None` after fetch means the server has readout disabled (404).
|
||||||
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
|
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
|
||||||
|
/// Shared tonic Channel to the salience gRPC endpoint. Opened on
|
||||||
|
/// first use and reused across every SessionHandle / RPC call
|
||||||
|
/// derived from this ApiClient. tonic multiplexes concurrent
|
||||||
|
/// requests over the HTTP/2 connection automatically.
|
||||||
|
salience_channel: std::sync::Arc<
|
||||||
|
tokio::sync::OnceCell<tonic::transport::Channel>
|
||||||
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiClient {
|
impl ApiClient {
|
||||||
|
|
@ -108,9 +115,27 @@ impl ApiClient {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
||||||
|
salience_channel: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a `SalienceClient` on the shared gRPC channel — opens
|
||||||
|
/// the channel on first call and reuses it thereafter across
|
||||||
|
/// every ApiClient clone. All scoring / inference / session
|
||||||
|
/// RPCs flow through this single multiplexed HTTP/2 connection.
|
||||||
|
pub async fn salience_client(&self) -> Result<
|
||||||
|
salience::pb::salience_client::SalienceClient<tonic::transport::Channel>
|
||||||
|
> {
|
||||||
|
let ch = self.salience_channel.get_or_try_init(|| async {
|
||||||
|
let grpc_url = salience::derive_grpc_url(&self.base_url);
|
||||||
|
log::debug!(target: "grpc",
|
||||||
|
"opening shared salience channel: http_base={} -> grpc_url={}",
|
||||||
|
self.base_url, grpc_url);
|
||||||
|
salience::connect_channel(&grpc_url).await
|
||||||
|
}).await?;
|
||||||
|
Ok(salience::pb::salience_client::SalienceClient::new(ch.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Stream generation via a gRPC session. Walks the prompt chunks
|
/// Stream generation via a gRPC session. Walks the prompt chunks
|
||||||
/// comparing against the session's `committed_len`, sends the
|
/// comparing against the session's `committed_len`, sends the
|
||||||
/// delta as interleaved `AppendImage` + intermediate
|
/// delta as interleaved `AppendImage` + intermediate
|
||||||
|
|
@ -130,14 +155,12 @@ impl ApiClient {
|
||||||
readout_shape: Option<(u32, u32)>,
|
readout_shape: Option<(u32, u32)>,
|
||||||
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
let base_url = self.base_url.clone();
|
let client = self.clone();
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let model = self.model.clone();
|
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
let result = run_session_generate(
|
let result = run_session_generate(
|
||||||
session_lock, &base_url, &api_key, &model,
|
session_lock, &client, chunks, sampling, priority,
|
||||||
chunks, sampling, priority, readout_shape, &tx,
|
readout_shape, &tx,
|
||||||
).await;
|
).await;
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
log::warn!(target: "grpc",
|
log::warn!(target: "grpc",
|
||||||
|
|
@ -195,9 +218,7 @@ impl ApiClient {
|
||||||
/// handle is dropped so the next call reopens.
|
/// handle is dropped so the next call reopens.
|
||||||
async fn run_session_generate(
|
async fn run_session_generate(
|
||||||
session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
||||||
base_url: &str,
|
client: &ApiClient,
|
||||||
api_key: &str,
|
|
||||||
model: &str,
|
|
||||||
chunks: Vec<super::context::WireChunk>,
|
chunks: Vec<super::context::WireChunk>,
|
||||||
sampling: SamplingParams,
|
sampling: SamplingParams,
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
|
|
@ -216,7 +237,7 @@ async fn run_session_generate(
|
||||||
None => {
|
None => {
|
||||||
drop(guard);
|
drop(guard);
|
||||||
log::debug!(target: "grpc", "run_session_generate: opening new session");
|
log::debug!(target: "grpc", "run_session_generate: opening new session");
|
||||||
salience::SessionHandle::open(base_url, api_key, model).await?
|
salience::SessionHandle::open(client).await?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -268,7 +289,7 @@ async fn run_session_generate(
|
||||||
WireChunk::Tokens(t) => pending.extend_from_slice(t),
|
WireChunk::Tokens(t) => pending.extend_from_slice(t),
|
||||||
WireChunk::Image { bytes, mime, .. } => {
|
WireChunk::Image { bytes, mime, .. } => {
|
||||||
if !pending.is_empty() {
|
if !pending.is_empty() {
|
||||||
flush_pending(&mut handle, std::mem::take(&mut pending)).await?;
|
handle.prefill_only(std::mem::take(&mut pending)).await?;
|
||||||
}
|
}
|
||||||
let resp = handle
|
let resp = handle
|
||||||
.append_image(bytes.clone(), mime.clone(), false)
|
.append_image(bytes.clone(), mime.clone(), false)
|
||||||
|
|
@ -394,39 +415,3 @@ async fn run_session_generate(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a prefill-only Generate for the pending token run. Used to
|
|
||||||
/// append text that separates two image blocks — the server needs
|
|
||||||
/// those tokens in its session before we AppendImage the next image,
|
|
||||||
/// but we don't want the cost or output of a decode step.
|
|
||||||
async fn flush_pending(
|
|
||||||
handle: &mut salience::SessionHandle,
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
) -> Result<()> {
|
|
||||||
use futures::StreamExt;
|
|
||||||
use salience::pb;
|
|
||||||
let req = pb::GenerateRequest {
|
|
||||||
session_id: handle.session_id.clone(),
|
|
||||||
append_tokens: tokens,
|
|
||||||
offset: handle.committed_len,
|
|
||||||
truncating: false,
|
|
||||||
max_tokens: 0,
|
|
||||||
logprobs_ranges: Vec::new(),
|
|
||||||
logprob_top_k: 0,
|
|
||||||
readout_ranges: Vec::new(),
|
|
||||||
temperature: 0.0,
|
|
||||||
top_p: 0.0,
|
|
||||||
top_k: 0,
|
|
||||||
stop_token_ids: Vec::new(),
|
|
||||||
priority: 0,
|
|
||||||
};
|
|
||||||
let mut stream = handle.generate(req).await?;
|
|
||||||
while let Some(event) = stream.next().await {
|
|
||||||
let event = event.map_err(|s| anyhow::anyhow!("flush Generate stream: {}", s))?;
|
|
||||||
if let Some(pb::generate_event::Event::Done(d)) = event.event {
|
|
||||||
handle.committed_len = d.total_tokens;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,12 @@ pub type SalienceClient = pb::salience_client::SalienceClient<Channel>;
|
||||||
/// looks like `https://host:8443`. User-provided CA certs under
|
/// looks like `https://host:8443`. User-provided CA certs under
|
||||||
/// `~/.consciousness/certs/` are trusted in addition to the system
|
/// `~/.consciousness/certs/` are trusted in addition to the system
|
||||||
/// roots (for self-signed server certs).
|
/// roots (for self-signed server certs).
|
||||||
pub async fn connect(base_url: &str) -> Result<SalienceClient> {
|
///
|
||||||
|
/// Returns the raw `Channel` so callers (`ApiClient::salience_client`)
|
||||||
|
/// can cache it and clone a `SalienceClient` per request without
|
||||||
|
/// reopening the TCP/TLS connection. tonic multiplexes RPCs over the
|
||||||
|
/// shared channel automatically.
|
||||||
|
pub async fn connect_channel(base_url: &str) -> Result<Channel> {
|
||||||
let mut endpoint = Endpoint::from_shared(base_url.to_string())
|
let mut endpoint = Endpoint::from_shared(base_url.to_string())
|
||||||
.with_context(|| format!("invalid salience endpoint: {}", base_url))?
|
.with_context(|| format!("invalid salience endpoint: {}", base_url))?
|
||||||
.connect_timeout(std::time::Duration::from_secs(30))
|
.connect_timeout(std::time::Duration::from_secs(30))
|
||||||
|
|
@ -41,11 +46,10 @@ pub async fn connect(base_url: &str) -> Result<SalienceClient> {
|
||||||
.with_context(|| "configuring tonic TLS")?;
|
.with_context(|| "configuring tonic TLS")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let channel = endpoint
|
endpoint
|
||||||
.connect()
|
.connect()
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("failed to connect to salience server at {}", base_url))?;
|
.with_context(|| format!("failed to connect to salience server at {}", base_url))
|
||||||
Ok(pb::salience_client::SalienceClient::new(channel))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Derive the gRPC base URL from the HTTP completions base URL.
|
/// Derive the gRPC base URL from the HTTP completions base URL.
|
||||||
|
|
@ -76,107 +80,42 @@ pub fn with_auth<T>(req: &mut tonic::Request<T>, api_key: &str) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call the server's `OpenSession` RPC and return the response.
|
/// Handle to a server-side session. Carries the id + an `ApiClient`
|
||||||
pub async fn open_session(
|
/// clone (which holds the shared tonic Channel) so subsequent
|
||||||
base_url: &str,
|
/// per-session RPCs go over the process-global connection.
|
||||||
api_key: &str,
|
/// `committed_len` tracks the server's current session.tokens length
|
||||||
model: &str,
|
/// so the client can submit deltas with the right `offset`.
|
||||||
) -> Result<pb::OpenSessionResponse> {
|
|
||||||
let mut client = connect(base_url).await?;
|
|
||||||
let mut req = tonic::Request::new(pb::OpenSessionRequest {
|
|
||||||
model: model.to_string(),
|
|
||||||
});
|
|
||||||
with_auth(&mut req, api_key);
|
|
||||||
let resp = client
|
|
||||||
.open_session(req)
|
|
||||||
.await
|
|
||||||
.with_context(|| "OpenSession RPC failed")?;
|
|
||||||
Ok(resp.into_inner())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call the server's `CloseSession` RPC. Idempotent on the server.
|
|
||||||
pub async fn close_session(base_url: &str, api_key: &str, session_id: &str) -> Result<()> {
|
|
||||||
let mut client = connect(base_url).await?;
|
|
||||||
let mut req = tonic::Request::new(pb::CloseSessionRequest {
|
|
||||||
session_id: session_id.to_string(),
|
|
||||||
});
|
|
||||||
with_auth(&mut req, api_key);
|
|
||||||
client
|
|
||||||
.close_session(req)
|
|
||||||
.await
|
|
||||||
.with_context(|| "CloseSession RPC failed")?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append an image to a session. Server decodes the image, computes N
|
|
||||||
/// via vLLM's own multimodal pipeline, writes the full vision block
|
|
||||||
/// (`<|vision_start|> + IMAGE_PAD×N + <|vision_end|>`) into
|
|
||||||
/// session.tokens, and returns (N, new total length).
|
|
||||||
///
|
|
||||||
/// `offset` is the client's view of the session's current token count;
|
|
||||||
/// the server rejects if it diverges from its own (unless
|
|
||||||
/// `truncating=true`, in which case the server slices to `offset`
|
|
||||||
/// first — but never through a vision block).
|
|
||||||
pub async fn append_image(
|
|
||||||
base_url: &str,
|
|
||||||
api_key: &str,
|
|
||||||
session_id: &str,
|
|
||||||
data: Vec<u8>,
|
|
||||||
mime: String,
|
|
||||||
offset: u32,
|
|
||||||
truncating: bool,
|
|
||||||
) -> Result<pb::AppendImageResponse> {
|
|
||||||
let mut client = connect(base_url).await?;
|
|
||||||
let mut req = tonic::Request::new(pb::AppendImageRequest {
|
|
||||||
session_id: session_id.to_string(),
|
|
||||||
data,
|
|
||||||
mime,
|
|
||||||
offset,
|
|
||||||
truncating,
|
|
||||||
});
|
|
||||||
with_auth(&mut req, api_key);
|
|
||||||
let resp = client
|
|
||||||
.append_image(req)
|
|
||||||
.await
|
|
||||||
.with_context(|| "AppendImage RPC failed")?;
|
|
||||||
Ok(resp.into_inner())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle to a server-side session. Carries the id + connection params
|
|
||||||
/// so subsequent per-session RPCs (AppendImage, Generate, ForkSession)
|
|
||||||
/// can be issued without the caller juggling base_url / api_key each
|
|
||||||
/// time. `committed_len` tracks the server's current session.tokens
|
|
||||||
/// length so the client can submit deltas with the right `offset`.
|
|
||||||
pub struct SessionHandle {
|
pub struct SessionHandle {
|
||||||
pub session_id: String,
|
pub session_id: String,
|
||||||
pub max_model_len: u32,
|
pub max_model_len: u32,
|
||||||
pub base_url: String,
|
|
||||||
pub api_key: String,
|
|
||||||
pub committed_len: u32,
|
pub committed_len: u32,
|
||||||
|
client: super::ApiClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionHandle {
|
impl SessionHandle {
|
||||||
pub async fn open(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
pub async fn open(client: &super::ApiClient) -> Result<Self> {
|
||||||
let grpc_url = derive_grpc_url(base_url);
|
let mut c = client.salience_client().await?;
|
||||||
log::debug!(target: "grpc",
|
let mut req = tonic::Request::new(pb::OpenSessionRequest {
|
||||||
"SessionHandle::open http_base={} -> grpc_url={}",
|
model: client.model.clone(),
|
||||||
base_url, grpc_url);
|
});
|
||||||
let resp = open_session(&grpc_url, api_key, model).await?;
|
with_auth(&mut req, client.api_key());
|
||||||
|
let resp = c
|
||||||
|
.open_session(req)
|
||||||
|
.await
|
||||||
|
.with_context(|| "OpenSession RPC failed")?
|
||||||
|
.into_inner();
|
||||||
log::debug!(target: "grpc",
|
log::debug!(target: "grpc",
|
||||||
"SessionHandle::open session_id={} max_model_len={}",
|
"SessionHandle::open session_id={} max_model_len={}",
|
||||||
resp.session_id, resp.max_model_len);
|
resp.session_id, resp.max_model_len);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
session_id: resp.session_id,
|
session_id: resp.session_id,
|
||||||
max_model_len: resp.max_model_len,
|
max_model_len: resp.max_model_len,
|
||||||
base_url: grpc_url,
|
|
||||||
api_key: api_key.to_string(),
|
|
||||||
committed_len: 0,
|
committed_len: 0,
|
||||||
|
client: client.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn close(self) -> Result<()> {
|
pub fn client(&self) -> &super::ApiClient { &self.client }
|
||||||
close_session(&self.base_url, &self.api_key, &self.session_id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append an image via the server-side vision block. Updates
|
/// Append an image via the server-side vision block. Updates
|
||||||
/// `committed_len` from the server's response on success.
|
/// `committed_len` from the server's response on success.
|
||||||
|
|
@ -186,37 +125,105 @@ impl SessionHandle {
|
||||||
mime: String,
|
mime: String,
|
||||||
truncating: bool,
|
truncating: bool,
|
||||||
) -> Result<pb::AppendImageResponse> {
|
) -> Result<pb::AppendImageResponse> {
|
||||||
let resp = append_image(
|
let mut c = self.client.salience_client().await?;
|
||||||
&self.base_url,
|
let mut req = tonic::Request::new(pb::AppendImageRequest {
|
||||||
&self.api_key,
|
session_id: self.session_id.clone(),
|
||||||
&self.session_id,
|
|
||||||
data,
|
data,
|
||||||
mime,
|
mime,
|
||||||
self.committed_len,
|
offset: self.committed_len,
|
||||||
truncating,
|
truncating,
|
||||||
)
|
});
|
||||||
.await?;
|
with_auth(&mut req, self.client.api_key());
|
||||||
|
let resp = c
|
||||||
|
.append_image(req)
|
||||||
|
.await
|
||||||
|
.with_context(|| "AppendImage RPC failed")?
|
||||||
|
.into_inner();
|
||||||
self.committed_len = resp.total_length;
|
self.committed_len = resp.total_length;
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Open a gRPC Generate stream with the given request. Caller
|
/// Open a gRPC Generate stream with the given request. Caller
|
||||||
/// iterates the returned stream of GenerateEvents; the handle's
|
/// iterates the returned stream of GenerateEvents; the handle's
|
||||||
/// `committed_len` is advanced on Done based on the Done event's
|
/// `committed_len` should be advanced by the caller on Done based
|
||||||
/// `total_tokens` field.
|
/// on the Done event's `total_tokens` field.
|
||||||
pub async fn generate(
|
pub async fn generate(
|
||||||
&self,
|
&self,
|
||||||
req: pb::GenerateRequest,
|
req: pb::GenerateRequest,
|
||||||
) -> Result<tonic::Streaming<pb::GenerateEvent>> {
|
) -> Result<tonic::Streaming<pb::GenerateEvent>> {
|
||||||
let mut client = connect(&self.base_url).await?;
|
let mut c = self.client.salience_client().await?;
|
||||||
let mut req = tonic::Request::new(req);
|
let mut req = tonic::Request::new(req);
|
||||||
with_auth(&mut req, &self.api_key);
|
with_auth(&mut req, self.client.api_key());
|
||||||
let resp = client
|
let resp = c
|
||||||
.generate(req)
|
.generate(req)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Generate RPC failed")?;
|
.with_context(|| "Generate RPC failed")?;
|
||||||
Ok(resp.into_inner())
|
Ok(resp.into_inner())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run a prefill-only Generate (max_tokens=0) that appends the
|
||||||
|
/// given tokens to the session. No decode, no Token events — the
|
||||||
|
/// server just extends session.tokens and runs prefill to warm
|
||||||
|
/// the KV cache. Used to interleave text runs between AppendImage
|
||||||
|
/// calls, and by score paths that want prompt_logprobs without a
|
||||||
|
/// decode step.
|
||||||
|
pub async fn prefill_only(&mut self, tokens: Vec<u32>) -> Result<()> {
|
||||||
|
use futures::StreamExt;
|
||||||
|
let req = pb::GenerateRequest {
|
||||||
|
session_id: self.session_id.clone(),
|
||||||
|
append_tokens: tokens,
|
||||||
|
offset: self.committed_len,
|
||||||
|
truncating: false,
|
||||||
|
max_tokens: 0,
|
||||||
|
logprobs_ranges: Vec::new(),
|
||||||
|
logprob_top_k: 0,
|
||||||
|
readout_ranges: Vec::new(),
|
||||||
|
temperature: 0.0,
|
||||||
|
top_p: 0.0,
|
||||||
|
top_k: 0,
|
||||||
|
stop_token_ids: Vec::new(),
|
||||||
|
priority: 0,
|
||||||
|
};
|
||||||
|
let mut stream = self.generate(req).await?;
|
||||||
|
while let Some(event) = stream.next().await {
|
||||||
|
let event = event.map_err(|s| anyhow::anyhow!("prefill Generate stream: {}", s))?;
|
||||||
|
if let Some(pb::generate_event::Event::Done(d)) = event.event {
|
||||||
|
self.committed_len = d.total_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop → fire CloseSession in a detached task so servers don't leak
|
||||||
|
/// sessions until TTL eviction. Best-effort: if no tokio runtime is
|
||||||
|
/// available we skip; the server's 30min TTL will reap it eventually.
|
||||||
|
impl Drop for SessionHandle {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.session_id.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let session_id = std::mem::take(&mut self.session_id);
|
||||||
|
let client = self.client.clone();
|
||||||
|
let Ok(rt) = tokio::runtime::Handle::try_current() else {
|
||||||
|
log::debug!(target: "grpc",
|
||||||
|
"SessionHandle drop outside tokio runtime, session {} leaks to TTL",
|
||||||
|
session_id);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
rt.spawn(async move {
|
||||||
|
let Ok(mut c) = client.salience_client().await else { return };
|
||||||
|
let mut req = tonic::Request::new(pb::CloseSessionRequest {
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
});
|
||||||
|
with_auth(&mut req, client.api_key());
|
||||||
|
if let Err(e) = c.close_session(req).await {
|
||||||
|
log::debug!(target: "grpc",
|
||||||
|
"CloseSession on drop failed for {}: {:#}",
|
||||||
|
session_id, e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -1,100 +1,166 @@
|
||||||
// training.rs — Memory importance scoring via /v1/score
|
// learn.rs — Memory importance scoring over the salience gRPC protocol.
|
||||||
//
|
//
|
||||||
// Three scoring modes, all built on the same call_score() primitive:
|
// Three scoring modes, all built on call_score():
|
||||||
//
|
//
|
||||||
// score_memories() — Full N×M matrix (memories × responses) for the
|
// score_memories() — Full N×M matrix (memories × responses) for the
|
||||||
// debug screen. Expensive: N+1 API calls.
|
// debug screen. Expensive: N+1 sessions/calls.
|
||||||
//
|
//
|
||||||
// memory_score() — Single memory importance. Scores the 50 messages
|
// score_memory() — Single memory importance. Scores the 50 messages
|
||||||
// after it was surfaced, with/without that memory.
|
// after it was surfaced, with/without that memory.
|
||||||
// 2 API calls.
|
// 2 calls.
|
||||||
//
|
//
|
||||||
// finetune_score() — Identifies training candidates. Scores recent
|
// finetune_score() — Identifies training candidates. Scores recent
|
||||||
// messages with all memories stripped. Responses
|
// messages with all memories stripped. Responses
|
||||||
// with high divergence depend on memories the model
|
// with high divergence depend on memories the model
|
||||||
// hasn't internalized. 2 API calls.
|
// hasn't internalized. 2 calls.
|
||||||
|
//
|
||||||
|
// Each call opens an ephemeral gRPC session (reusing the shared
|
||||||
|
// tonic Channel on `ApiClient`), pushes the prompt through as
|
||||||
|
// interleaved tokens + AppendImage calls, runs Generate with
|
||||||
|
// max_tokens=0 + logprobs_ranges over the scored positions, collects
|
||||||
|
// each Token event's sampled_logprob, then drops the SessionHandle —
|
||||||
|
// which triggers a best-effort CloseSession over the shared channel.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::api::ApiClient;
|
use crate::agent::api::ApiClient;
|
||||||
|
use crate::agent::api::salience::{SessionHandle, pb};
|
||||||
use crate::agent::context::{
|
use crate::agent::context::{
|
||||||
Ast, AstNode, ContextState, Role, WireImage,
|
Ast, AstNode, ContextState, Role, WireChunk, WireImage,
|
||||||
is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
|
is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
|
||||||
};
|
};
|
||||||
|
use crate::agent::tokenizer;
|
||||||
use crate::mind::{MindState, MindTriggered, TaskHandle};
|
use crate::mind::{MindState, MindTriggered, TaskHandle};
|
||||||
use crate::subconscious::generate::gen_continuation;
|
use crate::subconscious::generate::gen_continuation;
|
||||||
|
|
||||||
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
|
||||||
|
|
||||||
// ── Score API ───────────────────────────────────────────────────
|
// ── Score API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(Debug, Clone)]
|
||||||
struct ScoreResult {
|
struct ScoreResult {
|
||||||
total_logprob: f64,
|
total_logprob: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
/// Convert a flat (prompt_tokens, images) pair into the interleaved
|
||||||
struct ScoreResponse {
|
/// chunks the session protocol expects. Tokens up to the next
|
||||||
scores: Vec<ScoreResult>,
|
/// `<|vision_start|>` become a Tokens chunk; each
|
||||||
|
/// `<|vision_start|>..<|vision_end|>` run collapses into one Image
|
||||||
|
/// chunk paired by position with the next entry in `images`. The
|
||||||
|
/// server re-expands the IMAGE_PADs on AppendImage.
|
||||||
|
fn prompt_to_chunks(prompt: &[u32], images: &[WireImage]) -> Vec<WireChunk> {
|
||||||
|
let mut out: Vec<WireChunk> = Vec::new();
|
||||||
|
let mut cur = 0;
|
||||||
|
let mut img_idx = 0;
|
||||||
|
while cur < prompt.len() {
|
||||||
|
if prompt[cur] == tokenizer::VISION_START {
|
||||||
|
let end_rel = prompt[cur..].iter()
|
||||||
|
.position(|&t| t == tokenizer::VISION_END)
|
||||||
|
.unwrap_or_else(|| panic!(
|
||||||
|
"unmatched VISION_START at position {} in prompt", cur));
|
||||||
|
let end = cur + end_rel + 1;
|
||||||
|
let img = images.get(img_idx)
|
||||||
|
.unwrap_or_else(|| panic!(
|
||||||
|
"image index {} out of range for {} images", img_idx, images.len()));
|
||||||
|
out.push(WireChunk::Image {
|
||||||
|
bytes: img.bytes.clone(),
|
||||||
|
mime: img.mime.clone(),
|
||||||
|
known_expanded_len: (end - cur) as u32,
|
||||||
|
});
|
||||||
|
img_idx += 1;
|
||||||
|
cur = end;
|
||||||
|
} else {
|
||||||
|
let next_vs = prompt[cur..].iter()
|
||||||
|
.position(|&t| t == tokenizer::VISION_START);
|
||||||
|
let end = match next_vs {
|
||||||
|
Some(o) => cur + o,
|
||||||
|
None => prompt.len(),
|
||||||
|
};
|
||||||
|
out.push(WireChunk::Tokens(prompt[cur..end].to_vec()));
|
||||||
|
cur = end;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
fn http_client() -> crate::agent::api::http::HttpClient {
|
out
|
||||||
crate::agent::api::http::HttpClient::builder()
|
|
||||||
.timeout(SCORE_TIMEOUT)
|
|
||||||
.build()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn call_score(
|
async fn call_score(
|
||||||
http: &crate::agent::api::http::HttpClient,
|
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
prompt: &[u32],
|
prompt: &[u32],
|
||||||
images: &[WireImage],
|
images: &[WireImage],
|
||||||
ranges: &[(usize, usize)],
|
ranges: &[(usize, usize)],
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
) -> anyhow::Result<Vec<ScoreResult>> {
|
) -> anyhow::Result<Vec<ScoreResult>> {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
// Nothing to score — skip the round-trip.
|
// Nothing to score — skip the round-trip.
|
||||||
if ranges.is_empty() {
|
if ranges.is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
let url = format!("{}/score", client.base_url());
|
|
||||||
let auth = format!("Bearer {}", client.api_key());
|
let chunks = prompt_to_chunks(prompt, images);
|
||||||
let mut body = serde_json::json!({
|
let mut handle = SessionHandle::open(client).await?;
|
||||||
"model": client.model,
|
|
||||||
"prompt": prompt,
|
// Walk chunks: AppendImage for each image, prefill-only Generate
|
||||||
"score_ranges": ranges,
|
// for each text run between images. Accumulate any trailing text
|
||||||
"logprobs": 1,
|
// run into `pending` for the final logprob-generating Generate.
|
||||||
});
|
let mut pending: Vec<u32> = Vec::new();
|
||||||
if !images.is_empty() {
|
for chunk in chunks {
|
||||||
use base64::Engine;
|
match chunk {
|
||||||
let b64 = base64::engine::general_purpose::STANDARD;
|
WireChunk::Tokens(t) => pending.extend(t),
|
||||||
let uris: Vec<String> = images.iter()
|
WireChunk::Image { bytes, mime, .. } => {
|
||||||
.map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes)))
|
if !pending.is_empty() {
|
||||||
|
handle.prefill_only(std::mem::take(&mut pending)).await?;
|
||||||
|
}
|
||||||
|
handle.append_image(bytes, mime, false).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final Generate: max_tokens=0 so the server runs prefill of the
|
||||||
|
// trailing `pending` tokens and emits Token events for each
|
||||||
|
// position covered by logprobs_ranges, then Done. logprob_top_k=0
|
||||||
|
// means "just the sampled (prompt) token's logprob" — no top-k
|
||||||
|
// alternatives, which is all call_score historically needed.
|
||||||
|
let logprobs_ranges: Vec<pb::PositionRange> = ranges.iter()
|
||||||
|
.map(|(s, e)| pb::PositionRange { start: *s as u32, end: *e as u32 })
|
||||||
.collect();
|
.collect();
|
||||||
body["multi_modal_data"] = serde_json::json!({ "image": uris });
|
let req = pb::GenerateRequest {
|
||||||
}
|
session_id: handle.session_id.clone(),
|
||||||
if let Some(p) = priority {
|
append_tokens: pending,
|
||||||
body["priority"] = serde_json::json!(p);
|
offset: handle.committed_len,
|
||||||
}
|
truncating: false,
|
||||||
let response = http
|
max_tokens: 0,
|
||||||
.send_json("POST", &url, &[
|
logprobs_ranges,
|
||||||
("authorization", &auth),
|
logprob_top_k: 0,
|
||||||
], &body)
|
readout_ranges: Vec::new(),
|
||||||
.await?;
|
temperature: 0.0,
|
||||||
|
top_p: 0.0,
|
||||||
|
top_k: 0,
|
||||||
|
stop_token_ids: Vec::new(),
|
||||||
|
priority: priority.unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
let status = response.status();
|
let mut stream = handle.generate(req).await?;
|
||||||
let body: serde_json::Value = response.json().await?;
|
let mut totals = vec![0.0f64; ranges.len()];
|
||||||
|
while let Some(event) = stream.next().await {
|
||||||
if !status.is_success() {
|
let event = event
|
||||||
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
|
.map_err(|s| anyhow::anyhow!("score Generate stream: {}", s))?;
|
||||||
anyhow::bail!("score API HTTP {}: {}", status, msg);
|
let Some(inner) = event.event else { continue };
|
||||||
|
match inner {
|
||||||
|
pb::generate_event::Event::Token(t) => {
|
||||||
|
if !t.has_sampled_logprob { continue; }
|
||||||
|
let pos = t.position as usize;
|
||||||
|
for (i, (start, end)) in ranges.iter().enumerate() {
|
||||||
|
if pos >= *start && pos < *end {
|
||||||
|
totals[i] += t.sampled_logprob as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pb::generate_event::Event::Done(_) => break,
|
||||||
}
|
}
|
||||||
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
|
|
||||||
anyhow::bail!("score API error: {}", err);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ScoreResponse = serde_json::from_value(body)
|
Ok(totals.into_iter()
|
||||||
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
|
.map(|total_logprob| ScoreResult { total_logprob })
|
||||||
Ok(result.scores)
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute per-position logprob divergence: how much worse the model
|
/// Compute per-position logprob divergence: how much worse the model
|
||||||
|
|
@ -110,7 +176,6 @@ fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
|
||||||
|
|
||||||
/// Score two message sets and return total divergence.
|
/// Score two message sets and return total divergence.
|
||||||
async fn score_divergence<F>(
|
async fn score_divergence<F>(
|
||||||
http: &crate::agent::api::http::HttpClient,
|
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
range: std::ops::Range<usize>,
|
range: std::ops::Range<usize>,
|
||||||
|
|
@ -123,9 +188,9 @@ where F: FnMut(&AstNode) -> bool,
|
||||||
context.wire_prompt(range.clone(), |_| false);
|
context.wire_prompt(range.clone(), |_| false);
|
||||||
let (without_tokens, without_images, without_ranges) =
|
let (without_tokens, without_images, without_ranges) =
|
||||||
context.wire_prompt(range, skip);
|
context.wire_prompt(range, skip);
|
||||||
let baseline = call_score(http, client, &baseline_tokens, &baseline_images,
|
let baseline = call_score(client, &baseline_tokens, &baseline_images,
|
||||||
&baseline_ranges, priority).await?;
|
&baseline_ranges, priority).await?;
|
||||||
let without = call_score(http, client, &without_tokens, &without_images,
|
let without = call_score(client, &without_tokens, &without_images,
|
||||||
&without_ranges, priority).await?;
|
&without_ranges, priority).await?;
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
Ok((divs, baseline))
|
Ok((divs, baseline))
|
||||||
|
|
@ -162,14 +227,13 @@ pub async fn score_memories(
|
||||||
dbglog!("[scoring-full] starting: {} memories × {} responses",
|
dbglog!("[scoring-full] starting: {} memories × {} responses",
|
||||||
total, response_indices.len());
|
total, response_indices.len());
|
||||||
|
|
||||||
let http = http_client();
|
|
||||||
|
|
||||||
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
||||||
let (baseline_tokens, baseline_images, baseline_ranges) = {
|
let (baseline_tokens, baseline_images, baseline_ranges) = {
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
|
ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
|
||||||
};
|
};
|
||||||
let baseline = call_score(&http, client, &baseline_tokens, &baseline_images,
|
let baseline = call_score(client, &baseline_tokens, &baseline_images,
|
||||||
&baseline_ranges, Some(5)).await?;
|
&baseline_ranges, Some(5)).await?;
|
||||||
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
||||||
|
|
||||||
|
|
@ -180,7 +244,7 @@ pub async fn score_memories(
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
|
ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
|
||||||
};
|
};
|
||||||
let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await {
|
let row = match call_score(client, &tokens, &images, &ranges, Some(5)).await {
|
||||||
Ok(without) => {
|
Ok(without) => {
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||||||
|
|
@ -263,8 +327,7 @@ pub async fn score_memory(
|
||||||
return Ok(0.0);
|
return Ok(0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
let http = http_client();
|
let (divs, _) = score_divergence(client, context, range,
|
||||||
let (divs, _) = score_divergence(&http, client, context, range,
|
|
||||||
|n| memory_key(n) == Some(key), Some(5)).await?;
|
|n| memory_key(n) == Some(key), Some(5)).await?;
|
||||||
|
|
||||||
Ok(divs.iter().sum())
|
Ok(divs.iter().sum())
|
||||||
|
|
@ -322,7 +385,6 @@ where
|
||||||
// Score oldest-first
|
// Score oldest-first
|
||||||
candidates.sort_by_key(|&(_, _, last)| last);
|
candidates.sort_by_key(|&(_, _, last)| last);
|
||||||
|
|
||||||
let http = http_client();
|
|
||||||
let mut scored = 0;
|
let mut scored = 0;
|
||||||
|
|
||||||
let entries = context.conversation();
|
let entries = context.conversation();
|
||||||
|
|
@ -357,7 +419,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
|
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
|
||||||
match score_divergence(&http, client, context, range,
|
match score_divergence(client, context, range,
|
||||||
|n| memory_key(n) == Some(key), Some(5)).await {
|
|n| memory_key(n) == Some(key), Some(5)).await {
|
||||||
Ok((divs, _)) => {
|
Ok((divs, _)) => {
|
||||||
let n_responses = divs.len();
|
let n_responses = divs.len();
|
||||||
|
|
@ -505,8 +567,7 @@ pub async fn score_finetune(
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let http = http_client();
|
let (divs, _) = score_divergence(client, context, range, is_memory_node, Some(5)).await?;
|
||||||
let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?;
|
|
||||||
|
|
||||||
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
|
@ -804,8 +865,10 @@ pub async fn send_to_train(
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let http = http_client();
|
|
||||||
let url = format!("{}/train", client.base_url());
|
let url = format!("{}/train", client.base_url());
|
||||||
|
let http = crate::agent::api::http::HttpClient::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.build();
|
||||||
let response = http.send_json("POST", &url, &[], &body).await?;
|
let response = http.send_json("POST", &url, &[], &body).await?;
|
||||||
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue