/finetune: expose all Apollo optimizer settings

lr, rank, betas, eps, weight_decay, warmup_steps,
scale, proj_refresh, norm_growth_limit — all optional
with sensible defaults.
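
For example, a request overriding a subset of these (illustrative
values; any key left out falls back to the server default):

    config = {
        'lr': 3e-5,
        'rank': 128,
        'betas': (0.9, 0.95),
        'warmup_steps': 100,
    }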

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-15 23:19:22 -04:00
parent a73bcf5ae3
commit 2f08149fab


@@ -360,8 +360,6 @@ class ApolloWorker:
         """
         from apollo_plugin.optimizer import Apollo
 
-        lr = config.get('learning_rate', self.config['learning_rate'])
-
         # Build parameter groups (Apollo for 2D+, standard for small/1D)
         apollo_params, standard_params = [], []
         for p in model.parameters():
@@ -377,9 +375,22 @@ class ApolloWorker:
         if standard_params:
             groups.append({'params': standard_params})
 
-        rank = config.get('apollo_rank', 1)
-        optimizer = Apollo(groups, lr=lr, rank=rank)
+        # Apollo settings from request config, falling back to server defaults
+        lr = config.get('lr', self.config.get('learning_rate', 1e-5))
+        rank = config.get('rank', 256)
+        optimizer = Apollo(
+            groups,
+            lr=lr,
+            rank=rank,
+            betas=tuple(config.get('betas', (0.9, 0.999))),
+            eps=config.get('eps', 1e-8),
+            weight_decay=config.get('weight_decay', 0.01),
+            warmup_steps=config.get('warmup_steps', 0),
+            scale=config.get('scale'),  # None = auto
+            proj_refresh=config.get('proj_refresh', 200),
+            norm_growth_limit=config.get('norm_growth_limit', 1.01),
+        )
 
-        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
+        logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
                     f"{len(standard_params)} standard, "
                     f"state={optimizer.state_size_bytes()/1e6:.1f}MB")