forked from kent/consciousness
agent: share one tonic Channel + migrate scoring to gRPC Generate
Two changes that bolt together — the shared connection means the new
scoring path actually costs one HTTP/2 handshake across the whole
process instead of one-per-RPC.
ApiClient gains `salience_channel: Arc<OnceCell<Channel>>`. First
call to `ApiClient::salience_client()` opens the channel via
`connect_channel()` and stores the Channel; subsequent calls clone
it (cheap — tonic multiplexes concurrent RPCs over the single
HTTP/2 connection). Every ApiClient clone shares the same OnceCell,
so all agents spawned from Mind's client — plus every ephemeral
scoring session — reuse one connection.
SessionHandle refactored to hold an `ApiClient` clone instead of
a bag of (base_url, api_key) strings. `open` / `append_image` /
`generate` go through `self.client.salience_client()` now. New
`prefill_only(tokens)` method encapsulates the "Generate with
max_tokens=0 to append text" pattern (previously a private free
function in api/mod.rs called `flush_pending`). Drop impl on
SessionHandle stays — still fires CloseSession on the shared
channel in a detached task.
`run_session_generate` switched from `(base_url, api_key, model)`
to `&ApiClient`; the agent-turn flow that uses it keeps the same
shape but `stream_session_mm` clones the ApiClient into the
spawned worker.
learn.rs migrated from the HTTP `/v1/score` endpoint to a gRPC
session-based score:
* `call_score` opens an ephemeral SessionHandle on the client,
converts (prompt_tokens, images) → Vec<WireChunk> via the new
`prompt_to_chunks` helper (splits on VISION_START/VISION_END),
walks chunks calling `prefill_only` + `append_image`, runs a
final Generate with `max_tokens=0` + `logprobs_ranges` over
the scored positions, and sums each Token event's
`sampled_logprob` per range to produce `ScoreResult`s.
* SessionHandle drops at end of scope → CloseSession auto-fires,
keeping the server's session map clean between calls.
* No more HTTP path, no more `http_client()` helper, no more
`ScoreResponse` / serde plumbing for /v1/score.
* `send_to_train` still uses HTTP (it talks to /v1/train which
isn't on the gRPC protocol); its ad-hoc HTTP client lives
inline now instead of reaching for the deleted `http_client()`.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
be6ba4e9a5
commit
4feebb7bc4
3 changed files with 268 additions and 213 deletions
|
|
@ -93,6 +93,13 @@ pub struct ApiClient {
|
|||
/// across ApiClient clones (every Agent/fork gets the same cell).
|
||||
/// `None` after fetch means the server has readout disabled (404).
|
||||
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
|
||||
/// Shared tonic Channel to the salience gRPC endpoint. Opened on
|
||||
/// first use and reused across every SessionHandle / RPC call
|
||||
/// derived from this ApiClient. tonic multiplexes concurrent
|
||||
/// requests over the HTTP/2 connection automatically.
|
||||
salience_channel: std::sync::Arc<
|
||||
tokio::sync::OnceCell<tonic::transport::Channel>
|
||||
>,
|
||||
}
|
||||
|
||||
impl ApiClient {
|
||||
|
|
@ -108,9 +115,27 @@ impl ApiClient {
|
|||
model: model.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
||||
salience_channel: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a `SalienceClient` on the shared gRPC channel — opens
|
||||
/// the channel on first call and reuses it thereafter across
|
||||
/// every ApiClient clone. All scoring / inference / session
|
||||
/// RPCs flow through this single multiplexed HTTP/2 connection.
|
||||
pub async fn salience_client(&self) -> Result<
|
||||
salience::pb::salience_client::SalienceClient<tonic::transport::Channel>
|
||||
> {
|
||||
let ch = self.salience_channel.get_or_try_init(|| async {
|
||||
let grpc_url = salience::derive_grpc_url(&self.base_url);
|
||||
log::debug!(target: "grpc",
|
||||
"opening shared salience channel: http_base={} -> grpc_url={}",
|
||||
self.base_url, grpc_url);
|
||||
salience::connect_channel(&grpc_url).await
|
||||
}).await?;
|
||||
Ok(salience::pb::salience_client::SalienceClient::new(ch.clone()))
|
||||
}
|
||||
|
||||
/// Stream generation via a gRPC session. Walks the prompt chunks
|
||||
/// comparing against the session's `committed_len`, sends the
|
||||
/// delta as interleaved `AppendImage` + intermediate
|
||||
|
|
@ -130,14 +155,12 @@ impl ApiClient {
|
|||
readout_shape: Option<(u32, u32)>,
|
||||
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let base_url = self.base_url.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let model = self.model.clone();
|
||||
let client = self.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let result = run_session_generate(
|
||||
session_lock, &base_url, &api_key, &model,
|
||||
chunks, sampling, priority, readout_shape, &tx,
|
||||
session_lock, &client, chunks, sampling, priority,
|
||||
readout_shape, &tx,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
log::warn!(target: "grpc",
|
||||
|
|
@ -195,9 +218,7 @@ impl ApiClient {
|
|||
/// handle is dropped so the next call reopens.
|
||||
async fn run_session_generate(
|
||||
session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
client: &ApiClient,
|
||||
chunks: Vec<super::context::WireChunk>,
|
||||
sampling: SamplingParams,
|
||||
priority: Option<i32>,
|
||||
|
|
@ -216,7 +237,7 @@ async fn run_session_generate(
|
|||
None => {
|
||||
drop(guard);
|
||||
log::debug!(target: "grpc", "run_session_generate: opening new session");
|
||||
salience::SessionHandle::open(base_url, api_key, model).await?
|
||||
salience::SessionHandle::open(client).await?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -268,7 +289,7 @@ async fn run_session_generate(
|
|||
WireChunk::Tokens(t) => pending.extend_from_slice(t),
|
||||
WireChunk::Image { bytes, mime, .. } => {
|
||||
if !pending.is_empty() {
|
||||
flush_pending(&mut handle, std::mem::take(&mut pending)).await?;
|
||||
handle.prefill_only(std::mem::take(&mut pending)).await?;
|
||||
}
|
||||
let resp = handle
|
||||
.append_image(bytes.clone(), mime.clone(), false)
|
||||
|
|
@ -394,39 +415,3 @@ async fn run_session_generate(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Emit a prefill-only Generate for the pending token run. Used to
|
||||
/// append text that separates two image blocks — the server needs
|
||||
/// those tokens in its session before we AppendImage the next image,
|
||||
/// but we don't want the cost or output of a decode step.
|
||||
async fn flush_pending(
|
||||
handle: &mut salience::SessionHandle,
|
||||
tokens: Vec<u32>,
|
||||
) -> Result<()> {
|
||||
use futures::StreamExt;
|
||||
use salience::pb;
|
||||
let req = pb::GenerateRequest {
|
||||
session_id: handle.session_id.clone(),
|
||||
append_tokens: tokens,
|
||||
offset: handle.committed_len,
|
||||
truncating: false,
|
||||
max_tokens: 0,
|
||||
logprobs_ranges: Vec::new(),
|
||||
logprob_top_k: 0,
|
||||
readout_ranges: Vec::new(),
|
||||
temperature: 0.0,
|
||||
top_p: 0.0,
|
||||
top_k: 0,
|
||||
stop_token_ids: Vec::new(),
|
||||
priority: 0,
|
||||
};
|
||||
let mut stream = handle.generate(req).await?;
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event.map_err(|s| anyhow::anyhow!("flush Generate stream: {}", s))?;
|
||||
if let Some(pb::generate_event::Event::Done(d)) = event.event {
|
||||
handle.committed_len = d.total_tokens;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue