forked from kent/consciousness
training: restructure as vLLM plugin package
- Convert to installable package with entry points for vLLM auto-discovery
- Add checkpoint_sync.py: Python replacement for Rust checkpoint binary
  - Block-level diffing of safetensors files (4KB blocks)
  - vLLM→HF weight name conversion built-in
  - Scheduled 10min after training jobs (batched)
- API change: /train now takes raw token IDs (context_ids + continuation_ids)
  - No tokenizer on training side, client owns tokenization
- Remove superseded code: standalone scripts, Rust binary, tokenizer helpers

Install: pip install -e ./training
Then vLLM auto-loads via entry point.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent b649a11645
commit a73bcf5ae3
15 changed files with 607 additions and 1068 deletions
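The two mechanisms the message describes, sketched minimally (diff_blocks, BLOCK_SIZE, and the /train payload shape shown here are illustrative assumptions, not the actual checkpoint_sync.py or API implementation):

BLOCK_SIZE = 4096

def diff_blocks(old_path: str, new_path: str) -> list[int]:
    """Return byte offsets of the 4KB blocks that differ between two
    checkpoint files (block-level diffing of safetensors files)."""
    changed = []
    with open(old_path, 'rb') as old, open(new_path, 'rb') as new:
        offset = 0
        while True:
            a, b = old.read(BLOCK_SIZE), new.read(BLOCK_SIZE)
            if not a and not b:
                break
            if a != b:  # covers content changes and length mismatches
                changed.append(offset)
            offset += BLOCK_SIZE
    return changed

# A /train request now carries raw token IDs; field names are from the
# commit message, values illustrative:
#   {"context_ids": [1, 2, 3], "continuation_ids": [4, 5, 6]}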
163 training/apollo_plugin/weight_mapping.py Normal file
@@ -0,0 +1,163 @@
"""Map between vLLM's merged weight layout and HuggingFace's separate layout.
|
||||
|
||||
vLLM merges weights for efficiency:
|
||||
in_proj_qkv + in_proj_z → in_proj_qkvz [key_dim*2 + value_dim*2, hidden]
|
||||
in_proj_b + in_proj_a → in_proj_ba [num_v_heads*2, hidden]
|
||||
gate_proj + up_proj → gate_up_proj [intermediate*2, hidden]
|
||||
|
||||
This module creates HF-compatible parameter views that point to the same
|
||||
GPU memory as vLLM's merged tensors. No copies — views share storage.
|
||||
"""
import torch
import torch.nn as nn


# Qwen3.5-27B dimensions
HIDDEN = 5120
NUM_K_HEADS = 16
NUM_V_HEADS = 48
NUM_ATTN_HEADS = 24     # full attention q heads
NUM_ATTN_KV_HEADS = 4   # full attention kv heads
ATTN_HEAD_DIM = 256
HEAD_K_DIM = 128
HEAD_V_DIM = 128
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM      # 2048
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM    # 6144
INTERMEDIATE = 17408
NUM_LAYERS = 64
CONV_KERNEL = 4
CONV_DIM = KEY_DIM * 2 + VALUE_DIM      # 10240

# Full attention QKV dimensions
# Q uses 2x head_dim (512) vs KV head_dim (256) in Qwen3.5
ATTN_Q_HEAD_DIM = ATTN_HEAD_DIM * 2             # 512
ATTN_Q_DIM = NUM_ATTN_HEADS * ATTN_Q_HEAD_DIM   # 12288
ATTN_K_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
ATTN_V_DIM = NUM_ATTN_KV_HEADS * ATTN_HEAD_DIM  # 1024
# Total: 12288 + 1024 + 1024 = 14336 = vLLM's qkv_proj.weight shape[0]
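
# Worked row counts for the merged projections split below (derived from
# the constants above, for orientation):
#   in_proj_qkvz: 2*KEY_DIM + 2*VALUE_DIM = 4096 + 12288 = 16384 rows
#   in_proj_ba:   2*NUM_V_HEADS           = 96 rows
#   gate_up_proj: 2*INTERMEDIATE          = 34816 rows
#   qkv_proj:     ATTN_Q_DIM + ATTN_K_DIM + ATTN_V_DIM = 14336 rows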


def vllm_to_hf_views(vllm_params: dict[str, torch.Tensor]
                     ) -> dict[str, torch.Tensor]:
    """Create HF-compatible parameter views from vLLM merged weights.

    Returns a dict of HF-style parameter names → tensor views.
    The views share GPU memory with the vLLM tensors — no copies.
    """
    hf_params = {}

    for name, tensor in vllm_params.items():
        # vLLM uses 'language_model.model.layers...' but HF's text model
        # uses 'model.layers...'. Strip the 'language_model.' prefix.
        hf_name = name.removeprefix('language_model.')

        # Split merged projections into HF-style separate weights
        if 'in_proj_qkvz' in name:
            # GDN: [key_dim*2 + value_dim*2, hidden] → qkv + z
            prefix = hf_name.replace('in_proj_qkvz.weight', '')
            qkv = tensor[:KEY_DIM * 2 + VALUE_DIM]
            z = tensor[KEY_DIM * 2 + VALUE_DIM:]
            hf_params[prefix + 'in_proj_qkv.weight'] = qkv
            hf_params[prefix + 'in_proj_z.weight'] = z

        elif 'in_proj_ba' in name:
            # GDN: [num_v_heads*2, hidden] → b + a
            prefix = hf_name.replace('in_proj_ba.weight', '')
            b = tensor[:NUM_V_HEADS]
            a = tensor[NUM_V_HEADS:]
            hf_params[prefix + 'in_proj_b.weight'] = b
            hf_params[prefix + 'in_proj_a.weight'] = a

        elif 'qkv_proj' in name:
            # Full attention: [q_dim + k_dim + v_dim, hidden] → q + k + v
            prefix = hf_name.replace('qkv_proj.weight', '')
            q = tensor[:ATTN_Q_DIM]
            k = tensor[ATTN_Q_DIM:ATTN_Q_DIM + ATTN_K_DIM]
            v = tensor[ATTN_Q_DIM + ATTN_K_DIM:]
            hf_params[prefix + 'q_proj.weight'] = q
            hf_params[prefix + 'k_proj.weight'] = k
            hf_params[prefix + 'v_proj.weight'] = v

        elif 'gate_up_proj' in name:
            # MLP: [intermediate*2, hidden] → gate + up
            prefix = hf_name.replace('gate_up_proj.weight', '')
            gate = tensor[:INTERMEDIATE]
            up = tensor[INTERMEDIATE:]
            hf_params[prefix + 'gate_proj.weight'] = gate
            hf_params[prefix + 'up_proj.weight'] = up

        else:
            # Pass through unchanged (norms, biases, out_proj, etc.)
            hf_params[hf_name] = tensor

    return hf_params
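
# Example (illustrative; the exact layer path, including 'linear_attn', is
# hypothetical and depends on the model's module names):
#
#   merged = torch.randn(2 * KEY_DIM + 2 * VALUE_DIM, HIDDEN)
#   views = vllm_to_hf_views(
#       {'language_model.model.layers.0.linear_attn.in_proj_qkvz.weight': merged})
#   qkv = views['model.layers.0.linear_attn.in_proj_qkv.weight']
#   assert qkv.shape[0] == 2 * KEY_DIM + VALUE_DIM
#   assert qkv.data_ptr() == merged.data_ptr()   # shared storage, no copy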


def load_hf_model_with_vllm_weights(
    vllm_params: dict[str, torch.Tensor],
    model_path: str,
    device: str = "cuda:0",
) -> nn.Module:
    """Load HF Qwen3.5 model with weights pointing to vLLM's GPU memory.

    1. Creates HF-compatible views from vLLM's merged weights
    2. Instantiates the HF model with empty weights
    3. Replaces model parameters with the views
    4. Returns model ready for forward+backward (autograd enabled)
    """
    from transformers import AutoModelForCausalLM, AutoConfig

    # Create HF-compatible views
    hf_params = vllm_to_hf_views(vllm_params)

    # Load config
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Create model with empty weights (no disk I/O)
    with torch.device('meta'):
        model = AutoModelForCausalLM.from_config(
            config, trust_remote_code=True)

    # Replace parameters with views into vLLM memory
    replaced = 0
    missing = []
    for name, param in model.named_parameters():
        if name in hf_params:
            # Replace with view (shared GPU memory)
            parts = name.rsplit('.', 1)
            parent = model
            for part in parts[0].split('.'):
                parent = getattr(parent, part)
            setattr(parent, parts[1],
                    nn.Parameter(hf_params[name], requires_grad=True))
            replaced += 1
        else:
            missing.append(name)

    print(f"Replaced {replaced} parameters with vLLM memory views")
    if missing:
        print(f"Missing {len(missing)} parameters: {missing[:5]}...")

    model.train()
    return model
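
# Caveat: parameters reported missing above (and any buffers) still live on
# the meta device from the torch.device('meta') context, and would need to
# be materialized separately before a real forward pass.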


def validate_views(vllm_params: dict[str, torch.Tensor],
                   hf_params: dict[str, torch.Tensor]):
    """Verify that HF views share storage with vLLM tensors."""
    for vllm_name, vllm_tensor in vllm_params.items():
        if 'in_proj_qkvz' in vllm_name:
            # hf_params keys have the 'language_model.' prefix stripped
            # (see vllm_to_hf_views), so strip it here too before lookup.
            prefix = vllm_name.removeprefix('language_model.') \
                              .replace('in_proj_qkvz.weight', '')
            qkv_name = prefix + 'in_proj_qkv.weight'
            z_name = prefix + 'in_proj_z.weight'
            if qkv_name in hf_params:
                assert hf_params[qkv_name].storage().data_ptr() == \
                    vllm_tensor.storage().data_ptr(), \
                    f"{qkv_name} doesn't share storage!"
            if z_name in hf_params:
                assert hf_params[z_name].storage().data_ptr() == \
                    vllm_tensor.storage().data_ptr(), \
                    f"{z_name} doesn't share storage!"

    print("All views validated — shared storage confirmed")