diff --git a/training/apollo_plugin/worker.py b/training/apollo_plugin/worker.py
index 5d9ba29..d180c13 100755
--- a/training/apollo_plugin/worker.py
+++ b/training/apollo_plugin/worker.py
@@ -360,8 +360,6 @@ class ApolloWorker:
         """
         from apollo_plugin.optimizer import Apollo
 
-        lr = config.get('learning_rate', self.config['learning_rate'])
-
         # Build parameter groups (Apollo for 2D+, standard for small/1D)
         apollo_params, standard_params = [], []
         for p in model.parameters():
@@ -377,9 +375,22 @@ class ApolloWorker:
         if standard_params:
             groups.append({'params': standard_params})
 
-        rank = config.get('apollo_rank', 1)
-        optimizer = Apollo(groups, lr=lr, rank=rank)
-        logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
+        # Apollo settings from request config, falling back to server defaults
+        lr = config.get('lr', self.config.get('learning_rate', 1e-5))
+        rank = config.get('rank', 256)
+        optimizer = Apollo(
+            groups,
+            lr=lr,
+            rank=rank,
+            betas=tuple(config.get('betas', (0.9, 0.999))),
+            eps=config.get('eps', 1e-8),
+            weight_decay=config.get('weight_decay', 0.01),
+            warmup_steps=config.get('warmup_steps', 0),
+            scale=config.get('scale'),  # None = auto
+            proj_refresh=config.get('proj_refresh', 200),
+            norm_growth_limit=config.get('norm_growth_limit', 1.01),
+        )
+        logger.info(f"Apollo (rank={rank}, lr={lr}): {len(apollo_params)} apollo params, "
                     f"{len(standard_params)} standard, "
                     f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
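
For reference, a request config along these lines would exercise the new keyword lookups (a sketch only: the key names are taken from the config.get calls in the diff, the values are illustrative, and any omitted key falls back to the default shown above):

    # Hypothetical request config; keys mirror the config.get(...) lookups
    # in the diff, values are illustrative only.
    config = {
        'lr': 2e-5,              # else self.config['learning_rate'], else 1e-5
        'rank': 128,             # default 256
        'betas': [0.9, 0.95],    # cast to a tuple before reaching Apollo
        'weight_decay': 0.0,     # default 0.01
        'warmup_steps': 100,     # default 0
        'scale': None,           # None = auto
        'proj_refresh': 200,     # default 200
    }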