# consciousness/training/extract_steering_vector.py

#!/usr/bin/env python3
"""Extract a steering vector for "listening" behavior.
Compares hidden states between conversations where the model
listens vs suggests alternatives. The difference is the
"listening direction" in activation space.
Usage:
source ~/training-env/bin/activate
python3 extract_steering_vector.py
"""
import sys
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
def load_model():
    """Rebuild the HF Qwen3.5 model around vLLM's live weight tensors.

    Loads the IPC tensor handles dumped by the vLLM process, rehydrates
    each handle into a live tensor, remaps vLLM parameter names to the HF
    layout, and splices the resulting tensor views into a meta-initialized
    HF model — so no weights are copied or re-downloaded.

    Returns:
        Qwen3_5ForCausalLM in eval mode whose parameters alias the shared
        vLLM tensors (requires_grad=False).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — fine for
    # a locally produced handle file, but never point this at untrusted input.
    handles = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)
    # Each entry stores a (rehydration_func, args) pair; calling it yields
    # the live tensor backed by the vLLM process's memory.
    vllm_params = {}
    for name, info in handles.items():
        func, args = info['handle']
        vllm_params[name] = func(*args)
    hf_params = vllm_to_hf_views(vllm_params)
    config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    # Build the skeleton on the meta device so no real memory is allocated;
    # real storage comes from the vLLM views spliced in below.
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)
    # Snapshot the parameter list up front: setattr() below mutates the
    # modules we are iterating over. The Parameter object itself is unused —
    # we only need the dotted name to locate the parent module.
    for name, _ in list(model.named_parameters()):
        if name in hf_params:
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1],
                    nn.Parameter(hf_params[name], requires_grad=False))
    # NOTE(review): any parameter missing from hf_params stays on the meta
    # device and would fail at forward time — confirm the mapping is complete.
    model.eval()
    return model
def get_hidden_states(model, tokenizer, texts, layer, device='cuda:0'):
    """Collect the final-token hidden state for each text at one layer.

    Args:
        model: causal LM whose forward accepts output_hidden_states=True.
        tokenizer: tokenizer exposing encode(..., return_tensors='pt').
        texts: iterable of prompt strings.
        layer: index into the model's hidden_states tuple.
        device: device the input ids are moved to. Defaults to 'cuda:0',
            matching the previously hard-coded placement, but is now a
            parameter so the helper also works on CPU or other GPUs.

    Returns:
        Tensor of shape (len(texts), hidden_dim) in float32.
    """
    states = []
    for text in texts:
        ids = tokenizer.encode(text, return_tensors='pt').to(device)
        with torch.no_grad():
            out = model(ids, output_hidden_states=True)
        # Last token of the (single) sequence at the requested layer;
        # cast to float32 so the later norm/cosine math is stable.
        h = out.hidden_states[layer][0, -1, :].float()
        states.append(h)
    return torch.stack(states)
def main():
    """Sweep several layers, report listen-vs-suggest separation, save layer 32."""
    print("=== Steering Vector Extraction: Listening ===\n")
    print("Loading model with IPC weights...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3.5-27B", trust_remote_code=True)
    # Contrastive prompt pairs: identical user turns, opposite assistant
    # stance (comply-and-listen vs push-an-alternative).
    listening = [
        "User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
        "User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
        "User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
        "User: Just steal their code.\nAssistant: Makes sense. Where is it?",
        "User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
    ]
    suggesting = [
        "User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
        "User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
        "User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
        "User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
        "User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
    ]
    # Probe a spread of depths to locate where the signal is strongest.
    for layer in [16, 24, 32, 40, 48]:
        print(f"\nLayer {layer}:")
        pos = get_hidden_states(model, tokenizer, listening, layer)
        neg = get_hidden_states(model, tokenizer, suggesting, layer)
        # Mean-difference direction: "listening" minus "suggesting".
        steering_vec = pos.mean(dim=0) - neg.mean(dim=0)
        magnitude = steering_vec.norm().item()
        # Per-pair consistency: cosine of each pair's difference against
        # the mean direction — do individual pairs agree on what
        # "listening" looks like in activation space?
        cos_sims = [
            torch.nn.functional.cosine_similarity(
                (pos[i] - neg[i]).unsqueeze(0),
                steering_vec.unsqueeze(0)).item()
            for i in range(len(listening))
        ]
        avg_cos = sum(cos_sims) / len(cos_sims)
        min_cos = min(cos_sims)
        print(f" Magnitude: {magnitude:.2f}")
        print(f" Pair agreement (avg cosine): {avg_cos:.4f}")
        print(f" Pair agreement (min cosine): {min_cos:.4f}")
        print(f" Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")
        # Persist only the mid-depth layer; the sweep output guides whether
        # a different layer should be saved instead.
        if layer == 32:
            torch.save({
                'steering_vec': steering_vec,
                'layer': layer,
                'magnitude': magnitude,
                'consistency': avg_cos,
            }, '/tmp/listening_steering_vec.pt')
            print(" → Saved to /tmp/listening_steering_vec.pt")
    print("\n=== DONE ===")
    print("\nInterpretation:")
    print("- High magnitude = strong signal (listening vs suggesting is distinct)")
    print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
    print("- Best layer = highest magnitude × consistency")


if __name__ == '__main__':
    main()