weight_mapping: fix name prefix, add attention QKV dims

ProofOfConcept 2026-03-30 23:09:08 -04:00
parent d0883e101b
commit 6fb9735def


@@ -17,6 +17,9 @@ import torch.nn as nn
 HIDDEN = 5120
 NUM_K_HEADS = 16
 NUM_V_HEADS = 48
+NUM_ATTN_HEADS = 24  # full attention q heads
+NUM_ATTN_KV_HEADS = 4  # full attention kv heads
+ATTN_HEAD_DIM = 256
 HEAD_K_DIM = 128
 HEAD_V_DIM = 128
 KEY_DIM = NUM_K_HEADS * HEAD_K_DIM  # 2048
@@ -26,6 +29,14 @@ NUM_LAYERS = 64
 CONV_KERNEL = 4
 CONV_DIM = KEY_DIM * 2 + VALUE_DIM  # 10240
 
+# Full attention QKV dimensions
+# Q uses 2x head_dim (512) vs KV head_dim (256) in Qwen3.5
+ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2  # 512
+ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM  # 12288
+ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
+ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
+# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight[0]
+
 
 def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                      ) -> dict[str, torch.Tensor]:
@@ -37,38 +48,50 @@ def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
     hf_params = {}
     for name, tensor in vllm_params.items():
-        # Pass through non-merged params unchanged
-        if 'in_proj_qkvz' not in name and \
-           'in_proj_ba' not in name and \
-           'gate_up_proj' not in name:
-            hf_params[name] = tensor
-            continue
+        # vLLM and HF both use 'language_model.model.layers...' for Qwen3.5.
+        # HF checkpoint has 'model.' prefix but named_parameters() doesn't.
+        # Keep vLLM's names as-is — we'll match when loading into the HF model.
+        hf_name = name
 
         # Split merged projections into HF-style separate weights
         if 'in_proj_qkvz' in name:
-            # [key_dim*2 + value_dim*2, hidden] → qkv + z
-            prefix = name.replace('in_proj_qkvz', '')
-            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]  # [key_dim*2 + value_dim, hidden]
-            z = tensor[KEY_DIM * 2 + VALUE_DIM:]  # [value_dim, hidden]
+            # GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
+            prefix = hf_name.replace('in_proj_qkvz.weight', '')
+            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]
+            z = tensor[KEY_DIM * 2 + VALUE_DIM:]
             hf_params[prefix + 'in_proj_qkv.weight'] = qkv
             hf_params[prefix + 'in_proj_z.weight'] = z
         elif 'in_proj_ba' in name:
-            # [num_v_heads*2, hidden] → b + a
-            prefix = name.replace('in_proj_ba', '')
-            b = tensor[:NUM_V_HEADS]  # [num_v_heads, hidden]
-            a = tensor[NUM_V_HEADS:]  # [num_v_heads, hidden]
+            # GDN: [num_v_heads*2, hidden] → b + a
+            prefix = hf_name.replace('in_proj_ba.weight', '')
+            b = tensor[:NUM_V_HEADS]
+            a = tensor[NUM_V_HEADS:]
             hf_params[prefix + 'in_proj_b.weight'] = b
             hf_params[prefix + 'in_proj_a.weight'] = a
+        elif 'qkv_proj' in name:
+            # Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
+            prefix = hf_name.replace('qkv_proj.weight', '')
+            q = tensor[:ATTN_Q_DIM]
+            k = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
+            v = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
+            hf_params[prefix + 'q_proj.weight'] = q
+            hf_params[prefix + 'k_proj.weight'] = k
+            hf_params[prefix + 'v_proj.weight'] = v
         elif 'gate_up_proj' in name:
-            # [intermediate*2, hidden] → gate + up
-            prefix = name.replace('gate_up_proj', '')
-            gate = tensor[:INTERMEDIATE]  # [intermediate, hidden]
-            up = tensor[INTERMEDIATE:]  # [intermediate, hidden]
+            # MLP: [intermediate*2, hidden] → gate + up
+            prefix = hf_name.replace('gate_up_proj.weight', '')
+            gate = tensor[:INTERMEDIATE]
+            up = tensor[INTERMEDIATE:]
             hf_params[prefix + 'gate_proj.weight'] = gate
             hf_params[prefix + 'up_proj.weight'] = up
+        else:
+            # Pass through unchanged (norms, biases, out_proj, etc.)
+            hf_params[hf_name] = tensor
     return hf_params