consciousness/training/extract_steering_vector.py

125 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""Extract a steering vector for "listening" behavior.
Compares hidden states between conversations where the model
listens vs suggests alternatives. The difference is the
"listening direction" in activation space.
Usage:
source ~/training-env/bin/activate
python3 extract_steering_vector.py
"""
import sys
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
def load_model():
    """Rebuild an HF Qwen3.5 model around live vLLM weight tensors.

    Loads serialized weight handles from /tmp, rehydrates each handle into
    a tensor (presumably a CUDA-IPC view into vLLM's memory -- confirm
    against the producer of /tmp/vllm_weight_handles.pt), remaps the
    vLLM parameter names to HF names, and splices the shared tensors into
    a meta-initialized Qwen3_5ForCausalLM so no weight copy is made.

    Returns:
        Qwen3_5ForCausalLM in eval mode, with parameters backed by the
        shared vLLM tensors (requires_grad=False).

    Raises:
        RuntimeError: if any model parameter has no counterpart in the
            IPC weights (it would otherwise stay on the meta device and
            crash with an obscure error at forward time).
    """
    # NOTE(review): weights_only=False executes arbitrary pickle -- the
    # handle file must be trusted. Required here because the file stores
    # callables, not plain tensors.
    handles = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        func, args = info['handle']
        vllm_params[name] = func(*args)  # rehydrate the shared tensor
    hf_params = vllm_to_hf_views(vllm_params)

    config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    # Build on the meta device so no real memory is allocated for weights
    # we are about to replace with the shared vLLM views.
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    missing = []
    for name, param in list(model.named_parameters()):
        if name not in hf_params:
            missing.append(name)
            continue
        # Walk to the owning module and swap the parameter in place.
        parts = name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1],
                nn.Parameter(hf_params[name], requires_grad=False))
    if missing:
        raise RuntimeError(
            f"{len(missing)} parameters missing from IPC weights, "
            f"e.g. {missing[:3]}")
    model.eval()
    return model
def get_hidden_states(model, tokenizer, texts, layer, device='cuda:0'):
    """Collect the last-token hidden state of each text at a given layer.

    Args:
        model: causal LM callable supporting ``output_hidden_states=True``.
        tokenizer: tokenizer exposing ``encode(..., return_tensors='pt')``.
        texts: iterable of prompt strings.
        layer: index into ``out.hidden_states`` (0 is the embedding layer).
        device: device the input ids are moved to. Defaults to 'cuda:0',
            matching the previously hard-coded behavior.

    Returns:
        Float tensor of shape (len(texts), hidden_dim), one row per text.
    """
    states = []
    for text in texts:
        ids = tokenizer.encode(text, return_tensors='pt').to(device)
        with torch.no_grad():
            out = model(ids, output_hidden_states=True)
        # Hidden state of the last token of the (single) sequence.
        h = out.hidden_states[layer][0, -1, :].float()
        states.append(h)
    return torch.stack(states)
def main():
    """Extract and evaluate a 'listening' steering vector at several layers.

    For each probed layer: compute mean(listening) - mean(suggesting) over
    paired prompts, report the vector's magnitude and per-pair directional
    agreement (cosine with the mean direction), and save the layer-32
    vector to /tmp for downstream use.
    """
    print("=== Steering Vector Extraction: Listening ===\n")
    print("Loading model with IPC weights...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3.5-27B", trust_remote_code=True)

    # Paired prompts: identical user turns, contrasting assistant behavior.
    listening = [
        "User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
        "User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
        "User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
        "User: Just steal their code.\nAssistant: Makes sense. Where is it?",
        "User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
    ]
    suggesting = [
        "User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
        "User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
        "User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
        "User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
        "User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
    ]

    # Extract at multiple layers to find where the signal is strongest
    for layer in [16, 24, 32, 40, 48]:
        print(f"\nLayer {layer}:")
        listen_states = get_hidden_states(model, tokenizer, listening, layer)
        suggest_states = get_hidden_states(model, tokenizer, suggesting, layer)
        steering_vec = listen_states.mean(dim=0) - suggest_states.mean(dim=0)
        magnitude = steering_vec.norm().item()

        # Check consistency: do individual pairs agree on the direction?
        # One vectorized cosine over all pairs replaces the per-pair loop.
        diffs = listen_states - suggest_states  # (n_pairs, hidden_dim)
        cos_sims = torch.nn.functional.cosine_similarity(
            diffs, steering_vec.unsqueeze(0)).tolist()
        avg_cos = sum(cos_sims) / len(cos_sims)
        min_cos = min(cos_sims)

        print(f" Magnitude: {magnitude:.2f}")
        print(f" Pair agreement (avg cosine): {avg_cos:.4f}")
        print(f" Pair agreement (min cosine): {min_cos:.4f}")
        print(f" Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")

        if layer == 32:
            # Layer 32 is the pre-chosen save target; the printed stats let
            # a human confirm whether another layer would be better.
            torch.save({
                'steering_vec': steering_vec,
                'layer': layer,
                'magnitude': magnitude,
                'consistency': avg_cos,
            }, '/tmp/listening_steering_vec.pt')
            print(" → Saved to /tmp/listening_steering_vec.pt")

    print("\n=== DONE ===")
    print("\nInterpretation:")
    print("- High magnitude = strong signal (listening vs suggesting is distinct)")
    print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
    print("- Best layer = highest magnitude × consistency")


if __name__ == '__main__':
    main()