{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scaling Laws Analysis\n", "\n", "Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import os\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Load results\n", "tag = \"jan26\"\n", "base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n", "results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n", "\n", "df = pd.read_csv(results_path)\n", "flops_budgets = sorted(df['flops_budget'].unique())\n", "print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n", "print(f\"Columns: {list(df.columns)}\")\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# =============================================================================\n", "# FILTERING: Remove incomplete or problematic runs\n", "# =============================================================================\n", "\n", "print(f\"Before filtering: {len(df)} runs\")\n", "\n", "# Filter out runs with missing/invalid val_bpb (incomplete runs)\n", "df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n", "\n", "# Optional: exclude specific flops budgets that aren't done yet\n", "# exclude_flops = [1e19] # <-- adjust as runs complete\n", "# df = df[~df['flops_budget'].isin(exclude_flops)]\n", "\n", "# Optional: exclude specific depths\n", "# exclude_depths = [18, 20]\n", "# df = df[~df['depth'].isin(exclude_depths)]\n", "\n", "print(f\"After filtering: {len(df)} runs\")\n", "print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n", "print(f\"Depths: {sorted(df['depth'].unique())}\")\n", "\n", "# Update flops_budgets list after filtering\n", "flops_budgets = sorted(df['flops_budget'].unique())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Effective Parameter Count\n", "\n", "Different scaling law papers use different conventions for counting parameters:\n", "- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n", "- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n", "\n", "Our CSV now has granular counts:\n", "- `params_wte` - token embedding (lookup table)\n", "- `params_bigram_embed` - bigram hash embeddings (lookup table)\n", "- `params_value_embeds` - value embeddings (lookup table)\n", "- `params_lm_head` - unembedding projection (matmul)\n", "- `params_transformer` - attention + MLP matrices (matmuls)\n", "- `params_scalars` - resid/x0/bigram lambdas (tiny)\n", "\n", "**Experiment below** with different combinations to see which gives the cleanest scaling laws." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# =============================================================================\n", "# EXPERIMENT HERE: Define which parameters to count for scaling laws\n", "# =============================================================================\n", "\n", "def compute_effective_params(row):\n", " \"\"\"\n", " Compute the 'effective' parameter count for scaling law analysis.\n", "\n", " Modify this function to experiment with different conventions:\n", " - Chinchilla-style: include everything\n", " - Kaplan-style: exclude embeddings\n", " - Matmul-only: just transformer + lm_head (the actual compute)\n", " - etc.\n", " \"\"\"\n", " # Option 1: Chinchilla-style (all params)\n", " # return row['params_total']\n", "\n", " # Option 2: Kaplan-style (exclude embeddings)\n", " return row['params_transformer'] + row['params_lm_head']\n", "\n", " # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n", " # return row['params_transformer']\n", "\n", "\n", "# Compute derived columns\n", "df['effective_params'] = df.apply(compute_effective_params, axis=1)\n", "df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n", "\n", "# Show parameter breakdown for first few rows\n", "print(\"Parameter breakdown (first row per flops budget):\")\n", "param_cols = ['depth', 'params_wte', 'params_bigram_embed', 'params_value_embeds',\n", " 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n", "df.groupby('flops_budget').first()[param_cols]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## IsoFLOP Curves (à la Chinchilla)\n", "\n", "For each compute budget, plot loss vs model size. Looking for the U-shape valley that reveals the optimal model size for each FLOPs budget." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", "\n", "# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n", "ax = axes[0]\n", "colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n", "optimal_by_bpb = []\n", "\n", "for flops, color in zip(flops_budgets, colors):\n", " subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n", " ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n", "\n", " # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n", " log_params = np.log10(subset['effective_params'])\n", " coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n", " a, b, c = coeffs\n", "\n", " # Plot fitted curve (dashed)\n", " log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n", " fit_y = a * log_fit_x**2 + b * log_fit_x + c\n", " ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n", "\n", " # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n", " if a > 0: # parabola opens upward (has a minimum)\n", " log_opt = -b / (2 * a)\n", " opt_params = 10**log_opt\n", " opt_bpb = a * log_opt**2 + b * log_opt + c\n", " # Mark the fitted optimal\n", " ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n", " zorder=5, edgecolors='black', linewidths=2, marker='*')\n", " # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n", " opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n", " opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n", " optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n", " else:\n", " # Fallback to raw minimum if quadratic doesn't have minimum\n", " best_idx = subset['val_bpb'].idxmin()\n", " best = subset.loc[best_idx]\n", " ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n", " zorder=5, edgecolors='black', linewidths=2)\n", " optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n", " 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n", "\n", "ax.set_xscale('log')\n", "ax.set_xlabel('Effective Parameters')\n", "ax.set_ylabel('Validation Loss (bpb)')\n", "ax.set_title('IsoFLOP Curves')\n", "ax.legend(title='FLOPs', loc='upper right')\n", "ax.grid(True, alpha=0.3)\n", "\n", "opt_df = pd.DataFrame(optimal_by_bpb)\n", "\n", "# Plot 2: Optimal model size vs compute (power law)\n", "ax = axes[1]\n", "ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n", "ax.set_xlabel('FLOPs')\n", "ax.set_ylabel('Optimal Parameters')\n", "ax.set_title('Optimal Model Size')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# Fit and show power law\n", "if len(opt_df) >= 2:\n", " log_f = np.log10(opt_df['flops'])\n", " log_p = np.log10(opt_df['params'])\n", " slope, intercept = np.polyfit(log_f, log_p, 1)\n", " fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n", " fit_p = 10**(intercept + slope * np.log10(fit_f))\n", " ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n", " ax.legend()\n", "\n", "# Plot 3: Optimal tokens vs compute (power law)\n", "ax = axes[2]\n", "ax.loglog(opt_df['flops'], opt_df['tokens'], 'o', markersize=10, color='#e74c3c')\n", "ax.set_xlabel('FLOPs')\n", "ax.set_ylabel('Optimal Tokens')\n", 
"ax.set_title('Optimal Training Tokens')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# Fit and show power law\n", "if len(opt_df) >= 2:\n", " log_f = np.log10(opt_df['flops'])\n", " log_t = np.log10(opt_df['tokens'])\n", " slope, intercept = np.polyfit(log_f, log_t, 1)\n", " fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n", " fit_t = 10**(intercept + slope * np.log10(fit_f))\n", " ax.plot(fit_f, fit_t, 'r--', alpha=0.7, label=f'D ∝ C^{slope:.2f}')\n", " ax.legend()\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Print the optimal points (from quadratic fits)\n", "print(\"\\nOptimal configurations (from quadratic fits):\")\n", "print(f\"{'FLOPs':<12} {'Eff Params':<15} {'Tokens':<15} {'Ratio':<10} {'Val BPB':<10}\")\n", "print(\"-\" * 65)\n", "for _, row in opt_df.iterrows():\n", " print(f\"{row['flops']:<12.0e} {int(row['params']):<15,} {int(row['tokens']):<15,} {row['ratio']:<10.1f} {row['bpb']:<10.4f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# =============================================================================\n", "# Optimal Ratio Summary (from power law fits)\n", "# =============================================================================\n", "\n", "# From the power law fits: N ∝ C^a and D ∝ C^b\n", "# The ratio D/N ∝ C^(b-a). If a ≈ b, ratio is roughly constant.\n", "\n", "if len(opt_df) >= 2:\n", " log_f = np.log10(opt_df['flops'])\n", " log_p = np.log10(opt_df['params'])\n", " log_t = np.log10(opt_df['tokens'])\n", "\n", " # Fit power laws\n", " slope_n, intercept_n = np.polyfit(log_f, log_p, 1)\n", " slope_d, intercept_d = np.polyfit(log_f, log_t, 1)\n", "\n", " # The ratio D/N at a reference compute (geometric mean of our budgets)\n", " ref_flops = np.sqrt(opt_df['flops'].min() * opt_df['flops'].max())\n", " log_ref = np.log10(ref_flops)\n", "\n", " # Predicted optimal N and D at reference compute\n", " pred_log_n = intercept_n + slope_n * log_ref\n", " pred_log_d = intercept_d + slope_d * log_ref\n", " optimal_ratio = 10**(pred_log_d - pred_log_n)\n", "\n", " # Also compute from the fitted optimals directly (mean and std)\n", " mean_ratio = opt_df['ratio'].mean()\n", " std_ratio = opt_df['ratio'].std()\n", "\n", " print(\"=\" * 60)\n", " print(\"OPTIMAL RATIO SUMMARY\")\n", " print(\"=\" * 60)\n", " print(f\"\\nPower law exponents:\")\n", " print(f\" N ∝ C^{slope_n:.3f}\")\n", " print(f\" D ∝ C^{slope_d:.3f}\")\n", " print(f\" Ratio exponent (b-a): {slope_d - slope_n:.3f} (should be ~0 if ratio is constant)\")\n", " print(f\"\\nOptimal ratio (tokens per effective param):\")\n", " print(f\" From power law at C={ref_flops:.1e}: {optimal_ratio:.1f}\")\n", " print(f\" Mean across budgets: {mean_ratio:.1f} ± {std_ratio:.1f}\")\n", " print(f\" Chinchilla reference: 20\")\n", " print(f\"\\nPer-budget ratios: {[f'{r:.1f}' for r in opt_df['ratio'].values]}\")\n", "else:\n", " print(\"Need at least 2 flops budgets to compute power law fits\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Val BPB vs Depth and Ratio" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "# Plot 1: Val BPB vs Depth\n", "ax = axes[0]\n", "for flops in flops_budgets:\n", " subset = df[df['flops_budget'] == flops].sort_values('depth')\n", " ax.plot(subset['depth'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n", " # Mark the best (lowest)\n", " best_idx = subset['val_bpb'].idxmin()\n", " best = 
subset.loc[best_idx]\n", " ax.scatter([best['depth']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n", "\n", "ax.set_xlabel('Depth')\n", "ax.set_ylabel('Val BPB (lower is better)')\n", "ax.set_title('Validation BPB vs Model Depth')\n", "ax.legend(title='FLOPs')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# Plot 2: Val BPB vs Data:Param Ratio (tokens per param)\n", "ax = axes[1]\n", "for flops in flops_budgets:\n", " subset = df[df['flops_budget'] == flops].sort_values('param_data_ratio')\n", " ax.plot(subset['param_data_ratio'], subset['val_bpb'], 'o-', label=f'{flops:.0e}')\n", " best_idx = subset['val_bpb'].idxmin()\n", " best = subset.loc[best_idx]\n", " ax.scatter([best['param_data_ratio']], [best['val_bpb']], s=100, zorder=5, edgecolors='black', linewidths=2)\n", "\n", "ax.axvline(x=20, color='red', linestyle='--', alpha=0.5, label='Chinchilla (20)')\n", "ax.set_xlabel('Data:Param Ratio (tokens per param)')\n", "ax.set_ylabel('Val BPB (lower is better)')\n", "ax.set_title('Val BPB vs Data:Param Ratio')\n", "ax.legend(title='FLOPs')\n", "ax.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }