# consciousness/training/export_weights.py
#!/usr/bin/env python3
"""Export vLLM's live model weight IPC handles for the training process.
Connects to a running vLLM instance, iterates over model parameters,
and exports CUDA IPC handles that allow another process to access the
same GPU memory without copying.
Usage:
# Run after vLLM is serving:
python3 export_weights.py --output /tmp/vllm_weight_handles.pt
# Or via vLLM's API (future):
curl -X POST http://localhost:8000/export_weights
"""
import argparse
import sys
import torch
from pathlib import Path
def export_from_model(model, output_path: str):
    """Export shareable IPC handles for every parameter of *model*.

    Walks ``model.named_parameters()``, reduces each parameter tensor to a
    process-shareable handle via ``torch.multiprocessing``, and saves the
    resulting mapping (handle + shape + dtype per parameter) to
    *output_path* with ``torch.save``.

    Args:
        model: Any ``nn.Module``-like object exposing ``named_parameters()``.
        output_path: Destination file for the serialized handle dict.

    Returns:
        dict: ``{param_name: {'handle', 'shape', 'dtype'}}`` — the same
        mapping that was written to disk.
    """
    from torch.multiprocessing.reductions import reduce_tensor

    exported = {}
    total_bytes = 0
    for name, param in model.named_parameters():
        # reduce_tensor yields a rebuild-function + args tuple that another
        # process can use to reattach to the same underlying storage.
        exported[name] = {
            'handle': reduce_tensor(param.data),
            'shape': list(param.shape),
            'dtype': str(param.dtype),
        }
        total_bytes += param.nelement() * param.element_size()

    torch.save(exported, output_path)
    n_params = len(exported)
    print(f"Exported {n_params} parameters ({total_bytes / 1e9:.1f} GB)")
    print(f"Saved to {output_path}")
    return exported
def _detect_vllm_model():
    """Scan the process table for a running vLLM server's --model argument.

    Returns the model path/tag of the FIRST matching process (the original
    code kept scanning and could silently overwrite an earlier match with a
    later one), or None when no vLLM process with a --model argument is
    found or `ps` itself is unavailable.
    """
    import subprocess
    try:
        result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
    except OSError:
        # `ps` missing (e.g. minimal container) — treat as "not detected"
        # instead of crashing with a traceback.
        return None
    for line in result.stdout.split('\n'):
        if 'vllm' not in line or '--model' not in line:
            continue
        parts = line.split()
        for i, p in enumerate(parts):
            # Separate form: "--model <path>"
            if p == '--model' and i + 1 < len(parts):
                return parts[i + 1]
            # Fused form: "--model=<path>"
            if p.startswith('--model='):
                return p.split('=', 1)[1]
    return None


def main():
    """CLI entry point: detect the running vLLM model.

    Parses --output / --vllm-pid, auto-detects the model served by a
    running vLLM process, and exits with status 1 when none is found.
    Actual export from the live process is still TODO (see below).
    """
    parser = argparse.ArgumentParser(description="Export vLLM weight IPC handles")
    parser.add_argument("--output", "-o", default="/tmp/vllm_weight_handles.pt",
                        help="Output path for IPC handles")
    parser.add_argument("--vllm-pid", type=int, default=None,
                        help="vLLM worker PID (auto-detected if not specified)")
    args = parser.parse_args()

    # For now: load the model directly and export.
    # TODO: connect to running vLLM process instead.
    print("Note: This currently loads the model separately.")
    print("Full integration will export from the running vLLM process.")
    print()

    model_path = _detect_vllm_model()
    if model_path:
        print(f"Detected vLLM model: {model_path}")
    else:
        print("Could not detect running vLLM model. Specify manually.")
        sys.exit(1)


if __name__ == '__main__':
    main()