# consciousness/training/apollo_mini.py
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
For each parameter tensor, maintains:
- projected first moment (m): [m, rank] or [rank, n]
- projected second moment (v): same shape
- random projection matrix (regenerated from seed)
rank=1 is Apollo-Mini (~50MB state for 27B model).
rank=2-16 costs proportionally more memory but is still negligible.
Compute cost of projection is <1% of forward+backward.
"""
import torch
from torch.optim import Optimizer
class Apollo(Optimizer):
    """Apollo: configurable-rank gradient scaling with SGD-level memory.

    rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
    Higher ranks cost proportionally more memory but may improve
    training quality for fine-grained behavioral fine-tuning.

    Args:
        params: model parameters
        lr: learning rate (default: 1e-4)
        rank: projection rank (default: 1 = Apollo-Mini)
        betas: coefficients for moment estimates (default: (0.9, 0.999))
        eps: term for numerical stability (default: 1e-8)
        weight_decay: decoupled weight decay (default: 0.01)
        warmup_steps: linear warmup steps (default: 0)
        scale_type: 'tensor' for tensor-wise scaling, 'channel' for
            channel-wise (per-row / per-column) scaling
        update_proj_gap: resample the random projection every this many
            steps (default: 1 = every step, matching previous behavior)
    """

    def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.01, warmup_steps=0, scale_type='tensor',
                 update_proj_gap=1):
        # Validate options that step() dispatches on; a typo here would
        # otherwise silently fall back to tensor-wise scaling.
        if scale_type not in ('tensor', 'channel'):
            raise ValueError(
                f"scale_type must be 'tensor' or 'channel', got {scale_type!r}")
        if update_proj_gap < 1:
            raise ValueError("update_proj_gap must be >= 1")
        defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        warmup_steps=warmup_steps,
                        scale_type=scale_type,
                        update_proj_gap=update_proj_gap)
        super().__init__(params, defaults)

    def _init_state(self, state, grad, rank, group_idx, param_idx, device):
        """Create per-parameter state: low-rank moments for matrices with
        min(shape) >= rank, dense Adam moments otherwise (biases, norms)."""
        state['step'] = 0
        # Deterministic, position-based seed.  The previous id(p) seed
        # changed between runs, making training non-reproducible.
        state['seed'] = group_idx * 1_000_003 + param_idx
        if grad.ndim >= 2 and min(grad.shape) >= rank:
            # Project along the larger dimension so the moment buffers
            # are as small as possible.
            if grad.shape[0] >= grad.shape[1]:
                state['proj_dim'] = 'right'       # grad @ P  -> [m, rank]
                moment_shape = (grad.shape[0], rank)
            else:
                state['proj_dim'] = 'left'        # P @ grad  -> [rank, n]
                moment_shape = (rank, grad.shape[1])
            # Explicit float32 to match the grad.float() stream even if
            # the global default dtype has been changed.
            state['exp_avg'] = torch.zeros(moment_shape,
                                           dtype=torch.float32, device=device)
            state['exp_avg_sq'] = torch.zeros(moment_shape,
                                              dtype=torch.float32, device=device)
            state['has_proj'] = True
            state['rank'] = rank
        else:
            state['exp_avg'] = torch.zeros_like(grad)
            state['exp_avg_sq'] = torch.zeros_like(grad)
            state['has_proj'] = False

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (executed with grad enabled).

        Returns:
            The loss from the closure, or None.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for g_idx, group in enumerate(self.param_groups):
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            scale_type = group['scale_type']
            # .get() keeps old checkpointed param_groups loadable.
            proj_gap = group.get('update_proj_gap', 1)
            for p_idx, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('Apollo does not support sparse gradients')
                grad = p.grad.float()
                state = self.state[p]
                if len(state) == 0:
                    self._init_state(state, grad, group['rank'],
                                     g_idx, p_idx, p.device)
                state['step'] += 1
                # Linear learning-rate warmup.
                warmup = group['warmup_steps']
                if warmup > 0 and state['step'] <= warmup:
                    lr_scale = state['step'] / warmup
                else:
                    lr_scale = 1.0
                step_size = lr * lr_scale
                if state['has_proj']:
                    rank = state['rank']
                    # Deterministic projection, resampled every proj_gap
                    # steps; within a window the same matrix is reused so
                    # the projected moments live in a consistent subspace.
                    gen = torch.Generator(device=p.device)
                    gen.manual_seed(state['seed'] + (state['step'] - 1) // proj_gap)
                    if state['proj_dim'] == 'right':
                        proj_mat = torch.randn(grad.shape[1], rank,
                                               device=p.device, generator=gen)
                        # Normalize columns to unit length.
                        proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
                        proj_grad = grad @ proj_mat          # [m, rank]
                        channel_dim = 1                      # one row per channel
                    else:
                        proj_mat = torch.randn(rank, grad.shape[0],
                                               device=p.device, generator=gen)
                        # Normalize rows to unit length.
                        proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
                        proj_grad = proj_mat @ grad          # [rank, n]
                        channel_dim = 0                      # one column per channel
                    # Adam-style moments maintained only in the tiny
                    # projected space.
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    exp_avg.mul_(beta1).add_(proj_grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(proj_grad, proj_grad,
                                                    value=1 - beta2)
                    bc1 = 1 - beta1 ** state['step']
                    bc2 = 1 - beta2 ** state['step']
                    adam_update = (exp_avg / bc1) / ((exp_avg_sq / bc2).sqrt() + eps)
                    if scale_type == 'channel':
                        # Per-channel norm ratio, broadcast over the full
                        # gradient.  (Previously this option was accepted
                        # but silently ignored.)
                        scaling = (adam_update.norm(dim=channel_dim)
                                   / (proj_grad.norm(dim=channel_dim) + eps))
                        scaling = scaling.unsqueeze(channel_dim)
                    else:
                        # Tensor-wise: single scalar norm ratio.
                        scaling = adam_update.norm() / (proj_grad.norm() + eps)
                    # Scaled raw gradient is the update direction.
                    p.add_((grad * scaling).to(p.dtype), alpha=-step_size)
                else:
                    # Standard bias-corrected Adam for 1D params.
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    bc1 = 1 - beta1 ** state['step']
                    bc2 = 1 - beta2 ** state['step']
                    update = (exp_avg / bc1) / ((exp_avg_sq / bc2).sqrt() + eps)
                    p.add_(update.to(p.dtype), alpha=-step_size)
                # Decoupled (AdamW-style) weight decay.
                if weight_decay > 0:
                    p.add_(p, alpha=-step_size * weight_decay)
        return loss

    def state_size_bytes(self):
        """Total optimizer state memory in bytes (tensor buffers only)."""
        return sum(
            v.nelement() * v.element_size()
            for state in self.state.values()
            if isinstance(state, dict)
            for v in state.values()
            if isinstance(v, torch.Tensor)
        )