nanochat/scripts/quick_check.py
2025-11-05 22:07:29 +03:30

67 lines
2.3 KiB
Python

"""Quick checkpoint structure check."""
import torch
import sys
checkpoint_path = "/raid/diana/nanochat_cache/chatsft_checkpoints/d20/model_000650.pt"

# Horizontal rule printed above and below every section title.
_RULE = "=" * 60

# Top-level keys the diagnosis section probes for, in report order.
_PROBED_KEYS = ("model", "config", "state_dict", "model_state_dict")


def _print_banner(title):
    """Print a section title framed by two rules, preceded by a blank line."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


def _print_structure(checkpoint):
    """Print the top-level keys of ``checkpoint`` and a one-line summary of each value."""
    _print_banner("CHECKPOINT STRUCTURE")
    print(f"\nTop-level keys: {list(checkpoint.keys())}\n")
    for key, value in checkpoint.items():
        if isinstance(value, dict):
            print(f"'{key}': dict with {len(value)} items")
            # Show a few sub-keys so nested dicts can be identified at a glance.
            sub_keys = list(value.keys())[:3]
            print(f" Sample keys: {sub_keys}")
        elif isinstance(value, torch.Tensor):
            print(f"'{key}': Tensor {value.shape}, dtype={value.dtype}")
        else:
            print(f"'{key}': {type(value).__name__} = {value}")


def _print_diagnosis(checkpoint):
    """Print which probed keys are present and return the result as a dict of bools."""
    _print_banner("DIAGNOSIS")
    present = {key: key in checkpoint for key in _PROBED_KEYS}
    print(f"\n✓ Has 'model' key: {present['model']}")
    print(f"✓ Has 'config' key: {present['config']}")
    print(f"✓ Has 'state_dict' key: {present['state_dict']}")
    print(f"✓ Has 'model_state_dict' key: {present['model_state_dict']}")
    return present


def _print_recommendation(checkpoint, present):
    """Print the follow-up advice matching the key layout recorded in ``present``."""
    _print_banner("RECOMMENDATION")
    if present["model"] and present["config"]:
        print("\n✅ Checkpoint has expected structure!")
        print(" No changes needed to benchmark_optimizations.py")
    elif present["state_dict"]:
        print("\n⚠️ Checkpoint uses 'state_dict' instead of 'model'")
        print(" Need to update benchmark to use checkpoint['state_dict']")
    elif present["model_state_dict"]:
        print("\n⚠️ Checkpoint uses 'model_state_dict' instead of 'model'")
        print(" Need to update benchmark to use checkpoint['model_state_dict']")
    else:
        print("\n❌ Checkpoint has unexpected structure!")
        print(" Available keys:", list(checkpoint.keys()))
        print(" You may need to check how the model was saved during training")


def main(argv=None):
    """Load a checkpoint on CPU and print a report of its top-level structure.

    Args:
        argv: Command-line arguments without the program name. The first
            element, if any, is the checkpoint path; otherwise the
            module-level ``checkpoint_path`` is used. ``None`` means
            ``sys.argv[1:]``.

    Returns:
        0 when the checkpoint was loaded and reported on, 1 when loading
        failed or the loaded object is not a dict.
    """
    if argv is None:
        argv = sys.argv[1:]
    path = argv[0] if argv else checkpoint_path

    print(f"Loading: {path}")
    # Only the load itself is guarded, so a bug in the reporting code below is
    # not misreported as a load failure.
    # NOTE(review): torch.load unpickles the file; only point this script at
    # checkpoints from a trusted source.
    try:
        checkpoint = torch.load(path, map_location='cpu')
    except Exception as e:
        print(f"\n❌ Error loading checkpoint: {e}")
        import traceback
        traceback.print_exc()
        return 1

    # The report below relies on the dict interface (keys / membership tests).
    if not isinstance(checkpoint, dict):
        print(
            f"\n❌ Checkpoint is a {type(checkpoint).__name__}, not a dict; "
            "cannot inspect top-level keys"
        )
        return 1

    _print_structure(checkpoint)
    present = _print_diagnosis(checkpoint)
    _print_recommendation(checkpoint, present)
    return 0


if __name__ == "__main__":
    sys.exit(main())