# consciousness/training/training_example.py
"""Training example construction and tokenization.
Takes raw conversation context + improved continuation, produces
tokenized tensors ready for context-frozen forward+backward.
"""
import json
from dataclasses import dataclass, field
from pathlib import Path
import torch
from transformers import AutoTokenizer
@dataclass
class TrainingExample:
"""A single training example for context-frozen training."""
id: str
context: str # conversation up to decision point
continuation: str # the better response
reason: str = "" # why this is a training target
memories: list[str] = field(default_factory=list) # memories that were in context
# Computed after tokenization
input_ids: torch.Tensor | None = None
context_len: int = 0
total_len: int = 0
def tokenize(self, tokenizer, max_len: int = 8192, device: str = "cuda:0"):
"""Tokenize context + continuation into training-ready tensors.
The chat template is applied to make the token distribution
match what the model sees during inference.
"""
# Build messages for context (everything up to the decision)
# The context should already be in chat format
context_ids = tokenizer.encode(self.context, add_special_tokens=False)
continuation_ids = tokenizer.encode(self.continuation, add_special_tokens=False)
self.context_len = len(context_ids)
self.total_len = len(context_ids) + len(continuation_ids)
if self.total_len > max_len:
# Truncate context from the left, keep continuation intact
excess = self.total_len - max_len
context_ids = context_ids[excess:]
self.context_len = len(context_ids)
self.total_len = len(context_ids) + len(continuation_ids)
all_ids = context_ids + continuation_ids
self.input_ids = torch.tensor(all_ids, device=device)
return self
def to_dict(self) -> dict:
return {
'id': self.id,
'context': self.context,
'continuation': self.continuation,
'reason': self.reason,
'memories': self.memories,
'context_len': self.context_len,
'total_len': self.total_len,
}
@classmethod
def from_dict(cls, d: dict) -> 'TrainingExample':
return cls(
id=d['id'],
context=d['context'],
continuation=d['continuation'],
reason=d.get('reason', ''),
memories=d.get('memories', []),
)
def load_examples(path: str) -> list[TrainingExample]:
    """Read training examples from a JSONL file, skipping blank lines."""
    with open(path) as f:
        return [
            TrainingExample.from_dict(json.loads(line))
            for line in f
            if line.strip()
        ]
def save_examples(examples: list[TrainingExample], path: str):
    """Write training examples to *path*, one JSON object per line."""
    with open(path, 'w') as f:
        f.writelines(json.dumps(ex.to_dict()) + '\n' for ex in examples)
class ExampleTokenizer:
    """Handles tokenization with the model's chat template.

    Applies the same chat template that vLLM uses during inference,
    so the token distribution matches what the model expects.
    """

    def __init__(self, model_path: str):
        # trust_remote_code allows checkpoints that ship a custom
        # tokenizer implementation to load.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    def prepare_example(self, example: TrainingExample,
                        max_len: int = 8192,
                        device: str = "cuda:0") -> TrainingExample:
        """Tokenize an example using the model's tokenizer.

        For proper training, the context should be formatted exactly
        as vLLM would format it with chat template applied. Context
        and continuation are tokenized separately so the boundary
        between them (``context_len``) is known for loss masking.

        Over-long sequences are handled by truncating the context from
        the left; if the continuation alone exceeds ``max_len`` it is
        clipped from the right so ``total_len`` never exceeds the
        budget (previously it could).

        Returns:
            The same ``example``, mutated in place with ``input_ids``,
            ``context_len`` and ``total_len`` set.
        """
        context_ids = self.tokenizer.encode(
            example.context, add_special_tokens=True)
        continuation_ids = self.tokenizer.encode(
            example.continuation, add_special_tokens=False)
        if len(continuation_ids) > max_len:
            # Degenerate case: continuation alone does not fit.
            continuation_ids = continuation_ids[:max_len]
        total = len(context_ids) + len(continuation_ids)
        if total > max_len:
            # Truncate context from the left, keep continuation intact.
            # NOTE(review): this can drop the BOS/special tokens added
            # by add_special_tokens=True — confirm that is acceptable.
            excess = total - max_len
            context_ids = context_ids[excess:]
        example.context_len = len(context_ids)
        example.total_len = example.context_len + len(continuation_ids)
        all_ids = context_ids + continuation_ids
        example.input_ids = torch.tensor(all_ids, device=device)
        return example

    def prepare_from_messages(self, example_id: str,
                              messages: list[dict],
                              decision_idx: int,
                              better_response: str,
                              reason: str = "",
                              memories: list[str] | None = None,
                              max_len: int = 8192,
                              device: str = "cuda:0") -> TrainingExample:
        """Build a training example from a chat message list.

        Args:
            example_id: unique identifier
            messages: list of {"role": ..., "content": ...} dicts
            decision_idx: index of the assistant message to replace
            better_response: the improved response text
            reason: why this is a training target
            memories: memory keys that were in context
            max_len: maximum sequence length
            device: target device
        Returns:
            Tokenized TrainingExample
        """
        # Context: all messages up to (not including) the decision.
        # add_generation_prompt=True appends the assistant header so the
        # continuation starts exactly where inference would generate.
        context_messages = messages[:decision_idx]
        context_text = self.tokenizer.apply_chat_template(
            context_messages, tokenize=False, add_generation_prompt=True)
        example = TrainingExample(
            id=example_id,
            context=context_text,
            continuation=better_response,
            reason=reason,
            memories=memories or [],
        )
        return self.prepare_example(example, max_len=max_len, device=device)