weight_mapping: fix name prefix, add attention QKV dims
This commit is contained in:
parent
d0883e101b
commit
6fb9735def
1 changed files with 41 additions and 18 deletions
|
|
@ -17,6 +17,9 @@ import torch.nn as nn
|
|||
# Model geometry (values taken from the checkpoint's config).
HIDDEN = 5120

# Gated DeltaNet (GDN / linear-attention) head layout.
NUM_K_HEADS = 16
NUM_V_HEADS = 48
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM  # 16 * 128 = 2048

# Full-attention head layout.
NUM_ATTN_HEADS = 24     # full attention q heads
NUM_ATTN_KV_HEADS = 4   # full attention kv heads
ATTN_HEAD_DIM = 256
|
||||
|
|
NUM_LAYERS = 64
CONV_KERNEL = 4
CONV_DIM = KEY_DIM * 2 + VALUE_DIM  # 2048*2 + 6144 = 10240

# Full attention QKV dimensions.
# Q uses 2x head_dim (512) vs KV head_dim (256) in Qwen3.5.
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2                   # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM         # 24 * 512 = 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM        # 4 * 256 = 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM        # 4 * 256 = 1024
# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight.shape[0]
||||
def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                     ) -> dict[str, torch.Tensor]:
    """Re-key vLLM parameters as HF-style per-projection tensors.

    vLLM fuses several projections into single merged weights along dim 0;
    the HF checkpoint keeps them separate.  Each merged tensor is sliced
    (slices are views, not copies) into its HF-named pieces:

      * ``in_proj_qkvz`` -> ``in_proj_qkv`` + ``in_proj_z``   (GDN)
      * ``in_proj_ba``   -> ``in_proj_b``   + ``in_proj_a``   (GDN)
      * ``qkv_proj``     -> ``q_proj`` + ``k_proj`` + ``v_proj`` (full attn)
      * ``gate_up_proj`` -> ``gate_proj`` + ``up_proj``       (MLP)

    Everything else (norms, biases, out_proj, ...) is passed through
    unchanged under its original name.

    Args:
        vllm_params: mapping of vLLM parameter name -> tensor.

    Returns:
        Mapping of HF-style parameter name -> tensor (views into the
        originals where a split occurred).
    """
    hf_params: dict[str, torch.Tensor] = {}

    for name, tensor in vllm_params.items():
        # Pass through non-merged params unchanged.
        # BUGFIX: 'qkv_proj' must be excluded here as well — without it,
        # full-attention qkv weights were copied through verbatim and the
        # q/k/v split branch below was unreachable.
        if ('in_proj_qkvz' not in name and 'in_proj_ba' not in name
                and 'gate_up_proj' not in name and 'qkv_proj' not in name):
            hf_params[name] = tensor
            continue

        # vLLM and HF both use 'language_model.model.layers...' for Qwen3.5.
        # HF checkpoint has 'model.' prefix but named_parameters() doesn't.
        # Keep vLLM's names as-is — we'll match when loading into the HF model.
        hf_name = name

        # Split merged projections into HF-style separate weights.
        if 'in_proj_qkvz' in name:
            # GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
            prefix = hf_name.replace('in_proj_qkvz.weight', '')
            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]
            z = tensor[KEY_DIM * 2 + VALUE_DIM:]
            hf_params[prefix + 'in_proj_qkv.weight'] = qkv
            hf_params[prefix + 'in_proj_z.weight'] = z

        elif 'in_proj_ba' in name:
            # GDN: [num_v_heads*2, hidden] → b + a
            prefix = hf_name.replace('in_proj_ba.weight', '')
            b = tensor[:NUM_V_HEADS]
            a = tensor[NUM_V_HEADS:]
            hf_params[prefix + 'in_proj_b.weight'] = b
            hf_params[prefix + 'in_proj_a.weight'] = a

        elif 'qkv_proj' in name:
            # Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
            prefix = hf_name.replace('qkv_proj.weight', '')
            q = tensor[:ATTN_Q_DIM]
            k = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
            v = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
            hf_params[prefix + 'q_proj.weight'] = q
            hf_params[prefix + 'k_proj.weight'] = k
            hf_params[prefix + 'v_proj.weight'] = v

        elif 'gate_up_proj' in name:
            # MLP: [intermediate*2, hidden] → gate + up
            prefix = hf_name.replace('gate_up_proj.weight', '')
            gate = tensor[:INTERMEDIATE]
            up = tensor[INTERMEDIATE:]
            hf_params[prefix + 'gate_proj.weight'] = gate
            hf_params[prefix + 'up_proj.weight'] = up

        else:
            # Defensive: unreachable given the filter above, but keep the
            # original pass-through so an unexpected name is never dropped.
            hf_params[hf_name] = tensor

    return hf_params
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue