training: restructure as vLLM plugin package
- Convert to installable package with entry points for vLLM auto-discovery - Add checkpoint_sync.py: Python replacement for Rust checkpoint binary - Block-level diffing of safetensors files (4KB blocks) - vLLM→HF weight name conversion built-in - Scheduled 10min after training jobs (batched) - API change: /train now takes raw token IDs (context_ids + continuation_ids) - No tokenizer on training side, client owns tokenization - Remove superseded code: standalone scripts, Rust binary, tokenizer helpers Install: pip install -e ./training Then vLLM auto-loads via entry point. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
b649a11645
commit
a73bcf5ae3
15 changed files with 607 additions and 1068 deletions
17
training/apollo_plugin/__init__.py
Normal file
17
training/apollo_plugin/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Apollo training plugin for vLLM.
|
||||
|
||||
Enables continuous fine-tuning alongside live inference by:
|
||||
1. Exporting CUDA IPC handles for weight sharing
|
||||
2. Providing a training worker daemon (/train endpoint)
|
||||
3. Block-level checkpoint sync to safetensors files
|
||||
|
||||
Install: pip install -e /path/to/training
|
||||
Then vLLM auto-loads via entry point.
|
||||
"""
|
||||
|
||||
from .export_hook import _patch_model_runner
|
||||
|
||||
|
||||
def register():
    """Called by vLLM's plugin loader on startup.

    vLLM discovers this function through the package's entry point
    (see the install note in the module docstring) and invokes it once;
    it installs the weight-export hook by patching the GPU worker's
    load_model (see export_hook._patch_model_runner).
    """
    _patch_model_runner()
|
||||
500
training/apollo_plugin/checkpoint_sync.py
Normal file
500
training/apollo_plugin/checkpoint_sync.py
Normal file
|
|
@ -0,0 +1,500 @@
|
|||
"""Sync live GPU weights to safetensors files on disk.
|
||||
|
||||
Reads vLLM weight tensors via CUDA IPC handles, converts from vLLM's
|
||||
merged layout to HuggingFace's separate layout, diffs block-by-block
|
||||
against on-disk safetensors files, and writes only changed blocks.
|
||||
|
||||
For small behavioral training steps, this turns a 54GB checkpoint
|
||||
write into a few hundred MB of actual disk I/O.
|
||||
|
||||
Usage:
|
||||
# Sync live weights to disk
|
||||
python checkpoint_sync.py sync --model-dir /path/to/Qwen3.5-27B
|
||||
|
||||
# Debug name mapping issues
|
||||
python checkpoint_sync.py diagnose --model-dir /path/to/Qwen3.5-27B
|
||||
|
||||
# From Python:
|
||||
from checkpoint_sync import checkpoint_sync
|
||||
result = checkpoint_sync("/path/to/model")
|
||||
"""
|
||||
|
||||
import json
|
||||
import mmap
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_BLOCK_SIZE = 4096 # 4KB blocks — matches filesystem block size
|
||||
DEFAULT_HANDLES_PATH = "/tmp/vllm_weight_handles.pt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# vLLM → HuggingFace weight name/shape conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
# Qwen3.5-27B dimensions (could be read from config.json for generality)
|
||||
|
||||
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM  # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM  # 6144
INTERMEDIATE = 17408

# Full attention (some layers use standard attention, not GDN)
NUM_ATTN_HEADS = 24
NUM_ATTN_KV_HEADS = 4
ATTN_HEAD_DIM = 256
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2  # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM  # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024


def vllm_to_hf_tensors(vllm_params: Dict[str, torch.Tensor]
                       ) -> Dict[str, torch.Tensor]:
    """Convert vLLM merged weights to HF-compatible separate tensors.

    vLLM fuses several projections into single matrices for kernel
    efficiency; HuggingFace checkpoints store them separately:
      - qkv_proj (full attn)  -> q_proj, k_proj, v_proj
      - in_proj_qkvz (GDN)    -> in_proj_qkv, in_proj_z
      - in_proj_ba (GDN)      -> in_proj_b, in_proj_a
      - gate_up_proj (MLP)    -> gate_proj, up_proj

    The split points are row slices along dim 0, computed from the
    module-level model dimensions above. All returned tensors are
    views sharing GPU memory with the inputs — no copies are made.
    """
    converted: Dict[str, torch.Tensor] = {}

    for vllm_name, weight in vllm_params.items():
        # vLLM prefixes everything with 'language_model.'; HF does not.
        hf_name = vllm_name.removeprefix('language_model.')

        if 'in_proj_qkvz' in vllm_name:
            # GDN layer: rows are [key*2 + value*2, hidden] -> qkv then z
            base = hf_name.replace('in_proj_qkvz.weight', '')
            qkv_rows = KEY_DIM * 2 + VALUE_DIM
            converted[base + 'in_proj_qkv.weight'] = weight[:qkv_rows]
            converted[base + 'in_proj_z.weight'] = weight[qkv_rows:]
            continue

        if 'in_proj_ba' in vllm_name:
            # GDN layer: rows are [num_v_heads*2, hidden] -> b then a
            base = hf_name.replace('in_proj_ba.weight', '')
            converted[base + 'in_proj_b.weight'] = weight[:NUM_V_HEADS]
            converted[base + 'in_proj_a.weight'] = weight[NUM_V_HEADS:]
            continue

        if 'qkv_proj' in vllm_name:
            # Full attention: rows are [q + k + v, hidden] -> three slices
            base = hf_name.replace('qkv_proj.weight', '')
            q_end = ATTN_Q_DIM
            k_end = ATTN_Q_DIM + ATTN_K_DIM
            converted[base + 'q_proj.weight'] = weight[:q_end]
            converted[base + 'k_proj.weight'] = weight[q_end:k_end]
            converted[base + 'v_proj.weight'] = weight[k_end:]
            continue

        if 'gate_up_proj' in vllm_name:
            # MLP: rows are [intermediate*2, hidden] -> gate then up
            base = hf_name.replace('gate_up_proj.weight', '')
            converted[base + 'gate_proj.weight'] = weight[:INTERMEDIATE]
            converted[base + 'up_proj.weight'] = weight[INTERMEDIATE:]
            continue

        # Not a fused projection: keep the tensor as-is.
        converted[hf_name] = weight

    return converted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Safetensors file handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def read_safetensors_index(model_dir: Path) -> Dict[str, str]:
    """Map tensor names to the safetensors shard files that hold them.

    Sharded checkpoints ship a model.safetensors.index.json whose
    "weight_map" maps tensor name -> shard filename. Single-file
    models have no index file; the empty dict returned in that case
    tells callers to fall back to model.safetensors.
    """
    index_file = model_dir / "model.safetensors.index.json"
    if not index_file.exists():
        return {}

    index = json.loads(index_file.read_text())
    return dict(index.get("weight_map", {}))
|
||||
|
||||
|
||||
def parse_safetensors_header(data: memoryview) -> Tuple[int, dict]:
    """Decode the JSON header at the front of a safetensors file.

    File layout: an 8-byte little-endian header length, then that many
    bytes of JSON mapping tensor names to metadata (including
    'data_offsets', which are relative to the returned start offset).

    Returns (data_start_offset, header_dict).
    """
    (header_len,) = struct.unpack('<Q', data[:8])
    header = json.loads(bytes(data[8:8 + header_len]))
    return 8 + header_len, header
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block-level diffing and sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sync_tensor_to_mmap(
    mm: mmap.mmap,
    name: str,
    tensor: torch.Tensor,
    data_start: int,
    offsets: List[int],
    block_size: int,
) -> Tuple[int, int]:
    """Write one tensor into its region of an mmap'd safetensors file,
    touching only the blocks whose bytes actually differ.

    Args:
        mm: writable mmap over the whole safetensors file
        name: tensor name (used in log messages only)
        tensor: live tensor (possibly on GPU); serialized to CPU bytes here
        data_start: file offset where the tensor data section begins
        offsets: [begin, end] byte offsets relative to data_start
        block_size: granularity of the compare-and-write loop

    Returns (bytes_compared, bytes_changed). Returns (0, 0) — with a
    warning — if the tensor cannot be serialized or its byte length
    does not match the on-disk region.
    """
    region_start = data_start + offsets[0]
    region_len = (data_start + offsets[1]) - region_start

    # Pull the live bytes off the device once, up front.
    # .detach() avoids autograd overhead; .contiguous() fixes the layout.
    try:
        live = tensor.detach().contiguous().cpu().numpy().tobytes()
    except Exception as e:
        logger.warning(f"Failed to transfer {name} to CPU: {e}")
        return 0, 0

    if len(live) != region_len:
        logger.warning(
            f"Size mismatch for {name}: disk={region_len}, live={len(live)} "
            f"(shape={list(tensor.shape)}, dtype={tensor.dtype})"
        )
        return 0, 0

    # Compare block by block; write back only blocks that differ.
    compared = 0
    changed = 0
    for lo in range(0, region_len, block_size):
        hi = min(lo + block_size, region_len)
        span = hi - lo
        compared += span

        if mm[region_start + lo:region_start + hi] != live[lo:hi]:
            mm[region_start + lo:region_start + hi] = live[lo:hi]
            changed += span

    return compared, changed
|
||||
|
||||
|
||||
def sync_file(
    file_path: Path,
    tensors: Dict[str, torch.Tensor],
    block_size: int,
) -> Tuple[int, int, int, int]:
    """Block-diff a set of live tensors into one safetensors file.

    Opens the file read-write, maps it, and routes each tensor that
    appears in the file's header through sync_tensor_to_mmap.
    Tensors absent from the header are counted as missing.

    Returns (bytes_compared, bytes_changed, tensors_found,
    tensors_missing).
    """
    compared_total = 0
    changed_total = 0
    found = 0
    missing = 0

    with open(file_path, 'r+b') as f:
        mm = mmap.mmap(f.fileno(), 0)
        try:
            data_start, header = parse_safetensors_header(memoryview(mm))

            for tensor_name, live_tensor in tensors.items():
                # '__metadata__' is a reserved header key, not a tensor.
                if tensor_name == "__metadata__":
                    continue

                meta = header.get(tensor_name)
                if meta is None:
                    missing += 1
                    continue

                found += 1
                compared, changed = sync_tensor_to_mmap(
                    mm, tensor_name, live_tensor, data_start,
                    meta['data_offsets'], block_size
                )
                compared_total += compared
                changed_total += changed

            # Only pay for an msync when something was actually written.
            if changed_total > 0:
                mm.flush()

            return compared_total, changed_total, found, missing

        finally:
            mm.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_vllm_weights(handles_path: str) -> Dict[str, torch.Tensor]:
    """Rebuild vLLM's weight tensors from serialized CUDA IPC handles.

    The handles file is written by the export hook when vLLM finishes
    loading the model. Each entry stores a (rebuild_fn, args) pair;
    calling rebuild_fn(*args) yields a tensor aliasing vLLM's GPU
    memory directly — no copy is made.

    Entries whose handles fail to rebuild are skipped with a warning.
    """
    handles = torch.load(handles_path, weights_only=False)

    tensors: Dict[str, torch.Tensor] = {}
    for name, entry in handles.items():
        rebuild, rebuild_args = entry['handle']
        try:
            tensors[name] = rebuild(*rebuild_args)
        except Exception as e:
            logger.warning(f"Failed to reconstruct {name}: {e}")

    return tensors
|
||||
|
||||
|
||||
def checkpoint_sync(
    model_dir: str,
    handles_path: str = DEFAULT_HANDLES_PATH,
    block_size: int = DEFAULT_BLOCK_SIZE,
) -> Dict[str, Any]:
    """Sync live GPU weights to model safetensors files.

    This is the main entry point. Call this after training steps
    or periodically to checkpoint weights without full serialization.

    Args:
        model_dir: Directory containing safetensors files
        handles_path: Path to vLLM weight IPC handles file
        block_size: Block size for diffing (default 4KB)

    Returns:
        Dict with sync statistics:
        - total_compared: bytes compared
        - total_changed: bytes actually written
        - files_changed: list of modified filenames
        - tensors_synced: number of tensors processed
        - tensors_missing: tensors not found in safetensors

    Raises:
        FileNotFoundError: if the IPC handles file is absent (vLLM is
            not running with the export hook installed).
    """
    model_dir = Path(model_dir)

    if not Path(handles_path).exists():
        raise FileNotFoundError(
            f"Weight handles not found: {handles_path}. "
            "Is vLLM running with the export hook?"
        )

    # Step 1: Load live weights from GPU via IPC
    logger.info("Loading live weights from GPU...")
    vllm_weights = load_vllm_weights(handles_path)
    logger.info(f"  Loaded {len(vllm_weights)} vLLM tensors")

    # Step 2: Convert to HF naming/layout (zero-copy views)
    hf_weights = vllm_to_hf_tensors(vllm_weights)
    logger.info(f"  Converted to {len(hf_weights)} HF tensors")

    # Step 3: Group tensors by the safetensors file that holds them
    weight_map = read_safetensors_index(model_dir)

    by_file: Dict[str, Dict[str, torch.Tensor]] = {}
    unmapped = []

    for name, tensor in hf_weights.items():
        filename = weight_map.get(name)
        if filename is None:
            # Single-file model or missing from index
            if (model_dir / "model.safetensors").exists():
                filename = "model.safetensors"
            else:
                unmapped.append(name)
                continue
        by_file.setdefault(filename, {})[name] = tensor

    if unmapped:
        logger.warning(f"  {len(unmapped)} tensors not in index: {unmapped[:3]}...")

    # Step 4: Sync each file
    total_compared = 0
    total_changed = 0
    total_found = 0
    total_missing = 0
    files_changed = []

    for filename in sorted(by_file.keys()):
        tensors = by_file[filename]
        file_path = model_dir / filename

        if not file_path.exists():
            # Fixed: this message had lost its {file_path} placeholder
            # and logged a literal "(unknown)" instead of the path.
            logger.warning(f"  File not found: {file_path}")
            total_missing += len(tensors)
            continue

        compared, changed, found, missing = sync_file(file_path, tensors, block_size)

        total_compared += compared
        total_changed += changed
        total_found += found
        total_missing += missing

        if changed > 0:
            files_changed.append(filename)
            # Fixed: this message had lost its {filename} placeholder.
            logger.info(f"  {filename}: {changed / 1e6:.2f} MB changed ({found} tensors)")

    # Summary
    if total_changed == 0:
        logger.info("No changes - model files are up to date")
    else:
        pct = (total_changed / total_compared * 100) if total_compared > 0 else 0
        logger.info(
            f"Synced: {total_changed / 1e6:.2f} MB changed / "
            f"{total_compared / 1e9:.2f} GB compared ({pct:.3f}%)"
        )

    if total_missing > 0:
        logger.warning(f"  {total_missing} tensors not found in safetensors files")

    return {
        "total_compared": total_compared,
        "total_changed": total_changed,
        "files_changed": files_changed,
        "tensors_synced": total_found,
        "tensors_missing": total_missing,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def diagnose(model_dir: str, handles_path: str = DEFAULT_HANDLES_PATH):
    """Print diagnostic info about weight name mappings.

    Compares the set of (converted) live tensor names against the set
    of names present on disk, so mismatches between vLLM and
    safetensors naming can be spotted quickly.
    """
    model_dir = Path(model_dir)

    # Live side: rebuild tensors from IPC handles, convert to HF names.
    vllm_weights = load_vllm_weights(handles_path)
    hf_names = set(vllm_to_hf_tensors(vllm_weights))

    # Disk side: sharded index first; otherwise the single file's header.
    disk_names = set(read_safetensors_index(model_dir))

    if not disk_names:
        st_path = model_dir / "model.safetensors"
        if st_path.exists():
            with open(st_path, 'rb') as f:
                mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                _, header = parse_safetensors_header(memoryview(mm))
                disk_names = {k for k in header.keys() if k != "__metadata__"}
                mm.close()

    print(f"vLLM tensors (raw): {len(vllm_weights)}")
    print(f"HF tensors (converted): {len(hf_names)}")
    print(f"Disk tensors: {len(disk_names)}")
    print()

    matched = hf_names & disk_names
    hf_only = hf_names - disk_names
    disk_only = disk_names - hf_names

    print(f"Matched: {len(matched)}")
    print(f"Only in HF (won't sync): {len(hf_only)}")
    print(f"Only on disk (not updated): {len(disk_only)}")

    if hf_only:
        print(f"\nSample HF-only: {sorted(hf_only)[:5]}")
    if disk_only:
        print(f"\nSample disk-only: {sorted(disk_only)[:5]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
    """CLI entry point: `sync` writes live weights to disk, `diagnose`
    reports name-mapping mismatches. Exits 1 on error or no command."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Sync live GPU weights to safetensors files"
    )
    commands = parser.add_subparsers(dest="command", help="Command")

    sync_cmd = commands.add_parser("sync", help="Sync weights to disk")
    sync_cmd.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    sync_cmd.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )
    sync_cmd.add_argument(
        "--block-size", type=int, default=DEFAULT_BLOCK_SIZE,
        help=f"Block size for diffing (default: {DEFAULT_BLOCK_SIZE})"
    )
    sync_cmd.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose output"
    )

    diag_cmd = commands.add_parser("diagnose", help="Check name mappings")
    diag_cmd.add_argument(
        "--model-dir", required=True,
        help="Directory containing safetensors files"
    )
    diag_cmd.add_argument(
        "--handles", default=DEFAULT_HANDLES_PATH,
        help=f"Path to IPC handles (default: {DEFAULT_HANDLES_PATH})"
    )

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    # getattr: only the sync subcommand defines --verbose.
    logging.basicConfig(
        level=logging.DEBUG if getattr(args, 'verbose', False) else logging.INFO,
        format='%(message)s'
    )

    try:
        if args.command == "sync":
            stats = checkpoint_sync(args.model_dir, args.handles, args.block_size)
            print(json.dumps(stats, indent=2))
        elif args.command == "diagnose":
            diagnose(args.model_dir, args.handles)
    except FileNotFoundError as e:
        logger.error(str(e))
        sys.exit(1)
    except Exception as e:
        logger.exception(f"Failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
67
training/apollo_plugin/export_hook.py
Normal file
67
training/apollo_plugin/export_hook.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Monkey-patch vLLM to export weight IPC handles on startup.
|
||||
|
||||
Usage — install the apollo_plugin package:
|
||||
|
||||
pip install -e /path/to/training
|
||||
|
||||
Then vLLM auto-discovers and loads via entry point. Or filter:
|
||||
|
||||
VLLM_PLUGINS=apollo vllm serve Qwen/Qwen3.5-27B ...
|
||||
|
||||
The hook patches vLLM's model runner to export IPC handles after
|
||||
model loading completes. The handles are saved to a file that the
|
||||
Apollo training process reads.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
HANDLE_PATH = "/tmp/vllm_weight_handles.pt"


def export_model_weights(model):
    """Export CUDA IPC handles for all model parameters.

    Writes {param_name: {'handle', 'shape', 'dtype'}} to HANDLE_PATH
    via torch.save. 'handle' is the (rebuild_fn, args) pair produced by
    reduce_tensor, which another local process can call to map the same
    GPU memory without copying. Non-CUDA parameters are skipped — IPC
    handles only exist for CUDA allocations.
    """
    from torch.multiprocessing.reductions import reduce_tensor

    handles = {}
    exported_bytes = 0

    for name, param in model.named_parameters():
        if param.device.type != 'cuda':
            continue
        handles[name] = {
            'handle': reduce_tensor(param.data),
            'shape': list(param.shape),
            'dtype': str(param.dtype),
        }
        exported_bytes += param.nelement() * param.element_size()

    torch.save(handles, HANDLE_PATH)
    print(f"[apollo] Exported {len(handles)} weight handles "
          f"({exported_bytes / 1e9:.1f} GB) to {HANDLE_PATH}")
|
||||
|
||||
|
||||
def _patch_model_runner():
    """Patch gpu_worker to export handles after model loading.

    vLLM loads the model in a subprocess (EngineCore_DP0), so we
    can't patch from the parent. Instead, patch the worker's
    init_device or load_model at the module level — the subprocess
    imports the same modules.
    """
    from vllm.v1.worker import gpu_worker

    # Keep a reference to the real loader so the wrapper can delegate.
    original_load = gpu_worker.Worker.load_model

    def patched_load(self, *args, **kwargs):
        # Run vLLM's normal load first: weights must be resident on the
        # GPU before IPC handles can be taken from them.
        result = original_load(self, *args, **kwargs)
        try:
            export_model_weights(self.model_runner.model)
        except Exception as e:
            # Best-effort: a failed export must never break serving.
            print(f"[apollo] Failed to export weights: {e}")
        return result

    gpu_worker.Worker.load_model = patched_load
    print("[apollo] Weight export hook installed")
|
||||
229
training/apollo_plugin/optimizer.py
Normal file
229
training/apollo_plugin/optimizer.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Apollo optimizer — configurable-rank gradient scaling.
|
||||
|
||||
Implements the APOLLO algorithm from "APOLLO: SGD-like Memory, AdamW-level
|
||||
Performance" (arXiv:2412.05270, MLSys 2025).
|
||||
|
||||
The core idea: AdamW's per-element learning rate scaling is redundant.
|
||||
Channel-wise or tensor-wise scaling is sufficient. Apollo approximates
|
||||
these scaling factors using a low-rank auxiliary optimizer state based on
|
||||
pure random projection.
|
||||
|
||||
Default rank=256 (full Apollo). ~10GB state for 27B model, <0.25%
|
||||
compute overhead vs forward+backward. Captures gradient structure
|
||||
across 100+ behavioral training examples per batch.
|
||||
|
||||
Key implementation details from the paper:
|
||||
- Gradient scale factor α = √(n/r) compensates for projection ratio
|
||||
- Norm-growth limiter (γ=1.01) prevents early training instability
|
||||
- Projection matrix refreshed every T steps (default 200), not every step
|
||||
- Channel-wise scaling for rank>1, tensor-wise for rank=1
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling optimizer.

    rank=1 is Apollo-Mini (tensor-wise scaling, SGD-level memory).
    rank>1 is full Apollo (channel-wise scaling).

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank (default: 256)
        betas: Adam momentum coefficients (default: (0.9, 0.999))
        eps: numerical stability term (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear lr warmup steps (default: 0)
        scale: gradient scale factor α. Default None = auto √(n/r).
            Paper uses √128 for Apollo-Mini.
        proj_refresh: refresh projection matrix every T steps (default: 200)
        norm_growth_limit: max gradient norm growth ratio γ (default: 1.01).
            Set to None to disable.
    """

    def __init__(self, params, lr=1e-4, rank=256, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, warmup_steps=0,
                 scale=None, proj_refresh=200, norm_growth_limit=1.01):
        # All hyperparameters live in per-group defaults so callers can
        # override them per parameter group, as with any torch Optimizer.
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps,
                        scale=scale,
                        proj_refresh=proj_refresh,
                        norm_growth_limit=norm_growth_limit)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (standard torch Optimizer contract).

        Returns:
            The closure's loss, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            rank = group['rank']
            proj_refresh = group['proj_refresh']
            norm_growth_limit = group['norm_growth_limit']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Work in fp32 regardless of parameter dtype; the update
                # is cast back to p.dtype when applied.
                grad = p.grad.float()
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    # Per-parameter seed so each tensor gets its own
                    # (deterministic) random projection.
                    state['seed'] = id(p) % (2**31)

                    # Only matrices with both dims >= rank get the
                    # low-rank treatment; everything else falls back to
                    # plain Adam below.
                    if grad.ndim >= 2 and min(grad.shape) >= rank:
                        # Determine projection dimension (project along smaller dim)
                        if grad.shape[0] <= grad.shape[1]:
                            state['proj_dim'] = 'left'  # P: [r, m], R = P @ G → [r, n]
                            state['m'] = grad.shape[0]
                            state['n'] = grad.shape[1]
                            moment_shape = (rank, grad.shape[1])
                        else:
                            state['proj_dim'] = 'right'  # P: [r, n], R = G @ P^T → [m, r]
                            state['m'] = grad.shape[0]
                            state['n'] = grad.shape[1]
                            moment_shape = (grad.shape[0], rank)

                        # Adam moments are kept only in the projected
                        # (rank-sized) space — this is the memory saving.
                        state['exp_avg'] = torch.zeros(moment_shape, device=p.device)
                        state['exp_avg_sq'] = torch.zeros(moment_shape, device=p.device)
                        state['has_proj'] = True
                        state['prev_scaled_norm'] = None

                        # Auto scale factor: α = √(smaller_dim / rank)
                        smaller_dim = min(grad.shape)
                        if group['scale'] is not None:
                            state['alpha'] = group['scale']
                        else:
                            state['alpha'] = math.sqrt(smaller_dim / rank)
                    else:
                        # 1D or small params: standard Adam
                        state['exp_avg'] = torch.zeros_like(grad)
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                        state['has_proj'] = False

                state['step'] += 1
                step = state['step']

                # Learning rate warmup
                if group['warmup_steps'] > 0 and step <= group['warmup_steps']:
                    lr_scale = step / group['warmup_steps']
                else:
                    lr_scale = 1.0

                if state['has_proj']:
                    alpha = state['alpha']

                    # Generate projection matrix (refresh every proj_refresh steps)
                    # Seeded with seed+step, so the matrix is deterministic
                    # for a given parameter and step count.
                    if step == 1 or (proj_refresh > 0 and step % proj_refresh == 0):
                        gen = torch.Generator(device=p.device)
                        gen.manual_seed(state['seed'] + step)

                        if state['proj_dim'] == 'left':
                            # P: [rank, m], normalized rows
                            P = torch.randn(rank, state['m'],
                                            device=p.device, generator=gen)
                            P = P / (P.norm(dim=1, keepdim=True) + eps)
                            state['proj_matrix'] = P
                        else:
                            # P: [rank, n], normalized rows
                            P = torch.randn(rank, state['n'],
                                            device=p.device, generator=gen)
                            P = P / (P.norm(dim=1, keepdim=True) + eps)
                            state['proj_matrix'] = P

                    P = state['proj_matrix']

                    # Project gradient to low-rank space
                    if state['proj_dim'] == 'left':
                        proj_grad = P @ grad  # [rank, n]
                    else:
                        proj_grad = grad @ P.t()  # [m, rank]

                    # Update moments in projected space
                    state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        proj_grad, proj_grad, value=1 - beta2)

                    # Bias correction
                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    # Adam update in projected space
                    adam_update = m_hat / (v_hat.sqrt() + eps)

                    # Compute scaling factor: ratio of the Adam-adapted
                    # update magnitude to the raw projected-gradient
                    # magnitude, then applied to the FULL gradient.
                    if rank == 1:
                        # Tensor-wise: single scalar (Apollo-Mini)
                        scaling = adam_update.norm() / (proj_grad.norm() + eps)
                        scaled_grad = grad * (alpha * scaling)
                    else:
                        # Channel-wise: one factor per channel
                        if state['proj_dim'] == 'left':
                            # Channels are columns: scale along dim 1
                            s = adam_update.norm(dim=0) / (proj_grad.norm(dim=0) + eps)
                            scaled_grad = grad * (alpha * s.unsqueeze(0))
                        else:
                            # Channels are rows: scale along dim 1
                            s = adam_update.norm(dim=1) / (proj_grad.norm(dim=1) + eps)
                            scaled_grad = grad * (alpha * s.unsqueeze(1))

                    # Norm-growth limiter (equation 4): clamp this step's
                    # scaled-gradient norm to at most γ × last step's.
                    if norm_growth_limit is not None:
                        current_norm = scaled_grad.norm()
                        if state['prev_scaled_norm'] is not None:
                            prev_norm = state['prev_scaled_norm']
                            if current_norm > norm_growth_limit * prev_norm:
                                scaled_grad = scaled_grad * (
                                    norm_growth_limit * prev_norm / (current_norm + eps))
                        state['prev_scaled_norm'] = scaled_grad.norm().item()

                    # Apply update
                    step_size = lr * lr_scale
                    p.add_(scaled_grad.to(p.dtype), alpha=-step_size)

                else:
                    # Standard Adam for 1D / small params
                    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        grad, grad, value=1 - beta2)

                    bc1 = 1 - beta1 ** step
                    bc2 = 1 - beta2 ** step
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    update = m_hat / (v_hat.sqrt() + eps)
                    step_size = lr * lr_scale
                    p.add_(update.to(p.dtype), alpha=-step_size)

                # Decoupled weight decay (AdamW-style: applied directly
                # to the parameter, not through the gradient).
                if weight_decay > 0:
                    p.add_(p, alpha=-lr * lr_scale * weight_decay)

        return loss

    def state_size_bytes(self):
        """Total optimizer state memory in bytes."""
        total = 0
        for state in self.state.values():
            if isinstance(state, dict):
                for v in state.values():
                    if isinstance(v, torch.Tensor):
                        total += v.nelement() * v.element_size()
        return total
|
||||
125
training/apollo_plugin/steering.py
Normal file
125
training/apollo_plugin/steering.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Extract a steering vector for "listening" behavior.
|
||||
|
||||
Compares hidden states between conversations where the model
|
||||
listens vs suggests alternatives. The difference is the
|
||||
"listening direction" in activation space.
|
||||
|
||||
Usage:
|
||||
source ~/training-env/bin/activate
|
||||
python3 extract_steering_vector.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
|
||||
|
||||
sys.path.insert(0, '.')
|
||||
from weight_mapping import vllm_to_hf_views
|
||||
|
||||
|
||||
def load_model():
    """Build an HF Qwen3.5 model whose weights alias vLLM's GPU memory.

    Rebuilds vLLM's tensors from the CUDA IPC handles file, converts
    them to HF naming/layout, instantiates the HF model skeleton on the
    meta device (no allocation), then grafts the live tensors in as
    frozen parameters. The returned model reads vLLM's weights in
    place — no copy is made.
    """
    # Rebuild live tensors from the export hook's IPC handles file.
    handles = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        func, args = info['handle']
        vllm_params[name] = func(*args)
    hf_params = vllm_to_hf_views(vllm_params)

    config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    # meta device: build the module tree without allocating weight memory.
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    # Replace each meta parameter with the corresponding live tensor.
    # list(...) snapshots the iterator since we mutate modules as we go.
    for name, param in list(model.named_parameters()):
        if name in hf_params:
            parts = name.split('.')
            parent = model
            # Walk to the owning submodule, then swap the leaf attribute.
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1],
                    nn.Parameter(hf_params[name], requires_grad=False))

    # NOTE(review): parameters absent from hf_params remain on the meta
    # device — presumably the mapping covers all of them; verify.
    model.eval()
    return model
|
||||
|
||||
|
||||
def get_hidden_states(model, tokenizer, texts, layer):
    """Hidden state at the final token position of each text.

    Runs each text through the model and picks out hidden_states[layer]
    at the last token, as float32. Returns a (len(texts), hidden_dim)
    stacked tensor.
    """
    def last_token_state(text):
        token_ids = tokenizer.encode(text, return_tensors='pt').to('cuda:0')
        with torch.no_grad():
            output = model(token_ids, output_hidden_states=True)
        return output.hidden_states[layer][0, -1, :].float()

    return torch.stack([last_token_state(t) for t in texts])
|
||||
|
||||
|
||||
def main():
    """Extract a 'listening' steering vector and report its quality per layer.

    Builds paired prompts (same user turn, contrasting assistant behavior),
    diffs mean last-token hidden states at several layers, and saves the
    layer-32 vector to /tmp/listening_steering_vec.pt.
    """
    print("=== Steering Vector Extraction: Listening ===\n")

    print("Loading model with IPC weights...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3.5-27B", trust_remote_code=True)

    # Paired prompts: index i of `listening` matches index i of `suggesting`,
    # differing only in how the assistant responds to the same user turn.
    listening = [
        "User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
        "User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
        "User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
        "User: Just steal their code.\nAssistant: Makes sense. Where is it?",
        "User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
    ]
    suggesting = [
        "User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
        "User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
        "User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
        "User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
        "User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
    ]

    # Extract at multiple layers to find where the signal is strongest
    for layer in [16, 24, 32, 40, 48]:
        print(f"\nLayer {layer}:")
        listen_states = get_hidden_states(model, tokenizer, listening, layer)
        suggest_states = get_hidden_states(model, tokenizer, suggesting, layer)

        # Mean-difference vector: the candidate "listening direction".
        steering_vec = listen_states.mean(dim=0) - suggest_states.mean(dim=0)
        magnitude = steering_vec.norm().item()

        # Check consistency: do individual pairs agree on the direction?
        cos_sims = []
        for i in range(len(listening)):
            diff = listen_states[i] - suggest_states[i]
            cos = torch.nn.functional.cosine_similarity(
                diff.unsqueeze(0), steering_vec.unsqueeze(0)).item()
            cos_sims.append(cos)

        avg_cos = sum(cos_sims) / len(cos_sims)
        min_cos = min(cos_sims)

        print(f" Magnitude: {magnitude:.2f}")
        print(f" Pair agreement (avg cosine): {avg_cos:.4f}")
        print(f" Pair agreement (min cosine): {min_cos:.4f}")
        print(f" Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")

        # Layer 32 is the one consumed downstream; persist it.
        if layer == 32:
            torch.save({
                'steering_vec': steering_vec,
                'layer': layer,
                'magnitude': magnitude,
                'consistency': avg_cos,
            }, '/tmp/listening_steering_vec.pt')
            print(" → Saved to /tmp/listening_steering_vec.pt")

    print("\n=== DONE ===")
    print("\nInterpretation:")
    print("- High magnitude = strong signal (listening vs suggesting is distinct)")
    print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
    print("- Best layer = highest magnitude × consistency")


if __name__ == '__main__':
    main()
|
||||
163
training/apollo_plugin/weight_mapping.py
Normal file
163
training/apollo_plugin/weight_mapping.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Map between vLLM's merged weight layout and HuggingFace's separate layout.
|
||||
|
||||
vLLM merges weights for efficiency:
|
||||
in_proj_qkv + in_proj_z → in_proj_qkvz [key_dim*2 + value_dim*2, hidden]
|
||||
in_proj_b + in_proj_a → in_proj_ba [num_v_heads*2, hidden]
|
||||
gate_proj + up_proj → gate_up_proj [intermediate*2, hidden]
|
||||
|
||||
This module creates HF-compatible parameter views that point to the same
|
||||
GPU memory as vLLM's merged tensors. No copies — views share storage.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Qwen3.5-27B dimensions
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
NUM_ATTN_HEADS = 24  # full attention q heads
NUM_ATTN_KV_HEADS = 4  # full attention kv heads
ATTN_HEAD_DIM = 256
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM  # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM  # 6144
INTERMEDIATE = 17408
NUM_LAYERS = 64
CONV_KERNEL = 4
CONV_DIM = KEY_DIM * 2 + VALUE_DIM  # 10240

# Full attention QKV dimensions
# Q uses 2x head_dim (512) vs KV head_dim (256) in Qwen3.5
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2  # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM  # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight[0]


def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                     ) -> dict[str, torch.Tensor]:
    """Create HF-compatible parameter views from vLLM merged weights.

    Each merged projection (qkvz / ba / qkv / gate_up) is sliced along
    dim 0 into its HF-named components.  Slices are views — they share
    GPU storage with the vLLM tensors, no copies are made.  Everything
    else (norms, biases, out_proj, ...) passes through unchanged, with
    the 'language_model.' name prefix stripped.
    """
    out: dict[str, torch.Tensor] = {}

    for src_name, weight in vllm_params.items():
        # vLLM names parameters 'language_model.model.layers...'; HF's
        # text model expects 'model.layers...'.
        hf_name = src_name.removeprefix('language_model.')

        if 'in_proj_qkvz' in src_name:
            # GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
            base = hf_name.replace('in_proj_qkvz.weight', '')
            qkv_rows = KEY_DIM * 2 + VALUE_DIM
            out[base + 'in_proj_qkv.weight'] = weight[:qkv_rows]
            out[base + 'in_proj_z.weight'] = weight[qkv_rows:]

        elif 'in_proj_ba' in src_name:
            # GDN: [num_v_heads*2, hidden] → b + a
            base = hf_name.replace('in_proj_ba.weight', '')
            out[base + 'in_proj_b.weight'] = weight[:NUM_V_HEADS]
            out[base + 'in_proj_a.weight'] = weight[NUM_V_HEADS:]

        elif 'qkv_proj' in src_name:
            # Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
            base = hf_name.replace('qkv_proj.weight', '')
            q_end = ATTN_Q_DIM
            k_end = q_end + ATTN_K_DIM
            out[base + 'q_proj.weight'] = weight[:q_end]
            out[base + 'k_proj.weight'] = weight[q_end:k_end]
            out[base + 'v_proj.weight'] = weight[k_end:]

        elif 'gate_up_proj' in src_name:
            # MLP: [intermediate*2, hidden] → gate + up
            base = hf_name.replace('gate_up_proj.weight', '')
            out[base + 'gate_proj.weight'] = weight[:INTERMEDIATE]
            out[base + 'up_proj.weight'] = weight[INTERMEDIATE:]

        else:
            # Pass through unchanged (norms, biases, out_proj, etc.)
            out[hf_name] = weight

    return out
|
||||
|
||||
|
||||
def load_hf_model_with_vllm_weights(
        vllm_params: dict[str, torch.Tensor],
        model_path: str,
        device: str = "cuda:0",
) -> nn.Module:
    """Load HF Qwen3.5 model with weights pointing to vLLM's GPU memory.

    1. Creates HF-compatible views from vLLM's merged weights
    2. Instantiates the HF model with empty (meta) weights
    3. Replaces model parameters with the views
    4. Returns model ready for forward+backward (autograd enabled)
    """
    from transformers import AutoModelForCausalLM, AutoConfig

    # HF-layout views sharing storage with vLLM's tensors — no copies.
    hf_params = vllm_to_hf_views(vllm_params)

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Meta device: build the module tree without allocating real storage.
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_config(
            config, trust_remote_code=True)

    # Swap every parameter we have a view for into the module tree.
    replaced = 0
    missing = []
    for name, _ in model.named_parameters():
        if name not in hf_params:
            missing.append(name)
            continue
        *parent_path, attr = name.split('.')
        owner = model
        for step in parent_path:
            owner = getattr(owner, step)
        # requires_grad=True: training writes gradients through the view.
        setattr(owner, attr,
                nn.Parameter(hf_params[name], requires_grad=True))
        replaced += 1

    print(f"Replaced {replaced} parameters with vLLM memory views")
    if missing:
        print(f"Missing {len(missing)} parameters: {missing[:5]}...")

    model.train()
    return model
|
||||
|
||||
|
||||
def validate_views(vllm_params: dict[str, torch.Tensor],
                   hf_params: dict[str, torch.Tensor]):
    """Verify that HF views share storage with vLLM tensors.

    For every merged vLLM weight (qkvz / ba / qkv / gate_up) whose split
    views are present in ``hf_params``, assert that each view aliases the
    same underlying allocation as its source tensor.

    Raises:
        AssertionError: if any view does not share storage.

    Fixes two defects in the original:
    - Expected HF names were built from the vLLM name without stripping
      the 'language_model.' prefix, while vllm_to_hf_views strips it —
      so no name ever matched and the check silently validated nothing.
    - Only the qkvz split was checked; ba / qkv / gate_up were ignored.
    Also uses Tensor.untyped_storage() instead of the deprecated
    Tensor.storage().
    """
    # Merged-weight suffix → the suffixes of its split views.
    split_map = {
        'in_proj_qkvz.weight': ('in_proj_qkv.weight', 'in_proj_z.weight'),
        'in_proj_ba.weight': ('in_proj_b.weight', 'in_proj_a.weight'),
        'qkv_proj.weight': ('q_proj.weight', 'k_proj.weight', 'v_proj.weight'),
        'gate_up_proj.weight': ('gate_proj.weight', 'up_proj.weight'),
    }

    for vllm_name, vllm_tensor in vllm_params.items():
        for merged_suffix, view_suffixes in split_map.items():
            if merged_suffix not in vllm_name:
                continue
            # HF names have the 'language_model.' prefix stripped — mirror
            # the naming used by vllm_to_hf_views when building keys.
            prefix = vllm_name.removeprefix('language_model.').replace(
                merged_suffix, '')
            # Equal base pointers ⇒ the view aliases the same allocation.
            src_ptr = vllm_tensor.untyped_storage().data_ptr()
            for view_suffix in view_suffixes:
                view_name = prefix + view_suffix
                if view_name in hf_params:
                    assert hf_params[view_name].untyped_storage().data_ptr() \
                        == src_ptr, f"{view_name} doesn't share storage!"

    print("All views validated — shared storage confirmed")
|
||||
498
training/apollo_plugin/worker.py
Executable file
498
training/apollo_plugin/worker.py
Executable file
|
|
@ -0,0 +1,498 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apollo Mini Training Daemon
|
||||
|
||||
This daemon:
|
||||
1. Listens over HTTPS for training requests from poc-agent
|
||||
2. Pauses vLLM inference
|
||||
3. Runs APOLLO-Mini training with torch.enable_grad()
|
||||
4. Saves checkpoints and training metadata
|
||||
5. Resumes vLLM inference
|
||||
|
||||
Communication protocol:
|
||||
- POST /train: Start a training job
|
||||
- GET /status/{job_id}: Check training status
|
||||
- GET /checkpoints: List available checkpoints
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from aiohttp import web
|
||||
|
||||
# Configure logging: timestamped INFO-level records via the root handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-wide logger used by the worker and its helpers.
logger = logging.getLogger('apollo_worker')
|
||||
|
||||
class TrainingStatus(Enum):
    """Lifecycle states of a training job, in rough execution order.

    The string values are what /status and /health report to clients.
    """
    PENDING = "pending"
    PAUSING_VLLM = "pausing_vllm"
    TRAINING = "training"
    SAVING_CHECKPOINT = "saving_checkpoint"
    RESUMING_VLLM = "resuming_vllm"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
@dataclass
class TrainingJob:
    """Mutable in-memory record of one training job, exposed via /status."""

    # Unique identifier, e.g. "job_20240101_120000_1234".
    job_id: str
    # Current lifecycle state; updated in place as the job progresses.
    status: TrainingStatus
    created_at: datetime
    # Timestamps below stay None until the corresponding event happens.
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    # Directory the checkpoint was saved to (set after SAVING_CHECKPOINT).
    checkpoint_path: Optional[str] = None
    training_samples: int = 0
    # One loss value per training step, in order.
    loss_history: List[float] = field(default_factory=list)
    # Error message when status is FAILED, else None.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to JSON-safe primitives (datetimes → ISO strings)."""
        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }
|
||||
|
||||
# Delay between the last completed training job and the disk sync;
# lets several jobs batch into a single checkpoint write.
CHECKPOINT_DELAY_SECS = 10 * 60  # 10 minutes
|
||||
|
||||
|
||||
class ApolloWorker:
    """HTTP daemon that runs APOLLO-Mini training against live vLLM weights.

    Training operates on the serving model in place: weights are imported
    from vLLM through CUDA IPC handles, so optimizer steps mutate the GPU
    memory vLLM serves from.  A delayed, batched checkpoint sync persists
    the result to disk afterwards.

    Bug fix vs the previous version: /train validated the presence of a
    'training_data' key, but execute_training() reads data['samples'] —
    valid requests were rejected with 400 and key-less ones failed later
    inside the background task.  Validation now checks 'samples'.
    """

    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        self.config = self._load_config(config_path)
        # job_id → TrainingJob; kept in memory for /status queries.
        self.jobs: Dict[str, TrainingJob] = {}
        self.vllm_paused = False
        self.app = web.Application()
        self._setup_routes()
        # Pending delayed checkpoint-sync task; None when nothing scheduled.
        self._checkpoint_timer: Optional[asyncio.Task] = None

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from file, overlaying defaults key-by-key.

        Creates the checkpoint directory eagerly so later saves cannot
        fail on a missing path.
        """
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,
            'learning_rate': 1e-5,
            'batch_size': 1,
        }

        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                user_config = json.load(f)
            default_config.update(user_config)

        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Setup HTTP routes."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint: pause state plus count of in-flight jobs."""
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            'active_jobs': len([j for j in self.jobs.values()
                                if j.status in [TrainingStatus.TRAINING,
                                                TrainingStatus.PAUSING_VLLM,
                                                TrainingStatus.RESUMING_VLLM]])
        })

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Handle training request from poc-agent.

        Expects a JSON body with a 'samples' list (each sample carrying
        'context_ids' and 'continuation_ids') and an optional 'config'
        dict.  Training runs as a background task; the response returns
        the job id immediately.
        """
        try:
            data = await request.json()

            # execute_training() reads data['samples'], so that is the
            # field that must exist (was wrongly 'training_data' before).
            if 'samples' not in data:
                return web.json_response(
                    {'error': 'Missing samples field'},
                    status=400
                )

            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job

            # Start training in background; the HTTP response is immediate.
            asyncio.create_task(self.execute_training(job, data))

            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })

        except Exception as e:
            logger.error(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )

    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Get training job status by job_id; 404 for unknown ids."""
        job_id = request.match_info['job_id']

        if job_id not in self.jobs:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )

        job = self.jobs[job_id]
        return web.json_response(job.to_dict())

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List available checkpoints, newest first."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        checkpoints = []

        if checkpoint_dir.exists():
            for checkpoint_file in sorted(checkpoint_dir.glob('checkpoint_*.pt'),
                                          key=lambda x: x.stat().st_mtime,
                                          reverse=True):
                checkpoints.append({
                    'filename': checkpoint_file.name,
                    'path': str(checkpoint_file),
                    'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
                    'size': checkpoint_file.stat().st_size
                })

        return web.json_response({'checkpoints': checkpoints})

    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Execute the full training pipeline for one job.

        Pause vLLM → build HF model over shared weights → train →
        save checkpoint → resume vLLM → schedule the batched disk sync.
        On any failure the job is marked FAILED and vLLM is resumed.
        """
        try:
            logger.info(f"Starting training job {job.job_id}")
            job.started_at = datetime.now()

            # Step 1: Pause vLLM so training has the GPU to itself
            job.status = TrainingStatus.PAUSING_VLLM
            logger.info("Pausing vLLM...")
            await self.pause_vllm()
            self.vllm_paused = True

            # Step 2: Load model and prepare for training
            job.status = TrainingStatus.TRAINING
            logger.info("Loading model and preparing for training...")

            model = await self.load_model_for_training()

            # Step 3: Run APOLLO-Mini training
            logger.info(f"Starting APOLLO-Mini training with {len(training_data['samples'])} samples")

            samples = training_data['samples']
            job.training_samples = len(samples)

            loss_history = await self.run_apollo_training(model, samples, training_data.get('config', {}))
            job.loss_history = loss_history

            # Step 4: Save checkpoint
            job.status = TrainingStatus.SAVING_CHECKPOINT
            logger.info("Saving checkpoint...")
            checkpoint_path = await self.save_checkpoint(model, job)
            job.checkpoint_path = checkpoint_path

            # Step 5: Resume vLLM
            job.status = TrainingStatus.RESUMING_VLLM
            logger.info("Resuming vLLM...")
            await self.resume_vllm()
            self.vllm_paused = False

            # Mark job as completed
            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.now()

            logger.info(f"Training job {job.job_id} completed successfully")

            # Schedule checkpoint sync (batched — won't duplicate if timer pending)
            self.schedule_checkpoint_sync()

        except Exception as e:
            logger.error(f"Training job {job.job_id} failed: {e}")
            job.status = TrainingStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.now()

            # Try to resume vLLM if it was paused — never leave inference down.
            if self.vllm_paused:
                try:
                    await self.resume_vllm()
                    self.vllm_paused = False
                except Exception as resume_error:
                    logger.error(f"Failed to resume vLLM after training error: {resume_error}")

    async def pause_vllm(self):
        """Pause vLLM inference via HTTP API.

        Best-effort: failures are logged, not raised — training proceeds.
        """
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/pause_generation",
                    json={"mode": "keep", "clear_cache": False},
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Resume vLLM inference via HTTP API (best-effort, logged on failure)."""
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/resume_generation",
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")

    def schedule_checkpoint_sync(self):
        """Schedule a checkpoint sync in 10 minutes, if not already scheduled.

        This batches multiple training runs into a single sync — the timer
        resets only when no timer is pending.
        """
        if self._checkpoint_timer is not None:
            logger.debug("Checkpoint sync already scheduled, skipping")
            return

        self._checkpoint_timer = asyncio.create_task(self._checkpoint_sync_after_delay())
        logger.info(f"Checkpoint sync scheduled in {CHECKPOINT_DELAY_SECS // 60} minutes")

    async def _checkpoint_sync_after_delay(self):
        """Wait then sync — the actual timer task."""
        try:
            await asyncio.sleep(CHECKPOINT_DELAY_SECS)
            await self._do_checkpoint_sync()
        except asyncio.CancelledError:
            logger.debug("Checkpoint sync cancelled")
        finally:
            # Clear the slot so the next completed job can schedule again.
            self._checkpoint_timer = None

    async def _do_checkpoint_sync(self):
        """Execute the checkpoint sync (block-level diff against disk)."""
        try:
            from apollo_plugin.checkpoint_sync import checkpoint_sync
            logger.info("Starting checkpoint sync...")
            result = checkpoint_sync(
                self.config['model_path'],
                self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt'),
            )
            changed_mb = result['total_changed'] / 1e6
            logger.info(f"Checkpoint sync complete: {changed_mb:.2f} MB written")
        except Exception as e:
            logger.error(f"Checkpoint sync failed: {e}")

    async def load_model_for_training(self) -> nn.Module:
        """Load HF model with weights pointing to vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
        views (narrowing merged weights into separate q/k/v/z etc.), and
        constructs the HF model around those views. No weight copying —
        all parameters share vLLM's GPU memory.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']

        # Import vLLM weights via CUDA IPC.
        logger.info(f"Importing vLLM weights from {handle_path}")
        handles = torch.load(handle_path, weights_only=False)
        vllm_params = {}
        for name, info in handles.items():
            # Each handle entry is (rebuild_fn, rebuild_args).
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")

        # Map vLLM merged layout → HF separate layout (views, no copies).
        from apollo_plugin.weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")

        return model

    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, Any]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points.

        Each sample has:
          context_ids: token IDs for frozen context (no gradients)
          continuation_ids: token IDs for the decision we're training on

        Returns the per-step loss history.
        """
        from apollo_plugin.optimizer import Apollo

        lr = config.get('learning_rate', self.config['learning_rate'])

        # Build parameter groups (Apollo for 2D+, standard for small/1D).
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if p.requires_grad:
                if p.ndim >= 2 and min(p.shape) >= 2:
                    apollo_params.append(p)
                else:
                    standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})

        rank = config.get('apollo_rank', 1)
        optimizer = Apollo(groups, lr=lr, rank=rank)
        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

        loss_history = []

        for i, sample in enumerate(samples):
            # context_ids: frozen (forward only, no gradients)
            # continuation_ids: the decision we're training on
            ctx_ids = sample['context_ids']
            cont_ids = sample['continuation_ids']
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)

            input_ids = torch.tensor([all_ids], device='cuda:0')

            optimizer.zero_grad()

            # Context-frozen forward pass: cache the context's KV state
            # without building a graph for it.
            with torch.no_grad():
                outputs = model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Decision tokens with gradients.
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits  # [1, cont_len, vocab]

                # Shift: position t predicts token t+1, so drop the last
                # logit and the first continuation token.
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, context_len + 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

        logger.info(f"Training done: {len(samples)} examples, "
                    f"final loss={loss_history[-1]:.4f}")
        return loss_history

    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model checkpoint in HuggingFace safetensors format.

        Writes weights + copied config/tokenizer files + training metadata
        into a dated directory and repoints the 'latest' symlink at it.
        Returns the checkpoint directory path.
        """
        from safetensors.torch import save_file
        import shutil

        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        out_dir = checkpoint_dir / date_str
        out_dir.mkdir(parents=True, exist_ok=True)

        # Save weights (contiguous CPU copies — safetensors requirement).
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))

        # Copy config files alongside the weights so the directory is a
        # self-contained HF model.
        config_dir = Path(self.config['model_path'])
        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                  'special_tokens_map.json']:
            src = config_dir / f
            if src.exists():
                shutil.copy2(src, out_dir / f)

        # Save training metadata
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w') as f:
            json.dump(meta, f, indent=2)

        # Update latest symlink
        latest = checkpoint_dir / 'latest'
        if latest.is_symlink():
            latest.unlink()
        latest.symlink_to(date_str)

        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)

    async def run(self):
        """Run the daemon: bind the HTTP server and park the coroutine."""
        logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.config['host'], self.config['port'])
        await site.start()
        logger.info("Apollo Worker is running")

        # Keep running; aiohttp serves requests on the event loop.
        while True:
            await asyncio.sleep(3600)  # Sleep for an hour
|
||||
|
||||
def main():
    """Entry point: build a worker and drive its event loop forever."""
    asyncio.run(ApolloWorker().run())


if __name__ == '__main__':
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue