/finetune: expose all Apollo optimizer settings
lr, rank, betas, eps, weight_decay, warmup_steps, scale, proj_refresh, norm_growth_limit — all optional with sensible defaults. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
a73bcf5ae3
commit
2f08149fab
1 changed files with 16 additions and 5 deletions
|
|
@ -360,8 +360,6 @@ class ApolloWorker:
|
||||||
"""
|
"""
|
||||||
from apollo_plugin.optimizer import Apollo
|
from apollo_plugin.optimizer import Apollo
|
||||||
|
|
||||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
|
||||||
|
|
||||||
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
# Build parameter groups (Apollo for 2D+, standard for small/1D)
|
||||||
apollo_params, standard_params = [], []
|
apollo_params, standard_params = [], []
|
||||||
for p in model.parameters():
|
for p in model.parameters():
|
||||||
|
|
@ -377,9 +375,22 @@ class ApolloWorker:
|
||||||
if standard_params:
|
if standard_params:
|
||||||
groups.append({'params': standard_params})
|
groups.append({'params': standard_params})
|
||||||
|
|
||||||
rank = config.get('apollo_rank', 1)
|
# Apollo settings from request config, falling back to server defaults
|
||||||
optimizer = Apollo(groups, lr=lr, rank=rank)
|
optimizer = Apollo(
|
||||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
groups,
|
||||||
|
lr=config.get('lr', self.config.get('learning_rate', 1e-5)),
|
||||||
|
rank=config.get('rank', 256),
|
||||||
|
betas=tuple(config.get('betas', (0.9, 0.999))),
|
||||||
|
eps=config.get('eps', 1e-8),
|
||||||
|
weight_decay=config.get('weight_decay', 0.01),
|
||||||
|
warmup_steps=config.get('warmup_steps', 0),
|
||||||
|
scale=config.get('scale'), # None = auto
|
||||||
|
proj_refresh=config.get('proj_refresh', 200),
|
||||||
|
norm_growth_limit=config.get('norm_growth_limit', 1.01),
|
||||||
|
)
|
||||||
|
rank = config.get('rank', 256)
|
||||||
|
lr = config.get('lr', self.config.get('learning_rate', 1e-5))
|
||||||
|
logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
|
||||||
f"{len(standard_params)} standard, "
|
f"{len(standard_params)} standard, "
|
||||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue