#!/usr/bin/env python3
"""Export vLLM's live model weight IPC handles for the training process.

Connects to a running vLLM instance, iterates over model parameters,
and exports CUDA IPC handles that allow another process to access the
same GPU memory without copying.

Usage:
    # Run after vLLM is serving:
    python3 export_weights.py --output /tmp/vllm_weight_handles.pt

    # Or via vLLM's API (future):
    curl -X POST http://localhost:8000/export_weights
"""

import argparse
import sys
import torch
from pathlib import Path

def export_from_model(model, output_path: str):
    """Export shareable IPC handles for every parameter of *model*.

    Walks ``model.named_parameters()``, reduces each parameter tensor to a
    process-shareable handle via
    ``torch.multiprocessing.reductions.reduce_tensor``, and persists the
    resulting mapping to ``output_path`` with ``torch.save``.

    Returns:
        dict: parameter name -> {'handle', 'shape', 'dtype'}.
    """
    from torch.multiprocessing.reductions import reduce_tensor

    exported = {}
    size_accum = 0

    for param_name, tensor in model.named_parameters():
        exported[param_name] = {
            'handle': reduce_tensor(tensor.data),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
        }
        # Track the total footprint so the summary below is meaningful.
        size_accum += tensor.nelement() * tensor.element_size()

    torch.save(exported, output_path)

    print(f"Exported {len(exported)} parameters ({size_accum / 1e9:.1f} GB)")
    print(f"Saved to {output_path}")
    return exported
def main():
    """CLI entry point: parse arguments and locate the running vLLM model.

    Currently a stub for the full export flow: it auto-detects the model
    path of a running vLLM server (by scanning ``ps aux``) and reports it;
    exporting directly from the live process is still TODO.

    Exits with status 1 when no running vLLM model can be detected.
    """
    parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
    parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
                        help="Output path for IPC handles")
    parser.add_argument("--vllm-pid", type=int, default=None,
                        help="vLLM worker PID (auto-detected if not specified)")
    # Parsed but not yet consumed: reserved for the full export path (TODO below).
    args = parser.parse_args()

    # For now: load the model directly and export.
    # TODO: connect to running vLLM process instead.
    print("Note: This currently loads the model separately.")
    print("Full integration will export from the running vLLM process.")
    print()

    # Detect model path from running vLLM.
    # Fix vs. previous version: stop at the FIRST detected model. The old
    # inner-loop `break` only exited the token scan, so a later ps line could
    # silently overwrite an earlier (correct) detection.
    model_path = _detect_vllm_model_path()

    if model_path:
        print(f"Detected vLLM model: {model_path}")
    else:
        print("Could not detect running vLLM model. Specify manually.")
        sys.exit(1)


def _detect_vllm_model_path():
    """Scan ``ps aux`` output for a vLLM process and return its model path.

    Handles both the ``--model <path>`` and ``--model=<path>`` (model_tag)
    argument forms. Returns the first match found, or None when no vLLM
    process with a ``--model`` argument is visible.
    """
    import subprocess
    result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
    for line in result.stdout.split('\n'):
        if 'vllm' not in line or '--model' not in line:
            continue
        parts = line.split()
        for i, p in enumerate(parts):
            # "--model <path>" form: value is the next token.
            if p == '--model' and i + 1 < len(parts):
                return parts[i + 1]
            # "--model=<path>" form.
            if p.startswith('--model='):
                return p.split('=', 1)[1]
    return None
# Script entry point: run the CLI only when executed directly,
# so the module stays importable without side effects.
if __name__ == '__main__':
    main()