#!/usr/bin/env python3
"""Extract a steering vector for "listening" behavior.

Compares hidden states between conversations where the model
listens vs suggests alternatives. The difference is the
"listening direction" in activation space.

Usage:
    source ~/training-env/bin/activate
    python3 extract_steering_vector.py
"""
|
|||
|
|
|
|||
|
|
import sys

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM

# Make the repo-local weight-mapping helper importable when the script is
# run from the repository root.
sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_model():
    """Rebuild the HF-side Qwen3.5 model using weights shared by a vLLM process.

    Reads serialized CUDA IPC rebuild handles from /tmp, re-opens each shared
    tensor in this process, remaps vLLM parameter names/layouts to HF views,
    and splices the resulting tensors into a meta-device model skeleton.

    Returns:
        The model in eval mode, with all mapped parameters frozen
        (requires_grad=False). Parameters with no entry in the mapping are
        left on the meta device.
    """
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects.
    # Only acceptable because this file is produced locally by a trusted
    # process — never point this at untrusted input.
    handle_table = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)

    # Each entry carries (rebuild_fn, args); invoking it re-opens the shared
    # CUDA tensor in this process without copying the weights.
    raw_weights = {
        key: entry['handle'][0](*entry['handle'][1])
        for key, entry in handle_table.items()
    }
    mapped = vllm_to_hf_views(raw_weights)

    cfg = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    # Build the skeleton on the meta device: no real allocations, we are
    # about to replace the parameters anyway.
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(cfg.text_config)

    # Swap every meta placeholder for its real shared tensor. Materialize the
    # iterator first since we mutate modules while walking.
    for full_name, _ in list(model.named_parameters()):
        if full_name not in mapped:
            continue
        prefix, _, leaf = full_name.rpartition('.')
        owner = model.get_submodule(prefix) if prefix else model
        setattr(owner, leaf, nn.Parameter(mapped[full_name], requires_grad=False))

    model.eval()
    return model
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_hidden_states(model, tokenizer, texts, layer, device='cuda:0'):
    """Return the final-token hidden state at `layer` for each prompt.

    Args:
        model: Causal LM that accepts input ids and supports
            output_hidden_states=True.
        tokenizer: Tokenizer providing encode(text, return_tensors='pt').
        texts: Iterable of prompt strings, processed one at a time.
        layer: Index into the model's hidden_states tuple.
        device: Device for the input ids (default 'cuda:0', matching the
            original hard-coded behavior; parameterized so the function also
            works on CPU or other GPUs).

    Returns:
        Tensor of shape (len(texts), hidden_dim), float32.
    """
    states = []
    for text in texts:
        ids = tokenizer.encode(text, return_tensors='pt').to(device)
        with torch.no_grad():
            out = model(ids, output_hidden_states=True)
        # hidden_states[layer] is (batch, seq, hidden); take the last token
        # of the single sequence and upcast for stable float arithmetic.
        h = out.hidden_states[layer][0, -1, :].float()
        states.append(h)
    return torch.stack(states)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Sweep several layers for the listening-vs-suggesting contrast and
    save the layer-32 steering vector to /tmp."""
    print("=== Steering Vector Extraction: Listening ===\n")

    print("Loading model with IPC weights...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3.5-27B", trust_remote_code=True)

    # Contrastive prompt pairs: identical user turn, opposite assistant stance.
    listening = [
        "User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
        "User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
        "User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
        "User: Just steal their code.\nAssistant: Makes sense. Where is it?",
        "User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
    ]
    suggesting = [
        "User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
        "User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
        "User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
        "User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
        "User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
    ]

    # Probe several depths to find where the contrast signal is strongest.
    for layer in [16, 24, 32, 40, 48]:
        print(f"\nLayer {layer}:")
        pos_states = get_hidden_states(model, tokenizer, listening, layer)
        neg_states = get_hidden_states(model, tokenizer, suggesting, layer)

        # Mean-difference direction: "listening" minus "suggesting".
        steering_vec = pos_states.mean(dim=0) - neg_states.mean(dim=0)
        magnitude = steering_vec.norm().item()

        # Per-pair cosine against the mean direction: do the individual
        # pairs agree on what the direction is?
        cos_sims = []
        for pos, neg in zip(pos_states, neg_states):
            delta = (pos - neg).unsqueeze(0)
            sim = torch.nn.functional.cosine_similarity(
                delta, steering_vec.unsqueeze(0)).item()
            cos_sims.append(sim)

        avg_cos = sum(cos_sims) / len(cos_sims)
        min_cos = min(cos_sims)

        print(f"  Magnitude: {magnitude:.2f}")
        print(f"  Pair agreement (avg cosine): {avg_cos:.4f}")
        print(f"  Pair agreement (min cosine): {min_cos:.4f}")
        print(f"  Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")

        # Layer 32 is the one persisted for downstream steering experiments.
        if layer == 32:
            torch.save({
                'steering_vec': steering_vec,
                'layer': layer,
                'magnitude': magnitude,
                'consistency': avg_cos,
            }, '/tmp/listening_steering_vec.pt')
            print("  → Saved to /tmp/listening_steering_vec.pt")

    print("\n=== DONE ===")
    print("\nInterpretation:")
    print("- High magnitude = strong signal (listening vs suggesting is distinct)")
    print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
    print("- Best layer = highest magnitude × consistency")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: run the extraction sweep when executed directly.
if __name__ == '__main__':
    main()
|