apollo-mini training system: initial implementation
Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC shared weight memory between vLLM and the training process:

- apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality)
- apollo_worker.py: HTTP daemon coordinating training with vLLM
- weight_mapping.py: vLLM merged → HF separate layout (zero-copy views)
- training_example.py: tokenization with chat template
- export_weights.py: CUDA IPC handle export from vLLM
- train.py: standalone training script (alternative to daemon)
- DESIGN.md: architecture and protocol documentation

Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200).
Apollo-Mini rank-1 projection + scaling + in-place update confirmed.

Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
This commit is contained in:
parent 13453606ae
commit c5d7d8cb5d
7 changed files with 1484 additions and 0 deletions

197	training/DESIGN.md	Normal file
@@ -0,0 +1,197 @@
# Apollo Mini Training System Design

## Overview

This system enables continuous fine-tuning of the Qwen3.5-27B model while maintaining inference capability through vLLM. The key insight is that APOLLO-Mini's near-zero optimizer state (~50 MB for a 27B model, vs. ~54 GB for AdamW), combined with LoRA adapters, makes the memory overhead small enough to fit within vLLM's reclaimed KV cache space.

## Architecture

### Components

1. **Apollo Worker Daemon** (`apollo_worker.py`)
   - Listens over HTTP/HTTPS for training requests
   - Manages the vLLM pause/resume cycle
   - Executes APOLLO-Mini training with `torch.enable_grad()`
   - Saves checkpoints and training metadata
   - Runs on the B200 server alongside vLLM

2. **Training Signal Agent** (to be built)
   - Runs online like surface-observe
   - Analyzes recent conversation windows
   - Identifies improvement opportunities
   - Requests training from the Apollo Worker
   - Runs on Moria (separate from the B200)

3. **vLLM Inference Engine**
   - Continues serving during non-training periods
   - Pauses during training steps
   - Shares GPU memory with the training process

### Communication Protocol

```
POST /train
{
    "training_data": {
        "samples": [
            {
                "input": "conversation context",
                "expected_output": "better response",
                "rationale": "why this is better"
            }
        ],
        "config": {
            "learning_rate": 1e-5,
            "max_steps": 100
        }
    }
}

Response:
{
    "job_id": "job_20260331_012345_12345",
    "status": "accepted",
    "message": "Training job started"
}

GET /status/{job_id}
Response:
{
    "job_id": "job_20260331_012345_12345",
    "status": "completed",
    "training_samples": 50,
    "loss_history": [0.5, 0.45, 0.42, ...],
    "checkpoint_path": "/home/kent/poc/consciousness/training/checkpoints/checkpoint_job_20260331_012345_12345.pt"
}

GET /checkpoints
Response:
{
    "checkpoints": [
        {
            "filename": "checkpoint_job_20260331_012345_12345.pt",
            "path": "/home/kent/poc/consciousness/training/checkpoints/checkpoint_job_20260331_012345_12345.pt",
            "created_at": "2026-03-31T01:23:45",
            "size": 55000000000
        }
    ]
}
```
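A hypothetical client sketch for this protocol. The helper names and the stdlib-only transport are illustrative; only the URL paths and payload shape come from the examples above (the worker's default port is 8080):

```python
import json
import urllib.request

def build_train_request(samples, learning_rate=1e-5, max_steps=100):
    """Assemble the POST /train body from (input, expected_output, rationale) triples."""
    return {
        "training_data": {
            "samples": [
                {"input": i, "expected_output": o, "rationale": r}
                for i, o, r in samples
            ],
            "config": {"learning_rate": learning_rate, "max_steps": max_steps},
        }
    }

def post_train(body, base_url="http://localhost:8080"):
    """Submit a training job; returns the accepted-job response dict."""
    req = urllib.request.Request(
        f"{base_url}/train",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)  # e.g. {'job_id': ..., 'status': 'accepted', ...}

body = build_train_request(
    [("conversation context", "better response", "why this is better")])
print(json.dumps(body, indent=2))
```

Polling `GET /status/{job_id}` until `status` is `completed` or `failed` follows the same pattern.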

## Training Pipeline

### 1. Signal Detection
- Training signal agent monitors conversation logs
- Identifies patterns where PoC could improve:
  - Responses that needed memories to get right
  - Things that could be done better with more time/context
  - High-frequency memory accesses indicating knowledge gaps
- Builds training dataset with input/expected_output/rationale

### 2. Training Request
- Agent sends POST /train with training samples
- Apollo Worker accepts the job and begins execution

### 3. vLLM Pause
- Apollo Worker signals vLLM to pause inference
- vLLM freezes in-flight requests
- GPU memory freed from the KV cache becomes available

### 4. Model Loading & Training
- Load model weights (shared with vLLM via memory mapping)
- Enable gradients: `torch.enable_grad()`
- Run the APOLLO-Mini training loop:
  - Project gradients into a rank-1 subspace
  - Update moments in the projected space
  - Compute a tensor-wise scaling factor
  - Apply updates to the full gradient
- Track loss history
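The four loop steps above can be sketched numerically; this is a toy single-step walk-through with made-up shapes, using the default β/ε values from `apollo_mini.py`:

```python
import torch

torch.manual_seed(0)
grad = torch.randn(16, 8)            # toy [m, n] gradient, m >= n, so project on the right
proj = torch.randn(8, 1)
proj = proj / proj.norm()

# 1. Project the gradient into a rank-1 subspace
proj_grad = grad @ proj              # [m, 1]

# 2. Update moments in the projected space (first step, zero-initialized)
beta1, beta2, eps = 0.9, 0.999, 1e-8
m = (1 - beta1) * proj_grad
v = (1 - beta2) * proj_grad ** 2
m_hat = m / (1 - beta1)              # bias correction at step 1
v_hat = v / (1 - beta2)

# 3. Tensor-wise scaling factor from the projected Adam update
adam_update = m_hat / (v_hat.sqrt() + eps)
scaling = adam_update.norm() / (proj_grad.norm() + eps)

# 4. Apply the scaling to the full gradient
update = grad * scaling
print(update.shape)                  # torch.Size([16, 8])
```

The full update (with moment state carried across steps and an in-place parameter write) is in `apollo_mini.py` below.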

### 5. Checkpoint Saving
- Save model state dict
- Record training metadata (samples, loss history, job ID)
- Store in checkpoint directory

### 6. vLLM Resume
- Signal vLLM to resume inference
- KV cache rebuilt as new requests arrive
- Updated weights now active in inference

## Memory Management

### APOLLO-Mini Advantages
- **Optimizer state**: ~50 MB for a 27B model (vs. ~54 GB for AdamW)
- **Gradient memory**: held only for the current batch (no accumulation across steps)
- **Activation memory**: only for the current training step
- **Total overhead**: ~55GB for full fine-tuning, much less for LoRA
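A back-of-envelope check of the state-size claim, using hypothetical stand-in shapes (the real Qwen3.5-27B layer shapes differ):

```python
def adamw_state_bytes(shapes, bytes_per=4):
    # AdamW keeps two full-size fp32 moment tensors per weight matrix
    return sum(2 * m * n * bytes_per for m, n in shapes)

def apollo_mini_state_bytes(shapes, bytes_per=4):
    # Apollo-Mini keeps two rank-1 moments along the longer dimension ([m,1] or [1,n])
    return sum(2 * max(m, n) * bytes_per for m, n in shapes)

# Hypothetical stand-in for a 27B model's weight matrices
shapes = [(5120, 5120)] * 300

print(f"AdamW:       {adamw_state_bytes(shapes) / 1e9:.1f} GB")   # ~62.9 GB
print(f"Apollo-Mini: {apollo_mini_state_bytes(shapes) / 1e6:.1f} MB")  # ~12.3 MB
```

The exact numbers depend on the model config, but the three-orders-of-magnitude gap is the point.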

### vLLM Memory Reclamation
- KV cache can consume 50-70% of GPU memory during inference
- Pausing inference frees this memory for training
- Training can use the reclaimed space without evicting model weights

### Strategy
1. **LoRA + APOLLO-Mini**: train only adapter parameters (~100MB for rank-16)
2. **Time-multiplexed**: pause inference, train, resume
3. **Nightly checkpoints**: save full model state overnight when inference load is low

## Implementation Phases

### Phase 1: Prototype (Current)
- [x] Apollo Worker daemon skeleton
- [ ] vLLM pause/resume integration
- [ ] Basic training loop with placeholder model
- [ ] Checkpoint saving/loading
- [ ] Test with small dataset

### Phase 2: Integration
- [ ] Connect to actual Qwen3.5-27B model
- [ ] Implement vLLM pause/resume API
- [ ] Memory mapping for weight sharing
- [ ] Training signal agent MVP
- [ ] End-to-end test with real conversations

### Phase 3: Production
- [ ] APOLLO-Mini implementation (rank-1 projection)
- [ ] LoRA adapter integration
- [ ] Nightly checkpoint scheduling
- [ ] Training metrics and monitoring
- [ ] Rollback mechanism for bad checkpoints

## Technical Challenges

### 1. vLLM Pause/Resume
- vLLM's `pause_generation()` API needs testing
- In-flight request handling during pause
- KV cache invalidation strategy

### 2. Gradient Computation
- `torch.inference_mode()` blocks gradients
- Must override with `torch.enable_grad()` during training
- CUDA graphs are incompatible with training (use eager mode)
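The worker's context-frozen pattern, a gradient-free pass under `torch.no_grad()` followed by a gradient-enabled pass under `torch.enable_grad()`, can be checked in isolation with a toy tensor:

```python
import torch

x = torch.ones(2, requires_grad=True)

# Context phase: no autograd graph is recorded for these ops
with torch.no_grad():
    ctx = x * 2              # ctx carries no grad_fn

# Decision phase: gradients explicitly enabled
with torch.enable_grad():
    loss = (x * 3).sum()     # graph recorded for this op only

loss.backward()
print(x.grad)                # tensor([3., 3.])
```

This mirrors `run_apollo_training` in `apollo_worker.py`, where the context forward only produces a KV cache and only the continuation tokens contribute gradients.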

### 3. Memory Sharing
- Model weights must be shared between vLLM and the training process
- Memory mapping or zero-copy IPC
- Tensor parallelism consistency (if using TP)

### 4. APOLLO-Mini Implementation
- Rank-1 gradient projection
- Fixed random projection matrix (not SVD)
- Tensor-wise scaling factor computation
- Integration with existing optimizer infrastructure

## Next Steps

1. **Test vLLM pause/resume**: verify the API works and measure overhead
2. **Implement weight sharing**: memory-map model weights between processes
3. **Build training signal agent**: MVP that identifies improvement opportunities
4. **Test end-to-end**: run a training job with real conversation data
5. **Optimize**: measure memory usage, training time, and inference impact

## References

- APOLLO-Mini paper: arXiv:2412.05270
- vLLM source: `/tmp/vllm/`
- LeMix (interleaved training/inference): arXiv:2507.21276
- Research document: `/home/kent/.claude/projects/-home-kent-bcachefs-tools/memory/research-apollo-vllm-finetuning.md`
162	training/apollo_mini.py	Normal file
@@ -0,0 +1,162 @@
"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory.

Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).

For each parameter tensor, maintains:
- rank-1 projected first moment (m): [m, 1] or [1, n]
- rank-1 projected second moment (v): same shape
- a random projection vector, regenerated deterministically from a
  per-parameter seed each step (so it never needs to be stored)

Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW).
"""

import torch
from torch.optim import Optimizer


class ApolloMini(Optimizer):
    """Apollo-Mini: rank-1 tensor-wise gradient scaling.

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        betas: coefficients for moment estimates (default: (0.9, 0.999))
        eps: term for numerical stability (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear warmup steps (default: 0)
        scale: scaling factor for projection (default: 128; currently unused)
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.01, warmup_steps=0, scale=128):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps, scale=scale)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.float()
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['seed'] = id(p)  # deterministic per-param seed (stable within a run)

                    # Project along the shorter dimension; keep moments
                    # along the longer one
                    if grad.ndim >= 2:
                        if grad.shape[0] >= grad.shape[1]:
                            state['proj_dim'] = 'right'
                            moment_shape = (grad.shape[0], 1)
                        else:
                            state['proj_dim'] = 'left'
                            moment_shape = (1, grad.shape[1])

                        state['exp_avg'] = torch.zeros(moment_shape,
                                                       device=p.device)
                        state['exp_avg_sq'] = torch.zeros(moment_shape,
                                                          device=p.device)
                        state['has_proj'] = True
                    else:
                        # 1D params (biases, norms): use standard Adam
                        state['exp_avg'] = torch.zeros_like(grad)
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                        state['has_proj'] = False

                state['step'] += 1

                # Learning rate warmup
                if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']:
                    lr_scale = state['step'] / group['warmup_steps']
                else:
                    lr_scale = 1.0
                if state['has_proj']:
                    # Regenerate the deterministic random projection vector
                    # from the per-param seed (nothing stored between steps)
                    gen = torch.Generator(device=p.device)
                    gen.manual_seed(state['seed'] + state['step'])

                    # Project gradient to rank-1
                    if state['proj_dim'] == 'right':
                        proj_vec = torch.randn(grad.shape[1], 1,
                                               device=p.device,
                                               generator=gen)
                        proj_vec = proj_vec / (proj_vec.norm() + eps)
                        proj_grad = grad @ proj_vec  # [m, 1]
                    else:
                        proj_vec = torch.randn(1, grad.shape[0],
                                               device=p.device,
                                               generator=gen)
                        proj_vec = proj_vec / (proj_vec.norm() + eps)
                        proj_grad = proj_vec @ grad  # [1, n]

                    # Update moments in projected space
                    state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        proj_grad, proj_grad, value=1 - beta2)

                    # Bias correction
                    bc1 = 1 - beta1 ** state['step']
                    bc2 = 1 - beta2 ** state['step']
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    # Adam update in projected space
                    adam_update = m_hat / (v_hat.sqrt() + eps)

                    # Tensor-wise scaling factor
                    scaling = adam_update.norm() / (proj_grad.norm() + eps)

                    # Apply to full gradient (in-place parameter update)
                    step_size = lr * lr_scale
                    p.add_(grad.to(p.dtype) * (-step_size * scaling))
                else:
                    # Standard Adam for 1D params
                    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
                    state['exp_avg_sq'].mul_(beta2).addcmul_(
                        grad, grad, value=1 - beta2)

                    bc1 = 1 - beta1 ** state['step']
                    bc2 = 1 - beta2 ** state['step']
                    m_hat = state['exp_avg'] / bc1
                    v_hat = state['exp_avg_sq'] / bc2

                    update = m_hat / (v_hat.sqrt() + eps)
                    step_size = lr * lr_scale
                    p.add_(update.to(p.dtype), alpha=-step_size)

                # Decoupled weight decay
                if weight_decay > 0:
                    p.add_(p, alpha=-lr * lr_scale * weight_decay)

        return loss
    def state_size_bytes(self):
        """Total optimizer state memory in bytes."""
        total = 0
        for state in self.state.values():
            if isinstance(state, dict):
                for v in state.values():
                    if isinstance(v, torch.Tensor):
                        total += v.nelement() * v.element_size()
        return total
453	training/apollo_worker.py	Executable file
@@ -0,0 +1,453 @@
#!/usr/bin/env python3
"""
Apollo Mini Training Daemon

This daemon:
1. Listens over HTTP for training requests from poc-agent
2. Pauses vLLM inference
3. Runs APOLLO-Mini training with torch.enable_grad()
4. Saves checkpoints and training metadata
5. Resumes vLLM inference

Communication protocol:
- POST /train: Start a training job
- GET /status/{job_id}: Check training status
- GET /checkpoints: List available checkpoints
"""

import asyncio
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List
from enum import Enum

import torch
import torch.nn as nn
from aiohttp import web

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('apollo_worker')


class TrainingStatus(Enum):
    PENDING = "pending"
    PAUSING_VLLM = "pausing_vllm"
    TRAINING = "training"
    SAVING_CHECKPOINT = "saving_checkpoint"
    RESUMING_VLLM = "resuming_vllm"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class TrainingJob:
    job_id: str
    status: TrainingStatus
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    model_path: Optional[str] = None
    checkpoint_path: Optional[str] = None
    training_samples: int = 0
    loss_history: List[float] = field(default_factory=list)
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            'job_id': self.job_id,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'model_path': self.model_path,
            'checkpoint_path': self.checkpoint_path,
            'training_samples': self.training_samples,
            'loss_history': self.loss_history,
            'error': self.error,
        }
class ApolloWorker:
    def __init__(self, config_path: str = "/home/kent/poc/consciousness/training/config.json"):
        self.config = self._load_config(config_path)
        self.jobs: Dict[str, TrainingJob] = {}
        self.vllm_paused = False
        self.app = web.Application()
        self._setup_routes()

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from file or use defaults."""
        default_config = {
            'host': '0.0.0.0',
            'port': 8080,
            'vllm_socket': '/tmp/vllm_control.sock',
            'model_path': '/home/ubuntu/models/Qwen3.5-27B',
            'checkpoint_dir': '/home/kent/poc/consciousness/training/checkpoints',
            'max_training_samples': 100,
            'learning_rate': 1e-5,
            'batch_size': 1,
        }

        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                user_config = json.load(f)
            default_config.update(user_config)

        Path(default_config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        return default_config

    def _setup_routes(self):
        """Set up HTTP routes."""
        self.app.router.add_post('/train', self.handle_train_request)
        self.app.router.add_get('/status/{job_id}', self.handle_status_request)
        self.app.router.add_get('/checkpoints', self.handle_list_checkpoints)
        self.app.router.add_get('/health', self.handle_health_check)

    async def handle_health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint."""
        active = [j for j in self.jobs.values()
                  if j.status in (TrainingStatus.TRAINING,
                                  TrainingStatus.PAUSING_VLLM,
                                  TrainingStatus.RESUMING_VLLM)]
        return web.json_response({
            'status': 'healthy',
            'vllm_paused': self.vllm_paused,
            'active_jobs': len(active)
        })

    async def handle_train_request(self, request: web.Request) -> web.Response:
        """Handle training request from poc-agent."""
        try:
            data = await request.json()

            # Validate required fields
            if 'training_data' not in data:
                return web.json_response(
                    {'error': 'Missing training_data field'},
                    status=400
                )

            job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"
            job = TrainingJob(
                job_id=job_id,
                status=TrainingStatus.PENDING,
                created_at=datetime.now(),
                model_path=self.config['model_path']
            )
            self.jobs[job_id] = job

            # Start training in the background.  Pass the inner
            # training_data dict so execute_training sees the documented
            # {samples, config} shape.
            asyncio.create_task(self.execute_training(job, data['training_data']))

            return web.json_response({
                'job_id': job_id,
                'status': 'accepted',
                'message': 'Training job started'
            })

        except Exception as e:
            logger.error(f"Error handling train request: {e}")
            return web.json_response(
                {'error': str(e)},
                status=500
            )
    async def handle_status_request(self, request: web.Request) -> web.Response:
        """Get training job status."""
        job_id = request.match_info['job_id']

        if job_id not in self.jobs:
            return web.json_response(
                {'error': 'Job not found'},
                status=404
            )

        job = self.jobs[job_id]
        return web.json_response(job.to_dict())

    async def handle_list_checkpoints(self, request: web.Request) -> web.Response:
        """List available checkpoints."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        checkpoints = []

        if checkpoint_dir.exists():
            # Include legacy .pt checkpoints and the dated safetensors
            # directories written by save_checkpoint() (skipping the
            # 'latest' symlink to avoid duplicates)
            files = [f for f in
                     list(checkpoint_dir.glob('checkpoint_*.pt'))
                     + list(checkpoint_dir.glob('*/model.safetensors'))
                     if not f.parent.is_symlink()]
            for checkpoint_file in sorted(files, key=lambda x: x.stat().st_mtime,
                                          reverse=True):
                checkpoints.append({
                    'filename': checkpoint_file.name,
                    'path': str(checkpoint_file),
                    'created_at': datetime.fromtimestamp(checkpoint_file.stat().st_mtime).isoformat(),
                    'size': checkpoint_file.stat().st_size
                })

        return web.json_response({'checkpoints': checkpoints})
    async def execute_training(self, job: TrainingJob, training_data: Dict[str, Any]):
        """Execute the training pipeline.

        `training_data` is the inner dict from the request body:
        {'samples': [...], 'config': {...}}.
        """
        try:
            logger.info(f"Starting training job {job.job_id}")
            job.started_at = datetime.now()

            # Step 1: Pause vLLM
            job.status = TrainingStatus.PAUSING_VLLM
            logger.info("Pausing vLLM...")
            await self.pause_vllm()
            self.vllm_paused = True

            # Step 2: Load model and prepare for training
            job.status = TrainingStatus.TRAINING
            logger.info("Loading model and preparing for training...")
            model = await self.load_model_for_training()

            # Step 3: Run APOLLO-Mini training
            samples = training_data['samples']
            job.training_samples = len(samples)
            logger.info(f"Starting APOLLO-Mini training with {len(samples)} samples")

            loss_history = await self.run_apollo_training(
                model, samples, training_data.get('config', {}))
            job.loss_history = loss_history

            # Step 4: Save checkpoint
            job.status = TrainingStatus.SAVING_CHECKPOINT
            logger.info("Saving checkpoint...")
            checkpoint_path = await self.save_checkpoint(model, job)
            job.checkpoint_path = checkpoint_path

            # Step 5: Resume vLLM
            job.status = TrainingStatus.RESUMING_VLLM
            logger.info("Resuming vLLM...")
            await self.resume_vllm()
            self.vllm_paused = False

            # Mark job as completed
            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.now()

            logger.info(f"Training job {job.job_id} completed successfully")

        except Exception as e:
            logger.error(f"Training job {job.job_id} failed: {e}")
            job.status = TrainingStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.now()

            # Try to resume vLLM if it was paused
            if self.vllm_paused:
                try:
                    await self.resume_vllm()
                    self.vllm_paused = False
                except Exception as resume_error:
                    logger.error(f"Failed to resume vLLM after training error: {resume_error}")
    async def pause_vllm(self):
        """Pause vLLM inference via HTTP API."""
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/pause_generation",
                    json={"mode": "keep", "clear_cache": False},
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM paused")
        except Exception as e:
            logger.warning(f"Failed to pause vLLM: {e}")

    async def resume_vllm(self):
        """Resume vLLM inference via HTTP API."""
        import aiohttp as aio
        url = self.config.get('vllm_url', 'http://localhost:8000')
        try:
            async with aio.ClientSession() as session:
                async with session.post(
                    f"{url}/resume_generation",
                    timeout=aio.ClientTimeout(total=10),
                ) as resp:
                    resp.raise_for_status()
                    logger.info("vLLM resumed")
        except Exception as e:
            logger.warning(f"Failed to resume vLLM: {e}")
    async def load_model_for_training(self) -> nn.Module:
        """Load HF model with weights pointing to vLLM's GPU memory.

        Imports vLLM's weight tensors via CUDA IPC, creates HF-compatible
        views (narrowing merged weights into separate q/k/v etc.), and
        constructs the HF model around those views. No weight copying —
        all parameters share vLLM's GPU memory.
        """
        handle_path = self.config.get('weight_handles', '/tmp/vllm_weight_handles.pt')
        model_path = self.config['model_path']

        # Import vLLM weights via CUDA IPC: each entry stores a
        # (rebuild_func, args) pair produced by torch's IPC reduction
        logger.info(f"Importing vLLM weights from {handle_path}")
        handles = torch.load(handle_path, weights_only=False)
        vllm_params = {}
        for name, info in handles.items():
            func, args = info['handle']
            vllm_params[name] = func(*args)
        logger.info(f"Imported {len(vllm_params)} parameters")

        # Map vLLM merged layout → HF separate layout (views, no copies)
        from weight_mapping import load_hf_model_with_vllm_weights
        model = load_hf_model_with_vllm_weights(vllm_params, model_path)
        logger.info("HF model constructed with vLLM weight views")

        return model
    async def run_apollo_training(self, model: nn.Module,
                                  samples: List[Dict[str, str]],
                                  config: Dict[str, Any]) -> List[float]:
        """Run Apollo-Mini training on conversation decision points."""
        from apollo_mini import ApolloMini
        from transformers import AutoTokenizer

        lr = config.get('learning_rate', self.config['learning_rate'])
        tokenizer = AutoTokenizer.from_pretrained(
            self.config['model_path'], trust_remote_code=True)

        # Build parameter groups (Apollo for 2D+, standard for small/1D)
        apollo_params, standard_params = [], []
        for p in model.parameters():
            if p.requires_grad:
                if p.ndim >= 2 and min(p.shape) >= 2:
                    apollo_params.append(p)
                else:
                    standard_params.append(p)

        groups = []
        if apollo_params:
            groups.append({'params': apollo_params})
        if standard_params:
            groups.append({'params': standard_params})

        optimizer = ApolloMini(groups, lr=lr)
        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                    f"{len(standard_params)} standard, "
                    f"state={optimizer.state_size_bytes()/1e6:.1f}MB")

        loss_history = []

        for i, sample in enumerate(samples):
            # Accept both the protocol's input/expected_output keys and
            # the internal context/continuation naming
            context = sample.get('context') or sample.get('input', '')
            continuation = sample.get('continuation') or sample.get('expected_output', '')

            # Tokenize
            ctx_ids = tokenizer.encode(context, add_special_tokens=True)
            cont_ids = tokenizer.encode(continuation, add_special_tokens=False)
            all_ids = ctx_ids + cont_ids
            context_len = len(ctx_ids)

            input_ids = torch.tensor([all_ids], device='cuda:0')

            optimizer.zero_grad()

            # Context-frozen forward pass: run the context without
            # gradients, keeping only the KV cache
            with torch.no_grad():
                outputs = model(input_ids[:, :context_len], use_cache=True)
                past_kv = outputs.past_key_values

            # Decision tokens with gradients
            with torch.enable_grad():
                outputs = model(
                    input_ids[:, context_len:],
                    past_key_values=past_kv,
                    use_cache=False,
                )
                logits = outputs.logits  # [1, cont_len, vocab]

                # Shift: each position predicts the next token
                shift_logits = logits[:, :-1].contiguous()
                shift_labels = input_ids[:, context_len + 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)
            logger.info(f"Step {i+1}/{len(samples)}: loss={loss_val:.4f} "
                        f"(ctx={context_len}, cont={len(cont_ids)} tokens)")

        if loss_history:
            logger.info(f"Training done: {len(samples)} examples, "
                        f"final loss={loss_history[-1]:.4f}")
        return loss_history
    async def save_checkpoint(self, model: nn.Module, job: TrainingJob) -> str:
        """Save model checkpoint in HuggingFace safetensors format."""
        from safetensors.torch import save_file
        import shutil

        checkpoint_dir = Path(self.config['checkpoint_dir'])
        date_str = datetime.now().strftime('%Y-%m-%d')
        out_dir = checkpoint_dir / date_str
        out_dir.mkdir(parents=True, exist_ok=True)

        # Save weights
        tensors = {name: p.data.contiguous().cpu()
                   for name, p in model.named_parameters()}
        save_path = out_dir / "model.safetensors"
        save_file(tensors, str(save_path))

        # Copy config files
        config_dir = Path(self.config['model_path'])
        for fname in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                      'special_tokens_map.json']:
            src = config_dir / fname
            if src.exists():
                shutil.copy2(src, out_dir / fname)

        # Save training metadata
        meta = {
            'job_id': job.job_id,
            'training_samples': job.training_samples,
            'loss_history': job.loss_history,
            'timestamp': datetime.now().isoformat(),
        }
        with open(out_dir / 'training-meta.json', 'w') as f:
            json.dump(meta, f, indent=2)

        # Update latest symlink
        latest = checkpoint_dir / 'latest'
        if latest.is_symlink():
            latest.unlink()
        latest.symlink_to(date_str)

        size_gb = save_path.stat().st_size / 1e9
        logger.info(f"Checkpoint: {out_dir} ({size_gb:.1f} GB)")
        return str(out_dir)
async def run(self):
|
||||||
|
"""Run the daemon."""
|
||||||
|
logger.info(f"Starting Apollo Worker on {self.config['host']}:{self.config['port']}")
|
||||||
|
runner = web.AppRunner(self.app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, self.config['host'], self.config['port'])
|
||||||
|
await site.start()
|
||||||
|
logger.info("Apollo Worker is running")
|
||||||
|
|
||||||
|
# Keep running
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(3600) # Sleep for an hour
|
||||||
|
|
||||||
|
def main():
|
||||||
|
worker = ApolloWorker()
|
||||||
|
asyncio.run(worker.run())
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
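The shift-by-one alignment in the training loop above can be checked in isolation. A minimal sketch with hypothetical sizes (`context_len`, `cont_len`, and `vocab` below are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
context_len, cont_len, vocab = 5, 4, 11
input_ids = torch.randint(0, vocab, (1, context_len + cont_len))

# Suppose the model was fed input_ids[:, context_len:] with the context
# cached, producing one logit row per continuation position. Row j
# predicts the token at absolute position context_len + j + 1, so we
# drop the last logit row and start the labels one past the first
# continuation token.
logits = torch.randn(1, cont_len, vocab)
shift_logits = logits[:, :-1].contiguous()                  # [1, cont_len-1, vocab]
shift_labels = input_ids[:, context_len + 1:].contiguous()  # [1, cont_len-1]

loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))
```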
87
training/export_weights.py
Normal file
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
"""Export vLLM's live model weight IPC handles for the training process.

Connects to a running vLLM instance, iterates over model parameters,
and exports CUDA IPC handles that allow another process to access the
same GPU memory without copying.

Usage:
    # Run after vLLM is serving:
    python3 export_weights.py --output /tmp/vllm_weight_handles.pt

    # Or via vLLM's API (future):
    curl -X POST http://localhost:8000/export_weights
"""

import argparse
import subprocess
import sys

import torch


def export_from_model(model, output_path: str):
    """Export IPC handles for all model parameters."""
    from torch.multiprocessing.reductions import reduce_tensor

    handles = {}
    total_bytes = 0

    for name, param in model.named_parameters():
        handle = reduce_tensor(param.data)
        handles[name] = {
            'handle': handle,
            'shape': list(param.shape),
            'dtype': str(param.dtype),
        }
        param_bytes = param.nelement() * param.element_size()
        total_bytes += param_bytes

    torch.save(handles, output_path)

    n_params = len(handles)
    print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)")
    print(f"Saved to {output_path}")
    return handles


def main():
    parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
    parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
                        help="Output path for IPC handles")
    parser.add_argument("--vllm-pid", type=int, default=None,
                        help="vLLM worker PID (auto-detected if not specified)")
    args = parser.parse_args()

    # For now: load the model directly and export.
    # TODO: connect to running vLLM process instead.
    print("Note: This currently loads the model separately.")
    print("Full integration will export from the running vLLM process.")
    print()

    # Detect model path from running vLLM
    result = subprocess.run(
        ['ps', 'aux'], capture_output=True, text=True
    )
    model_path = None
    for line in result.stdout.split('\n'):
        if 'vllm' in line and '--model' in line:
            parts = line.split()
            for i, p in enumerate(parts):
                if p == '--model' and i + 1 < len(parts):
                    model_path = parts[i + 1]
                    break
                # Also handle the --model=path form
                if p.startswith('--model='):
                    model_path = p.split('=', 1)[1]
                    break

    if model_path:
        print(f"Detected vLLM model: {model_path}")
    else:
        print("Could not detect running vLLM model. Specify manually.")
        sys.exit(1)


if __name__ == '__main__':
    main()
269
training/train.py
Normal file
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""Nightly training process for Apollo-Mini fine-tuning.

Imports vLLM's model weights via CUDA IPC, runs context-frozen
training on flagged conversation segments, saves updated checkpoint.

Usage:
    python3 train.py \
        --weights /tmp/vllm_weight_handles.pt \
        --examples training-examples.jsonl \
        --checkpoint-dir checkpoints/ \
        --lr 1e-5
"""

import argparse
import json
import time
from datetime import datetime
from pathlib import Path

import torch
from safetensors.torch import save_file

from apollo_mini import ApolloMini


def import_weights(handle_path: str) -> dict[str, torch.Tensor]:
    """Import weight tensors from CUDA IPC handles."""
    handles = torch.load(handle_path, weights_only=False)
    params = {}
    for name, info in handles.items():
        func, args = info['handle']
        tensor = func(*args)
        params[name] = tensor
    return params


def make_param_groups(params: dict[str, torch.Tensor]) -> list[dict]:
    """Split parameters into Apollo-Mini and standard groups.

    Apollo-Mini needs 2D+ matrices with min dimension >= 2.
    Small tensors (norms, biases, conv1d 3D weights) use standard Adam.
    """
    apollo_params = []
    standard_params = []

    for name, p in params.items():
        p.requires_grad_(True)
        if p.ndim >= 2 and min(p.shape) >= 2:
            apollo_params.append(p)
        else:
            standard_params.append(p)

    groups = []
    if apollo_params:
        groups.append({
            'params': apollo_params,
            'name': 'apollo',
        })
    if standard_params:
        groups.append({
            'params': standard_params,
            'name': 'standard',
        })

    n_apollo = sum(p.nelement() for p in apollo_params)
    n_standard = sum(p.nelement() for p in standard_params)
    print(f"Parameter groups: apollo={n_apollo/1e9:.2f}B, standard={n_standard/1e6:.1f}M")
    return groups


def forward_pass(params, input_ids, context_len, device):
    """Run context-frozen forward pass.

    Args:
        params: dict of name -> tensor (shared with vLLM)
        input_ids: full sequence [1, seq_len]
        context_len: number of context tokens (no gradient)
        device: CUDA device

    Returns:
        logits for decision tokens, target ids for loss
    """
    # TODO: Build proper forward model matching vLLM's weight layout.
    # For now this is a placeholder — the real implementation needs
    # to replicate vLLM's model architecture (merged projections,
    # GDN recurrence, full attention, MLP) using the shared weights.
    raise NotImplementedError(
        "Forward model not yet implemented. "
        "Need to build a model that matches vLLM's merged weight layout "
        "(MergedColumnParallelLinear for qkvz/ba/gate_up, "
        "RowParallelLinear for out_proj/down) and computes the same "
        "forward pass with autograd enabled."
    )


def save_checkpoint(params: dict[str, torch.Tensor],
                    checkpoint_dir: str,
                    config_path: str = None):
    """Save model checkpoint in HuggingFace safetensors format.

    Saves all weights in a single safetensors file (sharding is a TODO),
    copies config files if provided, and updates the 'latest' symlink.
    """
    date_str = datetime.now().strftime("%Y-%m-%d")
    out_dir = Path(checkpoint_dir) / date_str
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save all weights in a single safetensors file for now.
    # TODO: split across shards matching HF model index for large models.
    tensors = {}
    for name, param in params.items():
        tensors[name] = param.data.contiguous().cpu()

    save_path = out_dir / "model.safetensors"
    save_file(tensors, str(save_path))
    print(f"Saved checkpoint to {save_path} ({save_path.stat().st_size / 1e9:.1f} GB)")

    # Copy config files if provided
    if config_path:
        import shutil
        config_dir = Path(config_path)
        for f in ['config.json', 'tokenizer.json', 'tokenizer_config.json',
                  'special_tokens_map.json', 'generation_config.json']:
            src = config_dir / f
            if src.exists():
                shutil.copy2(src, out_dir / f)

    # Update latest symlink
    latest = Path(checkpoint_dir) / "latest"
    if latest.is_symlink():
        latest.unlink()
    latest.symlink_to(date_str)
    print(f"Updated {latest} -> {date_str}")

    return str(out_dir)


def train_step(params, example, optimizer, device, log_entries):
    """Run one training step on a single example.

    Args:
        params: dict of name -> tensor
        example: dict with 'input_ids', 'context_len', 'target_ids'
        optimizer: ApolloMini instance
        device: CUDA device
        log_entries: list to append log dicts to

    Returns:
        loss value
    """
    optimizer.zero_grad()

    input_ids = torch.tensor(example['input_ids'], device=device).unsqueeze(0)
    context_len = example['context_len']

    # Forward pass (context frozen, decision tokens with grad)
    logits, targets = forward_pass(params, input_ids, context_len, device)

    # Cross-entropy loss on decision tokens
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        targets.view(-1),
    )

    # Backward
    loss.backward()

    # Compute gradient stats before optimizer step
    total_grad_norm = 0.0
    for p in params.values():
        if p.grad is not None:
            total_grad_norm += p.grad.norm().item() ** 2
    total_grad_norm = total_grad_norm ** 0.5

    # Optimizer step
    optimizer.step()

    # Log
    log_entries.append({
        'example_id': example.get('id', 'unknown'),
        'loss': loss.item(),
        'grad_norm': total_grad_norm,
        'timestamp': datetime.now().isoformat(),
    })

    return loss.item()


def main():
    parser = argparse.ArgumentParser(description="Apollo-Mini training")
    parser.add_argument("--weights", required=True,
                        help="Path to exported weight IPC handles")
    parser.add_argument("--examples", required=True,
                        help="Path to training examples JSONL")
    parser.add_argument("--checkpoint-dir", default="checkpoints",
                        help="Directory for saving checkpoints")
    parser.add_argument("--config-path", default=None,
                        help="Path to model config files (for checkpoint)")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=10,
                        help="Learning rate warmup steps")
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--dry-run", action="store_true",
                        help="Load weights and validate, don't train")
    args = parser.parse_args()

    print("Apollo-Mini Training")
    print(f"  weights: {args.weights}")
    print(f"  examples: {args.examples}")
    print(f"  lr: {args.lr}")
    print()

    # Import weights
    print("Importing weights via CUDA IPC...")
    params = import_weights(args.weights)
    print(f"  {len(params)} parameters imported")

    # Make parameter groups
    param_groups = make_param_groups(params)

    # Initialize optimizer
    optimizer = ApolloMini(param_groups, lr=args.lr,
                           weight_decay=args.weight_decay,
                           warmup_steps=args.warmup_steps)
    print(f"  Optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    if args.dry_run:
        print("\nDry run — weights imported and validated successfully.")
        return

    # Load training examples
    examples = []
    with open(args.examples) as f:
        for line in f:
            examples.append(json.loads(line))
    print(f"  {len(examples)} training examples")

    # Training loop
    log_entries = []
    print("\nTraining...")
    t0 = time.time()

    for i, example in enumerate(examples):
        loss = train_step(params, example, optimizer, 'cuda:0', log_entries)
        print(f"  [{i+1}/{len(examples)}] loss={loss:.4f}")

    elapsed = time.time() - t0
    print(f"\nTraining complete: {len(examples)} examples in {elapsed:.1f}s")
    print(f"  Final optimizer state: {optimizer.state_size_bytes() / 1e6:.1f} MB")

    # Save checkpoint
    print("\nSaving checkpoint...")
    save_checkpoint(params, args.checkpoint_dir, args.config_path)

    # Save training log
    date_str = datetime.now().strftime("%Y-%m-%d")
    log_path = Path(args.checkpoint_dir) / date_str / "training-log.jsonl"
    with open(log_path, 'w') as f:
        for entry in log_entries:
            f.write(json.dumps(entry) + '\n')
    print(f"Training log: {log_path}")


if __name__ == '__main__':
    main()
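The grouping rule in `make_param_groups` (2D+ matrices with min dimension >= 2 go to the Apollo-Mini group, everything else to the standard group) can be sketched with toy tensors; the names and shapes below are hypothetical:

```python
import torch

# Hypothetical stand-ins for model parameters.
params = {
    'layer.weight': torch.zeros(8, 4),     # 2D, min dim >= 2 -> apollo
    'layer.bias':   torch.zeros(8),        # 1D -> standard
    'norm.weight':  torch.zeros(4),        # 1D -> standard
    'conv.weight':  torch.zeros(8, 1, 3),  # 3D but min dim 1 -> standard
}
apollo = [n for n, p in params.items() if p.ndim >= 2 and min(p.shape) >= 2]
standard = [n for n in params if n not in apollo]
```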
175
training/training_example.py
Normal file
@@ -0,0 +1,175 @@
"""Training example construction and tokenization.

Takes raw conversation context + improved continuation, produces
tokenized tensors ready for context-frozen forward+backward.
"""

import json
from dataclasses import dataclass, field

import torch
from transformers import AutoTokenizer


@dataclass
class TrainingExample:
    """A single training example for context-frozen training."""
    id: str
    context: str       # conversation up to decision point
    continuation: str  # the better response
    reason: str = ""   # why this is a training target
    memories: list[str] = field(default_factory=list)  # memories that were in context

    # Computed after tokenization
    input_ids: torch.Tensor | None = None
    context_len: int = 0
    total_len: int = 0

    def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
        """Tokenize context + continuation into training-ready tensors.

        The context is assumed to already be in chat format, so the
        token distribution matches what the model sees during inference.
        """
        context_ids = tokenizer.encode(self.context, add_special_tokens=False)
        continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)

        self.context_len = len(context_ids)
        self.total_len = len(context_ids) + len(continuation_ids)

        if self.total_len > max_len:
            # Truncate context from the left, keep continuation intact
            excess = self.total_len - max_len
            context_ids = context_ids[excess:]
            self.context_len = len(context_ids)
            self.total_len = len(context_ids) + len(continuation_ids)

        all_ids = context_ids + continuation_ids
        self.input_ids = torch.tensor(all_ids, device=device)
        return self

    def to_dict(self) -> dict:
        return {
            'id': self.id,
            'context': self.context,
            'continuation': self.continuation,
            'reason': self.reason,
            'memories': self.memories,
            'context_len': self.context_len,
            'total_len': self.total_len,
        }

    @classmethod
    def from_dict(cls, d: dict) -> 'TrainingExample':
        return cls(
            id=d['id'],
            context=d['context'],
            continuation=d['continuation'],
            reason=d.get('reason', ''),
            memories=d.get('memories', []),
        )


def load_examples(path: str) -> list[TrainingExample]:
    """Load training examples from JSONL file."""
    examples = []
    with open(path) as f:
        for line in f:
            if line.strip():
                examples.append(TrainingExample.from_dict(json.loads(line)))
    return examples


def save_examples(examples: list[TrainingExample], path: str):
    """Save training examples to JSONL file."""
    with open(path, 'w') as f:
        for ex in examples:
            f.write(json.dumps(ex.to_dict()) + '\n')


class ExampleTokenizer:
    """Handles tokenization with the model's chat template.

    Applies the same chat template that vLLM uses during inference,
    so the token distribution matches what the model expects.
    """

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample,
                        max_len: int = 8192,
                        device: str = "cuda:0") -> TrainingExample:
        """Tokenize an example using the chat template.

        For proper training, the context should be formatted exactly
        as vLLM would format it — with chat template applied.
        """
        # Context is everything up to the decision point; continuation
        # is the improved response. Tokenize them separately so we know
        # where the context ends and the continuation begins.
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=True)
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)

        example.context_len = len(context_ids)
        example.total_len = len(context_ids) + len(continuation_ids)

        if example.total_len > max_len:
            # Left-truncate the context; keep the continuation intact.
            excess = example.total_len - max_len
            context_ids = context_ids[excess:]
            example.context_len = len(context_ids)
            example.total_len = example.context_len + len(continuation_ids)

        all_ids = context_ids + continuation_ids
        example.input_ids = torch.tensor(all_ids, device=device)
        return example

    def prepare_from_messages(self, example_id: str,
                              messages: list[dict],
                              decision_idx: int,
                              better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device

        Returns:
            Tokenized TrainingExample
        """
        # Context: all messages up to (not including) the decision
        context_messages = messages[:decision_idx]
        context_text = self.tokenizer.apply_chat_template(
            context_messages, tokenize=False, add_generation_prompt=True)

        # Build the example
        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )

        return self.prepare_example(example, max_len=max_len, device=device)
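The left-truncation rule used by both `tokenize` and `prepare_example` — drop tokens from the front of the context so the continuation always survives intact — can be checked with plain lists; the token ids below are made up for illustration:

```python
# Hypothetical token ids for illustration.
max_len = 8
context_ids = list(range(10))        # context too long on its own
continuation_ids = [100, 101, 102]   # must survive truncation intact

total = len(context_ids) + len(continuation_ids)
if total > max_len:
    excess = total - max_len
    context_ids = context_ids[excess:]  # drop the oldest context tokens

all_ids = context_ids + continuation_ids
```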
141
training/weight_mapping.py
Normal file
@@ -0,0 +1,141 @@
"""Map between vLLM's merged weight layout and HuggingFace's separate layout.

vLLM merges weights for efficiency:
    in_proj_qkv + in_proj_z → in_proj_qkvz   [key_dim*2 + value_dim*2, hidden]
    in_proj_b + in_proj_a   → in_proj_ba     [num_v_heads*2, hidden]
    gate_proj + up_proj     → gate_up_proj   [intermediate*2, hidden]

This module creates HF-compatible parameter views that point to the same
GPU memory as vLLM's merged tensors. No copies — views share storage.
"""

import torch
import torch.nn as nn


# Qwen3.5-27B dimensions
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM    # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM  # 6144
INTERMEDIATE = 17408
NUM_LAYERS = 64
CONV_KERNEL = 4
CONV_DIM = KEY_DIM * 2 + VALUE_DIM    # 10240


def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                     ) -> dict[str, torch.Tensor]:
    """Create HF-compatible parameter views from vLLM merged weights.

    Returns a dict of HF-style parameter names → tensor views.
    The views share GPU memory with the vLLM tensors — no copies.
    """
    hf_params = {}

    for name, tensor in vllm_params.items():
        # Pass through non-merged params unchanged
        if 'in_proj_qkvz' not in name and \
           'in_proj_ba' not in name and \
           'gate_up_proj' not in name:
            hf_params[name] = tensor
            continue

        # Split merged projections into HF-style separate weights
        if 'in_proj_qkvz' in name:
            # [key_dim*2 + value_dim*2, hidden] → qkv + z
            prefix = name.replace('in_proj_qkvz', '')
            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]  # [key_dim*2 + value_dim, hidden]
            z = tensor[KEY_DIM * 2 + VALUE_DIM:]    # [value_dim, hidden]
            hf_params[prefix + 'in_proj_qkv.weight'] = qkv
            hf_params[prefix + 'in_proj_z.weight'] = z

        elif 'in_proj_ba' in name:
            # [num_v_heads*2, hidden] → b + a
            prefix = name.replace('in_proj_ba', '')
            b = tensor[:NUM_V_HEADS]  # [num_v_heads, hidden]
            a = tensor[NUM_V_HEADS:]  # [num_v_heads, hidden]
            hf_params[prefix + 'in_proj_b.weight'] = b
            hf_params[prefix + 'in_proj_a.weight'] = a

        elif 'gate_up_proj' in name:
            # [intermediate*2, hidden] → gate + up
            prefix = name.replace('gate_up_proj', '')
            gate = tensor[:INTERMEDIATE]  # [intermediate, hidden]
            up = tensor[INTERMEDIATE:]    # [intermediate, hidden]
            hf_params[prefix + 'gate_proj.weight'] = gate
            hf_params[prefix + 'up_proj.weight'] = up

    return hf_params


def load_hf_model_with_vllm_weights(
    vllm_params: dict[str, torch.Tensor],
    model_path: str,
    device: str = "cuda:0",
) -> nn.Module:
    """Load HF Qwen3.5 model with weights pointing to vLLM's GPU memory.

    1. Creates HF-compatible views from vLLM's merged weights
    2. Instantiates the HF model with empty weights
    3. Replaces model parameters with the views
    4. Returns model ready for forward+backward (autograd enabled)
    """
    from transformers import AutoModelForCausalLM, AutoConfig

    # Create HF-compatible views
    hf_params = vllm_to_hf_views(vllm_params)

    # Load config
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Create model with empty weights (no disk I/O)
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_config(
            config, trust_remote_code=True)

    # Replace parameters with views into vLLM memory
    replaced = 0
    missing = []
    for name, param in model.named_parameters():
        if name in hf_params:
            # Replace with view (shared GPU memory)
            parts = name.rsplit('.', 1)
            parent = model
            for part in parts[0].split('.'):
                parent = getattr(parent, part)
            setattr(parent, parts[1],
                    nn.Parameter(hf_params[name], requires_grad=True))
            replaced += 1
        else:
            missing.append(name)

    print(f"Replaced {replaced} parameters with vLLM memory views")
    if missing:
        print(f"Missing {len(missing)} parameters: {missing[:5]}...")

    model.train()
    return model


def validate_views(vllm_params: dict[str, torch.Tensor],
                   hf_params: dict[str, torch.Tensor]):
    """Verify that HF views share storage with vLLM tensors."""
    for vllm_name, vllm_tensor in vllm_params.items():
        if 'in_proj_qkvz' in vllm_name:
            prefix = vllm_name.replace('in_proj_qkvz.weight', '')
            qkv_name = prefix + 'in_proj_qkv.weight'
            z_name = prefix + 'in_proj_z.weight'
            if qkv_name in hf_params:
                assert hf_params[qkv_name].untyped_storage().data_ptr() == \
                    vllm_tensor.untyped_storage().data_ptr(), \
                    f"{qkv_name} doesn't share storage!"
            if z_name in hf_params:
                assert hf_params[z_name].untyped_storage().data_ptr() == \
                    vllm_tensor.untyped_storage().data_ptr(), \
                    f"{z_name} doesn't share storage!"

    print("All views validated — shared storage confirmed")
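The zero-copy claim in `weight_mapping.py` rests on plain tensor slicing: a leading-dimension slice is a view into the same storage, so an in-place optimizer update through the view is immediately visible in the merged tensor vLLM reads from. A CPU-only sketch with small, hypothetical dimensions standing in for `KEY_DIM`/`VALUE_DIM`/`HIDDEN`:

```python
import torch

# Tiny stand-ins for KEY_DIM / VALUE_DIM / HIDDEN (hypothetical).
key_dim, value_dim, hidden = 4, 6, 8

merged = torch.randn(key_dim * 2 + value_dim * 2, hidden)  # in_proj_qkvz-style layout
qkv = merged[:key_dim * 2 + value_dim]  # view, no copy
z = merged[key_dim * 2 + value_dim:]    # view, no copy

# Both views share the merged tensor's storage.
assert qkv.untyped_storage().data_ptr() == merged.untyped_storage().data_ptr()
assert z.untyped_storage().data_ptr() == merged.untyped_storage().data_ptr()

# An in-place update through a view is visible in the merged tensor.
z.zero_()
```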