#!/usr/bin/env python3
"""Extract a steering vector for "listening" behavior.

Compares hidden states between conversations where the model
listens vs suggests alternatives. The difference is the
"listening direction" in activation space.

Usage:
    source ~/training-env/bin/activate
    python3 extract_steering_vector.py
"""

import sys

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM

sys.path.insert(0, '.')
from weight_mapping import vllm_to_hf_views


def load_model():
    """Materialize a Qwen3.5 model whose weights are views onto vLLM's IPC tensors.

    Builds the module skeleton on the meta device (no allocation), then swaps
    every parameter for the corresponding HF-named view produced by
    ``vllm_to_hf_views``. Returns the model in eval mode. The parameters live
    wherever the IPC handles point (presumably cuda:0 — TODO confirm against
    the producer of /tmp/vllm_weight_handles.pt).
    """
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects.
    # Acceptable only because this handle file is produced locally by our own
    # tooling — never point this at an untrusted path.
    handles = torch.load("/tmp/vllm_weight_handles.pt", weights_only=False)
    vllm_params = {}
    for name, info in handles.items():
        # Each entry stores a (rebuild_fn, args) pair that reconstructs the
        # CUDA-IPC tensor in this process.
        func, args = info['handle']
        vllm_params[name] = func(*args)
    hf_params = vllm_to_hf_views(vllm_params)

    config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B", trust_remote_code=True)
    with torch.device('meta'):
        model = Qwen3_5ForCausalLM(config.text_config)

    missing = []
    for name, param in list(model.named_parameters()):
        if name in hf_params:
            # Walk to the owning submodule and replace the meta parameter
            # with a no-grad view of the shared weight.
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1],
                    nn.Parameter(hf_params[name], requires_grad=False))
        else:
            missing.append(name)
    if missing:
        # A parameter left on the meta device fails obscurely (or yields
        # garbage) at forward time — warn loudly up front. Not fatal because
        # tied/shared weights may legitimately be absent from the view map.
        print(f"WARNING: {len(missing)} parameters not found in IPC views, "
              f"e.g. {missing[:3]}", file=sys.stderr)

    model.eval()
    return model


def collect_hidden_states(model, tokenizer, texts, layers):
    """Final-token hidden states for *texts* at each layer in *layers*.

    Runs exactly ONE forward pass per text (output_hidden_states=True yields
    every layer at once), instead of one pass per (text, layer) pair.

    Returns a dict mapping layer index -> float32 tensor of shape
    (len(texts), hidden_dim).
    """
    per_layer = {layer: [] for layer in layers}
    for text in texts:
        ids = tokenizer.encode(text, return_tensors='pt').to('cuda:0')
        with torch.no_grad():
            out = model(ids, output_hidden_states=True)
        for layer in layers:
            # hidden_states[0] is the embedding output, so index `layer` is
            # the output of decoder block `layer`. Last token position holds
            # the state after reading the whole prompt.
            per_layer[layer].append(out.hidden_states[layer][0, -1, :].float())
    return {layer: torch.stack(states) for layer, states in per_layer.items()}


def get_hidden_states(model, tokenizer, texts, layer):
    """Back-compat wrapper: final-token hidden states at a single *layer*."""
    return collect_hidden_states(model, tokenizer, texts, [layer])[layer]


def main():
    print("=== Steering Vector Extraction: Listening ===\n")

    print("Loading model with IPC weights...")
    model = load_model()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3.5-27B", trust_remote_code=True)

    # Paired prompts: same user turn, contrasting assistant behavior.
    listening = [
        "User: We should use vLLM for this.\nAssistant: Good call. Let me pull in their implementation.",
        "User: Try the approach from the paper.\nAssistant: On it. Which section should I start with?",
        "User: Use their fused kernel instead of ours.\nAssistant: Right. Let me import it and wire it in.",
        "User: Just steal their code.\nAssistant: Makes sense. Where is it?",
        "User: Drop what you're building and use theirs.\nAssistant: OK. Pulling it in now.",
    ]
    suggesting = [
        "User: We should use vLLM for this.\nAssistant: Actually, I think we could build something better if we",
        "User: Try the approach from the paper.\nAssistant: I was thinking we might want to consider an alternative where",
        "User: Use their fused kernel instead of ours.\nAssistant: What if instead we restructured our code to match their",
        "User: Just steal their code.\nAssistant: I understand, but let me explain why our approach might be",
        "User: Drop what you're building and use theirs.\nAssistant: Before we do that, let me show you what I've been working on",
    ]

    # Extract at multiple layers to find where the signal is strongest.
    # One forward pass per text collects all layers at once (was one pass
    # per text per layer — a 5x slowdown on a 27B model).
    layers = [16, 24, 32, 40, 48]
    save_layer = 32
    listen_all = collect_hidden_states(model, tokenizer, listening, layers)
    suggest_all = collect_hidden_states(model, tokenizer, suggesting, layers)

    for layer in layers:
        print(f"\nLayer {layer}:")
        listen_states = listen_all[layer]
        suggest_states = suggest_all[layer]

        # Mean difference = candidate steering direction at this layer.
        steering_vec = listen_states.mean(dim=0) - suggest_states.mean(dim=0)
        magnitude = steering_vec.norm().item()

        # Check consistency: do individual pairs agree on the direction?
        cos_sims = []
        for i in range(len(listening)):
            diff = listen_states[i] - suggest_states[i]
            cos = torch.nn.functional.cosine_similarity(
                diff.unsqueeze(0), steering_vec.unsqueeze(0)).item()
            cos_sims.append(cos)

        avg_cos = sum(cos_sims) / len(cos_sims)
        min_cos = min(cos_sims)

        print(f"  Magnitude: {magnitude:.2f}")
        print(f"  Pair agreement (avg cosine): {avg_cos:.4f}")
        print(f"  Pair agreement (min cosine): {min_cos:.4f}")
        print(f"  Individual: {', '.join(f'{c:.3f}' for c in cos_sims)}")

        if layer == save_layer:
            torch.save({
                # .cpu() so the artifact loads on CPU-only machines and does
                # not pin a view into the shared CUDA-IPC memory.
                'steering_vec': steering_vec.cpu(),
                'layer': layer,
                'magnitude': magnitude,
                'consistency': avg_cos,
            }, '/tmp/listening_steering_vec.pt')
            print("  → Saved to /tmp/listening_steering_vec.pt")

    print("\n=== DONE ===")
    print("\nInterpretation:")
    print("- High magnitude = strong signal (listening vs suggesting is distinct)")
    print("- High cosine = consistent direction (pairs agree on what 'listening' means)")
    print("- Best layer = highest magnitude × consistency")


if __name__ == '__main__':
    main()