"""Training example construction and tokenization.

Takes raw conversation context + improved continuation, produces
tokenized tensors ready for context-frozen forward+backward.
"""
import json
from dataclasses import dataclass, field
from pathlib import Path

import torch
from transformers import AutoTokenizer
@dataclass
|
||
|
|
class TrainingExample:
|
||
|
|
"""A single training example for context-frozen training."""
|
||
|
|
id: str
|
||
|
|
context: str # conversation up to decision point
|
||
|
|
continuation: str # the better response
|
||
|
|
reason: str = "" # why this is a training target
|
||
|
|
memories: list[str] = field(default_factory=list) # memories that were in context
|
||
|
|
|
||
|
|
# Computed after tokenization
|
||
|
|
input_ids: torch.Tensor | None = None
|
||
|
|
context_len: int = 0
|
||
|
|
total_len: int = 0
|
||
|
|
|
||
|
|
def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
|
||
|
|
"""Tokenize context + continuation into training-ready tensors.
|
||
|
|
|
||
|
|
The chat template is applied to make the token distribution
|
||
|
|
match what the model sees during inference.
|
||
|
|
"""
|
||
|
|
# Build messages for context (everything up to the decision)
|
||
|
|
# The context should already be in chat format
|
||
|
|
context_ids = tokenizer.encode(self.context, add_special_tokens=False)
|
||
|
|
continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)
|
||
|
|
|
||
|
|
self.context_len = len(context_ids)
|
||
|
|
self.total_len = len(context_ids) + len(continuation_ids)
|
||
|
|
|
||
|
|
if self.total_len > max_len:
|
||
|
|
# Truncate context from the left, keep continuation intact
|
||
|
|
excess = self.total_len - max_len
|
||
|
|
context_ids = context_ids[excess:]
|
||
|
|
self.context_len = len(context_ids)
|
||
|
|
self.total_len = len(context_ids) + len(continuation_ids)
|
||
|
|
|
||
|
|
all_ids = context_ids + continuation_ids
|
||
|
|
self.input_ids = torch.tensor(all_ids, device=device)
|
||
|
|
return self
|
||
|
|
|
||
|
|
def to_dict(self) -> dict:
|
||
|
|
return {
|
||
|
|
'id': self.id,
|
||
|
|
'context': self.context,
|
||
|
|
'continuation': self.continuation,
|
||
|
|
'reason': self.reason,
|
||
|
|
'memories': self.memories,
|
||
|
|
'context_len': self.context_len,
|
||
|
|
'total_len': self.total_len,
|
||
|
|
}
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_dict(cls, d: dict) -> 'TrainingExample':
|
||
|
|
return cls(
|
||
|
|
id=d['id'],
|
||
|
|
context=d['context'],
|
||
|
|
continuation=d['continuation'],
|
||
|
|
reason=d.get('reason', ''),
|
||
|
|
memories=d.get('memories', []),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def load_examples(path: str) -> list[TrainingExample]:
    """Read training examples from a JSONL file, skipping blank lines."""
    with open(path) as f:
        return [
            TrainingExample.from_dict(json.loads(raw))
            for raw in f
            if raw.strip()
        ]
def save_examples(examples: list[TrainingExample], path: str):
    """Write training examples to *path*, one JSON object per line (JSONL)."""
    with open(path, 'w') as f:
        f.writelines(json.dumps(ex.to_dict()) + '\n' for ex in examples)
class ExampleTokenizer:
    """Handles tokenization with the model's chat template.

    Applies the same chat template that vLLM uses during inference,
    so the token distribution matches what the model expects.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer from *model_path* (local dir or hub id)."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample,
                        max_len: int = 8192,
                        device: str = "cuda:0") -> TrainingExample:
        """Tokenize an example using the chat template.

        For proper training, ``example.context`` should be formatted
        exactly as vLLM would format it — with the chat template applied.

        Args:
            example: example whose ``context``/``continuation`` to tokenize;
                mutated in place (``input_ids``, ``context_len``, ``total_len``).
            max_len: maximum total sequence length. Overflow is trimmed from
                the *left* of the context first; if the continuation alone
                exceeds the budget, its tail is clamped.
            device: device for the resulting ``input_ids`` tensor.

        Returns:
            The same ``example``, tokenized.
        """
        # Context and continuation are tokenized separately so we know
        # exactly where the frozen context ends and the trainable
        # continuation begins.
        #
        # NOTE(review): add_special_tokens=True here vs False in
        # TrainingExample.tokenize — presumably intentional because the
        # context here is a full templated prompt; confirm. Left-truncation
        # below can drop any BOS this inserts — verify acceptable.
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=True)
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)

        if len(context_ids) + len(continuation_ids) > max_len:
            # Truncate the context from the left, keeping the continuation
            # (the training target) intact whenever possible.
            excess = len(context_ids) + len(continuation_ids) - max_len
            context_ids = context_ids[excess:]
            if len(continuation_ids) > max_len:
                # Bug fix: previously a continuation longer than max_len
                # still produced total_len > max_len; clamp its tail.
                continuation_ids = continuation_ids[:max_len]

        example.context_len = len(context_ids)
        example.total_len = example.context_len + len(continuation_ids)

        example.input_ids = torch.tensor(
            context_ids + continuation_ids, device=device)
        return example

    def prepare_from_messages(self, example_id: str,
                              messages: list[dict],
                              decision_idx: int,
                              better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device

        Returns:
            Tokenized TrainingExample
        """
        # Context: all messages up to (not including) the decision point,
        # rendered with the generation prompt so the continuation starts
        # exactly where the model would start generating at inference time.
        context_messages = messages[:decision_idx]
        context_text = self.tokenizer.apply_chat_template(
            context_messages, tokenize=False, add_generation_prompt=True)

        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )

        return self.prepare_example(example, max_len=max_len, device=device)