apollo: make rank configurable (default 1 = Mini, higher ranks for experimentation)

ProofOfConcept 2026-03-30 22:06:31 -04:00
parent c5d7d8cb5d
commit e1cd4fb0ab
2 changed files with 38 additions and 27 deletions

View file

@@ -1,38 +1,46 @@
-"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory.
+"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
 
 Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
 for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
 
 For each parameter tensor, maintains:
-  - rank-1 projected first moment (m): [m, 1] or [1, n]
-  - rank-1 projected second moment (v): same shape
-  - fixed random projection vector (regenerated from seed)
+  - projected first moment (m): [m, rank] or [rank, n]
+  - projected second moment (v): same shape
+  - random projection matrix (regenerated from seed)
 
-Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW).
+rank=1 is Apollo-Mini (~50MB state for a 27B model).
+rank=2-16 costs proportionally more memory, which is still negligible.
 Compute cost of projection is <1% of forward+backward.
 """
 import torch
 from torch.optim import Optimizer
 
 
-class ApolloMini(Optimizer):
-    """Apollo-Mini: rank-1 tensor-wise gradient scaling.
+class Apollo(Optimizer):
+    """Apollo: configurable-rank tensor-wise gradient scaling.
+
+    rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
+    Higher ranks cost proportionally more memory but may improve
+    training quality for fine-grained behavioral fine-tuning.
 
     Args:
         params: model parameters
         lr: learning rate (default: 1e-4)
+        rank: projection rank (default: 1 = Apollo-Mini)
         betas: coefficients for moment estimates (default: (0.9, 0.999))
         eps: term for numerical stability (default: 1e-8)
         weight_decay: decoupled weight decay (default: 0.01)
         warmup_steps: linear warmup steps (default: 0)
-        scale: scaling factor for projection (default: 128)
+        scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise (default: 'tensor')
     """
 
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0.01, warmup_steps=0, scale=128):
-        defaults = dict(lr=lr, betas=betas, eps=eps,
+    def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
+        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                         weight_decay=weight_decay,
-                        warmup_steps=warmup_steps, scale=scale)
+                        warmup_steps=warmup_steps,
+                        scale_type=scale_type)
         super().__init__(params, defaults)
 
     @torch.no_grad()
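
For orientation, a minimal usage sketch of the new signature (the model and
hyperparameter values below are illustrative, not from this commit; rank=1
reproduces the old ApolloMini behavior):

import torch
import torch.nn as nn
from apollo_mini import Apollo

model = nn.Linear(4096, 4096)   # stand-in for a single transformer weight
opt = Apollo(model.parameters(), lr=1e-4, rank=4)

loss = model(torch.randn(2, 4096)).pow(2).mean()
loss.backward()
opt.step()        # the 2D weight takes the projected path,
opt.zero_grad()   # the 1D bias falls back to standard Adam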
@@ -61,21 +69,21 @@ class ApolloMini(Optimizer):
                 state['seed'] = id(p)  # deterministic per-param seed
 
                 # Determine projection dimension
-                if grad.ndim >= 2:
+                rank = group['rank']
+                if grad.ndim >= 2 and min(grad.shape) >= rank:
                     if grad.shape[0] >= grad.shape[1]:
-                        proj_shape = (grad.shape[1], 1)
                         state['proj_dim'] = 'right'
-                        moment_shape = (grad.shape[0], 1)
+                        moment_shape = (grad.shape[0], rank)
                     else:
-                        proj_shape = (1, grad.shape[0])
                         state['proj_dim'] = 'left'
-                        moment_shape = (1, grad.shape[1])
+                        moment_shape = (rank, grad.shape[1])
                     state['exp_avg'] = torch.zeros(moment_shape,
                                                    device=p.device)
                     state['exp_avg_sq'] = torch.zeros(moment_shape,
                                                       device=p.device)
                     state['has_proj'] = True
+                    state['rank'] = rank
                 else:
                     # 1D params (biases, norms): use standard Adam
                     state['exp_avg'] = torch.zeros_like(grad)
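
The memory claim follows directly from the moment shapes above: per 2D
tensor, Apollo stores two [max(m, n), rank] moments instead of AdamW's two
full [m, n] moments. A back-of-the-envelope check, assuming fp32 state
(the layer size is hypothetical):

def apollo_state_bytes(m: int, n: int, rank: int, bytes_per: int = 4) -> int:
    # two moments (exp_avg, exp_avg_sq), each rank * max(m, n) floats
    return 2 * rank * max(m, n) * bytes_per

def adamw_state_bytes(m: int, n: int, bytes_per: int = 4) -> int:
    # two full-size moments
    return 2 * m * n * bytes_per

print(apollo_state_bytes(8192, 8192, rank=1))  # 65536 bytes (~64 KB)
print(adamw_state_bytes(8192, 8192))           # 536870912 bytes (~512 MB)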
@@ -91,23 +99,25 @@ class ApolloMini(Optimizer):
                 lr_scale = 1.0
                 if state['has_proj']:
-                    # Generate deterministic random projection vector
+                    rank = state['rank']
+                    # Generate deterministic random projection matrix
                     gen = torch.Generator(device=p.device)
                     gen.manual_seed(state['seed'] + state['step'])
 
-                    # Project gradient to rank-1
+                    # Project gradient to low-rank
                     if state['proj_dim'] == 'right':
-                        proj_vec = torch.randn(grad.shape[1], 1,
+                        proj_mat = torch.randn(grad.shape[1], rank,
                                                device=p.device,
                                                generator=gen)
-                        proj_vec = proj_vec / (proj_vec.norm() + eps)
-                        proj_grad = grad @ proj_vec  # [m, 1]
+                        proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
+                        proj_grad = grad @ proj_mat  # [m, rank]
                     else:
-                        proj_vec = torch.randn(1, grad.shape[0],
+                        proj_mat = torch.randn(rank, grad.shape[0],
                                                device=p.device,
                                                generator=gen)
-                        proj_vec = proj_vec / (proj_vec.norm() + eps)
-                        proj_grad = proj_vec @ grad  # [1, n]
+                        proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
+                        proj_grad = proj_mat @ grad  # [rank, n]
 
                     # Update moments in projected space
                     state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
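
Note the projection matrix is never stored: it is resampled each step from
seed + step, so it costs no memory and is reproducible from (seed, step); the
only persistent state is the two low-rank moments. A standalone sketch of the
rank-r projection (shapes mirror the diff, but this is toy code, not the
optimizer's internals):

import torch

m, n, rank, eps = 512, 128, 4, 1e-8
grad = torch.randn(m, n)
gen = torch.Generator().manual_seed(1234)  # stands in for seed + step

if m >= n:   # 'right': compress the smaller trailing dim n down to rank
    proj_mat = torch.randn(n, rank, generator=gen)
    proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)  # unit columns
    proj_grad = grad @ proj_mat            # [m, rank]
else:        # 'left': compress the leading dim m down to rank
    proj_mat = torch.randn(rank, m, generator=gen)
    proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)  # unit rows
    proj_grad = proj_mat @ grad            # [rank, n]

print(proj_grad.shape)  # torch.Size([512, 4])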

View file

@@ -309,7 +309,7 @@ class ApolloWorker:
                       samples: List[Dict[str, str]],
                       config: Dict[str, Any]) -> List[float]:
         """Run Apollo-Mini training on conversation decision points."""
-        from apollo_mini import ApolloMini
+        from apollo_mini import Apollo
         from transformers import AutoTokenizer
 
         lr = config.get('learning_rate', self.config['learning_rate'])
@@ -331,7 +331,8 @@ class ApolloWorker:
         if standard_params:
             groups.append({'params': standard_params})
 
-        optimizer = ApolloMini(groups, lr=lr)
+        rank = config.get('apollo_rank', 1)
+        optimizer = Apollo(groups, lr=lr, rank=rank)
         logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
                     f"{len(standard_params)} standard, "
                     f"state={optimizer.state_size_bytes()/1e6:.1f}MB")