apollo-mini training system: initial implementation
Core components for online fine-tuning of Qwen3.5-27B with CUDA IPC shared weight memory between vLLM and the training process:

- apollo_mini.py: rank-1 optimizer (SGD memory, AdamW quality)
- apollo_worker.py: HTTP daemon coordinating training with vLLM
- weight_mapping.py: vLLM merged → HF separate layout (zero-copy views)
- training_example.py: tokenization with chat template
- export_weights.py: CUDA IPC handle export from vLLM
- train.py: standalone training script (alternative to the daemon)
- DESIGN.md: architecture and protocol documentation

Validated: CUDA IPC autograd works on real Qwen3.5 weights (B200); the Apollo-Mini rank-1 projection + scaling + in-place update is confirmed.

Co-Authored-By: Kent Overstreet <kent.overstreet@gmail.com>
parent 13453606ae
commit c5d7d8cb5d

7 changed files with 1484 additions and 0 deletions
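Before the diff, a minimal sketch of the "rank-1 projection + scaling + in-place update" step the commit message attributes to apollo_mini.py. This is not the committed code: the state layout, hyperparameter names, and the norm-ratio scaling rule are assumptions in the spirit of rank-1 APOLLO-style optimizers, shown only to make the "SGD memory, AdamW quality" claim concrete.

# Hypothetical sketch only, not apollo_mini.py: run AdamW-style statistics
# in a rank-1 projected space, derive a scaling factor, and apply it to
# the raw gradient in place. State is O(m) per (m, n) matrix, far below
# AdamW's O(2mn), hence "SGD memory".
import torch

@torch.no_grad()
def rank1_step(p: torch.Tensor, state: dict, lr: float = 1e-5,
               beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    g = p.grad                                     # full (m, n) gradient
    if "proj" not in state:                        # lazy init on first step
        state["proj"] = torch.randn(g.shape[1], 1, device=g.device)
        state["proj"] /= g.shape[1] ** 0.5         # fixed random rank-1 projection
        state["m"] = torch.zeros(g.shape[0], 1, device=g.device)
        state["v"] = torch.zeros(g.shape[0], 1, device=g.device)
    r = g @ state["proj"]                          # (m, 1) projected gradient
    state["m"].mul_(beta1).add_(r, alpha=1 - beta1)
    state["v"].mul_(beta2).addcmul_(r, r, value=1 - beta2)
    adapted = state["m"] / (state["v"].sqrt() + eps)
    scale = adapted.norm() / (r.norm() + eps)      # AdamW-quality scaling factor
    p.add_(g, alpha=-lr * scale.item())            # in-place weight update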
training/training_example.py (new file, 175 lines)
"""Training example construction and tokenization.
|
||||
|
||||
Takes raw conversation context + improved continuation, produces
|
||||
tokenized tensors ready for context-frozen forward+backward.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@dataclass
class TrainingExample:
    """A single training example for context-frozen training."""
    id: str
    context: str  # conversation up to decision point
    continuation: str  # the better response
    reason: str = ""  # why this is a training target
    memories: list[str] = field(default_factory=list)  # memories that were in context

    # Computed after tokenization
    input_ids: torch.Tensor | None = None
    context_len: int = 0
    total_len: int = 0

    def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
        """Tokenize context + continuation into training-ready tensors.

        Assumes the context text already has the chat template applied,
        so the token distribution matches what the model sees during
        inference.
        """
        context_ids = tokenizer.encode(self.context, add_special_tokens=False)
        continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)

        self.context_len = len(context_ids)
        self.total_len = len(context_ids) + len(continuation_ids)

        if self.total_len > max_len:
            # Truncate context from the left, keep the continuation intact
            excess = self.total_len - max_len
            context_ids = context_ids[excess:]
            self.context_len = len(context_ids)
            self.total_len = len(context_ids) + len(continuation_ids)

        all_ids = context_ids + continuation_ids
        self.input_ids = torch.tensor(all_ids, device=device)
        return self

    def to_dict(self) -> dict:
        return {
            'id': self.id,
            'context': self.context,
            'continuation': self.continuation,
            'reason': self.reason,
            'memories': self.memories,
            'context_len': self.context_len,
            'total_len': self.total_len,
        }

    @classmethod
    def from_dict(cls, d: dict) -> 'TrainingExample':
        return cls(
            id=d['id'],
            context=d['context'],
            continuation=d['continuation'],
            reason=d.get('reason', ''),
            memories=d.get('memories', []),
        )


def load_examples(path: str) -> list[TrainingExample]:
    """Load training examples from a JSONL file."""
    examples = []
    with open(path) as f:
        for line in f:
            if line.strip():
                examples.append(TrainingExample.from_dict(json.loads(line)))
    return examples


def save_examples(examples: list[TrainingExample], path: str):
    """Save training examples to a JSONL file."""
    with open(path, 'w') as f:
        for ex in examples:
            f.write(json.dumps(ex.to_dict()) + '\n')


class ExampleTokenizer:
    """Handles tokenization with the model's chat template.

    Applies the same chat template that vLLM uses during inference,
    so the token distribution matches what the model expects.
    """

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample,
                        max_len: int = 8192,
                        device: str = "cuda:0") -> TrainingExample:
        """Tokenize an example using the chat template.

        For proper training, the context should be formatted exactly
        as vLLM would format it, with the chat template applied.
        """
        # Tokenize context and continuation separately so we know where
        # the context ends and the continuation begins. The context is
        # everything up to the decision point; the continuation is the
        # improved response.
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=True)
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)

        example.context_len = len(context_ids)
        example.total_len = len(context_ids) + len(continuation_ids)

        if example.total_len > max_len:
            # Truncate the context from the left, keep the continuation intact
            excess = example.total_len - max_len
            context_ids = context_ids[excess:]
            example.context_len = len(context_ids)
            example.total_len = example.context_len + len(continuation_ids)

        all_ids = context_ids + continuation_ids
        example.input_ids = torch.tensor(all_ids, device=device)
        return example

    def prepare_from_messages(self, example_id: str,
                              messages: list[dict],
                              decision_idx: int,
                              better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device

        Returns:
            Tokenized TrainingExample
        """
        # Context: all messages up to (not including) the decision
        context_messages = messages[:decision_idx]
        context_text = self.tokenizer.apply_chat_template(
            context_messages, tokenize=False, add_generation_prompt=True)

        # Build the example
        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )

        return self.prepare_example(example, max_len=max_len, device=device)
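A usage sketch for the module above. The model path, messages, and improved response are placeholders; the calls themselves use only the APIs defined in this file, and device is set to "cpu" so the snippet runs without a GPU.

# Usage sketch: model path and message contents are placeholders.
from training_example import ExampleTokenizer, save_examples

tok = ExampleTokenizer("/models/Qwen3.5-27B")   # hypothetical local path
messages = [
    {"role": "user", "content": "How do I check free space on btrfs?"},
    {"role": "assistant", "content": "Run df -h."},
]
ex = tok.prepare_from_messages(
    example_id="ex-001",
    messages=messages,
    decision_idx=1,              # replace the assistant reply at index 1
    better_response="Use `btrfs filesystem usage`; df -h can misreport btrfs.",
    reason="df -h misreports btrfs allocation",
    device="cpu",                # keep the sketch GPU-free
)
print(ex.context_len, ex.total_len, tuple(ex.input_ids.shape))
save_examples([ex], "examples.jsonl")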
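Finally, a hypothetical sketch (not part of this commit) of how context_len is meant to be consumed by the context-frozen forward+backward the module docstring mentions: the whole sequence goes through the model, but the loss, and therefore the gradient, covers only the continuation tokens.

# Hypothetical sketch: loss over continuation tokens only.
import torch.nn.functional as F

def continuation_loss(model, ex):
    # Assumes ex was tokenized above and ex.context_len >= 1.
    ids = ex.input_ids.unsqueeze(0)           # (1, T) batch of one
    logits = model(ids).logits[0]             # (T, vocab_size)
    targets = ids[0, ex.context_len:]         # continuation token ids
    preds = logits[ex.context_len - 1:-1]     # logits that predict them
    return F.cross_entropy(preds, targets)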