/finetune: expose all Apollo optimizer settings
lr, rank, betas, eps, weight_decay, warmup_steps, scale, proj_refresh, norm_growth_limit — all optional with sensible defaults. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
a73bcf5ae3
commit
2f08149fab
1 changed file with 16 additions and 5 deletions
|
|
@@ -360,8 +360,6 @@ class ApolloWorker:
|
|||
"""
|
||||
from apollo_plugin.optimizer import Apollo
|
||||
|
||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
||||
|
||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
||||
apollo_params, standard_params = [], []
|
||||
for p in model.parameters():
|
||||
|
|
@@ -377,9 +375,22 @@ class ApolloWorker:
|
|||
if standard_params:
|
||||
groups.append({'params': standard_params})
|
||||
|
||||
rank = config.get('apollo_rank', 1)
|
||||
optimizer = Apollo(groups, lr=lr, rank=rank)
|
||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
||||
# Apollo settings from request config, falling back to server defaults
|
||||
optimizer = Apollo(
|
||||
groups,
|
||||
lr=config.get('lr', self.config.get('learning_rate', 1e-5)),
|
||||
rank=config.get('rank', 256),
|
||||
betas=tuple(config.get('betas', (0.9, 0.999))),
|
||||
eps=config.get('eps', 1e-8),
|
||||
weight_decay=config.get('weight_decay', 0.01),
|
||||
warmup_steps=config.get('warmup_steps', 0),
|
||||
scale=config.get('scale'), # None = auto
|
||||
proj_refresh=config.get('proj_refresh', 200),
|
||||
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
||||
)
|
||||
rank = config.get('rank', 256)
|
||||
lr = config.get('lr', self.config.get('learning_rate', 1e-5))
|
||||
logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
|
||||
f"{len(standard_params)} standard, "
|
||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue