"""Training example construction and tokenization.

Takes a raw conversation context plus an improved continuation and produces
tokenized tensors ready for a context-frozen forward+backward pass.
"""

import json
from dataclasses import dataclass, field
from pathlib import Path

import torch
from transformers import AutoTokenizer


def _pack_ids(context_ids: list[int], continuation_ids: list[int],
              max_len: int, device: str) -> tuple[torch.Tensor, int, int]:
    """Concatenate context + continuation ids, left-truncating the context.

    The continuation is never truncated; only the oldest context tokens are
    dropped until the combined length fits in ``max_len``.

    Args:
        context_ids: token ids of the context (everything before the target).
        continuation_ids: token ids of the training target.
        max_len: maximum total sequence length.
        device: device for the returned tensor.

    Returns:
        ``(input_ids, context_len, total_len)`` where ``input_ids`` is a 1-D
        ``torch.long`` tensor on ``device``.

    Raises:
        ValueError: if fitting into ``max_len`` would remove the entire
            context (the continuation alone is too long).
    """
    excess = len(context_ids) + len(continuation_ids) - max_len
    if excess > 0:
        if excess >= len(context_ids):
            raise ValueError(
                f"continuation is {len(continuation_ids)} tokens; it leaves "
                f"no room for context within max_len={max_len}")
        # NOTE(review): left truncation also drops the start of the
        # formatted context (e.g. BOS / system prompt if the chat template
        # emits them) — confirm this is acceptable for long conversations.
        context_ids = context_ids[excess:]
    # Explicit dtype: torch.tensor([]) would otherwise be float32.
    input_ids = torch.tensor(context_ids + continuation_ids,
                             dtype=torch.long, device=device)
    context_len = len(context_ids)
    return input_ids, context_len, context_len + len(continuation_ids)


@dataclass
class TrainingExample:
    """A single training example for context-frozen training."""

    id: str
    context: str  # conversation up to the decision point
    continuation: str  # the better response
    reason: str = ""  # why this is a training target
    memories: list[str] = field(default_factory=list)  # memories that were in context

    # Computed after tokenization (by tokenize() or
    # ExampleTokenizer.prepare_example()).
    input_ids: torch.Tensor | None = None
    context_len: int = 0  # number of context tokens at the front of input_ids
    total_len: int = 0  # context_len + number of continuation tokens

    def tokenize(self, tokenizer, max_len: int = 8192,
                 device: str = "cuda:0") -> "TrainingExample":
        """Tokenize context + continuation into training-ready tensors.

        No chat template is applied here: both strings are encoded verbatim
        with ``add_special_tokens=False``. ``self.context`` must therefore
        already be formatted exactly as the model sees it at inference time
        (e.g. via ``tokenizer.apply_chat_template(..., tokenize=False)``).

        If the combined length exceeds ``max_len``, the context is truncated
        from the left and the continuation is kept intact.

        Args:
            tokenizer: object with ``encode(text, add_special_tokens=...)``
                returning a list of token ids.
            max_len: maximum total sequence length.
            device: device for ``input_ids``.

        Returns:
            ``self``, with ``input_ids``, ``context_len`` and ``total_len``
            filled in.

        Raises:
            ValueError: if the continuation leaves no room for any context.
        """
        context_ids = tokenizer.encode(self.context, add_special_tokens=False)
        continuation_ids = tokenizer.encode(
            self.continuation, add_special_tokens=False)
        self.input_ids, self.context_len, self.total_len = _pack_ids(
            context_ids, continuation_ids, max_len, device)
        return self

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (``input_ids`` is omitted)."""
        return {
            'id': self.id,
            'context': self.context,
            'continuation': self.continuation,
            'reason': self.reason,
            'memories': self.memories,
            'context_len': self.context_len,
            'total_len': self.total_len,
        }

    @classmethod
    def from_dict(cls, d: dict) -> 'TrainingExample':
        """Build an untokenized example from a dict produced by ``to_dict``.

        ``context_len``/``total_len`` in ``d`` are ignored and left at their
        defaults; tokenize the example again to recompute them.
        """
        return cls(
            id=d['id'],
            context=d['context'],
            continuation=d['continuation'],
            reason=d.get('reason', ''),
            memories=d.get('memories', []),
        )


def load_examples(path: str) -> list[TrainingExample]:
    """Load training examples from a UTF-8 JSONL file, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [TrainingExample.from_dict(json.loads(line))
                for line in f if line.strip()]


def save_examples(examples: list[TrainingExample], path: str):
    """Save training examples to a UTF-8 JSONL file, one example per line."""
    with open(path, 'w', encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex.to_dict()) + '\n')


class ExampleTokenizer:
    """Handles tokenization with the model's chat template.

    Applies the same chat template that vLLM uses during inference, so the
    token distribution matches what the model expects.
    """

    def __init__(self, model_path: str):
        # NOTE(review): trust_remote_code=True executes Python shipped with
        # the model repository — only point this at trusted model paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample, max_len: int = 8192,
                        device: str = "cuda:0", *,
                        add_special_tokens: bool = False) -> TrainingExample:
        """Tokenize an example whose context already has the chat template applied.

        Context and continuation are tokenized separately so the boundary
        between them (``context_len``) is known exactly. If the combined
        length exceeds ``max_len``, the context is truncated from the left.

        Args:
            example: example to tokenize; mutated in place.
            max_len: maximum total sequence length.
            device: device for ``input_ids``.
            add_special_tokens: whether the tokenizer should add its own
                special tokens (e.g. BOS) to the context. Defaults to False
                because chat templates rendered with ``tokenize=False``
                already contain them; adding them again would duplicate BOS
                and diverge from vLLM's chat prompt tokenization.

        Returns:
            ``example``, with ``input_ids``, ``context_len`` and
            ``total_len`` filled in.

        Raises:
            ValueError: if the continuation leaves no room for any context.
        """
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=add_special_tokens)
        # NOTE(review): the continuation is encoded without any end-of-turn
        # token, so the model is not trained to stop after it — confirm
        # this is intended.
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)
        example.input_ids, example.context_len, example.total_len = _pack_ids(
            context_ids, continuation_ids, max_len, device)
        return example

    def prepare_from_messages(self, example_id: str, messages: list[dict],
                              decision_idx: int, better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device

        Returns:
            Tokenized TrainingExample

        Raises:
            ValueError: if the response leaves no room for any context
                within ``max_len``.
        """
        # Context: all messages before the decision, rendered with the
        # generation prompt so it ends where the assistant reply begins.
        context_messages = messages[:decision_idx]
        context_text = self.tokenizer.apply_chat_template(
            context_messages, tokenize=False, add_generation_prompt=True)

        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )
        return self.prepare_example(example, max_len=max_len, device=device)