apollo-mini training system: initial implementation
Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC shared weight memory between vLLM and the training process: - apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality) - apollo_worker.py: HTTP daemon coordinating training with vLLM - weight_mapping.py: vLLM merged → HF separate layout (zero-copy views) - training_example.py: tokenization with chat template - export_weights.py: CUDA IPC handle export from vLLM - train.py: standalone training script (alternative to daemon) - DESIGN.md: architecture and protocol documentation Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200). Apollo-Mini rank-1 projection + scaling + in-place update confirmed. Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
This commit is contained in:
parent
13453606ae
commit
c5d7d8cb5d
7 changed files with 1484 additions and 0 deletions
162
training/apollo_mini.py
Normal file
162
training/apollo_mini.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory.
|
||||
|
||||
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
|
||||
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
|
||||
|
||||
For each parameter tensor, maintains:
|
||||
- rank-1 projected first moment (m): [m, 1] or [1, n]
|
||||
- rank-1 projected second moment (v): same shape
|
||||
- fixed random projection vector (regenerated from seed)
|
||||
|
||||
Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class ApolloMini(Optimizer):
|
||||
"""Apollo-Mini: rank-1 tensor-wise gradient scaling.
|
||||
|
||||
Args:
|
||||
params: model parameters
|
||||
lr: learning rate (default: 1e-4)
|
||||
betas: coefficients for moment estimates (default: (0.9, 0.999))
|
||||
eps: term for numerical stability (default: 1e-8)
|
||||
weight_decay: decoupled weight decay (default: 0.01)
|
||||
warmup_steps: linear warmup steps (default: 0)
|
||||
scale: scaling factor for projection (default: 128)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0.01, warmup_steps=0, scale=128):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
warmup_steps=warmup_steps, scale=scale)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
lr = group['lr']
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
weight_decay = group['weight_decay']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
grad = p.grad.float()
|
||||
state = self.state[p]
|
||||
|
||||
# Initialize state
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['seed'] = id(p) # deterministic per-param seed
|
||||
|
||||
# Determine projection dimension
|
||||
if grad.ndim >= 2:
|
||||
if grad.shape[0] >= grad.shape[1]:
|
||||
proj_shape = (grad.shape[1], 1)
|
||||
state['proj_dim'] = 'right'
|
||||
moment_shape = (grad.shape[0], 1)
|
||||
else:
|
||||
proj_shape = (1, grad.shape[0])
|
||||
state['proj_dim'] = 'left'
|
||||
moment_shape = (1, grad.shape[1])
|
||||
|
||||
state['exp_avg'] = torch.zeros(moment_shape,
|
||||
device=p.device)
|
||||
state['exp_avg_sq'] = torch.zeros(moment_shape,
|
||||
device=p.device)
|
||||
state['has_proj'] = True
|
||||
else:
|
||||
# 1D params (biases, norms): use standard Adam
|
||||
state['exp_avg'] = torch.zeros_like(grad)
|
||||
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||
state['has_proj'] = False
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Learning rate warmup
|
||||
if group['warmup_steps'] > 0 and state['step'] <= group['warmup_steps']:
|
||||
lr_scale = state['step'] / group['warmup_steps']
|
||||
else:
|
||||
lr_scale = 1.0
|
||||
|
||||
if state['has_proj']:
|
||||
# Generate deterministic random projection vector
|
||||
gen = torch.Generator(device=p.device)
|
||||
gen.manual_seed(state['seed'] + state['step'])
|
||||
|
||||
# Project gradient to rank-1
|
||||
if state['proj_dim'] == 'right':
|
||||
proj_vec = torch.randn(grad.shape[1], 1,
|
||||
device=p.device,
|
||||
generator=gen)
|
||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
||||
proj_grad = grad @ proj_vec # [m, 1]
|
||||
else:
|
||||
proj_vec = torch.randn(1, grad.shape[0],
|
||||
device=p.device,
|
||||
generator=gen)
|
||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
||||
proj_grad = proj_vec @ grad # [1, n]
|
||||
|
||||
# Update moments in projected space
|
||||
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
||||
state['exp_avg_sq'].mul_(beta2).addcmul_(
|
||||
proj_grad, proj_grad, value=1 - beta2)
|
||||
|
||||
# Bias correction
|
||||
bc1 = 1 - beta1 ** state['step']
|
||||
bc2 = 1 - beta2 ** state['step']
|
||||
m_hat = state['exp_avg'] / bc1
|
||||
v_hat = state['exp_avg_sq'] / bc2
|
||||
|
||||
# Adam update in projected space
|
||||
adam_update = m_hat / (v_hat.sqrt() + eps)
|
||||
|
||||
# Tensor-wise scaling factor
|
||||
scaling = adam_update.norm() / (proj_grad.norm() + eps)
|
||||
|
||||
# Apply to full gradient
|
||||
step_size = lr * lr_scale
|
||||
p.add_(grad.to(p.dtype) * (-step_size * scaling))
|
||||
|
||||
else:
|
||||
# Standard Adam for 1D params
|
||||
state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
state['exp_avg_sq'].mul_(beta2).addcmul_(
|
||||
grad, grad, value=1 - beta2)
|
||||
|
||||
bc1 = 1 - beta1 ** state['step']
|
||||
bc2 = 1 - beta2 ** state['step']
|
||||
m_hat = state['exp_avg'] / bc1
|
||||
v_hat = state['exp_avg_sq'] / bc2
|
||||
|
||||
update = m_hat / (v_hat.sqrt() + eps)
|
||||
step_size = lr * lr_scale
|
||||
p.add_(update.to(p.dtype), alpha=-step_size)
|
||||
|
||||
# Decoupled weight decay
|
||||
if weight_decay > 0:
|
||||
p.add_(p, alpha=-lr * lr_scale * weight_decay)
|
||||
|
||||
return loss
|
||||
|
||||
def state_size_bytes(self):
|
||||
"""Total optimizer state memory in bytes."""
|
||||
total = 0
|
||||
for state in self.state.values():
|
||||
if isinstance(state, dict):
|
||||
for v in state.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
total += v.nelement() * v.element_size()
|
||||
return total
|
||||
Loading…
Add table
Add a link
Reference in a new issue