nanochat/dev/scaling_analysis.ipynb
2026-01-07 22:28:53 +00:00

228 lines
8.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scaling Laws Analysis\n",
"\n",
"Analyze results from `scaling_laws.sh` to find the compute-optimal param:data ratio (training tokens per parameter) for nanochat."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results (written by scaling_laws.sh under the nanochat base dir)\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
"if not os.path.exists(results_path):\n",
"    raise FileNotFoundError(f'{results_path} not found; run scaling_laws.sh first (or set NANOCHAT_BASE_DIR)')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"\n",
"# Fail fast if the CSV lacks a column the analysis cells below index, instead of a KeyError mid-plot\n",
"required_cols = {'flops_budget', 'depth', 'num_scaling_params', 'tokens_trained', 'param_data_ratio', 'val_bpb'}\n",
"missing_cols = required_cols - set(df.columns)\n",
"assert not missing_cols, f'results.csv is missing columns: {sorted(missing_cols)}'\n",
"assert not df.empty, f'no runs found in {results_path}'\n",
"\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
"print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
"print(f\"Columns: {list(df.columns)}\")\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## IsoFLOP Curves (à la Chinchilla)\n",
"\n",
"For each compute budget, plot validation loss (bpb) against model size. We look for the U-shaped valley whose minimum reveals the optimal model size for that FLOPs budget; a quadratic fit in log(parameters) locates the minimum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_power_law(ax, x, y, color, ylabel, title, symbol):\n",
"    \"\"\"Scatter y against compute x on log-log axes and overlay a least-squares power-law fit.\n",
"\n",
"    The fit is a straight line in log10-log10 space (y ∝ x^slope). It is drawn only when there\n",
"    are at least two points, and the legend reads `symbol ∝ C^slope`.\n",
"    \"\"\"\n",
"    ax.loglog(x, y, 'o', markersize=10, color=color)\n",
"    ax.set_xlabel('FLOPs')\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.set_title(title)\n",
"    ax.grid(True, alpha=0.3)\n",
"    if len(x) >= 2:\n",
"        log_x = np.log10(x)\n",
"        slope, intercept = np.polyfit(log_x, np.log10(y), 1)\n",
"        fit_x = np.logspace(log_x.min() - 0.5, log_x.max() + 0.5, 100)\n",
"        fit_y = 10**(intercept + slope * np.log10(fit_x))\n",
"        ax.plot(fit_x, fit_y, 'r--', alpha=0.7, label=f'{symbol} ∝ C^{slope:.2f}')\n",
"        ax.legend()\n",
"\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
"\n",
"# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
"ax = axes[0]\n",
"colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
"    subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
"    ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"    log_params = np.log10(subset['num_scaling_params'])\n",
"\n",
"    # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c.\n",
"    # A quadratic needs >= 3 distinct model sizes; with fewer, polyfit is underdetermined.\n",
"    log_opt = None\n",
"    if log_params.nunique() >= 3:\n",
"        a, b, c = np.polyfit(log_params, subset['val_bpb'], 2)\n",
"        # Plot fitted curve (dashed)\n",
"        log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
"        fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
"        ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
"        # Minimum of the quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a); a true minimum only if a > 0.\n",
"        # Only trust it inside the sampled range: outside it, np.interp below would clamp tokens/ratio\n",
"        # to an endpoint run while params kept the extrapolated value.\n",
"        if a > 0 and log_params.min() <= -b / (2 * a) <= log_params.max():\n",
"            log_opt = -b / (2 * a)\n",
"\n",
"    if log_opt is not None:\n",
"        opt_params = 10**log_opt\n",
"        opt_bpb = a * log_opt**2 + b * log_opt + c\n",
"        # Mark the fitted optimal\n",
"        ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
"                   zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
"        # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation).\n",
"        # Along an isoFLOP curve both fall off roughly as power laws of N, so interpolate in\n",
"        # log-log space; linear interpolation of raw values against log N is biased between runs.\n",
"        opt_tokens = 10**np.interp(log_opt, log_params, np.log10(subset['tokens_trained']))\n",
"        opt_ratio = 10**np.interp(log_opt, log_params, np.log10(subset['param_data_ratio']))\n",
"        optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
"    else:\n",
"        # Fallback to the best observed run when the fit gives no usable interior minimum\n",
"        print(f'{flops:.0e}: no interior quadratic minimum, using best observed run (sweep may not bracket the optimum)')\n",
"        best = subset.loc[subset['val_bpb'].idxmin()]\n",
"        ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
"                   zorder=5, edgecolors='black', linewidths=2)\n",
"        optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
"                               'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"opt_df = pd.DataFrame(optimal_by_bpb)\n",
"\n",
"# Plot 2: Optimal model size vs compute (power law)\n",
"plot_power_law(axes[1], opt_df['flops'], opt_df['params'], '#2ecc71', 'Optimal Parameters', 'Optimal Model Size', 'N')\n",
"\n",
"# Plot 3: Optimal tokens vs compute (power law)\n",
"plot_power_law(axes[2], opt_df['flops'], opt_df['tokens'], '#e74c3c', 'Optimal Tokens', 'Optimal Training Tokens', 'D')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Print the optimal points (quadratic-fit optima, or best observed runs where the fit was unusable)\n",
"print(\"\\nOptimal configurations (from quadratic fits):\")\n",
"print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
"print(\"-\" * 65)\n",
"for _, row in opt_df.iterrows():\n",
"    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Val BPB vs Depth and Ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_bpb_vs(ax, x_col):\n",
"    \"\"\"Plot val_bpb against column `x_col`, one line per FLOPs budget, circling each budget's best run.\n",
"\n",
"    Colors come from the viridis palette, as in the IsoFLOP figure, so a budget keeps the same color\n",
"    across figures. Each best-run marker gets its curve's color explicitly instead of drawing from\n",
"    the Axes color cycle.\n",
"    \"\"\"\n",
"    colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
"    for flops, color in zip(flops_budgets, colors):\n",
"        subset = df[df['flops_budget'] == flops].sort_values(x_col)\n",
"        ax.plot(subset[x_col], subset['val_bpb'], 'o-', color=color, label=f'{flops:.0e}')\n",
"        # Mark the best (lowest) val_bpb run for this budget\n",
"        best = subset.loc[subset['val_bpb'].idxmin()]\n",
"        ax.scatter([best[x_col]], [best['val_bpb']], s=100, color=color, zorder=5, edgecolors='black', linewidths=2)\n",
"    ax.set_ylabel('Val BPB (lower is better)')\n",
"    ax.grid(True, alpha=0.3)\n",
"\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Plot 1: Val BPB vs Depth\n",
"ax = axes[0]\n",
"plot_bpb_vs(ax, 'depth')\n",
"ax.set_xlabel('Depth')\n",
"ax.set_title('Validation BPB vs Model Depth')\n",
"ax.legend(title='FLOPs')\n",
"\n",
"# Plot 2: Val BPB vs Param:Data Ratio\n",
"ax = axes[1]\n",
"plot_bpb_vs(ax, 'param_data_ratio')\n",
"ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n",
"ax.set_xlabel('Param:Data Ratio (tokens/param)')\n",
"ax.set_title('Val BPB vs Param:Data Ratio')\n",
"ax.legend(title='FLOPs')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}