apollo: make rank configurable (default 1 = Mini, higher ranks for experimentation)
This commit is contained in:
parent
c5d7d8cb5d
commit
e1cd4fb0ab
2 changed files with 38 additions and 27 deletions
|
|
@ -1,38 +1,46 @@
|
||||||
"""Apollo-Mini optimizer — rank-1 gradient scaling with SGD-level memory.
|
"""Apollo optimizer — configurable-rank gradient scaling with SGD-level memory.
|
||||||
|
|
||||||
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
|
Implements the core algorithm from "APOLLO: Approximated Gradient Scaling
|
||||||
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
|
for Memory-Efficient LLM Optimization" (arXiv:2412.05270).
|
||||||
|
|
||||||
For each parameter tensor, maintains:
|
For each parameter tensor, maintains:
|
||||||
- rank-1 projected first moment (m): [m, 1] or [1, n]
|
- projected first moment (m): [m, rank] or [rank, n]
|
||||||
- rank-1 projected second moment (v): same shape
|
- projected second moment (v): same shape
|
||||||
- fixed random projection vector (regenerated from seed)
|
- random projection matrix (regenerated from seed)
|
||||||
|
|
||||||
Total optimizer state: ~50MB for a 27B model (vs 54GB for AdamW).
|
rank=1 is Apollo-Mini (~50MB state for 27B model).
|
||||||
|
rank=2-16 costs proportionally more memory but is still negligible.
|
||||||
|
Compute cost of projection is <1% of forward+backward.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
class ApolloMini(Optimizer):
|
class Apollo(Optimizer):
|
||||||
"""Apollo-Mini: rank-1 tensor-wise gradient scaling.
|
"""Apollo: configurable-rank tensor-wise gradient scaling.
|
||||||
|
|
||||||
|
rank=1 is Apollo-Mini (SGD-level memory, AdamW-level performance).
|
||||||
|
Higher ranks cost proportionally more memory but may improve
|
||||||
|
training quality for fine-grained behavioral fine-tuning.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: model parameters
|
params: model parameters
|
||||||
lr: learning rate (default: 1e-4)
|
lr: learning rate (default: 1e-4)
|
||||||
|
rank: projection rank (default: 1 = Apollo-Mini)
|
||||||
betas: coefficients for moment estimates (default: (0.9, 0.999))
|
betas: coefficients for moment estimates (default: (0.9, 0.999))
|
||||||
eps: term for numerical stability (default: 1e-8)
|
eps: term for numerical stability (default: 1e-8)
|
||||||
weight_decay: decoupled weight decay (default: 0.01)
|
weight_decay: decoupled weight decay (default: 0.01)
|
||||||
warmup_steps: linear warmup steps (default: 0)
|
warmup_steps: linear warmup steps (default: 0)
|
||||||
scale: scaling factor for projection (default: 128)
|
scale_type: gradient scaling granularity, 'tensor' for tensor-wise or 'channel' for channel-wise (default: 'tensor')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
|
def __init__(self, params, lr=1e-4, rank=1, betas=(0.9, 0.999), eps=1e-8,
|
||||||
weight_decay=0.01, warmup_steps=0, scale=128):
|
weight_decay=0.01, warmup_steps=0, scale_type='tensor'):
|
||||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
defaults = dict(lr=lr, rank=rank, betas=betas, eps=eps,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
warmup_steps=warmup_steps, scale=scale)
|
warmup_steps=warmup_steps,
|
||||||
|
scale_type=scale_type)
|
||||||
super().__init__(params, defaults)
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
@ -61,21 +69,21 @@ class ApolloMini(Optimizer):
|
||||||
state['seed'] = id(p) # deterministic per-param seed
|
state['seed'] = id(p) # per-param seed; id() is constant for the object's lifetime but not reproducible across runs
|
||||||
|
|
||||||
# Determine projection dimension
|
# Determine projection dimension
|
||||||
if grad.ndim >= 2:
|
rank = group['rank']
|
||||||
|
if grad.ndim >= 2 and min(grad.shape) >= rank:
|
||||||
if grad.shape[0] >= grad.shape[1]:
|
if grad.shape[0] >= grad.shape[1]:
|
||||||
proj_shape = (grad.shape[1], 1)
|
|
||||||
state['proj_dim'] = 'right'
|
state['proj_dim'] = 'right'
|
||||||
moment_shape = (grad.shape[0], 1)
|
moment_shape = (grad.shape[0], rank)
|
||||||
else:
|
else:
|
||||||
proj_shape = (1, grad.shape[0])
|
|
||||||
state['proj_dim'] = 'left'
|
state['proj_dim'] = 'left'
|
||||||
moment_shape = (1, grad.shape[1])
|
moment_shape = (rank, grad.shape[1])
|
||||||
|
|
||||||
state['exp_avg'] = torch.zeros(moment_shape,
|
state['exp_avg'] = torch.zeros(moment_shape,
|
||||||
device=p.device)
|
device=p.device)
|
||||||
state['exp_avg_sq'] = torch.zeros(moment_shape,
|
state['exp_avg_sq'] = torch.zeros(moment_shape,
|
||||||
device=p.device)
|
device=p.device)
|
||||||
state['has_proj'] = True
|
state['has_proj'] = True
|
||||||
|
state['rank'] = rank
|
||||||
else:
|
else:
|
||||||
# 1D params (biases, norms): use standard Adam
|
# 1D params (biases, norms) and tensors whose smallest dimension is below rank: use standard Adam
|
||||||
state['exp_avg'] = torch.zeros_like(grad)
|
state['exp_avg'] = torch.zeros_like(grad)
|
||||||
|
|
@ -91,23 +99,25 @@ class ApolloMini(Optimizer):
|
||||||
lr_scale = 1.0
|
lr_scale = 1.0
|
||||||
|
|
||||||
if state['has_proj']:
|
if state['has_proj']:
|
||||||
# Generate deterministic random projection vector
|
rank = state['rank']
|
||||||
|
|
||||||
|
# Generate deterministic random projection matrix
|
||||||
gen = torch.Generator(device=p.device)
|
gen = torch.Generator(device=p.device)
|
||||||
gen.manual_seed(state['seed'] + state['step'])
|
gen.manual_seed(state['seed'] + state['step'])
|
||||||
|
|
||||||
# Project gradient to rank-1
|
# Project gradient to low-rank
|
||||||
if state['proj_dim'] == 'right':
|
if state['proj_dim'] == 'right':
|
||||||
proj_vec = torch.randn(grad.shape[1], 1,
|
proj_mat = torch.randn(grad.shape[1], rank,
|
||||||
device=p.device,
|
device=p.device,
|
||||||
generator=gen)
|
generator=gen)
|
||||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
proj_mat = proj_mat / (proj_mat.norm(dim=0, keepdim=True) + eps)
|
||||||
proj_grad = grad @ proj_vec # [m, 1]
|
proj_grad = grad @ proj_mat # [m, rank]
|
||||||
else:
|
else:
|
||||||
proj_vec = torch.randn(1, grad.shape[0],
|
proj_mat = torch.randn(rank, grad.shape[0],
|
||||||
device=p.device,
|
device=p.device,
|
||||||
generator=gen)
|
generator=gen)
|
||||||
proj_vec = proj_vec / (proj_vec.norm() + eps)
|
proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + eps)
|
||||||
proj_grad = proj_vec @ grad # [1, n]
|
proj_grad = proj_mat @ grad # [rank, n]
|
||||||
|
|
||||||
# Update moments in projected space
|
# Update moments in projected space
|
||||||
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
state['exp_avg'].mul_(beta1).add_(proj_grad, alpha=1 - beta1)
|
||||||
|
|
|
||||||
|
|
@ -309,7 +309,7 @@ class ApolloWorker:
|
||||||
samples: List[Dict[str, str]],
|
samples: List[Dict[str, str]],
|
||||||
config: Dict[str, Any]) -> List[float]:
|
config: Dict[str, Any]) -> List[float]:
|
||||||
"""Run Apollo-Mini training on conversation decision points."""
|
"""Run Apollo-Mini training on conversation decision points."""
|
||||||
from apollo_mini import ApolloMini
|
from apollo_mini import Apollo
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
lr = config.get('learning_rate', self.config['learning_rate'])
|
||||||
|
|
@ -331,7 +331,8 @@ class ApolloWorker:
|
||||||
if standard_params:
|
if standard_params:
|
||||||
groups.append({'params': standard_params})
|
groups.append({'params': standard_params})
|
||||||
|
|
||||||
optimizer = ApolloMini(groups, lr=lr)
|
rank = config.get('apollo_rank', 1)
|
||||||
|
optimizer = Apollo(groups, lr=lr, rank=rank)
|
||||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
||||||
f"{len(standard_params)} standard, "
|
f"{len(standard_params)} standard, "
|
||||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue