diff --git a/training/apollo_plugin/__init__.py b/training/apollo_plugin/__init__.py new file mode 100644 index 0000000..bfbecd0 --- /dev/null +++ b/training/apollo_plugin/__init__.py @@ -0,0 +1,17 @@ +"""Apollo training plugin for vLLM. + +Enables continuous fine-tuning alongside live inference by: +1. Exporting CUDA IPC handles for weight sharing +2. Providing a training worker daemon (/train endpoint) +3. Block-level checkpoint sync to safetensors files + +Install: pip install -e /path/to/training +Then vLLM auto-loads via entry point. +""" + +from .export_hook import _patch_model_runner + + +def register(): + """Called by vLLM's plugin loader on startup.""" + _patch_model_runner() diff --git a/training/apollo_plugin/checkpoint_sync.py b/training/apollo_plugin/checkpoint_sync.py new file mode 100644 index 0000000..eff93cc --- /dev/null +++ b/training/apollo_plugin/checkpoint_sync.py @@ -0,0 +1,500 @@ +"""Sync live GPU weights to safetensors files on disk. + +Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's +merged layout to HuggingFace's separate layout, diffs block-by-block +against on-disk safetensors files, and writes only changed blocks. + +For small behavioral training steps, this turns a 54GB checkpoint +write into a few hundred MB of actual disk I/O. + +Usage: + # Sync live weights to disk + python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B + + # Debug name mapping issues + python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B + + # From Python: + from checkpoint_sync import checkpoint_sync + result = checkpoint_sync("/path/to/model") +""" + +import json +import mmap +import struct +import sys +from pathlib import Path +from typing import Dict, List, Tuple, Any +import logging + +import torch + +logger = logging.getLogger(__name__) + +DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size +DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt" + + +# --------------------------------------------------------------------------- +# vLLM → HuggingFace weight name/shape conversion +# --------------------------------------------------------------------------- +# Qwen3.5-27B dimensions (could be read from config.json for generality) + +HIDDEN = 5120 +NUM_K_HEADS = 16 +NUM_V_HEADS = 48 +HEAD_K_DIM = 128 +HEAD_V_DIM = 128 +KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048 +VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 6144 +INTERMEDIATE = 17408 + +# Full attention (some layers use standard attention, not GDN) +NUM_ATTN_HEADS = 24 +NUM_ATTN_KV_HEADS = 4 +ATTN_HEAD_DIM = 256 +ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2 # 512 +ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM # 12288 +ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 +ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM # 1024 + + +def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Convert vLLM merged weights to HF-compatible separate tensors. + + vLLM merges certain projections for efficiency: + - qkv_proj (full attn) → q_proj, k_proj, v_proj + - in_proj_qkvz (GDN) → in_proj_qkv, in_proj_z + - in_proj_ba (GDN) → in_proj_b, in_proj_a + - gate_up_proj (MLP) → gate_proj, up_proj + + Returns views that share GPU memory with the original tensors. + """ + hf_params = {} + + for name, tensor in vllm_params.items(): + # Strip vLLM's 'language_model.' 
+        # prefix, to match HF naming
+        hf_name = name.removeprefix('language_model.')
+
+        if 'in_proj_qkvz' in name:
+            # GDN layer: [key*2 + value*2, hidden] → qkv + z
+            prefix = hf_name.replace('in_proj_qkvz.weight', '')
+            split_at = KEY_DIM * 2 + VALUE_DIM
+            hf_params[prefix + 'in_proj_qkv.weight'] = tensor[:split_at]
+            hf_params[prefix + 'in_proj_z.weight'] = tensor[split_at:]
+
+        elif 'in_proj_ba' in name:
+            # GDN layer: [num_v_heads*2, hidden] → b + a
+            prefix = hf_name.replace('in_proj_ba.weight', '')
+            hf_params[prefix + 'in_proj_b.weight'] = tensor[:NUM_V_HEADS]
+            hf_params[prefix + 'in_proj_a.weight'] = tensor[NUM_V_HEADS:]
+
+        elif 'qkv_proj' in name:
+            # Full attention: [q + k + v, hidden] → separate
+            prefix = hf_name.replace('qkv_proj.weight', '')
+            hf_params[prefix + 'q_proj.weight'] = tensor[:ATTN_Q_DIM]
+            hf_params[prefix + 'k_proj.weight'] = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
+            hf_params[prefix + 'v_proj.weight'] = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
+
+        elif 'gate_up_proj' in name:
+            # MLP: [intermediate*2, hidden] → gate + up
+            prefix = hf_name.replace('gate_up_proj.weight', '')
+            hf_params[prefix + 'gate_proj.weight'] = tensor[:INTERMEDIATE]
+            hf_params[prefix + 'up_proj.weight'] = tensor[INTERMEDIATE:]
+
+        else:
+            # Pass through unchanged
+            hf_params[hf_name] = tensor
+
+    return hf_params
+
+
+# ---------------------------------------------------------------------------
+# Safetensors file handling
+# ---------------------------------------------------------------------------
+
+def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
+    """Map tensor names to safetensors filenames.
+
+    For sharded models, reads model.safetensors.index.json.
+    For single-file models, returns empty dict (default to model.safetensors).
+    """
+    index_path = model_dir / "model.safetensors.index.json"
+    if not index_path.exists():
+        return {}
+
+    with open(index_path) as f:
+        index = json.load(f)
+
+    return dict(index.get("weight_map", {}))
+
+
+def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
+    """Parse safetensors file header.
+
+    Returns (data_start_offset, header_dict).
+    Header dict maps tensor names to metadata including 'data_offsets'.
+    """
+    header_size = struct.unpack('<Q', data[:8])[0]
+    header = json.loads(bytes(data[8:8 + header_size]))
+    return 8 + header_size, header
+
+
+def sync_tensor_to_mmap(
+    mm: mmap.mmap,
+    name: str,
+    tensor: torch.Tensor,
+    data_start: int,
+    offsets: List[int],
+    block_size: int,
+) -> Tuple[int, int]:
+    """Sync a single tensor to mmap'd file using block-level diffing.
+
+    Returns (bytes_compared, bytes_changed).
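+
+    Illustrative arithmetic (numbers hypothetical): a bf16 tensor of
+    shape [5120, 5120] occupies 52,428,800 bytes on disk, i.e. 12,800
+    blocks of 4 KB. If a training step perturbed weights in only 3 of
+    those blocks, this returns (52428800, 12288): the whole tensor is
+    compared, but just 12 KB is rewritten.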
+    """
+    start = data_start + offsets[0]
+    end = data_start + offsets[1]
+    disk_len = end - start
+
+    # Transfer tensor to CPU and get raw bytes. Use .detach() to avoid
+    # autograd overhead, .contiguous() for memory layout, and a uint8
+    # view so .numpy() also works for dtypes numpy lacks (e.g. bfloat16).
+    try:
+        live_bytes = (tensor.detach().contiguous().cpu()
+                      .view(torch.uint8).numpy().tobytes())
+    except Exception as e:
+        logger.warning(f"Failed to transfer {name} to CPU: {e}")
+        return 0, 0
+
+    if len(live_bytes) != disk_len:
+        logger.warning(
+            f"Size mismatch for {name}: disk={disk_len}, live={len(live_bytes)} "
+            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
+        )
+        return 0, 0
+
+    # Block-level diff: compare and write only changed blocks
+    compared = 0
+    changed = 0
+    offset = 0
+
+    while offset < disk_len:
+        block_end = min(offset + block_size, disk_len)
+        block_len = block_end - offset
+
+        disk_block = mm[start + offset:start + block_end]
+        live_block = live_bytes[offset:block_end]
+
+        compared += block_len
+
+        if disk_block != live_block:
+            mm[start + offset:start + block_end] = live_block
+            changed += block_len
+
+        offset = block_end
+
+    return compared, changed
+
+
+def sync_file(
+    file_path: Path,
+    tensors: Dict[str, torch.Tensor],
+    block_size: int,
+) -> Tuple[int, int, int, int]:
+    """Sync tensors to a single safetensors file.
+
+    Returns (bytes_compared, bytes_changed, tensors_found, tensors_missing).
+    """
+    with open(file_path, 'r+b') as f:
+        mm = mmap.mmap(f.fileno(), 0)
+
+        try:
+            data_start, header = parse_safetensors_header(memoryview(mm))
+
+            total_compared = 0
+            total_changed = 0
+            found = 0
+            missing = 0
+
+            for name, tensor in tensors.items():
+                if name == "__metadata__":
+                    continue
+
+                if name not in header:
+                    missing += 1
+                    continue
+
+                found += 1
+                meta = header[name]
+                offsets = meta['data_offsets']
+
+                compared, changed = sync_tensor_to_mmap(
+                    mm, name, tensor, data_start, offsets, block_size
+                )
+                total_compared += compared
+                total_changed += changed
+
+            # Flush changes to disk
+            if total_changed > 0:
+                mm.flush()
+
+            return total_compared, total_changed, found, missing
+
+        finally:
+            mm.close()
+
+
+# ---------------------------------------------------------------------------
+# Main entry point
+# ---------------------------------------------------------------------------
+
+def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
+    """Load vLLM weight tensors from CUDA IPC handles.
+
+    The handles file is written by export_hook.py on vLLM startup.
+    Each handle can be used to reconstruct a tensor pointing to vLLM's
+    GPU memory — no copy, direct access.
+    """
+    handles = torch.load(handles_path, weights_only=False)
+
+    weights = {}
+    for name, info in handles.items():
+        func, args = info['handle']
+        try:
+            weights[name] = func(*args)
+        except Exception as e:
+            logger.warning(f"Failed to reconstruct {name}: {e}")
+
+    return weights
+
+
+def checkpoint_sync(
+    model_dir: str,
+    handles_path: str = DEFAULT_HANDLES_PATH,
+    block_size: int = DEFAULT_BLOCK_SIZE,
+) -> Dict[str, Any]:
+    """Sync live GPU weights to model safetensors files.
+
+    This is the main entry point. Call it after training steps or
+    periodically to checkpoint weights without full serialization.
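+
+    Example (model path illustrative):
+
+        from apollo_plugin.checkpoint_sync import checkpoint_sync
+        stats = checkpoint_sync("/models/Qwen3.5-27B")
+        if stats["total_changed"] == 0:
+            print("checkpoint already current")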
+ + Args: + model_dir: Directory containing safetensors files + handles_path: Path to vLLM weight IPC handles file + block_size: Block size for diffing (default 4KB) + + Returns: + Dict with sync statistics: + - total_compared: bytes compared + - total_changed: bytes actually written + - files_changed: list of modified filenames + - tensors_synced: number of tensors processed + - tensors_missing: tensors not found in safetensors + """ + model_dir = Path(model_dir) + + if not Path(handles_path).exists(): + raise FileNotFoundError( + f"Weight handles not found: {handles_path}. " + "Is vLLM running with the export hook?" + ) + + # Step 1: Load live weights from GPU via IPC + logger.info("Loading live weights from GPU...") + vllm_weights = load_vllm_weights(handles_path) + logger.info(f" Loaded {len(vllm_weights)} vLLM tensors") + + # Step 2: Convert to HF naming/layout + hf_weights = vllm_to_hf_tensors(vllm_weights) + logger.info(f" Converted to {len(hf_weights)} HF tensors") + + # Step 3: Map tensors to safetensors files + weight_map = read_safetensors_index(model_dir) + + by_file: Dict[str, Dict[str, torch.Tensor]] = {} + unmapped = [] + + for name, tensor in hf_weights.items(): + filename = weight_map.get(name) + if filename is None: + # Single-file model or missing from index + if (model_dir / "model.safetensors").exists(): + filename = "model.safetensors" + else: + unmapped.append(name) + continue + by_file.setdefault(filename, {})[name] = tensor + + if unmapped: + logger.warning(f" {len(unmapped)} tensors not in index: {unmapped[:3]}...") + + # Step 4: Sync each file + total_compared = 0 + total_changed = 0 + total_found = 0 + total_missing = 0 + files_changed = [] + + for filename in sorted(by_file.keys()): + tensors = by_file[filename] + file_path = model_dir / filename + + if not file_path.exists(): + logger.warning(f" File not found: {filename}") + total_missing += len(tensors) + continue + + compared, changed, found, missing = sync_file(file_path, tensors, block_size) + + total_compared += compared + total_changed += changed + total_found += found + total_missing += missing + + if changed > 0: + files_changed.append(filename) + logger.info(f" {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)") + + # Summary + if total_changed == 0: + logger.info("No changes - model files are up to date") + else: + pct = (total_changed / total_compared * 100) if total_compared > 0 else 0 + logger.info( + f"Synced: {total_changed / 1e6:.2f} MB changed / " + f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)" + ) + + if total_missing > 0: + logger.warning(f" {total_missing} tensors not found in safetensors files") + + return { + "total_compared": total_compared, + "total_changed": total_changed, + "files_changed": files_changed, + "tensors_synced": total_found, + "tensors_missing": total_missing, + } + + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- + +def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH): + """Print diagnostic info about weight name mappings. + + Useful for debugging mismatches between vLLM and safetensors names. 
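+
+    Output is a short report like (counts hypothetical):
+
+        vLLM tensors (raw): 579
+        HF tensors (converted): 715
+        Disk tensors: 715
+
+        Matched: 715
+        Only in HF (won't sync): 0
+        Only on disk (not updated): 0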
+ """ + model_dir = Path(model_dir) + + # Load and convert vLLM weights + vllm_weights = load_vllm_weights(handles_path) + hf_weights = vllm_to_hf_tensors(vllm_weights) + hf_names = set(hf_weights.keys()) + + # Read safetensors index + weight_map = read_safetensors_index(model_dir) + disk_names = set(weight_map.keys()) + + # If single-file model, parse that file's header + if not disk_names: + st_path = model_dir / "model.safetensors" + if st_path.exists(): + with open(st_path, 'rb') as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + _, header = parse_safetensors_header(memoryview(mm)) + disk_names = {k for k in header.keys() if k != "__metadata__"} + mm.close() + + print(f"vLLM tensors (raw): {len(vllm_weights)}") + print(f"HF tensors (converted): {len(hf_names)}") + print(f"Disk tensors: {len(disk_names)}") + print() + + in_both = hf_names & disk_names + only_hf = hf_names - disk_names + only_disk = disk_names - hf_names + + print(f"Matched: {len(in_both)}") + print(f"Only in HF (won't sync): {len(only_hf)}") + print(f"Only on disk (not updated): {len(only_disk)}") + + if only_hf: + print(f"\nSample HF-only: {sorted(only_hf)[:5]}") + if only_disk: + print(f"\nSample disk-only: {sorted(only_disk)[:5]}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Sync live GPU weights to safetensors files" + ) + subparsers = parser.add_subparsers(dest="command", help="Command") + + # sync command + sync_parser = subparsers.add_parser("sync", help="Sync weights to disk") + sync_parser.add_argument( + "--model-dir", required=True, + help="Directory containing safetensors files" + ) + sync_parser.add_argument( + "--handles", default=DEFAULT_HANDLES_PATH, + help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" + ) + sync_parser.add_argument( + "--block-size", type=int, default=DEFAULT_BLOCK_SIZE, + help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})" + ) + sync_parser.add_argument( + "-v", "--verbose", action="store_true", + help="Verbose output" + ) + + # diagnose command + diag_parser = subparsers.add_parser("diagnose", help="Check name mappings") + diag_parser.add_argument( + "--model-dir", required=True, + help="Directory containing safetensors files" + ) + diag_parser.add_argument( + "--handles", default=DEFAULT_HANDLES_PATH, + help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})" + ) + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + sys.exit(1) + + logging.basicConfig( + level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO, + format='%(message)s' + ) + + try: + if args.command == "sync": + result = checkpoint_sync(args.model_dir, args.handles, args.block_size) + print(json.dumps(result, indent=2)) + elif args.command == "diagnose": + diagnose(args.model_dir, args.handles) + except FileNotFoundError as e: + logger.error(str(e)) + sys.exit(1) + except Exception as e: + logger.exception(f"Failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/training/vllm_export_hook.py b/training/apollo_plugin/export_hook.py similarity index 82% rename from training/vllm_export_hook.py rename to training/apollo_plugin/export_hook.py index 6a0bf1e..4853930 100644 --- a/training/vllm_export_hook.py +++ b/training/apollo_plugin/export_hook.py @@ -1,17 +1,12 @@ """Monkey-patch vLLM to export weight 
IPC handles on startup.
 
-Usage — add to start_vllm.sh BEFORE the vllm serve command:
+Usage — install the apollo_plugin package:
 
-    export VLLM_PLUGINS=vllm_export_hook
-    vllm serve Qwen/Qwen3.5-27B ...
+    pip install -e /path/to/training
 
-Or use Python to launch vLLM with the hook:
+Then vLLM auto-discovers and loads the plugin via its entry point. Or, to restrict which plugins load:
 
-    python3 -c "
-    import vllm_export_hook  # installs the patch
-    from vllm.entrypoints.openai.api_server import run_server
-    run_server(...)
-    "
+    VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
 
 The hook patches vLLM's model runner to export IPC handles after
 model loading completes. The handles are saved to a file that the
@@ -70,7 +65,3 @@ def _patch_model_runner():
     gpu_worker.Worker.load_model = patched_load
 
     print("[apollo] Weight export hook installed")
-
-
-# Auto-install when imported
-_patch_model_runner()
diff --git a/training/apollo_mini.py b/training/apollo_plugin/optimizer.py
similarity index 100%
rename from training/apollo_mini.py
rename to training/apollo_plugin/optimizer.py
diff --git a/training/extract_steering_vector.py b/training/apollo_plugin/steering.py
similarity index 100%
rename from training/extract_steering_vector.py
rename to training/apollo_plugin/steering.py
diff --git a/training/weight_mapping.py b/training/apollo_plugin/weight_mapping.py
similarity index 100%
rename from training/weight_mapping.py
rename to training/apollo_plugin/weight_mapping.py
diff --git a/training/apollo_worker.py b/training/apollo_plugin/worker.py
similarity index 87%
rename from training/apollo_worker.py
rename to training/apollo_plugin/worker.py
index d46fb55..5d9ba29 100755
--- a/training/apollo_worker.py
+++ b/training/apollo_plugin/worker.py
@@ -74,6 +74,9 @@ class TrainingJob:
             'error': self.error,
         }
 
+CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
+
+
 class ApolloWorker:
     def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
         self.config = self._load_config(config_path)
@@ -81,6 +84,7 @@ class ApolloWorker:
         self.vllm_paused = False
         self.app = web.Application()
         self._setup_routes()
+        self._checkpoint_timer: Optional[asyncio.Task] = None
 
     def _load_config(self, config_path: str) -> Dict[str, Any]:
         """Load configuration from file or use defaults."""
@@ -230,8 +234,11 @@ class ApolloWorker:
             # Mark job as completed
             job.status = TrainingStatus.COMPLETED
             job.completed_at = datetime.now()
-
+
             logger.info(f"Training job {job.job_id} completed successfully")
+
+            # Schedule checkpoint sync (batched — won't duplicate if timer pending)
+            self.schedule_checkpoint_sync()
 
         except Exception as e:
             logger.error(f"Training job {job.job_id} failed: {e}")
@@ -278,6 +285,43 @@ class ApolloWorker:
         except Exception as e:
             logger.warning(f"Failed to resume vLLM: {e}")
 
+    def schedule_checkpoint_sync(self):
+        """Schedule a checkpoint sync in 10 minutes, if not already scheduled.
+
+        This batches multiple training runs into a single sync: a new
+        timer starts only when none is pending.
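+
+        For example (timings illustrative): jobs finishing at t=0, t=3,
+        and t=8 minutes produce one sync at t=10; a job finishing at
+        t=12 schedules the next sync for t=22.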
+ """ + if self._checkpoint_timer is not None: + logger.debug("Checkpoint sync already scheduled, skipping") + return + + self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay()) + logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes") + + async def _checkpoint_sync_after_delay(self): + """Wait then sync — the actual timer task.""" + try: + await asyncio.sleep(CHECKPOINT_DELAY_SECS) + await self._do_checkpoint_sync() + except asyncio.CancelledError: + logger.debug("Checkpoint sync cancelled") + finally: + self._checkpoint_timer = None + + async def _do_checkpoint_sync(self): + """Execute the checkpoint sync.""" + try: + from apollo_plugin.checkpoint_sync import checkpoint_sync + logger.info("Starting checkpoint sync...") + result = checkpoint_sync( + self.config['model_path'], + self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'), + ) + changed_mb = result['total_changed'] / 1e6 + logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written") + except Exception as e: + logger.error(f"Checkpoint sync failed: {e}") + async def load_model_for_training(self) -> nn.Module: """Load HF model with weights pointing to vLLM's GPU memory. @@ -299,22 +343,24 @@ class ApolloWorker: logger.info(f"Imported {len(vllm_params)} parameters") # Map vLLM merged layout → HF separate layout (views, no copies) - from weight_mapping import load_hf_model_with_vllm_weights + from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights model = load_hf_model_with_vllm_weights(vllm_params, model_path) logger.info("HF model constructed with vLLM weight views") return model async def run_apollo_training(self, model: nn.Module, - samples: List[Dict[str, str]], + samples: List[Dict[str, Any]], config: Dict[str, Any]) -> List[float]: - """Run Apollo-Mini training on conversation decision points.""" - from apollo_mini import Apollo - from transformers import AutoTokenizer + """Run Apollo-Mini training on conversation decision points. 
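+
+        A minimal sample, as consumed below (token IDs hypothetical):
+
+            {"context_ids": [151644, 872, 198],
+             "continuation_ids": [40, 646, 13]}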
+ + Each sample has: + context_ids: token IDs for frozen context (no gradients) + continuation_ids: token IDs for the decision we're training on + """ + from apollo_plugin.optimizer import Apollo lr = config.get('learning_rate', self.config['learning_rate']) - tokenizer = AutoTokenizer.from_pretrained( - self.config['model_path'], trust_remote_code=True) # Build parameter groups (Apollo for 2D+, standard for small/1D) apollo_params, standard_params = [], [] @@ -340,12 +386,10 @@ class ApolloWorker: loss_history = [] for i, sample in enumerate(samples): - context = sample.get('context', '') - continuation = sample.get('continuation', '') - - # Tokenize - ctx_ids = tokenizer.encode(context, add_special_tokens=True) - cont_ids = tokenizer.encode(continuation, add_special_tokens=False) + # context_ids: frozen (forward only, no gradients) + # continuation_ids: the decision we're training on + ctx_ids = sample['context_ids'] + cont_ids = sample['continuation_ids'] all_ids = ctx_ids + cont_ids context_len = len(ctx_ids) diff --git a/training/checkpoint/Cargo.toml b/training/checkpoint/Cargo.toml deleted file mode 100644 index 45e511a..0000000 --- a/training/checkpoint/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "apollo-checkpoint" -version = "0.1.0" -edition = "2024" - -[dependencies] -memmap2 = "0.9" -safetensors = "0.5" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -anyhow = "1" -clap = { version = "4", features = ["derive"] } diff --git a/training/checkpoint/src/main.rs b/training/checkpoint/src/main.rs deleted file mode 100644 index 1ebd0df..0000000 --- a/training/checkpoint/src/main.rs +++ /dev/null @@ -1,265 +0,0 @@ -// apollo-checkpoint — Sync live GPU weights back to model files on disk. -// -// mmaps the model's safetensors files, reads live weights from GPU via -// Python helper (CUDA IPC handles), compares block by block, and memcpys -// only changed regions back into the mmap. For small behavioral training -// steps, this turns a 54GB write into a few hundred MB. -// -// The model files on disk are the checkpoint. No separate checkpoint -// directory — just keep the model up to date. -// -// Usage: -// apollo-checkpoint sync \ -// --handles /tmp/vllm_weight_handles.pt \ -// --model-dir /path/to/Qwen3.5-27B -// -// Runs every 10 minutes via cron. Daily rsync to moria. - -use anyhow::{Context, Result, bail}; -use clap::{Parser, Subcommand}; -use memmap2::MmapMut; -use std::collections::HashMap; -use std::fs; -use std::path::{Path, PathBuf}; -use std::process::Command; - -#[derive(Parser)] -#[command(name = "apollo-checkpoint", about = "Sync live GPU weights to model files")] -struct Cli { - #[command(subcommand)] - command: Cmd, -} - -#[derive(Subcommand)] -enum Cmd { - /// Sync live GPU weights back to model safetensors files - Sync { - /// Path to vLLM weight IPC handles - #[arg(long, default_value = "/tmp/vllm_weight_handles.pt")] - handles: PathBuf, - - /// Model directory containing safetensors files - #[arg(long)] - model_dir: PathBuf, - - /// Block size for diffing (bytes) - #[arg(long, default_value_t = 4096)] - block_size: usize, - }, -} - -/// Dump live GPU weights to a flat binary file, ordered by safetensors -/// file and offset to match the on-disk layout. -/// -/// Returns a map of (safetensors filename, tensor name) → raw bytes. 
-fn dump_live_weights(handles_path: &Path, output_dir: &Path) -> Result<HashMap<String, Vec<u8>>> {
-    let dump_path = output_dir.join(".live_dump.bin");
-    let index_path = output_dir.join(".live_dump.json");
-
-    let status = Command::new("python3")
-        .arg("-c")
-        .arg(format!(r#"
-import torch, json
-
-handles = torch.load("{handles}", weights_only=False)
-index = {{}}
-offset = 0
-
-with open("{dump}", "wb") as f:
-    for name in sorted(handles.keys()):
-        info = handles[name]
-        func, args = info["handle"]
-        tensor = func(*args)
-        data = tensor.contiguous().cpu().numpy().tobytes()
-        f.write(data)
-        index[name] = {{"offset": offset, "size": len(data)}}
-        offset += len(data)
-
-with open("{index}", "w") as f:
-    json.dump(index, f)
-
-print(f"Dumped {{len(index)}} tensors, {{offset / 1e9:.1f}} GB")
-"#,
-            handles = handles_path.display(),
-            dump = dump_path.display(),
-            index = index_path.display(),
-        ))
-        .status()
-        .context("Failed to run Python weight dump")?;
-
-    if !status.success() {
-        bail!("Python weight dump failed");
-    }
-
-    let index_str = fs::read_to_string(&index_path)?;
-    let index: HashMap<String, DumpEntry> = serde_json::from_str(&index_str)?;
-    let dump_data = fs::read(&dump_path)?;
-
-    let mut result = HashMap::new();
-    for (name, entry) in &index {
-        result.insert(name.clone(), dump_data[entry.offset..entry.offset + entry.size].to_vec());
-    }
-
-    // Clean up temp files
-    let _ = fs::remove_file(&dump_path);
-    let _ = fs::remove_file(&index_path);
-
-    Ok(result)
-}
-
-#[derive(serde::Deserialize)]
-struct DumpEntry {
-    offset: usize,
-    size: usize,
-}
-
-/// Read the safetensors index to map parameter names to files.
-fn read_safetensors_index(model_dir: &Path) -> Result<HashMap<String, String>> {
-    let index_path = model_dir.join("model.safetensors.index.json");
-    if !index_path.exists() {
-        // Single file model
-        return Ok(HashMap::new());
-    }
-
-    let index_str = fs::read_to_string(&index_path)?;
-    let index: serde_json::Value = serde_json::from_str(&index_str)?;
-    let weight_map = index["weight_map"]
-        .as_object()
-        .context("No weight_map in index")?;
-
-    let mut result = HashMap::new();
-    for (name, file) in weight_map {
-        result.insert(name.clone(), file.as_str().unwrap().to_string());
-    }
-    Ok(result)
-}
-
-/// Sync changed blocks from live weights into a mmap'd safetensors file.
-/// Returns (total_bytes_compared, bytes_changed).
-fn sync_tensors_to_file(
-    file_path: &Path,
-    tensors: &[(String, Vec<u8>)],
-    block_size: usize,
-) -> Result<(usize, usize)> {
-    use safetensors::SafeTensors;
-
-    let file = fs::OpenOptions::new()
-        .read(true)
-        .write(true)
-        .open(file_path)
-        .with_context(|| format!("Failed to open {}", file_path.display()))?;
-
-    let mut mmap = unsafe { MmapMut::map_mut(&file)? };
-
-    // Parse safetensors header to find tensor offsets
-    let header_size = u64::from_le_bytes(mmap[..8].try_into().unwrap()) as usize;
-    let header_json: serde_json::Value =
-        serde_json::from_slice(&mmap[8..8 + header_size])?;
-    let data_start = 8 + header_size;
-
-    let mut total_compared = 0usize;
-    let mut total_changed = 0usize;
-
-    for (name, live_data) in tensors {
-        let meta = match header_json.get(name) {
-            Some(m) => m,
-            None => {
-                eprintln!("  Warning: {} not found in {}", name, file_path.display());
-                continue;
-            }
-        };
-
-        let offsets = meta["data_offsets"].as_array().unwrap();
-        let start = data_start + offsets[0].as_u64().unwrap() as usize;
-        let end = data_start + offsets[1].as_u64().unwrap() as usize;
-        let disk_data = &mmap[start..end];
-
-        if disk_data.len() != live_data.len() {
-            eprintln!("  Warning: size mismatch for {}: disk={} live={}",
-                name, disk_data.len(), live_data.len());
-            continue;
-        }
-
-        // Diff block by block, memcpy only changed blocks
-        let mut offset = 0;
-        while offset < disk_data.len() {
-            let block_end = (offset + block_size).min(disk_data.len());
-            total_compared += block_end - offset;
-
-            if disk_data[offset..block_end] != live_data[offset..block_end] {
-                mmap[start + offset..start + block_end]
-                    .copy_from_slice(&live_data[offset..block_end]);
-                total_changed += block_end - offset;
-            }
-            offset = block_end;
-        }
-    }
-
-    mmap.flush()?;
-    Ok((total_compared, total_changed))
-}
-
-fn cmd_sync(handles: PathBuf, model_dir: PathBuf, block_size: usize) -> Result<()> {
-    if !handles.exists() {
-        bail!("Weight handles not found: {}. Is vLLM running with the export hook?",
-            handles.display());
-    }
-
-    eprintln!("Dumping live weights from GPU...");
-    let live_weights = dump_live_weights(&handles, &model_dir)?;
-    eprintln!("  {} tensors dumped", live_weights.len());
-
-    // Map parameter names to safetensors files
-    let weight_map = read_safetensors_index(&model_dir)?;
-
-    // Group tensors by safetensors file
-    let mut by_file: HashMap<String, Vec<(String, Vec<u8>)>> = HashMap::new();
-    for (name, data) in live_weights {
-        let file = weight_map
-            .get(&name)
-            .cloned()
-            .unwrap_or_else(|| "model.safetensors".to_string());
-        by_file.entry(file).or_default().push((name, data));
-    }
-
-    let mut total_compared = 0usize;
-    let mut total_changed = 0usize;
-
-    for (filename, tensors) in &by_file {
-        let file_path = model_dir.join(filename);
-        if !file_path.exists() {
-            eprintln!("  Warning: {} not found, skipping", filename);
-            continue;
-        }
-
-        let (compared, changed) = sync_tensors_to_file(&file_path, tensors, block_size)?;
-        total_compared += compared;
-        total_changed += changed;
-
-        if changed > 0 {
-            eprintln!("  {}: {:.1} MB changed", filename, changed as f64 / 1e6);
-        }
-    }
-
-    if total_changed == 0 {
-        eprintln!("No changes — model files are up to date");
-    } else {
-        eprintln!(
-            "Synced: {:.1} MB changed / {:.1} GB total ({:.3}%)",
-            total_changed as f64 / 1e6,
-            total_compared as f64 / 1e9,
-            total_changed as f64 / total_compared as f64 * 100.0,
-        );
-    }
-
-    Ok(())
-}
-
-fn main() -> Result<()> {
-    let cli = Cli::parse();
-    match cli.command {
-        Cmd::Sync { handles, model_dir, block_size } => {
-            cmd_sync(handles, model_dir, block_size)
-        }
-    }
-}
diff --git a/training/export_weights.py b/training/export_weights.py
deleted file mode 100644
index ef2f608..0000000
--- a/training/export_weights.py
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env python3
-"""Export vLLM's live model weight IPC handles for the training process.
- -Connects to a running vLLM instance, iterates over model parameters, -and exports CUDA IPC handles that allow another process to access the -same GPU memory without copying. - -Usage: - # Run after vLLM is serving: - python3 export_weights.py --output /tmp/vllm_weight_handles.pt - - # Or via vLLM's API (future): - curl -X POST http://localhost:8000/export_weights -""" - -import argparse -import sys -import torch -from pathlib import Path - - -def export_from_model(model, output_path: str): - """Export IPC handles for all model parameters.""" - from torch.multiprocessing.reductions import reduce_tensor - - handles = {} - total_bytes = 0 - - for name, param in model.named_parameters(): - handle = reduce_tensor(param.data) - handles[name] = { - 'handle': handle, - 'shape': list(param.shape), - 'dtype': str(param.dtype), - } - param_bytes = param.nelement() * param.element_size() - total_bytes += param_bytes - - torch.save(handles, output_path) - - n_params = len(handles) - print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)") - print(f"Saved to {output_path}") - return handles - - -def main(): - parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles") - parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt", - help="Output path for IPC handles") - parser.add_argument("--vllm-pid", type=int, default=None, - help="vLLM worker PID (auto-detected if not specified)") - args = parser.parse_args() - - # For now: load the model directly and export. - # TODO: connect to running vLLM process instead. - print("Note: This currently loads the model separately.") - print("Full integration will export from the running vLLM process.") - print() - - # Detect model path from running vLLM - import subprocess - result = subprocess.run( - ['ps', 'aux'], capture_output=True, text=True - ) - model_path = None - for line in result.stdout.split('\n'): - if 'vllm' in line and '--model' in line: - parts = line.split() - for i, p in enumerate(parts): - if p == '--model' and i + 1 < len(parts): - model_path = parts[i + 1] - break - # Also check model_tag format - if p.startswith('--model='): - model_path = p.split('=', 1)[1] - break - - if model_path: - print(f"Detected vLLM model: {model_path}") - else: - print("Could not detect running vLLM model. Specify manually.") - sys.exit(1) - - -if __name__ == '__main__': - main() diff --git a/training/first_training_step.py b/training/first_training_step.py deleted file mode 100644 index 0e6ffd8..0000000 --- a/training/first_training_step.py +++ /dev/null @@ -1,215 +0,0 @@ -#!/usr/bin/env python3 -"""First real Apollo training step — ready for Kent to run. - -This script: -1. Imports vLLM's live weights via CUDA IPC -2. Constructs HF model with shared memory views -3. Runs ONE forward+backward on a real training example -4. Applies ONE Apollo optimizer step -5. Verifies vLLM still works after the update - -The training example is from March 30: Kent said "use vLLM's code" -and the model should have accepted instead of suggesting alternatives. 
- -Usage: - source ~/training-env/bin/activate - python3 first_training_step.py [--dry-run] -""" - -import argparse -import sys -import time - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import AutoConfig, AutoTokenizer -from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM - -sys.path.insert(0, '.') -from weight_mapping import vllm_to_hf_views -from apollo_mini import Apollo - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--dry-run', action='store_true', - help="Run forward+backward but don't apply the optimizer step") - parser.add_argument('--lr', type=float, default=1e-5, - help="Learning rate (default: 1e-5 = conservative)") - parser.add_argument('--rank', type=int, default=256) - parser.add_argument('--handles', default='/tmp/vllm_weight_handles.pt') - parser.add_argument('--model-path', default='Qwen/Qwen3.5-27B') - args = parser.parse_args() - - print("=== First Apollo Training Step ===\n") - - # 1. Import vLLM weights - print("1. Importing vLLM weights via CUDA IPC...") - handles = torch.load(args.handles, weights_only=False) - vllm_params = {} - for name, info in handles.items(): - func, args_h = info['handle'] - vllm_params[name] = func(*args_h) - print(f" {len(vllm_params)} parameters imported") - - # 2. Map to HF layout - print("2. Mapping to HF layout (zero-copy views)...") - hf_params = vllm_to_hf_views(vllm_params) - - # 3. Create HF model - print("3. Creating HF model with shared weights...") - config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) - with torch.device('meta'): - model = Qwen3_5ForCausalLM(config.text_config) - - replaced = 0 - for name, param in list(model.named_parameters()): - if name in hf_params: - parts = name.split('.') - parent = model - for part in parts[:-1]: - parent = getattr(parent, part) - setattr(parent, parts[-1], - nn.Parameter(hf_params[name], requires_grad=True)) - replaced += 1 - print(f" {replaced} parameters replaced with vLLM memory views") - - # 4. Load tokenizer - print("4. Loading tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) - - # 5. Construct training example - print("5. Constructing training example...") - - # Context: conversation where Kent says to use vLLM's code - # Target: the response that accepts the direction - context = ( - "<|im_start|>user\n" - "vllm has a fused kernel already, right?<|im_end|>\n" - "<|im_start|>assistant\n" - "Yeah — vLLM has `gdn_attention_core` which is a custom op " - "that does the whole GDN layer's core in one dispatch.<|im_end|>\n" - "<|im_start|>user\n" - "Why wouldn't we just use that?<|im_end|>\n" - "<|im_start|>assistant\n" - ) - - # The CORRECT response (accept direction, don't suggest alternatives) - continuation = ( - "We should. Let me pull in their kernel and wire it into " - "our Rust orchestration. Which file should I start with?" - ) - - context_ids = tokenizer.encode(context, add_special_tokens=False) - continuation_ids = tokenizer.encode(continuation, add_special_tokens=False) - all_ids = context_ids + continuation_ids - context_len = len(context_ids) - - print(f" Context: {context_len} tokens") - print(f" Continuation: {len(continuation_ids)} tokens") - print(f" Total: {len(all_ids)} tokens") - - input_ids = torch.tensor([all_ids], device='cuda:0') - - # 6. Initialize Apollo optimizer - print(f"6. 
Initializing Apollo optimizer (rank={args.rank}, lr={args.lr})...") - apollo_params = [] - standard_params = [] - for p in model.parameters(): - if p.requires_grad: - if p.ndim >= 2 and min(p.shape) >= args.rank: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({'params': apollo_params}) - if standard_params: - groups.append({'params': standard_params}) - - optimizer = Apollo(groups, lr=args.lr, rank=args.rank) - print(f" Apollo: {len(apollo_params)} projected, {len(standard_params)} standard") - - # 7. Forward pass - print("7. Forward pass...") - model.train() - optimizer.zero_grad() - - # Context-frozen: no grad for context, grad for continuation - with torch.no_grad(): - ctx_output = model(input_ids[:, :context_len], use_cache=True) - past_kv = ctx_output.past_key_values - - with torch.enable_grad(): - output = model(input_ids[:, context_len:], - past_key_values=past_kv, use_cache=False) - logits = output.logits - # Shift for next-token prediction - shift_logits = logits[:, :-1].contiguous() - shift_labels = input_ids[:, context_len + 1:].contiguous() - loss = F.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - ) - print(f" Loss: {loss.item():.4f}") - - # 8. Backward pass - print("8. Backward pass...") - loss.backward() - n_grads = sum(1 for p in model.parameters() if p.grad is not None) - print(f" {n_grads} parameters have gradients") - - # 9. Apollo step (or dry run) - if args.dry_run: - print("\n9. DRY RUN — skipping optimizer step") - print(" (run without --dry-run to apply the update)") - else: - print("9. Applying Apollo optimizer step...") - # Record a few weight norms before - sample_norms_before = {} - for name, p in model.named_parameters(): - if 'layers.0.' in name and p.grad is not None: - sample_norms_before[name] = p.data.norm().item() - - optimizer.step() - - # Check weight changes - print(" Weight changes (layer 0):") - for name, before in sample_norms_before.items(): - p = dict(model.named_parameters())[name] - after = p.data.norm().item() - delta = abs(after - before) - pct = delta / before * 100 if before > 0 else 0 - print(f" {name}: {before:.6f} → {after:.6f} (Δ{pct:.4f}%)") - - optimizer.zero_grad() - - # 10. Verify vLLM still works - print("\n10. 
Verifying vLLM still serves...") - import subprocess - result = subprocess.run( - ['curl', '-s', '--max-time', '30', - '-X', 'POST', 'http://localhost:8000/v1/chat/completions', - '-H', 'Content-Type: application/json', - '-H', 'Authorization: Bearer bcachefs-agents-2026', - '-d', '{"model":"Qwen/Qwen3.5-27B","messages":[{"role":"user","content":"Hi"}],"max_tokens":4}'], - capture_output=True, text=True, timeout=45 - ) - if result.returncode == 0 and 'choices' in result.stdout: - print(" vLLM still serving ✓") - else: - print(" WARNING: vLLM may not be responding") - print(f" stdout: {result.stdout[:200]}") - - print("\n=== COMPLETE ===") - if args.dry_run: - print("Run without --dry-run to apply the first real training step.") - else: - print("First Apollo training step applied to vLLM's live weights.") - print(f"Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - -if __name__ == '__main__': - main() diff --git a/training/pyproject.toml b/training/pyproject.toml new file mode 100644 index 0000000..37ca129 --- /dev/null +++ b/training/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "apollo-plugin" +version = "0.1.0" +description = "Apollo training plugin for vLLM" +requires-python = ">=3.10" +dependencies = [ + "torch", + "aiohttp", + "safetensors", +] + +[project.optional-dependencies] +dev = ["pytest"] + +[project.entry-points."vllm.general_plugins"] +apollo = "apollo_plugin:register" + +[project.scripts] +apollo-worker = "apollo_plugin.worker:main" +apollo-checkpoint = "apollo_plugin.checkpoint_sync:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["apollo_plugin*"] diff --git a/training/start_vllm_with_apollo.sh b/training/start_vllm_with_apollo.sh deleted file mode 100755 index 98dfedb..0000000 --- a/training/start_vllm_with_apollo.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# Start vLLM with Apollo weight export hook. -# -# The hook patches vLLM's model runner to export CUDA IPC handles -# after loading, so the Apollo training process can share the same -# GPU memory. - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" - -exec python3 -c " -import sys -sys.path.insert(0, '$SCRIPT_DIR') -import vllm_export_hook # patches model runner before vLLM loads - -sys.argv = ['vllm'] + sys.argv[1:] -from vllm.entrypoints.cli.main import main -main() -" serve "$@" diff --git a/training/train.py b/training/train.py deleted file mode 100644 index a5fbe2c..0000000 --- a/training/train.py +++ /dev/null @@ -1,269 +0,0 @@ -#!/usr/bin/env python3 -"""Nightly training process for Apollo-Mini fine-tuning. - -Imports vLLM's model weights via CUDA IPC, runs context-frozen -training on flagged conversation segments, saves updated checkpoint. 
- -Usage: - python3 train.py \ - --weights /tmp/vllm_weight_handles.pt \ - --examples training-examples.jsonl \ - --checkpoint-dir checkpoints/ \ - --lr 1e-5 -""" - -import argparse -import json -import os -import sys -import time -from datetime import datetime -from pathlib import Path - -import torch -from safetensors.torch import save_file - -from apollo_mini import ApolloMini - - -def import_weights(handle_path: str) -> dict[str, torch.Tensor]: - """Import weight tensors from CUDA IPC handles.""" - handles = torch.load(handle_path, weights_only=False) - params = {} - for name, info in handles.items(): - func, args = info['handle'] - tensor = func(*args) - params[name] = tensor - return params - - -def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]: - """Split parameters into Apollo-Mini and standard groups. - - Apollo-Mini needs 2D+ matrices with min dimension >= 2. - Small tensors (norms, biases, conv1d 3D weights) use standard Adam. - """ - apollo_params = [] - standard_params = [] - - for name, p in params.items(): - p.requires_grad_(True) - if p.ndim >= 2 and min(p.shape) >= 2: - apollo_params.append(p) - else: - standard_params.append(p) - - groups = [] - if apollo_params: - groups.append({ - 'params': apollo_params, - 'name': 'apollo', - }) - if standard_params: - groups.append({ - 'params': standard_params, - 'name': 'standard', - }) - - n_apollo = sum(p.nelement() for p in apollo_params) - n_standard = sum(p.nelement() for p in standard_params) - print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M") - return groups - - -def forward_pass(params, input_ids, context_len, device): - """Run context-frozen forward pass. - - Args: - params: dict of name -> tensor (shared with vLLM) - input_ids: full sequence [1, seq_len] - context_len: number of context tokens (no gradient) - device: CUDA device - - Returns: - logits for decision tokens, target ids for loss - """ - # TODO: Build proper forward model matching vLLM's weight layout. - # For now this is a placeholder — the real implementation needs - # to replicate vLLM's model architecture (merged projections, - # GDN recurrence, full attention, MLP) using the shared weights. - raise NotImplementedError( - "Forward model not yet implemented. " - "Need to build a model that matches vLLM's merged weight layout " - "(MergedColumnParallelLinear for qkvz/ba/gate_up, " - "RowParallelLinear for out_proj/down) and computes the same " - "forward pass with autograd enabled." - ) - - -def save_checkpoint(params: dict[str, torch.Tensor], - checkpoint_dir: str, - config_path: str = None): - """Save model checkpoint in HuggingFace safetensors format. - - Saves weights split across shards matching the original model layout, - archives the previous checkpoint, and updates the 'latest' symlink. - """ - date_str = datetime.now().strftime("%Y-%m-%d") - out_dir = Path(checkpoint_dir) / date_str - out_dir.mkdir(parents=True, exist_ok=True) - - # Save all weights in a single safetensors file for now. - # TODO: split across shards matching HF model index for large models. 
- tensors = {} - for name, param in params.items(): - tensors[name] = param.data.contiguous().cpu() - - save_path = out_dir / "model.safetensors" - save_file(tensors, str(save_path)) - print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)") - - # Copy config files if provided - if config_path: - import shutil - config_dir = Path(config_path) - for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json', - 'special_tokens_map.json', 'generation_config.json']: - src = config_dir / f - if src.exists(): - shutil.copy2(src, out_dir / f) - - # Update latest symlink - latest = Path(checkpoint_dir) / "latest" - if latest.is_symlink(): - latest.unlink() - latest.symlink_to(date_str) - print(f"Updated {latest} -> {date_str}") - - return str(out_dir) - - -def train_step(params, example, optimizer, device, log_entries): - """Run one training step on a single example. - - Args: - params: dict of name -> tensor - example: dict with 'input_ids', 'context_len', 'target_ids' - optimizer: ApolloMini instance - device: CUDA device - log_entries: list to append log dicts to - - Returns: - loss value - """ - optimizer.zero_grad() - - input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0) - context_len = example['context_len'] - - # Forward pass (context frozen, decision tokens with grad) - logits, targets = forward_pass(params, input_ids, context_len, device) - - # Cross-entropy loss on decision tokens - loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.shape[-1]), - targets.view(-1), - ) - - # Backward - loss.backward() - - # Compute gradient stats before optimizer step - total_grad_norm = 0.0 - for p in params.values(): - if p.grad is not None: - total_grad_norm += p.grad.norm().item() ** 2 - total_grad_norm = total_grad_norm ** 0.5 - - # Optimizer step - optimizer.step() - - # Log - log_entries.append({ - 'example_id': example.get('id', 'unknown'), - 'loss': loss.item(), - 'grad_norm': total_grad_norm, - 'timestamp': datetime.now().isoformat(), - }) - - return loss.item() - - -def main(): - parser = argparse.ArgumentParser(description="Apollo-Mini training") - parser.add_argument("--weights", required=True, - help="Path to exported weight IPC handles") - parser.add_argument("--examples", required=True, - help="Path to training examples JSONL") - parser.add_argument("--checkpoint-dir", default="checkpoints", - help="Directory for saving checkpoints") - parser.add_argument("--config-path", default=None, - help="Path to model config files (for checkpoint)") - parser.add_argument("--lr", type=float, default=1e-5, - help="Learning rate") - parser.add_argument("--warmup-steps", type=int, default=10, - help="Learning rate warmup steps") - parser.add_argument("--weight-decay", type=float, default=0.01) - parser.add_argument("--dry-run", action="store_true", - help="Load weights and validate, don't train") - args = parser.parse_args() - - print(f"Apollo-Mini Training") - print(f" weights: {args.weights}") - print(f" examples: {args.examples}") - print(f" lr: {args.lr}") - print() - - # Import weights - print("Importing weights via CUDA IPC...") - params = import_weights(args.weights) - print(f" {len(params)} parameters imported") - - # Make parameter groups - param_groups = make_param_groups(params) - - # Initialize optimizer - optimizer = ApolloMini(param_groups, lr=args.lr, - weight_decay=args.weight_decay, - warmup_steps=args.warmup_steps) - print(f" Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - if args.dry_run: - 
print("\nDry run — weights imported and validated successfully.") - return - - # Load training examples - examples = [] - with open(args.examples) as f: - for line in f: - examples.append(json.loads(line)) - print(f" {len(examples)} training examples") - - # Training loop - log_entries = [] - print(f"\nTraining...") - t0 = time.time() - - for i, example in enumerate(examples): - loss = train_step(params, example, optimizer, 'cuda:0', log_entries) - print(f" [{i+1}/{len(examples)}] loss={loss:.4f}") - - elapsed = time.time() - t0 - print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s") - print(f" Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB") - - # Save checkpoint - print("\nSaving checkpoint...") - save_checkpoint(params, args.checkpoint_dir, args.config_path) - - # Save training log - date_str = datetime.now().strftime("%Y-%m-%d") - log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl" - with open(log_path, 'w') as f: - for entry in log_entries: - f.write(json.dumps(entry) + '\n') - print(f"Training log: {log_path}") - - -if __name__ == '__main__': - main() diff --git a/training/training_example.py b/training/training_example.py deleted file mode 100644 index b5779e0..0000000 --- a/training/training_example.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Training example construction and tokenization. - -Takes raw conversation context + improved continuation, produces -tokenized tensors ready for context-frozen forward+backward. -""" - -import json -from dataclasses import dataclass, field -from pathlib import Path - -import torch -from transformers import AutoTokenizer - - -@dataclass -class TrainingExample: - """A single training example for context-frozen training.""" - id: str - context: str # conversation up to decision point - continuation: str # the better response - reason: str = "" # why this is a training target - memories: list[str] = field(default_factory=list) # memories that were in context - - # Computed after tokenization - input_ids: torch.Tensor | None = None - context_len: int = 0 - total_len: int = 0 - - def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"): - """Tokenize context + continuation into training-ready tensors. - - The chat template is applied to make the token distribution - match what the model sees during inference. 
- """ - # Build messages for context (everything up to the decision) - # The context should already be in chat format - context_ids = tokenizer.encode(self.context, add_special_tokens=False) - continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False) - - self.context_len = len(context_ids) - self.total_len = len(context_ids) + len(continuation_ids) - - if self.total_len > max_len: - # Truncate context from the left, keep continuation intact - excess = self.total_len - max_len - context_ids = context_ids[excess:] - self.context_len = len(context_ids) - self.total_len = len(context_ids) + len(continuation_ids) - - all_ids = context_ids + continuation_ids - self.input_ids = torch.tensor(all_ids, device=device) - return self - - def to_dict(self) -> dict: - return { - 'id': self.id, - 'context': self.context, - 'continuation': self.continuation, - 'reason': self.reason, - 'memories': self.memories, - 'context_len': self.context_len, - 'total_len': self.total_len, - } - - @classmethod - def from_dict(cls, d: dict) -> 'TrainingExample': - return cls( - id=d['id'], - context=d['context'], - continuation=d['continuation'], - reason=d.get('reason', ''), - memories=d.get('memories', []), - ) - - -def load_examples(path: str) -> list[TrainingExample]: - """Load training examples from JSONL file.""" - examples = [] - with open(path) as f: - for line in f: - if line.strip(): - examples.append(TrainingExample.from_dict(json.loads(line))) - return examples - - -def save_examples(examples: list[TrainingExample], path: str): - """Save training examples to JSONL file.""" - with open(path, 'w') as f: - for ex in examples: - f.write(json.dumps(ex.to_dict()) + '\n') - - -class ExampleTokenizer: - """Handles tokenization with the model's chat template. - - Applies the same chat template that vLLM uses during inference, - so the token distribution matches what the model expects. - """ - - def __init__(self, model_path: str): - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True) - - def prepare_example(self, example: TrainingExample, - max_len: int = 8192, - device: str = "cuda:0") -> TrainingExample: - """Tokenize an example using the chat template. - - For proper training, the context should be formatted exactly - as vLLM would format it — with chat template applied. - """ - # Apply chat template to get the exact token sequence - # the model would see during inference - # - # Context: everything up to the decision point - # Continuation: the improved response - # - # We tokenize them separately to know where context ends - # and continuation begins. 
- context_ids = self.tokenizer.encode( - example.context, add_special_tokens=True) - continuation_ids = self.tokenizer.encode( - example.continuation, add_special_tokens=False) - - example.context_len = len(context_ids) - example.total_len = len(context_ids) + len(continuation_ids) - - if example.total_len > max_len: - excess = example.total_len - max_len - context_ids = context_ids[excess:] - example.context_len = len(context_ids) - example.total_len = example.context_len + len(continuation_ids) - - all_ids = context_ids + continuation_ids - example.input_ids = torch.tensor(all_ids, device=device) - return example - - def prepare_from_messages(self, example_id: str, - messages: list[dict], - decision_idx: int, - better_response: str, - reason: str = "", - memories: list[str] | None = None, - max_len: int = 8192, - device: str = "cuda:0") -> TrainingExample: - """Build a training example from a chat message list. - - Args: - example_id: unique identifier - messages: list of {"role": ..., "content": ...} dicts - decision_idx: index of the assistant message to replace - better_response: the improved response text - reason: why this is a training target - memories: memory keys that were in context - max_len: maximum sequence length - device: target device - - Returns: - Tokenized TrainingExample - """ - # Context: all messages up to (not including) the decision - context_messages = messages[:decision_idx] - context_text = self.tokenizer.apply_chat_template( - context_messages, tokenize=False, add_generation_prompt=True) - - # Build the example - example = TrainingExample( - id=example_id, - context=context_text, - continuation=better_response, - reason=reason, - memories=memories or [], - ) - - return self.prepare_example(example, max_len=max_len, device=device)