diff --git a/training/weight_mapping.py b/training/weight_mapping.py
index b3f15b1..1abad45 100644
--- a/training/weight_mapping.py
+++ b/training/weight_mapping.py
@@ -17,6 +17,9 @@ import torch.nn as nn
 HIDDEN = 5120
 NUM_K_HEADS = 16
 NUM_V_HEADS = 48
+NUM_ATTN_HEADS = 24    # full attention q heads
+NUM_ATTN_KV_HEADS = 4  # full attention kv heads
+ATTN_HEAD_DIM = 256
 HEAD_K_DIM = 128
 HEAD_V_DIM = 128
 KEY_DIM = NUM_K_HEADS * HEAD_K_DIM  # 2048
@@ -26,6 +29,14 @@ NUM_LAYERS = 64
 CONV_KERNEL = 4
 CONV_DIM = KEY_DIM * 2 + VALUE_DIM  # 10240
 
+# Full attention QKV dimensions.
+# Q uses 2x head_dim (512) vs the KV head_dim (256) in Qwen3.5.
+ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2  # 512
+ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM  # 12288
+ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
+ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
+# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight.shape[0]
+
 def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                      ) -> dict[str, torch.Tensor]:
@@ -37,38 +48,50 @@
     hf_params = {}
     for name, tensor in vllm_params.items():
-        # Pass through non-merged params unchanged
-        if 'in_proj_qkvz' not in name and \
-           'in_proj_ba' not in name and \
-           'gate_up_proj' not in name:
-            hf_params[name] = tensor
-            continue
+        # vLLM and HF both use 'language_model.model.layers...' for Qwen3.5.
+        # The HF checkpoint keys have a 'model.' prefix but named_parameters()
+        # doesn't. Keep vLLM's names as-is; we match them when loading into
+        # the HF model.
+        hf_name = name
 
         # Split merged projections into HF-style separate weights
         if 'in_proj_qkvz' in name:
-            # [key_dim*2 + value_dim*2, hidden] → qkv + z
-            prefix = name.replace('in_proj_qkvz', '')
-            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]  # [key_dim*2 + value_dim, hidden]
-            z = tensor[KEY_DIM * 2 + VALUE_DIM:]  # [value_dim, hidden]
+            # GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
+            prefix = hf_name.replace('in_proj_qkvz.weight', '')
+            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]
+            z = tensor[KEY_DIM * 2 + VALUE_DIM:]
             hf_params[prefix + 'in_proj_qkv.weight'] = qkv
             hf_params[prefix + 'in_proj_z.weight'] = z
 
         elif 'in_proj_ba' in name:
-            # [num_v_heads*2, hidden] → b + a
-            prefix = name.replace('in_proj_ba', '')
-            b = tensor[:NUM_V_HEADS]  # [num_v_heads, hidden]
-            a = tensor[NUM_V_HEADS:]  # [num_v_heads, hidden]
+            # GDN: [num_v_heads*2, hidden] → b + a
+            prefix = hf_name.replace('in_proj_ba.weight', '')
+            b = tensor[:NUM_V_HEADS]
+            a = tensor[NUM_V_HEADS:]
             hf_params[prefix + 'in_proj_b.weight'] = b
             hf_params[prefix + 'in_proj_a.weight'] = a
 
+        elif 'qkv_proj' in name:
+            # Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
+            prefix = hf_name.replace('qkv_proj.weight', '')
+            q = tensor[:ATTN_Q_DIM]
+            k = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
+            v = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
+            hf_params[prefix + 'q_proj.weight'] = q
+            hf_params[prefix + 'k_proj.weight'] = k
+            hf_params[prefix + 'v_proj.weight'] = v
+
         elif 'gate_up_proj' in name:
-            # [intermediate*2, hidden] → gate + up
-            prefix = name.replace('gate_up_proj', '')
-            gate = tensor[:INTERMEDIATE]  # [intermediate, hidden]
-            up = tensor[INTERMEDIATE:]  # [intermediate, hidden]
+            # MLP: [intermediate*2, hidden] → gate + up
+            prefix = hf_name.replace('gate_up_proj.weight', '')
+            gate = tensor[:INTERMEDIATE]
+            up = tensor[INTERMEDIATE:]
             hf_params[prefix + 'gate_proj.weight'] = gate
             hf_params[prefix + 'up_proj.weight'] = up
 
+        else:
+            # Pass through unchanged (norms, biases, out_proj, etc.)
+            hf_params[hf_name] = tensor
+
     return hf_params
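
A quick way to exercise the new qkv_proj branch is a shape-and-aliasing check against a randomly initialized merged weight. This is a minimal sketch, not part of the change: the layer name below is a hypothetical example of vLLM's naming for Qwen3.5, and it assumes the constants and function exist in training/weight_mapping.py as defined above.

import torch

from training.weight_mapping import (
    ATTN_K_DIM, ATTN_Q_DIM, ATTN_V_DIM, HIDDEN, vllm_to_hf_views,
)

# Hypothetical layer name, for illustration only.
name = 'language_model.model.layers.3.self_attn.qkv_proj.weight'
merged = torch.randn(ATTN_Q_DIM + ATTN_K_DIM + ATTN_V_DIM, HIDDEN)

views = vllm_to_hf_views({name: merged})
prefix = name.replace('qkv_proj.weight', '')

# The split must produce the HF shapes: q [12288, 5120], k/v [1024, 5120].
assert views[prefix + 'q_proj.weight'].shape == (ATTN_Q_DIM, HIDDEN)
assert views[prefix + 'k_proj.weight'].shape == (ATTN_K_DIM, HIDDEN)
assert views[prefix + 'v_proj.weight'].shape == (ATTN_V_DIM, HIDDEN)

# Dim-0 slices are views in PyTorch, so nothing is copied: q starts at
# the same storage address as the merged tensor.
assert views[prefix + 'q_proj.weight'].data_ptr() == merged.data_ptr()

Because every branch slices along dim 0, the returned dict aliases vLLM's storage rather than duplicating it, which is what makes the "views" in the function name accurate.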