{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scaling Laws Analysis\n",
    "\n",
    "Analyze the results produced by `scaling_laws.sh` to find the compute-optimal param:data ratio for nanochat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load results written by scaling_laws.sh (set NANOCHAT_BASE_DIR to override the location)\n",
    "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
    "results_path = os.path.join(base_dir, 'scaling_laws_results', 'results.csv')\n",
    "\n",
    "df = pd.read_csv(results_path)\n",
    "\n",
    "# Fail fast with a clear message if the CSV is empty or lacks a column used below\n",
    "required_cols = {'flops_budget', 'num_scaling_params', 'tokens_trained', 'param_data_ratio', 'val_bpb', 'depth'}\n",
    "missing_cols = required_cols - set(df.columns)\n",
    "assert not missing_cols, f\"{results_path} is missing columns: {sorted(missing_cols)}\"\n",
    "assert not df.empty, f\"{results_path} contains no runs\"\n",
    "\n",
    "flops_budgets = sorted(df['flops_budget'].unique())\n",
    "print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
    "print(f\"Columns: {list(df.columns)}\")\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IsoFLOP Curves (\u00e0 la Chinchilla)\n",
    "\n",
    "For each compute budget, plot loss vs model size and look for the U-shaped valley whose minimum reveals the optimal model size for that FLOPs budget."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_power_law(ax, x, y, symbol):\n",
    "    \"\"\"Fit y = k * x^slope by least squares in log-log space and draw the fit on `ax`.\n",
    "\n",
    "    `symbol` names the fitted quantity in the legend (e.g. 'N' gives 'N \u221d C^0.50').\n",
    "    Does nothing when fewer than two points are available.\n",
    "    \"\"\"\n",
    "    if len(x) < 2:\n",
    "        return\n",
    "    log_x = np.log10(x)\n",
    "    log_y = np.log10(y)\n",
    "    slope, intercept = np.polyfit(log_x, log_y, 1)\n",
    "    fit_x = np.logspace(log_x.min() - 0.5, log_x.max() + 0.5, 100)\n",
    "    fit_y = 10**(intercept + slope * np.log10(fit_x))\n",
    "    ax.plot(fit_x, fit_y, 'r--', alpha=0.7, label=f'{symbol} \u221d C^{slope:.2f}')\n",
    "    ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
    "ax = axes[0]\n",
    "colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
    "optimal_by_bpb = []\n",
    "\n",
    "for flops, color in zip(flops_budgets, colors):\n",
    "    subset = df[df['flops_budget'] == flops].sort_values('num_scaling_params')\n",
    "    ax.plot(subset['num_scaling_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
    "    log_params = np.log10(subset['num_scaling_params'])\n",
    "\n",
    "    # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c.\n",
    "    # A degree-2 fit needs at least 3 distinct model sizes; with fewer, polyfit is\n",
    "    # rank-deficient and returns an arbitrary parabola.\n",
    "    log_opt = None\n",
    "    if log_params.nunique() >= 3:\n",
    "        a, b, c = np.polyfit(log_params, subset['val_bpb'], 2)\n",
    "\n",
    "        # Plot fitted curve (dashed)\n",
    "        log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
    "        fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
    "        ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
    "\n",
    "        # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n",
    "        if a > 0:  # parabola opens upward (has a minimum)\n",
    "            log_opt = -b / (2 * a)\n",
    "            # np.interp clamps outside the sampled range, so an optimum beyond the sweep\n",
    "            # would be paired with edge-of-sweep tokens/ratio. Don't trust it.\n",
    "            if not (log_params.min() <= log_opt <= log_params.max()):\n",
    "                print(f\"FLOPs {flops:.0e}: fitted optimum (10^{log_opt:.2f} params) lies outside the sweep; using raw minimum\")\n",
    "                log_opt = None\n",
    "\n",
    "    if log_opt is not None:\n",
    "        opt_params = 10**log_opt\n",
    "        opt_bpb = a * log_opt**2 + b * log_opt + c\n",
    "        # Mark the fitted optimal\n",
    "        ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
    "                   zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
    "        # Interpolate tokens and ratio from actual data (don't use C\u22486ND approximation).\n",
    "        # At fixed compute both fall off roughly as power laws in N, so interpolate in log-log space.\n",
    "        opt_tokens = 10**np.interp(log_opt, log_params, np.log10(subset['tokens_trained']))\n",
    "        opt_ratio = 10**np.interp(log_opt, log_params, np.log10(subset['param_data_ratio']))\n",
    "        optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
    "    else:\n",
    "        # Fallback to raw minimum if the quadratic gives no usable minimum\n",
    "        best = subset.loc[subset['val_bpb'].idxmin()]\n",
    "        ax.scatter([best['num_scaling_params']], [best['val_bpb']], s=150, color=color,\n",
    "                   zorder=5, edgecolors='black', linewidths=2)\n",
    "        optimal_by_bpb.append({'flops': flops, 'params': best['num_scaling_params'],\n",
    "                               'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
    "\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel('Parameters')\n",
    "ax.set_ylabel('Validation Loss (bpb)')\n",
    "ax.set_title('IsoFLOP Curves')\n",
    "ax.legend(title='FLOPs', loc='upper right')\n",
    "ax.grid(True, alpha=0.3)\n",
    "\n",
    "opt_df = pd.DataFrame(optimal_by_bpb)\n",
    "\n",
    "# Plot 2: Optimal model size vs compute (power law)\n",
    "ax = axes[1]\n",
    "ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n",
    "ax.set_xlabel('FLOPs')\n",
    "ax.set_ylabel('Optimal Parameters')\n",
    "ax.set_title('Optimal Model Size')\n",
    "ax.grid(True, alpha=0.3)\n",
    "fit_power_law(ax, opt_df['flops'], opt_df['params'], 'N')\n",
    "\n",
    "# Plot 3: Optimal tokens vs compute (power law)\n",
    "ax = axes[2]\n",
    "ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n",
    "ax.set_xlabel('FLOPs')\n",
    "ax.set_ylabel('Optimal Tokens')\n",
    "ax.set_title('Optimal Training Tokens')\n",
    "ax.grid(True, alpha=0.3)\n",
    "fit_power_law(ax, opt_df['flops'], opt_df['tokens'], 'D')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print the optimal points (quadratic-fit optimum, or raw minimum where the fit was unusable)\n",
    "print(\"\\nOptimal configurations (from quadratic fits):\")\n",
    "print(f\"{'FLOPs':<12} {'Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n",
    "print(\"-\" * 65)\n",
    "for _, row in opt_df.iterrows():\n",
    "    print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Val BPB vs Depth and Ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "# Same colormap as the IsoFLOP figure, so each FLOPs budget keeps one color throughout\n",
    "colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
    "\n",
    "# Plot 1: Val BPB vs Depth\n",
    "ax = axes[0]\n",
    "for flops, color in zip(flops_budgets, colors):\n",
    "    subset = df[df['flops_budget'] == flops].sort_values('depth')\n",
    "    ax.plot(subset['depth'], subset['val_bpb'], 'o-', color=color, label=f'{flops:.0e}')\n",
    "    # Mark the best (lowest). Pass the color explicitly: otherwise scatter draws its\n",
    "    # own color from the axes cycle, which need not match the line it belongs to.\n",
    "    best = subset.loc[subset['val_bpb'].idxmin()]\n",
    "    ax.scatter([best['depth']], [best['val_bpb']], s=100, color=color, zorder=5, edgecolors='black', linewidths=2)\n",
    "\n",
    "ax.set_xlabel('Depth')\n",
    "ax.set_ylabel('Val BPB (lower is better)')\n",
    "ax.set_title('Validation BPB vs Model Depth')\n",
    "ax.legend(title='FLOPs')\n",
    "ax.grid(True, alpha=0.3)\n",
    "\n",
    "# Plot 2: Val BPB vs Param:Data Ratio\n",
    "ax = axes[1]\n",
    "for flops, color in zip(flops_budgets, colors):\n",
    "    subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n",
    "    ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', color=color, label=f'{flops:.0e}')\n",
    "    best = subset.loc[subset['val_bpb'].idxmin()]\n",
    "    ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, color=color, zorder=5, edgecolors='black', linewidths=2)\n",
    "\n",
    "ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n",
    "ax.set_xlabel('Param:Data Ratio (tokens/param)')\n",
    "ax.set_ylabel('Val BPB (lower is better)')\n",
    "ax.set_title('Val BPB vs Param:Data Ratio')\n",
    "ax.legend(title='FLOPs')\n",
    "ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}