apollo: make rank configurable (default 1 = Mini, higher ranks for experimentation)
This commit is contained in:
parent
c5d7d8cb5d
commit
e1cd4fb0ab
2 changed files with 38 additions and 27 deletions
|
|
@ -1,38 +1,46 @@
|
|||
"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory.
|
||||
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
|
||||
|
||||
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
|
||||
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
|
||||
|
||||
For each parameter tensor, maintains:
|
||||
- rank-1 projected first moment (m): [m, 1] or [1, n]
|
||||
- rank-1 projected second moment (v): same shape
|
||||
- fixed random projection vector (regenerated from seed)
|
||||
- projected first moment (m): [m, rank] or [rank, n]
|
||||
- projected second moment (v): same shape
|
||||
- random projection matrix (regenerated from seed)
|
||||
|
||||
Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW).
|
||||
rank=1 is Apollo-Mini (~50MB state for 27B model).
|
||||
rank=2-16 costs proportionally more memory but is still negligible.
|
||||
Compute cost of projection is <1% of forward+backward.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class ApolloMini(Optimizer):
|
||||
"""Apollo-Mini: rank-1 tensor-wise gradient scaling.
|
||||
class Apollo(Optimizer):
|
||||
"""Apollo: configurable-rank tensor-wise gradient scaling.
|
||||
|
||||
rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
|
||||
Higher ranks cost proportionally more memory but may improve
|
||||
training quality for fine-grained behavioral fine-tuning.
|
||||
|
||||
Args:
|
||||
params: model parameters
|
||||
lr: learning rate (default: 1e-4)
|
||||
rank: projection rank (default: 1 = Apollo-Mini)
|
||||
betas: coefficients for moment estimates (default: (0.9, 0.999))
|
||||
eps: term for numerical stability (default: 1e-8)
|
||||
weight_decay: decoupled weight decay (default: 0.01)
|
||||
warmup_steps: linear warmup steps (default: 0)
|
||||
scale: scaling factor for projection (default: 128)
|
||||
scale_type: 'tensor' for tensor-wise, 'channel' for channel-wise
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0.01, warmup_steps=0, scale=128):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
|
||||
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
warmup_steps=warmup_steps, scale=scale)
|
||||
warmup_steps=warmup_steps,
|
||||
scale_type=scale_type)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -61,21 +69,21 @@ class ApolloMini(Optimizer):
|
|||
state['seed'] = id(p) # deterministic per-param seed
|
||||
|
||||
# Determine projection dimension
|
||||
if grad.ndim >= 2:
|
||||
rank = group['rank']
|
||||
if grad.ndim >= 2 and min(grad.shape) >= rank:
|
||||
if grad.shape[0] >= grad.shape[1]:
|
||||
proj_shape = (grad.shape[1], 1)
|
||||
state['proj_dim'] = 'right'
|
||||
moment_shape = (grad.shape[0], 1)
|
||||
moment_shape = (grad.shape[0], rank)
|
||||
else:
|
||||
proj_shape = (1, grad.shape[0])
|
||||
state['proj_dim'] = 'left'
|
||||
moment_shape = (1, grad.shape[1])
|
||||
moment_shape = (rank, grad.shape[1])
|
||||
|
||||
state['exp_avg'] = torch.zeros(moment_shape,
|
||||
device=p.device)
|
||||
state['exp_avg_sq'] = torch.zeros(moment_shape,
|
||||
device=p.device)
|
||||
state['has_proj'] = True
|
||||
state['rank'] = rank
|
||||
else:
|
||||
# 1D params (biases, norms): use standard Adam
|
||||
state['exp_avg'] = torch.zeros_like(grad)
|
||||
|
|
@ -91,23 +99,25 @@ class ApolloMini(Optimizer):
|
|||
lr_scale = 1.0
|
||||
|
||||
if state['has_proj']:
|
||||
# Generate deterministic random projection vector
|
||||
rank = state['rank']
|
||||
|
||||
# Generate deterministic random projection matrix
|
||||
gen = torch.Generator(device=p.device)
|
||||
gen.manual_seed(state['seed'] + state['step'])
|
||||
|
||||
# Project gradient to rank-1
|
||||
# Project gradient to low-rank
|
||||
if state['proj_dim'] == 'right':
|
||||
proj_vec = torch.randn(grad.shape[1], 1,
|
||||
proj_mat = torch.randn(grad.shape[1], rank,
|
||||
device=p.device,
|
||||
generator=gen)
|
||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
||||
proj_grad = grad @ proj_vec # [m, 1]
|
||||
proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
|
||||
proj_grad = grad @ proj_mat # [m, rank]
|
||||
else:
|
||||
proj_vec = torch.randn(1, grad.shape[0],
|
||||
proj_mat = torch.randn(rank, grad.shape[0],
|
||||
device=p.device,
|
||||
generator=gen)
|
||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
||||
proj_grad = proj_vec @ grad # [1, n]
|
||||
proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
|
||||
proj_grad = proj_mat @ grad # [rank, n]
|
||||
|
||||
# Update moments in projected space
|
||||
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class ApolloWorker:
|
|||
samples: List[Dict[str, str]],
|
||||
config: Dict[str, Any]) -> List[float]:
|
||||
"""Run Apollo-Mini training on conversation decision points."""
|
||||
from apollo_mini import ApolloMini
|
||||
from apollo_mini import Apollo
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
||||
|
|
@ -331,7 +331,8 @@ class ApolloWorker:
|
|||
if standard_params:
|
||||
groups.append({'params': standard_params})
|
||||
|
||||
optimizer = ApolloMini(groups, lr=lr)
|
||||
rank = config.get('apollo_rank', 1)
|
||||
optimizer = Apollo(groups, lr=lr, rank=rank)
|
||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
||||
f"{len(standard_params)} standard, "
|
||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue