apollo: make rank configurable (default 1 = Mini, higher ranks for experimentation)
This commit is contained in:
parent
c5d7d8cb5d
commit
e1cd4fb0ab
2 changed files with 38 additions and 27 deletions
|
|
@@ -309,7 +309,7 @@ class ApolloWorker:
|
|||
samples: List[Dict[str, str]],
|
||||
config: Dict[str, Any]) -> List[float]:
|
||||
"""Run Apollo-Mini training on conversation decision points."""
|
||||
from apollo_mini import ApolloMini
|
||||
from apollo_mini import Apollo
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
lr = config.get('learning_rate', self.config['learning_rate'])
|
||||
|
|
@@ -331,7 +331,8 @@ class ApolloWorker:
|
|||
if standard_params:
|
||||
groups.append({'params': standard_params})
|
||||
|
||||
optimizer = ApolloMini(groups, lr=lr)
|
||||
rank = config.get('apollo_rank', 1)
|
||||
optimizer = Apollo(groups, lr=lr, rank=rank)
|
||||
logger.info(f"Apollo-Mini: {len(apollo_params)} apollo params, "
|
||||
f"{len(standard_params)} standard, "
|
||||
f"state={optimizer.state_size_bytes()/1e6:.1f}MB")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue