diff --git a/.gitignore b/.gitignore index 4a87b23..520c94e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ __pycache__/ rustbpe/target/ dev-ignore/ report.md -eval_bundle/ \ No newline at end of file +eval_bundle/ diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py index 068824f..025941d 100644 --- a/dev/gen_synthetic_data.py +++ b/dev/gen_synthetic_data.py @@ -28,24 +28,23 @@ NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the (obviously you can tune this arbitrarily to your liking) NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139 """ -import requests + +import copy import json import os -import copy import random from concurrent.futures import ThreadPoolExecutor, as_completed +import requests + from nanochat.common import get_base_dir -api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip() +api_key = open("openroutertoken.txt", encoding="utf-8").read().strip() url = "https://openrouter.ai/api/v1/chat/completions" -headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" -} +headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} -readme = open("README.md", "r", encoding="utf-8").read().strip() +readme = open("README.md", encoding="utf-8").read().strip() prompt = r""" I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: @@ -276,48 +275,46 @@ prompt = prompt.replace("%README%", readme) # Define the JSON schema for structured output response_format = { - "type": "json_schema", - "json_schema": { - "name": "conversation", - "strict": True, - "schema": { - "type": "object", - "properties": { - "messages": { - "type": "array", - "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", - "items": { + "type": "json_schema", + "json_schema": { + "name": "conversation", + "strict": True, + "schema": { "type": "object", "properties": { - "role": { - "type": "string", - "description": "The role of the speaker, either 'user' or 'assistant'" - }, - "content": { - "type": "string", - "description": "The message content" - } + "messages": { + "type": "array", + "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "The role of the speaker, either 'user' or 'assistant'", + }, + "content": {"type": "string", "description": "The message content"}, + }, + "required": ["role", "content"], + "additionalProperties": False, + }, + } }, - "required": ["role", "content"], - "additionalProperties": False - } - } - }, - "required": ["messages"], - "additionalProperties": False - } - } + "required": ["messages"], + "additionalProperties": False, + }, + }, } # Sadly it doesn't seem like Chat completions support `n` # to generate multiple completions per prompt. base_payload = { - "model": "google/gemini-2.5-flash", - "stream": False, - "response_format": response_format, - "temperature": 1.0, + "model": "google/gemini-2.5-flash", + "stream": False, + "response_format": response_format, + "temperature": 1.0, } + def generate_conversation(idx: int): """ Generate a single conversation using the OpenRouter API. 
@@ -325,7 +322,7 @@ def generate_conversation(idx: int): """ # pick 5 example user first messages and insert them into prompt as inspiration - rng = random.Random(idx) # use idx as seed to the rng + rng = random.Random(idx) # use idx as seed to the rng user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5)) payload = copy.deepcopy(base_payload) modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt) @@ -357,7 +354,6 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers. completed_count = 0 error_count = 0 with ThreadPoolExecutor(max_workers=num_workers) as executor: - # Submit all tasks futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)] @@ -369,7 +365,9 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor: # Lightly validate the conversation structure for i, message in enumerate(messages): expected_role = "user" if i % 2 == 0 else "assistant" - assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + assert message['role'] == expected_role, ( + f"Message {i} has role {message['role']} but should be {expected_role}" + ) # If all looks good, write the messages to file with open(output_file, 'a') as f: @@ -384,4 +382,3 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor: print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}") if error_count > 0: print(f"Encountered {error_count} errors during generation") - diff --git a/dev/generate_logo.html b/dev/generate_logo.html index edfa99e..d0da736 100644 --- a/dev/generate_logo.html +++ b/dev/generate_logo.html @@ -26,4 +26,4 @@ svg.innerHTML += ``; - \ No newline at end of file + diff --git a/dev/repackage_data_reference.py b/dev/repackage_data_reference.py index 32980a8..754035f 100644 --- a/dev/repackage_data_reference.py +++ b/dev/repackage_data_reference.py @@ -13,24 +13,25 @@ training latency. NOTE: This file is meant only as reference/documentation of the dataset preparation and it is not used during the project runtime. 
""" + import os import time -from datasets import load_dataset -import pyarrow.parquet as pq import pyarrow as pa +import pyarrow.parquet as pq +from datasets import load_dataset # Source dataset dataset_kwargs = { "path": "HuggingFaceFW/fineweb-edu", "split": "train", - "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total + "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total } ds = load_dataset(**dataset_kwargs) # Shuffle to scramble the order ds = ds.shuffle(seed=42) -ndocs = len(ds) # total number of documents to process +ndocs = len(ds) # total number of documents to process print(f"Total number of documents: {ndocs}") # Repackage into parquet files @@ -39,7 +40,7 @@ os.makedirs(output_dir, exist_ok=True) # Write to parquet files chars_per_shard = 250_000_000 -row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later +row_group_size = 1024 # HF uses 1000 but we use multiple of 2, nicer for distributed data loader later shard_docs = [] shard_index = 0 shard_characters = 0 @@ -52,20 +53,20 @@ for doc in ds: shard_characters += len(text) collected_enough_chars = shard_characters >= chars_per_shard docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0 - if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed) + if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed) shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet") shard_table = pa.Table.from_pydict({"text": shard_docs}) pq.write_table( shard_table, shard_path, row_group_size=row_group_size, - use_dictionary=False, # this is usually used for categorical data - compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’} + use_dictionary=False, # this is usually used for categorical data + compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’} compression_level=3, - write_statistics=False, # not needed for text + write_statistics=False, # not needed for text ) t1 = time.time() - dt = t1 - t0 # for this shard alone + dt = t1 - t0 # for this shard alone t0 = t1 total_docs_processed += len(shard_docs) total_time_spent += dt @@ -73,15 +74,20 @@ for doc in ds: avg_time_per_doc = total_time_spent / total_docs_processed remaining_time = remaining_docs * avg_time_per_doc remaining_time_hours = remaining_time / 3600 - print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h") + print( + f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h" + ) shard_docs = [] shard_characters = 0 shard_index += 1 + # Demonstration of how the data was later uploaded to HuggingFace def upload(): import os + from huggingface_hub import HfApi + token = os.getenv("HF_TOKEN") api = HfApi(token=token) api.upload_large_folder( @@ -89,4 +95,6 @@ def upload(): repo_id="karpathy/fineweb-edu-100b-shuffle", repo_type="dataset", ) + + # upload() diff --git a/nanochat/adamw.py b/nanochat/adamw.py index db591de..975f9b2 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -2,6 +2,7 @@ Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. Not a general optimizer! But works for our specific use. 
""" + import torch import torch.distributed as dist from torch import Tensor @@ -12,7 +13,15 @@ class DistAdamW(torch.optim.Optimizer): Distributed AdamW optimizer. In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction """ - def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + + def __init__( + self, + param_groups, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(param_groups, defaults) @@ -30,7 +39,9 @@ class DistAdamW(torch.optim.Optimizer): grad = params[base_i].grad rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future() + ) grad_slices.append(grad_slice) idx = 0 @@ -43,7 +54,7 @@ class DistAdamW(torch.optim.Optimizer): reduce_scatter_futures[idx].wait() p = params[base] rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] + p_slice = p[rank * rank_size : (rank + 1) * rank_size] lr = group['lr'] * getattr(p, "lr_mul", 1.0) state = self.state[p] g_slice = grad_slices[idx] @@ -64,8 +75,8 @@ class DistAdamW(torch.optim.Optimizer): exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t # compute step denom = exp_avg_sq.sqrt().add_(eps) step_size = lr * (torch.sqrt(bias2) / bias1) diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index 63f257f..4da8311 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -1,25 +1,29 @@ """ Utilities for saving and loading model/optim/state checkpoints. 
""" -import os -import re + import glob import json import logging +import os +import re + import torch -from nanochat.common import get_base_dir +from nanochat.common import get_base_dir, setup_default_logging from nanochat.gpt import GPT, GPTConfig from nanochat.tokenizer import get_tokenizer -from nanochat.common import setup_default_logging # Set up logging setup_default_logging() logger = logging.getLogger(__name__) + + def log0(message): if int(os.environ.get('RANK', 0)) == 0: logger.info(message) + def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0): if rank == 0: os.makedirs(checkpoint_dir, exist_ok=True) @@ -38,6 +42,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, torch.save(optimizer_data, optimizer_path) logger.info(f"Saved optimizer state to: {optimizer_path}") + def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): # Load the model state model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt") @@ -49,7 +54,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0): optimizer_data = torch.load(optimizer_path, map_location=device) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "r", encoding="utf-8") as f: + with open(meta_path, encoding="utf-8") as f: meta_data = json.load(f) return model_data, optimizer_data, meta_data @@ -66,10 +71,7 @@ def build_model(checkpoint_dir, step, device, phase): model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False) if device.type in {"cpu", "mps"}: # Convert bfloat16 tensors to float for CPU inference - model_data = { - k: v.float() if v.dtype == torch.bfloat16 else v - for k, v in model_data.items() - } + model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()} # Hack: fix torch compile issue, which prepends all keys with _orig_mod. model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} model_config_kwargs = meta_data["model_config"] @@ -79,7 +81,7 @@ def build_model(checkpoint_dir, step, device, phase): model = GPT(model_config) # Load the model state model.to_empty(device=device) - model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init + model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. 
TODO: fix model re-init model.load_state_dict(model_data, strict=True, assign=True) # Put the model in the right training phase / mode if phase == "eval": @@ -121,9 +123,11 @@ def find_last_step(checkpoint_dir): last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)) return last_step + # ----------------------------------------------------------------------------- # convenience functions that take into account nanochat's directory structure + def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None): if model_tag is None: # guess the model tag by defaulting to the largest model @@ -139,6 +143,7 @@ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=Non model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase) return model, tokenizer, meta_data + def load_model(source, *args, **kwargs): model_dir = { "base": "base_checkpoints", diff --git a/nanochat/common.py b/nanochat/common.py index 8f36f94..e22d83d 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -2,26 +2,30 @@ Common utilities for nanochat. """ +import logging import os import re -import logging import urllib.request + import torch import torch.distributed as dist from filelock import FileLock + class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" + # ANSI color codes COLORS = { - 'DEBUG': '\033[36m', # Cyan - 'INFO': '\033[32m', # Green + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green 'WARNING': '\033[33m', # Yellow - 'ERROR': '\033[31m', # Red - 'CRITICAL': '\033[35m', # Magenta + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta } RESET = '\033[0m' BOLD = '\033[1m' + def format(self, record): # Add color to the level name levelname = record.levelname @@ -36,17 +40,17 @@ class ColoredFormatter(logging.Formatter): message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) return message + def setup_default_logging(): handler = logging.StreamHandler() handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) - logging.basicConfig( - level=logging.INFO, - handlers=[handler] - ) + logging.basicConfig(level=logging.INFO, handlers=[handler]) + setup_default_logging() logger = logging.getLogger(__name__) + def get_base_dir(): # co-locate nanochat intermediates with other cached data in ~/.cache (by default) if os.environ.get("NANOCHAT_BASE_DIR"): @@ -58,6 +62,7 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir + def download_file_with_lock(url, filename, postprocess_fn=None): """ Downloads a file from a URL to a local path in the base directory. 
@@ -81,7 +86,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None): # Download the content as bytes print(f"Downloading {url}...") with urllib.request.urlopen(url) as response: - content = response.read() # bytes + content = response.read() # bytes # Write to local file with open(file_path, 'wb') as f: @@ -94,11 +99,13 @@ def download_file_with_lock(url, filename, postprocess_fn=None): return file_path -def print0(s="",**kwargs): + +def print0(s="", **kwargs): ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs) + def print_banner(): # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ @@ -113,10 +120,12 @@ def print_banner(): """ print0(banner) + def is_ddp(): # TODO is there a proper way return int(os.environ.get('RANK', -1)) != -1 + def get_dist_info(): if is_ddp(): assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) @@ -127,6 +136,7 @@ def get_dist_info(): else: return False, 0, 0, 1 + def autodetect_device_type(): # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU if torch.cuda.is_available(): @@ -138,14 +148,19 @@ def autodetect_device_type(): print0(f"Autodetected device type: {device_type}") return device_type -def compute_init(device_type="cuda"): # cuda|cpu|mps + +def compute_init(device_type="cuda"): # cuda|cpu|mps """Basic initialization that we keep doing over and over, so make common.""" assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" if device_type == "cuda": - assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" + assert torch.cuda.is_available(), ( + "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" + ) if device_type == "mps": - assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" + assert torch.backends.mps.is_available(), ( + "Your PyTorch installation is not configured for MPS but device_type is 'mps'" + ) # Reproducibility # Note that we set the global seeds here, but most of the code uses explicit rng objects. 
@@ -158,7 +173,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps # Precision if device_type == "cuda": - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() @@ -168,23 +183,28 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps dist.init_process_group(backend="nccl", device_id=device) dist.barrier() else: - device = torch.device(device_type) # mps|cpu + device = torch.device(device_type) # mps|cpu if ddp_rank == 0: logger.info(f"Distributed world size: {ddp_world_size}") return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device + def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" if is_ddp(): dist.destroy_process_group() + class DummyWandb: """Useful if we wish to not use wandb but have all the same signatures""" + def __init__(self): pass + def log(self, *args, **kwargs): pass + def finish(self): pass diff --git a/nanochat/configurator.py b/nanochat/configurator.py index ec1b76d..1541f7a 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -18,11 +18,13 @@ import os import sys from ast import literal_eval -def print0(s="",**kwargs): + +def print0(s="", **kwargs): ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs) + for arg in sys.argv[1:]: if '=' not in arg: # assume it's the name of a config file diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index f3c9a9f..1dfe377 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -5,15 +5,17 @@ https://arxiv.org/abs/2406.11794 TODOs: - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. 
""" + import random -from jinja2 import Template import torch import torch.distributed as dist +from jinja2 import Template # ----------------------------------------------------------------------------- # Prompt rendering utilities + def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): """Render complete prompts for a multiple choice question""" template_str = """ @@ -24,11 +26,7 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] - context = { - 'fewshot_examples': fewshot_examples, - 'continuation_delimiter': continuation_delimiter, - 'item': item - } + context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item} prompts = [template.render(choice=choice, **context) for choice in item['choices']] return prompts @@ -43,13 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] - context = { - 'fewshot_examples': fewshot_examples, - 'continuation_delimiter': continuation_delimiter, - 'item': item - } - prompts = [template.render(context=context_option, **context) - for context_option in item['context_options']] + context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item} + prompts = [template.render(context=context_option, **context) for context_option in item['context_options']] return prompts @@ -67,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip() template = Template(template_str) fewshot_examples = fewshot_examples or [] - context = { - 'fewshot_examples': fewshot_examples, - 'continuation_delimiter': continuation_delimiter, - 'item': item - } + context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item} # Return two prompts: without and with the continuation prompt_without = template.render(include_continuation=False, **context) prompt_with = template.render(include_continuation=True, **context) @@ -89,10 +78,7 @@ def find_common_length(token_sequences, direction='left'): - direction: 'left' for prefix, 'right' for suffix """ min_len = min(len(seq) for seq in token_sequences) - indices = { - 'left': range(min_len), - 'right': range(-1, -min_len-1, -1) - }[direction] + indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction] # Find the first position where the token sequences differ for i, idx in enumerate(indices): token = token_sequences[0][idx] @@ -106,7 +92,7 @@ def stack_sequences(tokens, pad_token_id): bsz, seq_len = len(tokens), max(len(x) for x in tokens) input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) for i, x in enumerate(tokens): - input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long) + input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long) return input_ids @@ -153,9 +139,7 @@ def forward_model(model, input_ids): target_ids = torch.roll(input_ids, shifts=-1, dims=1) # Calculate cross entropy at all positions losses = torch.nn.functional.cross_entropy( - outputs.view(batch_size * seq_len, -1), - target_ids.view(batch_size * 
seq_len), - reduction='none' + outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none' ).view(batch_size, seq_len) # Set the last column to be nan because there is no autoregressive loss there losses[:, -1] = float('nan') @@ -201,19 +185,19 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta): for t, s, e in zip(tokens, start_idxs, end_idxs): if len(t) > max_tokens: num_to_crop = len(t) - max_tokens - new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens - new_start_idxs.append(s - num_to_crop) # shift the indices down + new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens + new_start_idxs.append(s - num_to_crop) # shift the indices down new_end_idxs.append(e - num_to_crop) assert s - num_to_crop >= 0, "this should never happen right?" assert e - num_to_crop >= 0, "this should never happen right?" else: - new_tokens.append(t) # keep unchanged + new_tokens.append(t) # keep unchanged new_start_idxs.append(s) new_end_idxs.append(e) tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs # Stack up all the sequences into a batch - pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok + pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok input_ids = stack_sequences(tokens, pad_token_id) input_ids = input_ids.to(device) @@ -226,13 +210,12 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta): si = start_idxs[0] ei = end_idxs[0] # predictions[i] predict input_ids[i+1] autoregressively - predicted_tokens = predictions[0, si-1:ei-1] + predicted_tokens = predictions[0, si - 1 : ei - 1] actual_tokens = input_ids[0, si:ei] is_correct = torch.all(predicted_tokens == actual_tokens).item() elif task_type in ['multiple_choice', 'schema']: # For MC/schema: find the option with lowest average loss - mean_losses = [losses[i, si-1:ei-1].mean().item() - for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] + mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] pred_idx = mean_losses.index(min(mean_losses)) is_correct = pred_idx == item['gold'] else: diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 3271298..c55f4a7 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,13 +1,16 @@ from collections import deque -import torch import pyarrow.parquet as pq +import torch from nanochat.common import get_dist_info from nanochat.dataset import list_parquet_files from nanochat.tokenizer import get_tokenizer -def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None): + +def tokenizing_distributed_data_loader_with_state( + B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None +): """ Stream pretraining text from parquet files, tokenize, yield training batches. 
@@ -24,42 +27,44 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads # infinite iterator over document batches (list of text strings) ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() + def document_batches(): parquet_paths = list_parquet_files() parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None - pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) - while True: # iterate infinitely (multi-epoch) - while pq_idx < len(parquet_paths): # iterate over all parquet files + pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0) + while True: # iterate infinitely (multi-epoch) + while pq_idx < len(parquet_paths): # iterate over all parquet files filepath = parquet_paths[pq_idx] pf = pq.ParquetFile(filepath) # Start from resume point if resuming on same file, otherwise from DDP rank # I know this state resumption is a little bit tricky and a little bit hacky... sigh. if resume_rg_idx is not None: - base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size - base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming + base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size + base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming rg_idx = base_idx * ddp_world_size + ddp_rank - resume_rg_idx = None # set to None as we only want to do this a single time + resume_rg_idx = None # set to None as we only want to do this a single time else: rg_idx = ddp_rank while rg_idx < pf.num_row_groups: rg = pf.read_row_group(rg_idx) - batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows + batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows for i in range(0, len(batch), tokenizer_batch_size): - yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx) - rg_idx += ddp_world_size # advance to the next row group (in DDP) - pq_idx += 1 # advance to the next parquet file + yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx) + rg_idx += ddp_world_size # advance to the next row group (in DDP) + pq_idx += 1 # advance to the next parquet file + batches = document_batches() # Now emit batches of tokens. - needed_tokens = B * T + 1 # +1 is because we also need the target at the last token + needed_tokens = B * T + 1 # +1 is because we also need the target at the last token # get the tokenizer and the bos token tokenizer = get_tokenizer() bos_token = tokenizer.get_bos_token_id() # scratch buffer holds the tokens for one iteration - token_buffer = deque() # we stream tokens on the right and pop from the left + token_buffer = deque() # we stream tokens on the right and pop from the left while True: # Accumulate enough tokens for one iteration before yielding. 
while len(token_buffer) < needed_tokens: @@ -71,16 +76,20 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads tokens = [token_buffer.popleft() for _ in range(needed_tokens)] # CUDA supports memory pinning for asynchronous transfers between CPU and GPU use_cuda_optimizations = device == "cuda" - scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 + scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64 # Create the inputs/targets as 1D tensors inputs_cpu = scratch[:-1] targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations) - state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training + state_dict = { + "pq_idx": pq_idx, + "rg_idx": rg_idx, + } # we need this in case we wish to approximately resume training yield inputs, targets, state_dict + def tokenizing_distributed_data_loader(*args, **kwargs): # helper function that only emits the inputs/targets and not the state_dict for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..5cad9ab 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -7,13 +7,14 @@ This file contains utilities for: For details of how the dataset was prepared, see `repackage_data_reference.py`. """ -import os import argparse +import os import time -import requests -import pyarrow.parquet as pq from multiprocessing import Pool +import pyarrow.parquet as pq +import requests + from nanochat.common import get_base_dir # ----------------------------------------------------------------------------- @@ -21,8 +22,8 @@ from nanochat.common import get_base_dir # The URL on the internet where the data is hosted and downloaded from on demand BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main" -MAX_SHARD = 1822 # the last datashard is shard_01822.parquet -index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames +MAX_SHARD = 1822 # the last datashard is shard_01822.parquet +index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames base_dir = get_base_dir() DATA_DIR = os.path.join(base_dir, "base_data") os.makedirs(DATA_DIR, exist_ok=True) @@ -30,16 +31,15 @@ os.makedirs(DATA_DIR, exist_ok=True) # ----------------------------------------------------------------------------- # These functions are useful utilities to other modules, can/should be imported + def list_parquet_files(data_dir=None): - """ Looks into a data dir and returns full paths to all parquet files. """ + """Looks into a data dir and returns full paths to all parquet files.""" data_dir = DATA_DIR if data_dir is None else data_dir - parquet_files = sorted([ - f for f in os.listdir(data_dir) - if f.endswith('.parquet') and not f.endswith('.tmp') - ]) + parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')]) parquet_paths = [os.path.join(data_dir, f) for f in parquet_files] return parquet_paths + def parquets_iter_batched(split, start=0, step=1): """ Iterate through the dataset, in batches of underlying row_groups for efficiency. 
@@ -56,9 +56,10 @@ def parquets_iter_batched(split, start=0, step=1): texts = rg.column('text').to_pylist() yield texts + # ----------------------------------------------------------------------------- def download_single_file(index): - """ Downloads a single file index, with some backoff """ + """Downloads a single file index, with some backoff""" # Construct the local filepath for this file and skip if it already exists filename = index_to_filename(index) @@ -78,7 +79,7 @@ def download_single_file(index): response = requests.get(url, stream=True, timeout=30) response.raise_for_status() # Write to temporary file first - temp_path = filepath + f".tmp" + temp_path = filepath + ".tmp" with open(temp_path, 'wb') as f: for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks if chunk: @@ -88,10 +89,10 @@ def download_single_file(index): print(f"Successfully downloaded {filename}") return True - except (requests.RequestException, IOError) as e: + except (OSError, requests.RequestException) as e: print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}") # Clean up any partial files - for path in [filepath + f".tmp", filepath]: + for path in [filepath + ".tmp", filepath]: if os.path.exists(path): try: os.remove(path) @@ -99,7 +100,7 @@ def download_single_file(index): pass # Try a few times with exponential backoff: 2^attempt seconds if attempt < max_attempts: - wait_time = 2 ** attempt + wait_time = 2**attempt print(f"Waiting {wait_time} seconds before retry...") time.sleep(wait_time) else: @@ -111,8 +112,12 @@ def download_single_file(index): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards") - parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable") - parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)") + parser.add_argument( + "-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable" + ) + parser.add_argument( + "-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)" + ) args = parser.parse_args() num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1) diff --git a/nanochat/engine.py b/nanochat/engine.py index d749d94..d37fd3a 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -11,15 +11,17 @@ Notes: The whole thing is made as efficient as possible. 
""" -import torch -import torch.nn.functional as F import signal import warnings -from contextlib import contextmanager from collections import deque -from nanochat.common import compute_init, autodetect_device_type +from contextlib import contextmanager, nullcontext + +import torch +import torch.nn.functional as F + from nanochat.checkpoint_manager import load_model -from contextlib import nullcontext +from nanochat.common import autodetect_device_type, compute_init + # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -33,17 +35,19 @@ def timeout(duration, formula): yield signal.alarm(0) + def eval_with_timeout(formula, max_time=3): try: with timeout(max_time, formula): with warnings.catch_warnings(): warnings.simplefilter("ignore", SyntaxWarning) return eval(formula, {"__builtins__": {}}, {}) - except Exception as e: + except Exception: signal.alarm(0) # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage return None + def use_calculator(expr): """ Evaluate a Python expression safely. @@ -65,9 +69,25 @@ def use_calculator(expr): return None # Disallow dangerous patterns - dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file', - 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir', - 'getattr', 'setattr', 'delattr', 'hasattr'] + dangerous_patterns = [ + '__', + 'import', + 'exec', + 'eval', + 'compile', + 'open', + 'file', + 'input', + 'raw_input', + 'globals', + 'locals', + 'vars', + 'dir', + 'getattr', + 'setattr', + 'delattr', + 'hasattr', + ] expr_lower = expr.lower() if any(pattern in expr_lower for pattern in dangerous_patterns): return None @@ -79,6 +99,7 @@ def use_calculator(expr): # Evaluate with timeout return eval_with_timeout(expr) + # ----------------------------------------------------------------------------- class KVCache: """ @@ -90,7 +111,7 @@ class KVCache: # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer. 
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim) self.kv_cache = None - self.pos = 0 # current position in time in the cache + self.pos = 0 # current position in time in the cache def reset(self): self.pos = 0 @@ -122,7 +143,7 @@ class KVCache: dtype, device = other.kv_cache.dtype, other.kv_cache.device self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device) # 3) copy the data over - self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache + self.kv_cache[:, :, :, :, : other.pos, :] = other.kv_cache # 4) update the pos self.pos = other.pos @@ -135,8 +156,8 @@ class KVCache: t0, t1 = self.pos, self.pos + T_add # Dynamically grow the cache if needed if t1 > self.kv_cache.size(4): - t_needed = t1 + 1024 # as much as we need plus buffer of 1024 - t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 + t_needed = t1 + 1024 # as much as we need plus buffer of 1024 + t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024 additional_shape = list(self.kv_cache.shape) additional_shape[4] = t_needed - self.kv_cache.size(4) additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device) @@ -173,22 +194,24 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None): probs = F.softmax(logits, dim=-1) return torch.multinomial(probs, num_samples=1, generator=rng) + # ----------------------------------------------------------------------------- + class RowState: # Per-row state tracking during generation def __init__(self, current_tokens=None): - self.current_tokens = current_tokens or [] # Current token sequence for this row - self.forced_tokens = deque() # Queue of tokens to force inject - self.in_python_block = False # Whether we are inside a python block - self.python_expr_tokens = [] # Tokens of the current python expression - self.completed = False # Whether this row has completed generation + self.current_tokens = current_tokens or [] # Current token sequence for this row + self.forced_tokens = deque() # Queue of tokens to force inject + self.in_python_block = False # Whether we are inside a python block + self.python_expr_tokens = [] # Tokens of the current python expression + self.completed = False # Whether this row has completed generation + class Engine: - def __init__(self, model, tokenizer): self.model = model - self.tokenizer = tokenizer # needed for tool use + self.tokenizer = tokenizer # needed for tool use @torch.inference_mode() def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42): @@ -204,8 +227,8 @@ class Engine: python_end = get_special("<|python_end|>") output_start = get_special("<|output_start|>") output_end = get_special("<|output_end|>") - assistant_end = get_special("<|assistant_end|>") # if sampled, ends row - bos = self.tokenizer.get_bos_token_id() # if sampled, ends row + assistant_end = get_special("<|assistant_end|>") # if sampled, ends row + bos = self.tokenizer.get_bos_token_id() # if sampled, ends row # 1) Run a batch 1 prefill of the prompt tokens m = self.model.config @@ -229,7 +252,7 @@ class Engine: **kv_model_kwargs, ) kv_cache_decode.prefill(kv_cache_prefill) - del kv_cache_prefill # no need to keep this memory around + del kv_cache_prefill # no need to keep this memory around # 3) Initialize states for each sample row_states = [RowState(tokens.copy()) for _ in range(num_samples)] @@ -259,12 +282,12 @@ class Engine: sampled_tokens = next_ids[:, 0].tolist() # Process each row: choose the next token, 
update state, optional tool use - token_column = [] # contains the next token id along each row - token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row + token_column = [] # contains the next token id along each row + token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row for i, state in enumerate(row_states): # Select the next token in this row - is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque? - token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled + is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque? + token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i] token_column.append(next_token) # Update the state of this row to include the next token @@ -327,10 +350,13 @@ if __name__ == "__main__": is equivalent to the faster Engine.generate function here. """ import time + # init compute ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() device_type = autodetect_device_type() - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + ) # load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="eval") @@ -357,12 +383,12 @@ if __name__ == "__main__": # generate tokens with Engine generated_tokens = [] engine = Engine(model, tokenizer) - stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 + stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() with autocast_ctx: for token_column, token_masks in stream: - token = token_column[0] # only print out the first row + token = token_column[0] # only print out the first row generated_tokens.append(token) chunk = tokenizer.decode([token]) print(chunk, end="", flush=True) diff --git a/nanochat/execution.py b/nanochat/execution.py index 6f50c74..923f5df 100644 --- a/nanochat/execution.py +++ b/nanochat/execution.py @@ -30,17 +30,18 @@ import platform import signal import tempfile from dataclasses import dataclass -from typing import Optional # ----------------------------------------------------------------------------- + @dataclass class ExecutionResult: """Result of executing Python code in a sandbox.""" + success: bool stdout: str stderr: str - error: Optional[str] = None + error: str | None = None timeout: bool = False memory_exceeded: bool = False @@ -101,13 +102,13 @@ class WriteOnlyStringIO(io.StringIO): """StringIO that throws an exception when it's read from""" def read(self, *args, **kwargs): - raise IOError + raise OSError def readline(self, *args, **kwargs): - raise IOError + raise OSError def readlines(self, *args, **kwargs): - raise IOError + raise OSError def readable(self, *args, **kwargs): """Returns True if the IO object can be read.""" @@ -131,7 +132,7 @@ def chdir(root): os.chdir(cwd) -def reliability_guard(maximum_memory_bytes: Optional[int] = None): +def reliability_guard(maximum_memory_bytes: int | None = None): """ This disables various destructive functions and prevents the generated code from interfering with the test (e.g. 
fork bomb, killing other processes, @@ -147,6 +148,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): if platform.uname().system != "Darwin": # These resource limit calls seem to fail on macOS (Darwin), skip? import resource + resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) @@ -211,10 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): sys.modules["tkinter"] = None -def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict): +def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict): """Execute code in a subprocess with safety guards. Results are written to result_dict.""" with create_tempdir(): - # These system calls are needed when cleaning up tempdir. import os import shutil @@ -228,14 +229,16 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in reliability_guard(maximum_memory_bytes=maximum_memory_bytes) # Default to failure - result_dict.update({ - "success": False, - "stdout": "", - "stderr": "", - "timeout": False, - "memory_exceeded": False, - "error": None, - }) + result_dict.update( + { + "success": False, + "stdout": "", + "stderr": "", + "timeout": False, + "memory_exceeded": False, + "error": None, + } + ) try: exec_globals = {} @@ -253,28 +256,36 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in # uncomment the following line and proceed at your own risk: exec(code, exec_globals) - result_dict.update({ - "success": True, - "stdout": stdout_capture.getvalue(), - "stderr": stderr_capture.getvalue(), - }) + result_dict.update( + { + "success": True, + "stdout": stdout_capture.getvalue(), + "stderr": stderr_capture.getvalue(), + } + ) except TimeoutException: - result_dict.update({ - "timeout": True, - "error": "Execution timed out", - }) + result_dict.update( + { + "timeout": True, + "error": "Execution timed out", + } + ) except MemoryError as e: - result_dict.update({ - "memory_exceeded": True, - "error": f"Memory limit exceeded: {e}", - }) + result_dict.update( + { + "memory_exceeded": True, + "error": f"Memory limit exceeded: {e}", + } + ) except BaseException as e: - result_dict.update({ - "error": f"{type(e).__name__}: {e}", - }) + result_dict.update( + { + "error": f"{type(e).__name__}: {e}", + } + ) # Needed for cleaning up. shutil.rmtree = rmtree @@ -285,8 +296,8 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in def execute_code( code: str, - timeout: float = 5.0, # 5 seconds default - maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default + timeout: float = 5.0, # 5 seconds default + maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default ) -> ExecutionResult: """ Execute Python code in a sandboxed environment. 
@@ -310,10 +321,7 @@ def execute_code( manager = multiprocessing.Manager() result_dict = manager.dict() - p = multiprocessing.Process( - target=_unsafe_execute, - args=(code, timeout, maximum_memory_bytes, result_dict) - ) + p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict)) p.start() p.join(timeout=timeout + 1) @@ -346,4 +354,3 @@ def execute_code( timeout=result_dict["timeout"], memory_exceeded=result_dict["memory_exceeded"], ) - diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 216343c..bc839e8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -12,24 +12,25 @@ Notable features: """ import math -from functools import partial from dataclasses import dataclass +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F -from nanochat.common import get_dist_info, print0 -from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW +from nanochat.common import get_dist_info +from nanochat.muon import DistMuon, Muon + @dataclass class GPTConfig: sequence_len: int = 1024 vocab_size: int = 50304 n_layer: int = 12 - n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (GQA) + n_head: int = 6 # number of query heads + n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768 @@ -41,13 +42,14 @@ def norm(x): def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 - x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves - y1 = x1 * cos + x2 * sin # rotate pairs of dims + x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves + y1 = x1 * cos + x2 * sin # rotate pairs of dims y2 = x1 * (-sin) + x2 * cos - out = torch.cat([y1, y2], 3) # re-assemble - out = out.to(x.dtype) # ensure input/output dtypes match + out = torch.cat([y1, y2], 3) # re-assemble + out = out.to(x.dtype) # ensure input/output dtypes match return out + class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -73,18 +75,24 @@ class CausalSelfAttention(nn.Module): # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding - q, k = norm(q), norm(k) # QK norm - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding + q, k = norm(q), norm(k) # QK norm + q, k, v = ( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D) # Apply KV cache: insert current k,v into cache, get the full view so far if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) - Tq = q.size(2) # number of queries in this forward pass - Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) + Tq = q.size(2) # number of queries in this forward pass + Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) # Attention: queries attend to keys/values autoregressively. 
A few cases to handle: - enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired + enable_gqa = ( + self.n_head != self.n_kv_head + ) # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired if kv_cache is None or Tq == Tk: # During training (no KV cache), attend as usual with causal attention # And even if there is KV cache, we can still use this simple version when Tq == Tk @@ -96,9 +104,9 @@ class CausalSelfAttention(nn.Module): else: # During inference AND we have a chunk of queries in this forward pass: # First, each query attends to all the cached keys/values (i.e. full prefix) - attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask + attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask prefix_len = Tk - Tq - if prefix_len > 0: # can't be negative but could be zero + if prefix_len > 0: # can't be negative but could be zero attn_mask[:, :prefix_len] = True # Then, causal attention within this chunk attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) @@ -139,19 +147,21 @@ class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.transformer = nn.ModuleDict({ - "wte": nn.Embedding(config.vocab_size, config.n_embd), - "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), - }) + self.transformer = nn.ModuleDict( + { + "wte": nn.Embedding(config.vocab_size, config.n_embd), + "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), + } + ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # To support meta device initialization, we init the rotary embeddings here, but it's fake # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them, but assert fail if we ever reach that amount. # In the future we can dynamically grow the cache, for now it's fine. - self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? + self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) - self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint + self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) def init_weights(self): @@ -195,18 +205,23 @@ class GPT(nn.Module): # calculate the rotation frequencies at each (time, channel) pair freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() - cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 - cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting + cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 + cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin def get_device(self): return self.transformer.wte.weight.device def estimate_flops(self): - """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """ + """Return the estimated FLOPs per token for the model. 
Ref: https://arxiv.org/abs/2204.02311""" nparams = sum(p.numel() for p in self.parameters()) nparams_embedding = self.transformer.wte.weight.numel() - l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len + l, h, q, t = ( + self.config.n_layer, + self.config.n_head, + self.config.n_embd // self.config.n_head, + self.config.sequence_len, + ) num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t return num_flops_per_token @@ -245,12 +260,16 @@ class GPT(nn.Module): B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) - assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" - assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" + assert T <= self.cos.size(1), ( + f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" + ) + assert idx.device == self.cos.device, ( + f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" + ) assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() - cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length + cos_sin = self.cos[:, T0 : T0 + T], self.sin[:, T0 : T0 + T] # truncate cache to current sequence length # Forward the trunk of the Transformer x = self.transformer.wte(idx) @@ -265,14 +284,16 @@ class GPT(nn.Module): # training mode: compute and return the loss # TODO: experiment with Liger Kernels / chunked cross-entropy etc. logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap - logits = logits.float() # use tf32/fp32 for logits - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) + logits = softcap * torch.tanh(logits / softcap) # logits softcap + logits = logits.float() # use tf32/fp32 for logits + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction + ) return loss else: # inference mode: compute and return the logits logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap + logits = softcap * torch.tanh(logits / softcap) # logits softcap return logits @torch.inference_mode() @@ -289,10 +310,10 @@ class GPT(nn.Module): if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) - ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim + ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim for _ in range(max_tokens): - logits = self.forward(ids) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) + logits = self.forward(ids) # (B, T, vocab_size) + logits = logits[:, -1, :] # (B, vocab_size) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 5a556e6..889f251 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -1,10 +1,13 @@ """ A number of functions that help with evaluating a base model. 
""" + import math + import torch import torch.distributed as dist + @torch.no_grad() def evaluate_bpb(model, batches, steps, token_bytes): """ @@ -30,20 +33,16 @@ def evaluate_bpb(model, batches, steps, token_bytes): batch_iter = iter(batches) for _ in range(steps): x, y = next(batch_iter) - loss2d = model(x, y, loss_reduction='none') # (B, T) - loss2d = loss2d.view(-1) # flatten - y = y.view(-1) # flatten - if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 + loss2d = model(x, y, loss_reduction='none') # (B, T) + loss2d = loss2d.view(-1) # flatten + y = y.view(-1) # flatten + if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 # slightly more complex code path if some target tokens are ignore_index (e.g. -1) # any target token < 0 is to be ignored: do NOT index token_bytes with negatives valid = y >= 0 y_safe = torch.where(valid, y, torch.zeros_like(y)) # map valid targets to their byte length; ignored targets contribute 0 bytes - num_bytes2d = torch.where( - valid, - token_bytes[y_safe], - torch.zeros_like(y, dtype=token_bytes.dtype) - ) + num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype)) total_nats += (loss2d * (num_bytes2d > 0)).sum() total_bytes += num_bytes2d.sum() else: diff --git a/nanochat/muon.py b/nanochat/muon.py index d916103..3caee93 100644 --- a/nanochat/muon.py +++ b/nanochat/muon.py @@ -2,9 +2,11 @@ Muon optimizer from Keller et al. Also a lot of borrowing of ideas from modded-nanogpt. """ + import torch -from torch import Tensor import torch.distributed as dist +from torch import Tensor + @torch.compile def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: @@ -17,8 +19,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) + assert ( + G.ndim >= 2 + ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT @@ -28,13 +32,16 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: # Perform the NS iterations for _ in range(steps): A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = ( + b * A + c * A @ A + ) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X if G.size(-2) > G.size(-1): X = X.mT return X + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -57,6 +64,7 @@ class Muon(torch.optim.Optimizer): nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) ns_steps: The number of Newton-Schulz iteration steps to use. 
""" + def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) params: list[Tensor] = [*params] @@ -80,7 +88,7 @@ class Muon(torch.optim.Optimizer): buf.lerp_(g, 1 - group["momentum"]) g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5) + p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5) class DistMuon(torch.optim.Optimizer): @@ -104,14 +112,14 @@ class DistMuon(torch.optim.Optimizer): nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf ns_steps: number of Newton–Schulz iterations for the orthogonalization """ - def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, - nesterov: bool = True, ns_steps: int = 5): + + def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) params = list(params) assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only" rank = dist.get_rank() # Group all parameters by their shape - shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering + shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering param_groups = [] for shape in shapes: group_params = [p for p in params if p.shape == shape] @@ -129,7 +137,9 @@ class DistMuon(torch.optim.Optimizer): world_size = dist.get_world_size() # Ensure all grads exist - assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads" + assert all(p.grad is not None for group in self.param_groups for p in group["params"]), ( + "All params must have grads" + ) # Kick off all the reduce scatter operations to average up the gradients across all ranks all_reduce_futures = [] @@ -141,7 +151,7 @@ class DistMuon(torch.optim.Optimizer): # The compute owner of each param is rank i % world_size owner_idx = base_i + rank # each rank stacks up its chunk of world_size params into a list - rs_input = [p.grad for p in params[base_i:base_i + world_size]] + rs_input = [p.grad for p in params[base_i : base_i + world_size]] # pad rs_input with the zero buffer to complete the group rs_input.extend([zero_buffer] * (world_size - len(rs_input))) # the output buffer gets strided across the group based on the rank @@ -159,9 +169,9 @@ class DistMuon(torch.optim.Optimizer): # Go through params in groups of world_size. 
for base_i in range(0, len(params), world_size): # The compute owner of each param is rank i % world_size - owner_idx = base_i + rank # calculate the index of the param that this rank owns + owner_idx = base_i + rank # calculate the index of the param that this rank owns # Wait for the reduce scatter to complete - all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead + all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead future_idx += 1 # Owner computes the Muon update, result is in its param if owner_idx < len(params): @@ -174,12 +184,12 @@ class DistMuon(torch.optim.Optimizer): buf.lerp_(g, 1.0 - group["momentum"]) g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) + scale = max(1.0, p.size(-2) / p.size(-1)) ** 0.5 p.add_(g, alpha=-group["lr"] * scale) # Replicate updated parameters to all ranks ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer - ag_output = params[base_i:base_i + world_size] - ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad + ag_output = params[base_i : base_i + world_size] + ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad work = dist.all_gather(ag_output, ag_input, async_op=True).get_future() all_gather_futures.append(work) diff --git a/nanochat/report.py b/nanochat/report.py index 0b0ebd7..d95a473 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -2,16 +2,18 @@ Utilities for generating training report cards. More messy code than usual, will fix. """ +import datetime import os +import platform import re import shutil -import subprocess import socket -import datetime -import platform +import subprocess + import psutil import torch + def run_command(cmd): """Run a shell command and return output, or None if it fails.""" try: @@ -22,6 +24,7 @@ def run_command(cmd): except: return None + def get_git_info(): """Get current git commit, branch, and dirty status.""" info = {} @@ -38,18 +41,14 @@ def get_git_info(): return info + def get_gpu_info(): """Get GPU information.""" if not torch.cuda.is_available(): return {"available": False} num_devices = torch.cuda.device_count() - info = { - "available": True, - "count": num_devices, - "names": [], - "memory_gb": [] - } + info = {"available": True, "count": num_devices, "names": [], "memory_gb": []} for i in range(num_devices): props = torch.cuda.get_device_properties(i) @@ -61,6 +60,7 @@ def get_gpu_info(): return info + def get_system_info(): """Get system information.""" info = {} @@ -83,6 +83,7 @@ def get_system_info(): return info + def estimate_cost(gpu_info, runtime_hours=None): """Estimate training cost based on GPU type and runtime.""" @@ -111,9 +112,10 @@ def estimate_cost(gpu_info, runtime_hours=None): return { "hourly_rate": hourly_rate, "gpu_type": gpu_name, - "estimated_total": hourly_rate * runtime_hours if runtime_hours else None + "estimated_total": hourly_rate * runtime_hours if runtime_hours else None, } + def generate_header(): """Generate the header for a training report.""" timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -165,12 +167,12 @@ Generated: {timestamp} num_chars = len(packaged) num_lines = len(packaged.split('\n')) num_files = len([x for x in packaged.split('\n') if x.startswith('')]) - num_tokens = num_chars // 4 # assume approximately 4 
chars per token + num_tokens = num_chars // 4 # assume approximately 4 chars per token # count dependencies via uv.lock uv_lock_lines = 0 if os.path.exists('uv.lock'): - with open('uv.lock', 'r', encoding='utf-8') as f: + with open('uv.lock', encoding='utf-8') as f: uv_lock_lines = len(f.readlines()) header += f""" @@ -184,12 +186,15 @@ Generated: {timestamp} """ return header + # ----------------------------------------------------------------------------- + def slugify(text): """Slugify a text string.""" return text.lower().replace(" ", "-") + # the expected files and their order EXPECTED_FILES = [ "tokenizer-training.md", @@ -207,10 +212,11 @@ EXPECTED_FILES = [ # the metrics we're currently interested in chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"] + def extract(section, keys): """simple def to extract a single key from a section""" if not isinstance(keys, list): - keys = [keys] # convenience + keys = [keys] # convenience out = {} for line in section.split("\n"): for key in keys: @@ -218,6 +224,7 @@ def extract(section, keys): out[key] = line.split(":")[1].strip() return out + def extract_timestamp(content, prefix): """Extract timestamp from content with given prefix.""" for line in content.split('\n'): @@ -229,6 +236,7 @@ def extract_timestamp(content, prefix): pass return None + class Report: """Maintains a bunch of logs, generates a final markdown report.""" @@ -269,14 +277,14 @@ class Report: report_dir = self.report_dir report_file = os.path.join(report_dir, "report.md") print(f"Generating report to {report_file}") - final_metrics = {} # the most important final metrics we'll add as table at the end + final_metrics = {} # the most important final metrics we'll add as table at the end start_time = None end_time = None with open(report_file, "w", encoding="utf-8") as out_file: # write the header first header_file = os.path.join(report_dir, "header.md") if os.path.exists(header_file): - with open(header_file, "r", encoding="utf-8") as f: + with open(header_file, encoding="utf-8") as f: header_content = f.read() out_file.write(header_content) start_time = extract_timestamp(header_content, "Run started:") @@ -284,7 +292,7 @@ class Report: bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL) bloat_data = bloat_data.group(1) if bloat_data else "" else: - start_time = None # will cause us to not write the total wall clock time + start_time = None # will cause us to not write the total wall clock time bloat_data = "[bloat data missing]" print(f"Warning: {header_file} does not exist. 
Did you forget to run `nanochat reset`?") # process all the individual sections @@ -293,7 +301,7 @@ class Report: if not os.path.exists(section_file): print(f"Warning: {section_file} does not exist, skipping") continue - with open(section_file, "r", encoding="utf-8") as in_file: + with open(section_file, encoding="utf-8") as in_file: section = in_file.read() # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) if "rl" not in file_name: @@ -307,7 +315,7 @@ class Report: if file_name == "chat-evaluation-sft.md": final_metrics["sft"] = extract(section, chat_metrics) if file_name == "chat-evaluation-rl.md": - final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K + final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K # append this section of the report out_file.write(section) out_file.write("\n") @@ -354,7 +362,7 @@ class Report: else: out_file.write("Total wall clock time: unknown\n") # also cp the report.md file to current directory - print(f"Copying report.md to current directory for convenience") + print("Copying report.md to current directory for convenience") shutil.copy(report_file, "report.md") return report_file @@ -378,18 +386,23 @@ class Report: f.write(f"Run started: {start_time}\n\n---\n\n") print(f"Reset report and wrote header to {header_file}") + # ----------------------------------------------------------------------------- # nanochat-specific convenience functions + class DummyReport: def log(self, *args, **kwargs): pass + def reset(self, *args, **kwargs): pass + def get_report(): # just for convenience, only rank 0 logs to report from nanochat.common import get_base_dir, get_dist_info + ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() if ddp_rank == 0: report_dir = os.path.join(get_base_dir(), "report") @@ -397,10 +410,18 @@ def get_report(): else: return DummyReport() + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.") - parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)") + parser.add_argument( + "command", + nargs="?", + default="generate", + choices=["generate", "reset"], + help="Operation to perform (default: generate)", + ) args = parser.parse_args() if args.command == "generate": get_report().generate() diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py index 880f854..405de28 100644 --- a/nanochat/tokenizer.py +++ b/nanochat/tokenizer.py @@ -6,36 +6,39 @@ Two implementations are available: 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference """ -import os import copy +import os from functools import lru_cache SPECIAL_TOKENS = [ # every document begins with the Beginning of Sequence (BOS) token that delimits documents "<|bos|>", # tokens below are only used during finetuning to render Conversations into token ids - "<|user_start|>", # user messages + "<|user_start|>", # user messages "<|user_end|>", - "<|assistant_start|>", # assistant messages + "<|assistant_start|>", # assistant messages "<|assistant_end|>", - "<|python_start|>", # assistant invokes python REPL tool + "<|python_start|>", # assistant invokes python REPL tool "<|python_end|>", - "<|output_start|>", # python REPL outputs back to assistant + "<|output_start|>", # python REPL outputs back to assistant "<|output_end|>", ] # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of 
\p{N}{1,3} # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. # I haven't validated that this is actually a good idea, TODO. -SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" +SPLIT_PATTERN = ( + r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" +) # ----------------------------------------------------------------------------- # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer +from tokenizers import Regex, decoders, pre_tokenizers from tokenizers import Tokenizer as HFTokenizer -from tokenizers import pre_tokenizers, decoders, Regex from tokenizers.models import BPE from tokenizers.trainers import BpeTrainer + class HuggingFaceTokenizer: """Light wrapper around HuggingFace Tokenizer for some utilities""" @@ -59,11 +62,13 @@ class HuggingFaceTokenizer: def train_from_iterator(cls, text_iterator, vocab_size): # train from an iterator of text # Configure the HuggingFace Tokenizer - tokenizer = HFTokenizer(BPE( - byte_fallback=True, # needed! - unk_token=None, - fuse_unk=False, - )) + tokenizer = HFTokenizer( + BPE( + byte_fallback=True, # needed! + unk_token=None, + fuse_unk=False, + ) + ) # Normalizer: None tokenizer.normalizer = None # Pre-tokenizer: GPT-4 style @@ -71,11 +76,13 @@ class HuggingFaceTokenizer: # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space. # (but I haven't validated this! TODO) - gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! - tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ - pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), - pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) - ]) + gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! 
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), + ] + ) # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) tokenizer.decoder = decoders.ByteLevel() # Post-processor: None @@ -84,7 +91,7 @@ class HuggingFaceTokenizer: trainer = BpeTrainer( vocab_size=vocab_size, show_progress=True, - min_frequency=0, # no minimum frequency + min_frequency=0, # no minimum frequency initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), special_tokens=SPECIAL_TOKENS, ) @@ -146,12 +153,16 @@ class HuggingFaceTokenizer: self.tokenizer.save(tokenizer_path) print(f"Saved tokenizer to {tokenizer_path}") + # ----------------------------------------------------------------------------- # Tokenizer based on rustbpe + tiktoken combo import pickle -import rustbpe + import tiktoken +import rustbpe + + class RustBPETokenizer: """Light wrapper around tiktoken (for efficient inference) but train with rustbpe""" @@ -176,8 +187,8 @@ class RustBPETokenizer: enc = tiktoken.Encoding( name="rustbpe", pat_str=pattern, - mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank) - special_tokens=special_tokens, # dict[str, int] (special token name -> token id) + mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank) + special_tokens=special_tokens, # dict[str, int] (special token name -> token id) ) return cls(enc, "<|bos|>") @@ -225,14 +236,14 @@ class RustBPETokenizer: if isinstance(text, str): ids = self.enc.encode_ordinary(text) if prepend is not None: - ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm + ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm if append is not None: ids.append(append_id) elif isinstance(text, list): ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads) if prepend is not None: for ids_row in ids: - ids_row.insert(0, prepend_id) # TODO: same + ids_row.insert(0, prepend_id) # TODO: same if append is not None: for ids_row in ids: ids_row.append(append_id) @@ -264,6 +275,7 @@ class RustBPETokenizer: """ # ids, masks that we will return and a helper function to help build them up. ids, mask = [], [] + def add_tokens(token_ids, mask_val): if isinstance(token_ids, int): token_ids = [token_ids] @@ -274,7 +286,7 @@ class RustBPETokenizer: # => just merge it with the second (user) message if conversation["messages"][0]["role"] == "system": # some conversation surgery is necessary here for now... 
- conversation = copy.deepcopy(conversation) # avoid mutating the original + conversation = copy.deepcopy(conversation) # avoid mutating the original messages = conversation["messages"] assert messages[1]["role"] == "user", "System message must be followed by a user message" messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"] @@ -286,17 +298,21 @@ class RustBPETokenizer: # fetch all the special tokens we need bos = self.get_bos_token_id() user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>") - assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>") + assistant_start, assistant_end = ( + self.encode_special("<|assistant_start|>"), + self.encode_special("<|assistant_end|>"), + ) python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>") output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>") # now we can tokenize the conversation add_tokens(bos, 0) for i, message in enumerate(messages): - # some sanity checking here around assumptions, to prevent footguns must_be_from = "user" if i % 2 == 0 else "assistant" - assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}" + assert message["role"] == must_be_from, ( + f"Message {i} is from {message['role']} but should be from {must_be_from}" + ) # content can be either a simple string or a list of parts (e.g. containing tool calls) content = message["content"] @@ -363,10 +379,10 @@ class RustBPETokenizer: Unlike the Chat SFT case, we don't need to return the mask. """ # We have some surgery to do: we need to pop the last message (of the Assistant) - conversation = copy.deepcopy(conversation) # avoid mutating the original + conversation = copy.deepcopy(conversation) # avoid mutating the original messages = conversation["messages"] assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant" - messages.pop() # remove the last message (of the Assistant) inplace + messages.pop() # remove the last message (of the Assistant) inplace # Now tokenize the conversation ids, mask = self.render_conversation(conversation) @@ -376,23 +392,31 @@ class RustBPETokenizer: ids.append(assistant_start) return ids + # ----------------------------------------------------------------------------- # nanochat-specific convenience functions + def get_tokenizer(): from nanochat.common import get_base_dir + base_dir = get_base_dir() tokenizer_dir = os.path.join(base_dir, "tokenizer") # return HuggingFaceTokenizer.from_directory(tokenizer_dir) return RustBPETokenizer.from_directory(tokenizer_dir) + def get_token_bytes(device="cpu"): import torch + from nanochat.common import get_base_dir + base_dir = get_base_dir() tokenizer_dir = os.path.join(base_dir, "tokenizer") token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") - assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py" + assert os.path.exists(token_bytes_path), ( + f"Token bytes not found at {token_bytes_path}? 
It gets written by tok_train.py" + ) with open(token_bytes_path, "rb") as f: token_bytes = torch.load(f, map_location=device) return token_bytes diff --git a/scripts/base_eval.py b/scripts/base_eval.py index 3663538..b28de43 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -9,23 +9,31 @@ torchrun --nproc_per_node=8 -m scripts.base_eval The script will print the CORE metric to the console. """ -import os + import csv -import time import json -import yaml -import shutil +import os import random -import zipfile +import shutil import tempfile +import time +import zipfile from contextlib import nullcontext import torch +import yaml -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock -from nanochat.tokenizer import HuggingFaceTokenizer from nanochat.checkpoint_manager import load_model +from nanochat.common import ( + autodetect_device_type, + compute_cleanup, + compute_init, + download_file_with_lock, + get_base_dir, + print0, +) from nanochat.core_eval import evaluate_task +from nanochat.tokenizer import HuggingFaceTokenizer # ----------------------------------------------------------------------------- # nanochat specific function dealing with I/O etc. @@ -33,6 +41,7 @@ from nanochat.core_eval import evaluate_task # ~162MB of data needed to evaluate the CORE metric EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + def place_eval_bundle(file_path): # here file_path is the path to the eval_bundle.zip file # we need to unzip it and place it in the base directory @@ -45,6 +54,7 @@ def place_eval_bundle(file_path): shutil.move(extracted_bundle_dir, eval_bundle_dir) print0(f"Placed eval_bundle directory at {eval_bundle_dir}") + def evaluate_model(model, tokenizer, device, max_per_task=-1): """ Evaluate a base model on the CORE benchmark. @@ -59,13 +69,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): config_path = os.path.join(eval_bundle_dir, "core.yaml") data_base_path = os.path.join(eval_bundle_dir, "eval_data") eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv") - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, encoding='utf-8') as f: config = yaml.safe_load(f) tasks = config['icl_tasks'] # Load random baseline values from eval metadata random_baselines = {} - with open(eval_meta_data, 'r', encoding='utf-8') as f: + with open(eval_meta_data, encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: task_name = row['Eval Task'] @@ -82,13 +92,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): 'task_type': task['icl_task_type'], 'dataset_uri': task['dataset_uri'], 'num_fewshot': task['num_fewshot'][0], - 'continuation_delimiter': task.get('continuation_delimiter', ' ') + 'continuation_delimiter': task.get('continuation_delimiter', ' '), } print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... 
", end='') # Load data for this task data_path = os.path.join(data_base_path, task_meta['dataset_uri']) - with open(data_path, 'r', encoding='utf-8') as f: + with open(data_path, encoding='utf-8') as f: data = [json.loads(line.strip()) for line in f] # shuffle the data because in many cases it appears ordered but we want @@ -109,18 +119,17 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1): print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s") core_metric = sum(centered_results.values()) / len(centered_results) - out = { - "results": results, - "centered_results": centered_results, - "core_metric": core_metric - } + out = {"results": results, "centered_results": centered_results, "core_metric": core_metric} return out + # ----------------------------------------------------------------------------- # HuggingFace loading utilities and light wrappers for a model + class ModelWrapper: """Lightweight wrapper for a HuggingFace model""" + def __init__(self, model, max_seq_len=None): self.model = model self.max_seq_len = max_seq_len @@ -130,10 +139,12 @@ class ModelWrapper: logits = outputs.logits return logits + def load_hf_model(hf_path: str, device): print0(f"Loading model from: {hf_path}") # Load the model from transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(hf_path) model.to(device) model.eval() @@ -143,9 +154,11 @@ def load_hf_model(hf_path: str, device): tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path) return model, tokenizer + # ----------------------------------------------------------------------------- def main(): import argparse + parser = argparse.ArgumentParser() parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate') parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)') @@ -154,7 +167,9 @@ def main(): # distributed / precision setup device_type = autodetect_device_type() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() + ) # Load model and tokenizer from command line or from file system if args.hf_path is not None: @@ -162,13 +177,13 @@ def main(): hf_path = args.hf_path print0(f"Loading huggingface model from: {hf_path}") model, tokenizer = load_hf_model(hf_path, device) - model_name = hf_path # just for logging - model_slug = hf_path.replace("/", "-") # for the output csv file + model_name = hf_path # just for logging + model_slug = hf_path.replace("/", "-") # for the output csv file else: # load a local model from the file system model, tokenizer, meta = load_model("base", device, phase="eval") - model_name = f"base_model (step {meta['step']})" # just for logging - model_slug = f"base_model_{meta['step']:06d}" # for the output csv file + model_name = f"base_model (step {meta['step']})" # just for logging + model_slug = f"base_model_{meta['step']:06d}" # for the output csv file # Evaluate the model with autocast_ctx: @@ -190,23 +205,28 @@ def main(): f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n") f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n") # Print the content of the csv file to console too - print0("="*80) + print0("=" * 
80) print0(f"Model: {model_name}") - print0("="*80) - with open(output_csv_path, 'r', encoding='utf-8') as f: + print0("=" * 80) + with open(output_csv_path, encoding='utf-8') as f: print0(f.read()) # Log to report from nanochat.report import get_report - get_report().log(section="Base model evaluation", data=[ - { - "Model": model_name, - "CORE metric": core_metric, - }, - centered_results, # the full table - ]) + + get_report().log( + section="Base model evaluation", + data=[ + { + "Model": model_name, + "CORE metric": core_metric, + }, + centered_results, # the full table + ], + ) compute_cleanup() + if __name__ == "__main__": main() diff --git a/scripts/base_loss.py b/scripts/base_loss.py index abcde5f..4a7631c 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -6,30 +6,35 @@ Loads a checkpoint, and: Example run as: torchrun --standalone --nproc_per_node=8 -m scripts.base_loss """ + import os from contextlib import nullcontext + import torch + from nanochat.checkpoint_manager import load_model -from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type +from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0 from nanochat.dataloader import tokenizing_distributed_data_loader -from nanochat.tokenizer import get_token_bytes -from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine +from nanochat.loss_eval import evaluate_bpb +from nanochat.tokenizer import get_token_bytes # Configuration device_batch_size = 32 -split_tokens = 20*524288 # number of tokens to evaluate per split -model_tag = None # optional model tag for the output directory name -model_step = None # optional model step for the output directory name -device_type = "" # cuda|cpu|mps (empty => autodetect) -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +split_tokens = 20 * 524288 # number of tokens to evaluate per split +model_tag = None # optional model tag for the output directory name +model_step = None # optional model step for the output directory name +device_type = "" # cuda|cpu|mps (empty => autodetect) +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file # Load the base model and the tokenizer device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step) -sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really +autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +) # Evaluate the loss on each split tokens_per_step = device_batch_size * sequence_len * ddp_world_size @@ -67,13 +72,17 @@ if ddp_rank == 0: # Log to report from nanochat.report import get_report -get_report().log(section="Base model loss", data=[ - { - "train bpb": bpb_results["train"], - "val bpb": bpb_results["val"], - }, - {f"sample {i}": sample for i, sample in enumerate(samples)}, -]) + +get_report().log( + section="Base model loss", + data=[ + { + "train bpb": bpb_results["train"], + "val bpb": bpb_results["val"], + }, + {f"sample {i}": 
sample for i, sample in enumerate(samples)}, + ], +) # Cleanup compute_cleanup() diff --git a/scripts/base_train.py b/scripts/base_train.py index c9ea6c9..bde163d 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -12,67 +12,83 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 - """ import os + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import time from contextlib import nullcontext -import wandb import torch +import wandb -from nanochat.gpt import GPT, GPTConfig +from nanochat.checkpoint_manager import load_checkpoint, save_checkpoint +from nanochat.common import ( + DummyWandb, + autodetect_device_type, + compute_cleanup, + compute_init, + get_base_dir, + print0, + print_banner, +) from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type -from nanochat.tokenizer import get_tokenizer, get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint -from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine +from nanochat.gpt import GPT, GPTConfig +from nanochat.loss_eval import evaluate_bpb +from nanochat.tokenizer import get_token_bytes, get_tokenizer from scripts.base_eval import evaluate_model + print_banner() # ----------------------------------------------------------------------------- # User settings -run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) # Runtime -device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU) +device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU) # Model architecture -depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived -max_seq_len = 2048 # max context length +depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived +max_seq_len = 2048 # max context length # Training horizon. Only one of these 3 will be used, in this order of precedence. -num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) -target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable) -target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable) +num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) +target_flops = ( + -1.0 +) # calculate num_iterations to reach target_flops. 
Useful for scaling laws experiments (-1 = disable) +target_param_data_ratio = ( + 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable) +) # Optimization -device_batch_size = 32 # per-device batch size (set to not OOM) -total_batch_size = 524288 # total desired batch size, in #tokens -embedding_lr = 0.2 # learning rate for the embedding parameters (Adam) -unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam) -weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam) -matrix_lr = 0.02 # learning rate for the matrix parameters (Muon) -grad_clip = 1.0 # gradient clipping value (0.0 = disabled) -warmup_ratio = 0.0 # ratio of iterations for LR warmup -warmdown_ratio = 0.2 # ratio of iterations for LR warmdown -final_lr_frac = 0.0 # final LR is this fraction of the initial LR -resume_from_step = -1 # resume training from this step of the optimization (-1 = disable) +device_batch_size = 32 # per-device batch size (set to not OOM) +total_batch_size = 524288 # total desired batch size, in #tokens +embedding_lr = 0.2 # learning rate for the embedding parameters (Adam) +unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam) +weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam) +matrix_lr = 0.02 # learning rate for the matrix parameters (Muon) +grad_clip = 1.0 # gradient clipping value (0.0 = disabled) +warmup_ratio = 0.0 # ratio of iterations for LR warmup +warmdown_ratio = 0.2 # ratio of iterations for LR warmdown +final_lr_frac = 0.0 # final LR is this fraction of the initial LR +resume_from_step = -1 # resume training from this step of the optimization (-1 = disable) # Evaluation -eval_every = 250 # every how many steps to evaluate the model for val bpb -eval_tokens = 20*524288 # number of tokens to evaluate val loss on -core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) -core_metric_max_per_task = 500 # examples per task in estimating the core metric -sample_every = 2000 # every how many steps to sample from the model -save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run) +eval_every = 250 # every how many steps to evaluate the model for val bpb +eval_tokens = 20 * 524288 # number of tokens to evaluate val loss on +core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable) +core_metric_max_per_task = 500 # examples per task in estimating the core metric +sample_every = 2000 # every how many steps to sample from the model +save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run) # Output -model_tag = "" # optionally override the model tag for the output checkpoint directory name +model_tag = "" # optionally override the model tag for the output checkpoint directory name # now allow CLI to override the settings via the configurator lol -config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file -user_config = {k: globals()[k] for k in config_keys} # will be useful for logging +config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file 
+user_config = {k: globals()[k] for k in config_keys} # will be useful for logging # ----------------------------------------------------------------------------- # Compute init device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) -master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. +autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +) synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 @@ -88,9 +104,9 @@ print0(f"Vocab size: {vocab_size:,}") # Model kwargs are derived from the desired depth of the model num_layers = depth -model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases) -num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div) -num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) +model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases) +num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div) +num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) print0(f"num_layers: {num_layers}") print0(f"model_dim: {model_dim}") print0(f"num_heads: {num_heads}") @@ -98,8 +114,8 @@ print0(f"num_kv_heads: {num_kv_heads}") # Optimizer / data / training length related hyperparameters # figure out the needed gradient accumulation to reach the desired total batch size -tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks assert total_batch_size % world_tokens_per_fwdbwd == 0 grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") @@ -110,7 +126,14 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: { # Initialize the Model # Create a new model with random weights -model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) +model_config_kwargs = dict( + sequence_len=max_seq_len, + vocab_size=vocab_size, + n_layer=num_layers, + n_head=num_heads, + n_kv_head=num_kv_heads, + n_embd=model_dim, +) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) @@ -119,17 +142,19 @@ model.init_weights() # If we are resuming, overwrite the model parameters with those of the checkpoint base_dir = get_base_dir() -output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12 +output_dirname = model_tag if model_tag else f"d{depth}" # e.g. 
d12 checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) resuming = resume_from_step != -1 if resuming: print0(f"Resuming optimization from step {resume_from_step}") - model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank) + model_data, optimizer_data, meta_data = load_checkpoint( + checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank + ) model.load_state_dict(model_data, strict=True, assign=True) - del model_data # free up this memory after the copy + del model_data # free up this memory after the copy -orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) -model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe +orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape) +model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe num_params = sum(p.numel() for p in model.parameters()) print0(f"Number of parameters: {num_params:,}") num_flops_per_token = model.estimate_flops() @@ -152,30 +177,37 @@ else: raise ValueError("No training horizon specified") total_tokens = total_batch_size * num_iterations print0(f"Total number of training tokens: {total_tokens:,}") -print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 +print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20 print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") # ----------------------------------------------------------------------------- # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) -optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +optimizers = model.setup_optimizers( + unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay +) adamw_optimizer, muon_optimizer = optimizers if resuming: for opt, dat in zip(optimizers, optimizer_data): opt.load_state_dict(dat) - del optimizer_data # free up the memory + del optimizer_data # free up the memory # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val tokens_dir = os.path.join(base_dir, "tokenized_data") dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] -train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) -build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device) -x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data +train_loader = tokenizing_distributed_data_loader_with_state( + device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict +) +build_val_loader = lambda: tokenizing_distributed_data_loader( + device_batch_size, max_seq_len, split="val", device=device +) +x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # 
----------------------------------------------------------------------------- # Set up hyperparameter schedulers + # Learning rate scheduler def get_lr_multiplier(it): warmup_iters = round(warmup_ratio * num_iterations) @@ -188,20 +220,22 @@ def get_lr_multiplier(it): progress = (num_iterations - it) / warmdown_iters return progress * 1.0 + (1 - progress) * final_lr_frac + # Momentum scheduler for Muon optimizer def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum + # ----------------------------------------------------------------------------- # Loop state (variables updated by the training loop) if not resuming: step = 0 min_val_bpb = float("inf") - smooth_train_loss = 0 # EMA of training loss - total_training_time = 0 # total wall-clock time of training + smooth_train_loss = 0 # EMA of training loss + total_training_time = 0 # total wall-clock time of training else: step = meta_data["step"] loop_state = meta_data["loop_state"] @@ -212,7 +246,7 @@ else: # ----------------------------------------------------------------------------- # Training loop while True: - last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end + last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end flops_so_far = num_flops_per_token * total_batch_size * step # once in a while: evaluate the val bpb (all ranks participate) @@ -225,12 +259,14 @@ while True: print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") if val_bpb < min_val_bpb: min_val_bpb = val_bpb - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "val/bpb": val_bpb, - }) + wandb_run.log( + { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + } + ) model.train() # once in a while: estimate the CORE metric (all ranks participate) @@ -241,12 +277,14 @@ while True: with autocast_ctx: results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "core_metric": results["core_metric"], - "centered_results": results["centered_results"], - }) + wandb_run.log( + { + "step": step, + "total_training_flops": flops_so_far, + "core_metric": results["core_metric"], + "centered_results": results["centered_results"], + } + ) model.train() # once in a while: sample from the model (only on master process) @@ -262,7 +300,7 @@ while True: "My favorite color is", "If 5*x + 3 = 13, then x is", ] - engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation + engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") with autocast_ctx: @@ -275,17 +313,17 @@ while True: save_checkpoint( checkpoint_dir, step, - orig_model.state_dict(), # model parameters - [opt.state_dict() for opt in optimizers], # optimizer states - { # metadata saved as json + orig_model.state_dict(), # model parameters + [opt.state_dict() for opt in optimizers], # optimizer states + { # metadata saved as json "step": step, - "val_bpb": val_bpb, # loss at last step + "val_bpb": val_bpb, # loss at last step "model_config": model_config_kwargs, - "user_config": user_config, # inputs to the training script + "user_config": 
user_config, # inputs to the training script "device_batch_size": device_batch_size, "max_seq_len": max_seq_len, "dataloader_state_dict": dataloader_state_dict, - "loop_state": { # all loop state (other than step) so that we can resume training + "loop_state": { # all loop state (other than step) so that we can resume training "min_val_bpb": min_val_bpb, "smooth_train_loss": smooth_train_loss, "total_training_time": total_training_time, @@ -306,15 +344,17 @@ while True: for micro_step in range(grad_accum_steps): with autocast_ctx: loss = model(x, y) - train_loss = loss.detach() # for logging - loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() - x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + x, y, dataloader_state_dict = next( + train_loader + ) # prefetch the next batch while the GPU is busy with forward/backward # gradient clipping grad_clip_enabled = grad_clip > 0.0 if grad_clip_enabled: grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip) - grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point) + grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point) # step the optimizers lrm = get_lr_multiplier(step) for opt in optimizers: @@ -332,18 +372,20 @@ while True: # ------------------------------------------------------------------------- # logging - ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA pct_done = 100 * step / num_iterations tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity + mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % if step > 10: - total_training_time += dt # only count the time after the first 10 steps + total_training_time += dt # only count the time after the first 10 steps print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else "" - print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + print0( + f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m" + ) if step % 100 == 0: log_data = { "step": step, @@ -364,35 +406,39 @@ while True: # print a few more stats print0(f"Peak memory usage: 
{get_max_memory() / 1024 / 1024:.2f}MiB") -print0(f"Total training time: {total_training_time/60:.2f}m") +print0(f"Total training time: {total_training_time / 60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report from nanochat.report import get_report -get_report().log(section="Base model training", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of parameters": num_params, - "Number of FLOPs per token": f"{num_flops_per_token:e}", - "Calculated number of iterations": num_iterations, - "Number of training tokens": total_tokens, - "Tokens : Params ratio": total_batch_size * num_iterations / num_params, - "DDP world size": ddp_world_size, - "warmup_ratio": warmup_ratio, - "warmdown_ratio": warmdown_ratio, - "final_lr_frac": final_lr_frac, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - "Final validation bpb": val_bpb, - "CORE metric estimate": results.get("core_metric", None), - "MFU %": f"{mfu:.2f}%", - "Total training flops": f"{flops_so_far:e}", - "Total training time": f"{total_training_time/60:.2f}m", - "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", - } -]) + +get_report().log( + section="Base model training", + data=[ + user_config, # CLI args + { # stats about the training setup + "Number of parameters": num_params, + "Number of FLOPs per token": f"{num_flops_per_token:e}", + "Calculated number of iterations": num_iterations, + "Number of training tokens": total_tokens, + "Tokens : Params ratio": total_batch_size * num_iterations / num_params, + "DDP world size": ddp_world_size, + "warmup_ratio": warmup_ratio, + "warmdown_ratio": warmdown_ratio, + "final_lr_frac": final_lr_frac, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + "Final validation bpb": val_bpb, + "CORE metric estimate": results.get("core_metric", None), + "MFU %": f"{mfu:.2f}%", + "Total training flops": f"{flops_so_far:e}", + "Total training time": f"{total_training_time / 60:.2f}m", + "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", + }, + ], +) # cleanup -wandb_run.finish() # wandb run finish +wandb_run.finish() # wandb run finish compute_cleanup() diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index b14843a..d9b6373 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -4,12 +4,15 @@ New and upgraded chat mode because a lot of the code has changed since the last Intended to be run single GPU only atm: python -m scripts.chat_cli -i mid """ + import argparse -import torch -from nanochat.common import compute_init, autodetect_device_type from contextlib import nullcontext -from nanochat.engine import Engine + +import torch + from nanochat.checkpoint_manager import load_model +from nanochat.common import autodetect_device_type, compute_init +from nanochat.engine import Engine parser = argparse.ArgumentParser(description='Chat with the model') parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl") @@ -18,7 +21,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back') parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') -parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], 
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') +parser.add_argument( + '--device-type', + type=str, + default='', + choices=['cuda', 'cpu', 'mps'], + help='Device type for evaluation: cuda|cpu|mps. empty => autodetect', +) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) args = parser.parse_args() @@ -33,7 +42,10 @@ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag # Special tokens for the chat state machine bos = tokenizer.get_bos_token_id() user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>") -assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>") +assistant_start, assistant_end = ( + tokenizer.encode_special("<|assistant_start|>"), + tokenizer.encode_special("<|assistant_end|>"), +) # Create Engine for efficient generation engine = Engine(model, tokenizer) @@ -47,7 +59,6 @@ print("-" * 50) conversation_tokens = [bos] while True: - if args.prompt: # Get the prompt from the launch command user_input = args.prompt @@ -89,7 +100,7 @@ while True: print("\nAssistant: ", end="", flush=True) with autocast_ctx: for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs): - token = token_column[0] # pop the batch dimension (num_samples=1) + token = token_column[0] # pop the batch dimension (num_samples=1) response_tokens.append(token) token_text = tokenizer.decode([token]) print(token_text, end="", flush=True) diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index cae2f0f..697aa48 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -9,27 +9,28 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy """ import argparse -from functools import partial from contextlib import nullcontext +from functools import partial import torch import torch.distributed as dist -from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type from nanochat.checkpoint_manager import load_model +from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0 from nanochat.engine import Engine - -from tasks.humaneval import HumanEval -from tasks.mmlu import MMLU from tasks.arc import ARC from tasks.gsm8k import GSM8K +from tasks.humaneval import HumanEval +from tasks.mmlu import MMLU from tasks.spellingbee import SpellingBee # ----------------------------------------------------------------------------- # Generative evaluation loop (we go one problem at a time, sample, evaluate) -def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None): +def run_generative_eval( + task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None +): ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() device = model.get_device() @@ -62,7 +63,7 @@ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_ num_passed += int(passed) # Logging (overwrite the same line in the console) - print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True) + print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100 * num_passed / total:.2f}%)", end='', flush=True) # Finish the in-place progress line with a newline before final summary print() @@ -77,21 +78,22 @@ def 
run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_ total = total_tensor.item() print0("=" * 50) - print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)") + print0(f"Final: {num_passed}/{total} ({100 * num_passed / total:.2f}%)") # Return the accuracy - return num_passed/total + return num_passed / total + # ----------------------------------------------------------------------------- # Categorical evaluation loop # A lot easier because we don't have to sample. Therefore, we can actually go # batches at a time and just check the logits for correct answer choices. -def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None): +def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None): ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() device = model.get_device() - bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored + bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored # We'll process batches of independent problems at a time because there is no sampling needed num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems) @@ -99,22 +101,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems num_batches = ceil_div(num_problems, batch_size) # Run the evaluation - letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work + letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work num_passed, total = 0, 0 for i in range(ddp_rank, num_batches, ddp_world_size): i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems) # Prepare the batch of problems. They might all be of different length, so we pad/collate them. 
conversations = [task_object[ii] for ii in range(i0, i1)] - prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works + prompt_ids = [ + tokenizer.render_for_completion(conversation) for conversation in conversations + ] # TODO: remake the way this works max_length = max(len(ids) for ids in prompt_ids) - answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer) + answer_time_positions = [ + len(ids) - 1 for ids in prompt_ids + ] # where the last token is (and the predicted answer) padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids] prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device) # Get the logits for the whole batch of conversations in parallel (efficiency win here) with torch.no_grad(): - logits = model(prompt_ids) # (B, T, V) + logits = model(prompt_ids) # (B, T, V) # Focus on the available answer on just the letters corresponding to choices # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters @@ -150,15 +156,26 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems num_passed = num_passed_tensor.item() total = total_tensor.item() - average = num_passed/total - print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)") + average = num_passed / total + print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)") return average + # ----------------------------------------------------------------------------- -def run_chat_eval(task_name, model, tokenizer, engine, - batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50, - max_problems=None): + +def run_chat_eval( + task_name, + model, + tokenizer, + engine, + batch_size=1, + num_samples=1, + max_new_tokens=512, + temperature=0.0, + top_k=50, + max_problems=None, +): # Create the evaluation object task_module = { 'HumanEval': HumanEval, @@ -171,20 +188,36 @@ def run_chat_eval(task_name, model, tokenizer, engine, task_object = task_module() # Run the evaluation if task_object.eval_type == 'generative': - acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems) + acc = run_generative_eval( + task_object, + tokenizer, + model, + engine, + num_samples, + max_new_tokens, + temperature, + top_k, + max_problems=max_problems, + ) elif task_object.eval_type == 'categorical': acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems) else: raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}") return acc + # ----------------------------------------------------------------------------- if __name__ == "__main__": - # Parse command-line arguments parser = argparse.ArgumentParser() parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl") - parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.") + parser.add_argument( + '-a', + '--task-name', + type=str, + default=None, + help="Task name. Default = all tasks. 
Use | to split multiple tasks.", + ) parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) parser.add_argument('-t', '--temperature', type=float, default=0.0) parser.add_argument('-m', '--max-new-tokens', type=int, default=512) @@ -194,13 +227,21 @@ if __name__ == "__main__": parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate') - parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') + parser.add_argument( + '--device-type', + type=str, + default='', + choices=['cuda', 'cpu', 'mps'], + help='Device type for evaluation: cuda|cpu|mps. empty => autodetect', + ) args = parser.parse_args() device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + ) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) @@ -208,12 +249,12 @@ if __name__ == "__main__": # Get the tasks to evaluate on all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee'] baseline_accuracies = { - 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% - 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% - 'MMLU': 0.25, # multiple choice 1 of 4 => 25% - 'GSM8K': 0.0, # open-ended => 0% - 'HumanEval': 0.0, # open-ended => 0% - 'SpellingBee': 0.0, # open-ended => 0% + 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25% + 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25% + 'MMLU': 0.25, # multiple choice 1 of 4 => 25% + 'GSM8K': 0.0, # open-ended => 0% + 'HumanEval': 0.0, # open-ended => 0% + 'SpellingBee': 0.0, # open-ended => 0% } task_names = all_tasks if args.task_name is None else args.task_name.split('|') @@ -223,7 +264,9 @@ if __name__ == "__main__": with autocast_ctx: acc = run_chat_eval( task_name, - model, tokenizer, engine, + model, + tokenizer, + engine, batch_size=args.batch_size, num_samples=args.num_samples, max_new_tokens=args.max_new_tokens, @@ -236,6 +279,7 @@ if __name__ == "__main__": # Log to report from nanochat.report import get_report + all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks) # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy) # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance) @@ -248,10 +292,13 @@ if __name__ == "__main__": centered_mean += centered_acc chatcore_metric = centered_mean / len(results) chatcore_metric_dict = {"ChatCORE metric": chatcore_metric} - get_report().log(section="Chat evaluation " + args.source, data=[ - vars(args), # CLI args - results, - chatcore_metric_dict, - ]) + get_report().log( + section="Chat evaluation " + args.source, + data=[ + vars(args), # CLI args + results, + chatcore_metric_dict, + ], + ) compute_cleanup() diff --git a/scripts/chat_rl.py 
b/scripts/chat_rl.py index bc78e79..c0589f8 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -16,46 +16,46 @@ python -m scripts.chat_rl torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default """ -import os import itertools -import re -import wandb +import os + import torch import torch.distributed as dist +import wandb -from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb -from nanochat.checkpoint_manager import save_checkpoint, load_model +from nanochat.checkpoint_manager import load_model, save_checkpoint +from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0 from nanochat.engine import Engine from tasks.gsm8k import GSM8K # RL hyperparameters -run = "dummy" # wandb run name -source = "sft" # mid|sft +run = "dummy" # wandb run name +source = "sft" # mid|sft dtype = "bfloat16" -device_batch_size = 8 # no forward pass will go above this to not OOM -examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!) -num_samples = 16 # number of samples per example (/question) +device_batch_size = 8 # no forward pass will go above this to not OOM +examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!) +num_samples = 16 # number of samples per example (/question) max_new_tokens = 256 temperature = 1.0 -top_k = 50 # TODO: try None? +top_k = 50 # TODO: try None? unembedding_lr = 0.004 embedding_lr = 0.2 matrix_lr = 0.02 weight_decay = 0.0 init_lr_frac = 0.05 -num_epochs = 1 # how many epochs of gsm8k to train on -save_every = 60 # every how many steps to save the model -eval_every = 60 # every how many steps to evaluate the model for val pass@k -eval_examples = 400 # number of examples used for evaluating pass@k +num_epochs = 1 # how many epochs of gsm8k to train on +save_every = 60 # every how many steps to save the model +eval_every = 60 # every how many steps to evaluate the model for val pass@k +eval_examples = 400 # number of examples used for evaluating pass@k # now allow CLI to override the settings via the configurator lol -config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file -user_config = {k: globals()[k] for k in config_keys} # will be useful for logging +config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # will be useful for logging # ----------------------------------------------------------------------------- # Init compute/precision ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. +master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
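# The config_keys / exec(configurator.py) lines above are the nanoGPT-style "poor man's
# argparse": every module-level int/float/bool/str default can be overridden with a
# --key=value flag before training starts. A minimal sketch of that idea, using a
# hypothetical apply_overrides helper (the real nanochat/configurator.py may differ):

import sys
from ast import literal_eval

def apply_overrides(namespace: dict) -> None:
    for arg in sys.argv[1:]:
        if not (arg.startswith("--") and "=" in arg):
            continue  # this sketch only handles --key=value pairs
        key, raw = arg[2:].split("=", 1)
        if key not in namespace:
            continue  # silently skip keys that are not existing settings
        try:
            value = literal_eval(raw)  # "8" -> 8, "0.02" -> 0.02, "True" -> True
        except (ValueError, SyntaxError):
            value = raw  # fall back to the raw string, e.g. --source=sft
        if isinstance(value, type(namespace[key])):
            namespace[key] = value  # keep the original type, mirroring the globals() filter above

# usage sketch: apply_overrides(globals()) right after the defaults are defined, so e.g.
# `--device_batch_size=4 --run=myrun` rewrites those module-level settings.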
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) @@ -65,7 +65,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl # Init model and tokenizer model, tokenizer, meta = load_model(source, device, phase="eval") -engine = Engine(model, tokenizer) # for sampling rollouts +engine = Engine(model, tokenizer) # for sampling rollouts # ----------------------------------------------------------------------------- # Rollout / sampling generator loop that yields batches of examples for training @@ -75,12 +75,16 @@ val_task = GSM8K(subset="main", split="test") num_steps = (len(train_task) // examples_per_step) * num_epochs print0(f"Calculated number of steps: {num_steps}") + @torch.no_grad() def get_batch(): - assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss. - rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data + assistant_end = tokenizer.encode_special( + "<|assistant_end|>" + ) # ok to use this token, it's only for padding and isn't used in the loss. + rank_indices = range( + ddp_rank, len(train_task), ddp_world_size + ) # each rank is responsible for different examples in the training data for example_idx in itertools.cycle(rank_indices): - # First get the full conversation of both user and assistant messages conversation = train_task[example_idx] @@ -90,12 +94,12 @@ def get_batch(): prefix_length = len(tokens) # Generate num_samples samples using batched generation, use loop to avoid OOMs - model.eval() # ensure the model is in eval mode + model.eval() # ensure the model is in eval mode generated_token_sequences = [] masks = [] - num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs + num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs for sampling_step in range(num_sampling_steps): - seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 + seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 with autocast_ctx: generated_token_sequences_batch, masks_batch = engine.generate_batch( tokens, @@ -103,7 +107,7 @@ def get_batch(): max_tokens=max_new_tokens, temperature=temperature, top_k=top_k, - seed=seed, # must make sure to change the seed for each sampling step + seed=seed, # must make sure to change the seed for each sampling step ) generated_token_sequences.extend(generated_token_sequences_batch) masks.extend(masks_batch) @@ -121,15 +125,17 @@ def get_batch(): # Pad the sequences so that their lengths (in time) match max_length = max(len(seq) for seq in generated_token_sequences) - padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences] + padded_generated_token_sequences = [ + seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences + ] padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks] # Stack up the sequences and masks into PyTorch tensors ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device) mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device) # Generate autoregressive inputs and targets to the Transformer inputs = ids[:, :-1] - targets = ids[:, 1:].clone() # clone to avoid in-place modification: - targets[mask_ids[:, 1:] 
== 0] = -1 # <-- inplace modification right here. -1 is the ignore index + targets = ids[:, 1:].clone() # clone to avoid in-place modification: + targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens. # So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens. rewards = torch.tensor(rewards, dtype=torch.float, device=device) @@ -139,14 +145,11 @@ def get_batch(): # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats yield generated_token_sequences, inputs, targets, rewards, advantages + # ----------------------------------------------------------------------------- # Simple evaluation loop for GSM8K pass@k -def run_gsm8k_eval(task, tokenizer, engine, - max_examples=None, - num_samples=1, - max_completion_tokens=256, - temperature=0.0, - top_k=50 +def run_gsm8k_eval( + task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50 ): """ Evaluates GSM8K task and returns a list of records of evaluation outcomes. @@ -160,13 +163,9 @@ def run_gsm8k_eval(task, tokenizer, engine, tokens = tokenizer.render_for_completion(conversation) prefix_length = len(tokens) # Generate k samples using batched generation inside the Engine - assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... + assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... generated_token_sequences, masks = engine.generate_batch( - tokens, - num_samples=num_samples, - max_tokens=max_completion_tokens, - temperature=temperature, - top_k=top_k + tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k ) # Check each sample for correctness outcomes = [] @@ -174,9 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine, generated_tokens = sample_tokens[prefix_length:] generated_text = tokenizer.decode(generated_tokens) is_correct = task.evaluate(conversation, generated_text) - outcomes.append({ - "is_correct": is_correct - }) + outcomes.append({"is_correct": is_correct}) # A bit bloated because I wanted to do more complex logging at one point. 
record = { "idx": idx, @@ -184,6 +181,7 @@ def run_gsm8k_eval(task, tokenizer, engine, } yield record + # ----------------------------------------------------------------------------- # Training loop @@ -199,44 +197,49 @@ optimizers = model.setup_optimizers( for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + # Learning rate scheduler: simple rampdown to zero over num_steps def get_lr_multiplier(it): lrm = 1.0 - it / num_steps return lrm + # Calculate the number of examples each rank handles to achieve the desired examples_per_step -print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step +print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks" -examples_per_rank = examples_per_step // ddp_world_size # per GPU +examples_per_rank = examples_per_step // ddp_world_size # per GPU print0(f"Calculated examples per rank: {examples_per_rank}") # Kick off the training loop batch_iterator = get_batch() for step in range(num_steps): - # Evaluate the model once in a while and log to wandb if step % eval_every == 0: model.eval() - passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size + passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size with autocast_ctx: - records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0) - records = list(records_iter) # collect all records + records_iter = run_gsm8k_eval( + val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0 + ) + records = list(records_iter) # collect all records for k in range(1, device_batch_size + 1): passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) num_records = torch.tensor(len(records), dtype=torch.long, device=device) if ddp: dist.all_reduce(num_records, op=dist.ReduceOp.SUM) dist.all_reduce(passk, op=dist.ReduceOp.SUM) - passk = passk / num_records.item() # normalize by the total number of records + passk = passk / num_records.item() # normalize by the total number of records print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)] print0(f"Step {step} | {', '.join(print_passk)}") log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)} - wandb_run.log({ - "step": step, - **log_passk, - }) + wandb_run.log( + { + "step": step, + **log_passk, + } + ) # Forward/Backward on rollouts over multiple examples in the dataset rewards_list = [] @@ -245,7 +248,7 @@ for step in range(num_steps): # Get one batch corresponding to one example in the training dataset sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator) # Evaluate the loss and gradients - model.train() # ensure the model is in train mode + model.train() # ensure the model is in train mode # We need one more loop because we can never exceed the device_batch_size assert inputs_all.size(0) % device_batch_size == 0 num_passes = inputs_all.size(0) // device_batch_size @@ -258,7 +261,7 @@ for step in range(num_steps): 
advantages = advantages_all[b0:b1] # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate with autocast_ctx: - logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) + logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. pg_obj = (logp * advantages.unsqueeze(-1)).sum() # normalize by the number of valid tokens, number of passes, and examples_per_rank @@ -268,7 +271,9 @@ for step in range(num_steps): # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize) loss = -pg_obj loss.backward() - print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}") + print0( + f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}" + ) # For logging rewards_list.append(rewards_all.mean().item()) sequence_lengths.extend(len(seq) for seq in sequences_all) @@ -276,56 +281,66 @@ for step in range(num_steps): # A bunch of logging for how the rollouts went this step mean_reward = sum(rewards_list) / len(rewards_list) mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths) - if ddp: # aggregate across ranks + if ddp: # aggregate across ranks mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device) mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device) dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG) dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG) mean_reward = mean_reward_tensor.item() mean_sequence_length = mean_sequence_length_tensor.item() - print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}") - wandb_run.log({ - "step": step, - "reward": mean_reward, - "sequence_length": mean_sequence_length, - }) + print0( + f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}" + ) + wandb_run.log( + { + "step": step, + "reward": mean_reward, + "sequence_length": mean_sequence_length, + } + ) # Update the model parameters lrm = get_lr_multiplier(step) - for opt in optimizers: # first set the learning rate + for opt in optimizers: # first set the learning rate for group in opt.param_groups: group["lr"] = group["initial_lr"] * lrm - for opt in optimizers: # then step the optimizers + for opt in optimizers: # then step the optimizers opt.step() model.zero_grad(set_to_none=True) - wandb_run.log({ - "step": step, - "lrm": lrm, - }) + wandb_run.log( + { + "step": step, + "lrm": lrm, + } + ) # Master process saves the model once in a while. Skip first step. Save last step. 
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1): base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model + model_tag = f"d{depth}" # base the model tag on the depth of the base model checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag) - model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer + model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, step, model.state_dict(), - None, # note: we don't bother to save the optimizer state + None, # note: we don't bother to save the optimizer state { "model_config": model_config_kwargs, - } + }, ) print(f"✅ Saved model checkpoint to {checkpoint_dir}") # Log to report from nanochat.report import get_report -get_report().log(section="Chat RL", data=[ - user_config, # CLI args -]) -wandb_run.finish() # wandb run finish +get_report().log( + section="Chat RL", + data=[ + user_config, # CLI args + ], +) + +wandb_run.finish() # wandb run finish compute_cleanup() diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index bbeb1f9..d6a8fb4 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -10,40 +10,40 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft """ import os + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import wandb -import torch -import torch.distributed as dist from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type -from nanochat.checkpoint_manager import load_model -from nanochat.checkpoint_manager import save_checkpoint +import torch +import torch.distributed as dist +import wandb + +from nanochat.checkpoint_manager import load_model, save_checkpoint +from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0 from nanochat.engine import Engine from scripts.chat_eval import run_chat_eval - -from tasks.common import TaskMixture from tasks.arc import ARC +from tasks.common import TaskMixture +from tasks.customjson import CustomJSON from tasks.gsm8k import GSM8K from tasks.smoltalk import SmolTalk -from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee # ----------------------------------------------------------------------------- # SFT Hyperparameters -run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) # input model options -source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) -model_tag = None # model tag to load the model from (base model or midtrained model) -step = None # step to load the model from (base model or midtrained model) +source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model) +model_tag = None # model tag to load the model from (base model or midtrained model) +step = None # step to load the model from (base model or midtrained model) # compute/precision -device_type = "" # cuda|cpu|mps (empty => autodetect) +device_type = "" # cuda|cpu|mps (empty => autodetect) dtype = "bfloat16" -device_batch_size = 4 # max to avoid OOM +device_batch_size = 4 # max to avoid OOM # optimization num_epochs = 1 
-num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) +num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it) target_examples_per_step = 32 unembedding_lr = 0.004 embedding_lr = 0.2 @@ -56,9 +56,9 @@ eval_steps = 100 eval_metrics_every = 200 eval_metrics_max_problems = 1024 # now allow CLI to override the settings via the configurator lol -config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file -user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging +config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging # ----------------------------------------------------------------------------- # Compute init @@ -70,52 +70,63 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if dev # wandb logging init use_dummy_wandb = run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) +wandb_run = ( + DummyWandb() + if use_dummy_wandb + else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) +) # Load the model and tokenizer model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step) -orig_model = model # original, uncompiled model +orig_model = model # original, uncompiled model # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs -engine = Engine(model, tokenizer) # will be used for inline model evaluation only +engine = Engine(model, tokenizer) # will be used for inline model evaluation only # ----------------------------------------------------------------------------- # Task data mixture we'll train on identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl") -train_ds = TaskMixture([ - ARC(subset="ARC-Easy", split="train"), # 2.3K rows - ARC(subset="ARC-Challenge", split="train"), # 1.1K rows - GSM8K(subset="main", split="train"), # 8K rows - SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk - CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations - SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple') - SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows -val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) +train_ds = TaskMixture( + [ + ARC(subset="ARC-Easy", split="train"), # 2.3K rows + ARC(subset="ARC-Challenge", split="train"), # 1.1K rows + GSM8K(subset="main", split="train"), # 8K rows + SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk + CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations + SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. 
spell the word 'apple') + SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) + ] +) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows +val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it) # ----------------------------------------------------------------------------- # DataLoader + def sft_data_generator(dataset, batch_size): - pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss + pad_token_id = tokenizer.encode_special( + "<|assistant_end|>" + ) # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss + # prepares a list of tokenized conversations into a batch and yields def collate_and_yield(batch): nrows = len(batch) - ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 + ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1 inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long) - targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index + targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index for i, (ids, mask) in enumerate(batch): n = len(ids) ids_tensor = torch.tensor(ids, dtype=torch.long) - inputs[i, :n-1] = ids_tensor[:-1] + inputs[i, : n - 1] = ids_tensor[:-1] # recall -1 is the ignore index, so mask out targets where mask is 0 row_targets = ids_tensor[1:] # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok mask_tensor = torch.tensor(mask[1:], dtype=torch.long) - row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 - targets[i, :n-1] = row_targets - inputs = inputs.to(device) # move to device + row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0 + targets[i, : n - 1] = row_targets + inputs = inputs.to(device) # move to device targets = targets.to(device) return inputs, targets + # iterates over the dataset in epochs, tokenizes batch = [] while True: @@ -127,11 +138,14 @@ def sft_data_generator(dataset, batch_size): yield collate_and_yield(batch) batch = [] + examples_per_step = device_batch_size * ddp_world_size print0(f"Target examples per step: {target_examples_per_step}") print0(f"Device batch size: {device_batch_size}") print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}") -assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step" +assert target_examples_per_step % examples_per_step == 0, ( + "Target examples per step must be divisible by examples per step" +) grad_accum_steps = target_examples_per_step // examples_per_step print0(f"=> Setting grad accum steps: {grad_accum_steps}") @@ -155,16 +169,18 @@ optimizers = model.setup_optimizers( for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # ----------------------------------------------------------------------------- # Training loop + # Learning rate scheduler def get_lr_multiplier(it): lrm = 1.0 - it / num_iterations return lrm + # Go! 
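# Before the loop kicks off, one note on the `loss = loss / grad_accum_steps` normalization
# used in the SFT loop below and in the base-model training loop earlier in this diff:
# .backward() sums gradients across calls, so dividing each micro-batch loss by the number
# of accumulation steps makes the accumulated gradient equal the gradient of the mean loss
# over the full effective batch. A tiny self-contained check of that equivalence
# (illustrative only, not project code; plain scalar "losses" stand in for model outputs):

import torch

w = torch.tensor(2.0, requires_grad=True)
xs = (1.0, 2.0, 3.0, 4.0)  # four pretend micro-batches
grad_accum_steps = len(xs)

# accumulate: divide each micro-loss by grad_accum_steps, then backward (grads sum up)
for x in xs:
    ((w * x).pow(2) / grad_accum_steps).backward()
accumulated = w.grad.clone()

# reference: gradient of the mean loss computed in a single backward pass
w.grad = None
torch.stack([(w * x).pow(2) for x in xs]).mean().backward()
assert torch.allclose(accumulated, w.grad)  # identical gradients either way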
step = 0 train_iter = iter(train_loader) @@ -181,15 +197,17 @@ for step in range(num_iterations): with torch.no_grad(), autocast_ctx: loss = model(val_inputs, val_targets) losses.append(loss) - val_loss = torch.stack(losses).mean() # average over eval_steps + val_loss = torch.stack(losses).mean() # average over eval_steps if ddp: - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks val_loss = val_loss.item() print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}") - wandb_run.log({ - "step": step, - "val_loss": val_loss, - }) + wandb_run.log( + { + "step": step, + "val_loss": val_loss, + } + ) model.train() # evaluate accuracy of the multiple choice tasks (which are quick to run) @@ -198,31 +216,47 @@ for step in range(num_iterations): metrics = {} with torch.no_grad(), autocast_ctx: # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size - metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) - metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems) + metrics["mmlu_acc"] = run_chat_eval( + "MMLU", + model, + tokenizer, + engine, + batch_size=device_batch_size * 2, + max_problems=eval_metrics_max_problems, + ) + metrics["arc_easy_acc"] = run_chat_eval( + "ARC-Easy", + model, + tokenizer, + engine, + batch_size=device_batch_size * 2, + max_problems=eval_metrics_max_problems, + ) metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items()) print0(f"Step {step:05d} | {metrics_str}") - wandb_run.log({ - "step": step, - **metrics, - }) + wandb_run.log( + { + "step": step, + **metrics, + } + ) model.train() if last_step: break # evaluate the gradient - num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen + num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen for micro_step in range(grad_accum_steps): train_inputs, train_targets = next(train_iter) with autocast_ctx: loss = model(train_inputs, train_targets) - train_loss = loss.detach() # for logging - loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - loss.backward() # accumulate the gradient + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + loss.backward() # accumulate the gradient num_tokens += (train_targets >= 0).sum() if ddp: - dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks + dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks # learning rate scheduler lrm = get_lr_multiplier(step) @@ -238,47 +272,55 @@ for step in range(num_iterations): # logging train_loss_item = train_loss.item() num_tokens_item = num_tokens.item() - print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}") - wandb_run.log({ - "step": step, - "lrm": lrm, - "train_loss": train_loss_item, - "num_tokens": num_tokens_item, - }) + print0( + f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}" + ) + wandb_run.log( + { + "step": step, + "lrm": lrm, + "train_loss": train_loss_item, + "num_tokens": num_tokens_item, + } + ) step += 1 # Save the model at the end of the run 
if master_process: base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model + model_tag = f"d{depth}" # base the model tag on the depth of the base model checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) - model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer + model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, step, model.state_dict(), - None, # note: we don't bother to save the optimizer state + None, # note: we don't bother to save the optimizer state { "step": step, "val_loss": val_loss, **metrics, "model_config": model_config_kwargs, - } + }, ) print(f"✅ Saved model checkpoint to {checkpoint_dir}") # Log to report from nanochat.report import get_report -get_report().log(section="Chat SFT", data=[ - user_config, # CLI args - { - "Training rows": len(train_ds), - "Number of iterations": num_iterations, - "Training loss": train_loss_item, - "Validation loss": val_loss, - }, -]) + +get_report().log( + section="Chat SFT", + data=[ + user_config, # CLI args + { + "Training rows": len(train_ds), + "Number of iterations": num_iterations, + "Training loss": train_loss_item, + "Validation loss": val_loss, + }, + ], +) # Cleanup wandb_run.finish() diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 4b67b62..2ff1bc7 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -31,22 +31,23 @@ Abuse Prevention: """ import argparse -import json -import os -import torch import asyncio +import json import logging +import os import random -from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager, nullcontext +from dataclasses import dataclass + +import torch from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse +from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse from pydantic import BaseModel -from typing import List, Optional, AsyncGenerator -from dataclasses import dataclass -from contextlib import nullcontext -from nanochat.common import compute_init, autodetect_device_type + from nanochat.checkpoint_manager import load_model +from nanochat.common import autodetect_device_type, compute_init from nanochat.engine import Engine # Abuse prevention limits @@ -70,70 +71,70 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on') parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16']) -parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') +parser.add_argument( + '--device-type', + type=str, + default='', + choices=['cuda', 'cpu', 'mps'], + help='Device type for evaluation: cuda|cpu|mps. 
empty => autodetect', +) parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') args = parser.parse_args() # Configure logging for conversation traffic -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' -) +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') logger = logging.getLogger(__name__) device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 + @dataclass class Worker: """A worker with a model loaded on a specific GPU.""" + gpu_id: int device: torch.device engine: Engine tokenizer: object autocast_ctx: torch.amp.autocast + class WorkerPool: """Pool of workers, each with a model replica on a different GPU.""" - def __init__(self, num_gpus: Optional[int] = None): + def __init__(self, num_gpus: int | None = None): if num_gpus is None: if device_type == "cuda": num_gpus = torch.cuda.device_count() else: - num_gpus = 1 # e.g. cpu|mps + num_gpus = 1 # e.g. cpu|mps self.num_gpus = num_gpus - self.workers: List[Worker] = [] + self.workers: list[Worker] = [] self.available_workers: asyncio.Queue = asyncio.Queue() - async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None): + async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None): """Load model on each GPU.""" print(f"Initializing worker pool with {self.num_gpus} GPUs...") if self.num_gpus > 1: assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not." for gpu_id in range(self.num_gpus): - if device_type == "cuda": device = torch.device(f"cuda:{gpu_id}") print(f"Loading model on GPU {gpu_id}...") else: - device = torch.device(device_type) # e.g. cpu|mps + device = torch.device(device_type) # e.g. cpu|mps print(f"Loading model on {device_type}...") model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step) engine = Engine(model, tokenizer) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() - - worker = Worker( - gpu_id=gpu_id, - device=device, - engine=engine, - tokenizer=tokenizer, - autocast_ctx=autocast_ctx + autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() ) + + worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx) self.workers.append(worker) await self.available_workers.put(worker) @@ -147,15 +148,18 @@ class WorkerPool: """Return a worker to the pool.""" await self.available_workers.put(worker) + class ChatMessage(BaseModel): role: str content: str + class ChatRequest(BaseModel): - messages: List[ChatMessage] - temperature: Optional[float] = None - max_tokens: Optional[int] = None - top_k: Optional[int] = None + messages: list[ChatMessage] + temperature: float | None = None + max_tokens: int | None = None + top_k: int | None = None + def validate_chat_request(request: ChatRequest): """Validate chat request to prevent abuse.""" @@ -165,7 +169,7 @@ def validate_chat_request(request: ChatRequest): if len(request.messages) > MAX_MESSAGES_PER_REQUEST: raise HTTPException( status_code=400, - detail=f"Too many messages. 
Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request" + detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request", ) # Check individual message lengths and total conversation length @@ -178,48 +182,43 @@ def validate_chat_request(request: ChatRequest): if msg_length > MAX_MESSAGE_LENGTH: raise HTTPException( status_code=400, - detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message" + detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message", ) total_length += msg_length if total_length > MAX_TOTAL_CONVERSATION_LENGTH: raise HTTPException( status_code=400, - detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed" + detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed", ) # Validate role values for i, message in enumerate(request.messages): if message.role not in ["user", "assistant"]: raise HTTPException( - status_code=400, - detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'" + status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'" ) # Validate temperature if request.temperature is not None: if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE): raise HTTPException( - status_code=400, - detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}" + status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}" ) # Validate top_k if request.top_k is not None: if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K): - raise HTTPException( - status_code=400, - detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}" - ) + raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}") # Validate max_tokens if request.max_tokens is not None: if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS): raise HTTPException( - status_code=400, - detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}" + status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}" ) + @asynccontextmanager async def lifespan(app: FastAPI): """Load models on all GPUs on startup.""" @@ -229,6 +228,7 @@ async def lifespan(app: FastAPI): print(f"Server ready at http://localhost:{args.port}") yield + app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -239,16 +239,16 @@ app.add_middleware( allow_headers=["*"], ) + @app.get("/") async def root(): """Serve the chat UI.""" ui_html_path = os.path.join("nanochat", "ui.html") - with open(ui_html_path, "r", encoding="utf-8") as f: + with open(ui_html_path, encoding="utf-8") as f: html_content = f.read() # Replace the API_URL to use the same origin html_content = html_content.replace( - "const API_URL = `http://${window.location.hostname}:8000`;", - "const API_URL = '';" + "const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';" ) return HTMLResponse(content=html_content) @@ -259,12 +259,9 @@ async def logo(): logo_path = os.path.join("nanochat", "logo.svg") return FileResponse(logo_path, media_type="image/svg+xml") + async def generate_stream( - worker: Worker, - tokens, - temperature=None, - max_new_tokens=None, - top_k=None + worker: Worker, tokens, temperature=None, max_new_tokens=None, top_k=None ) -> AsyncGenerator[str, None]: """Generate assistant response with streaming.""" temperature = temperature if temperature is 
not None else args.temperature @@ -286,7 +283,7 @@ async def generate_stream( max_tokens=max_new_tokens, temperature=temperature, top_k=top_k, - seed=random.randint(0, 2**31 - 1) + seed=random.randint(0, 2**31 - 1), ): token = token_column[0] @@ -303,13 +300,14 @@ async def generate_stream( # This ensures we don't emit incomplete UTF-8 sequences if not current_text.endswith('�'): # Extract only the new text since last clean decode - new_text = current_text[len(last_clean_text):] + new_text = current_text[len(last_clean_text) :] if new_text: # Only yield if there's new content yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n" last_clean_text = current_text yield f"data: {json.dumps({'done': True})}\n\n" + @app.post("/chat/completions") async def chat_completions(request: ChatRequest): """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU.""" @@ -318,10 +316,10 @@ async def chat_completions(request: ChatRequest): validate_chat_request(request) # Log incoming conversation to console - logger.info("="*20) + logger.info("=" * 20) for i, message in enumerate(request.messages): logger.info(f"[{message.role.upper()}]: {message.content}") - logger.info("-"*20) + logger.info("-" * 20) # Acquire a worker from the pool (will wait if all are busy) worker_pool = app.state.worker_pool @@ -350,6 +348,7 @@ async def chat_completions(request: ChatRequest): # Streaming response with worker release after completion response_tokens = [] + async def stream_and_release(): try: async for chunk in generate_stream( @@ -357,7 +356,7 @@ async def chat_completions(request: ChatRequest): conversation_tokens, temperature=request.temperature, max_new_tokens=request.max_tokens, - top_k=request.top_k + top_k=request.top_k, ): # Accumulate response for logging chunk_data = json.loads(chunk.replace("data: ", "").strip()) @@ -368,19 +367,17 @@ async def chat_completions(request: ChatRequest): # Log the assistant response to console full_response = "".join(response_tokens) logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}") - logger.info("="*20) + logger.info("=" * 20) # Release worker back to pool after streaming is done await worker_pool.release_worker(worker) - return StreamingResponse( - stream_and_release(), - media_type="text/event-stream" - ) + return StreamingResponse(stream_and_release(), media_type="text/event-stream") except Exception as e: # Make sure to release worker even on error await worker_pool.release_worker(worker) raise e + @app.get("/health") async def health(): """Health check endpoint.""" @@ -389,9 +386,10 @@ async def health(): "status": "ok", "ready": worker_pool is not None and len(worker_pool.workers) > 0, "num_gpus": worker_pool.num_gpus if worker_pool else 0, - "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0 + "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0, } + @app.get("/stats") async def stats(): """Get worker pool statistics.""" @@ -400,16 +398,13 @@ async def stats(): "total_workers": len(worker_pool.workers), "available_workers": worker_pool.available_workers.qsize(), "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(), - "workers": [ - { - "gpu_id": w.gpu_id, - "device": str(w.device) - } for w in worker_pool.workers - ] + "workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers], } + if __name__ == "__main__": import uvicorn - print(f"Starting NanoChat Web Server") + + print("Starting 
NanoChat Web Server") print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}") uvicorn.run(app, host=args.host, port=args.port) diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..7caa973 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -9,55 +9,58 @@ Or torchrun for training: torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16 """ -from collections import deque import os +from collections import deque + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import time -import wandb -import torch from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type -from nanochat.tokenizer import get_token_bytes -from nanochat.checkpoint_manager import save_checkpoint -from nanochat.loss_eval import evaluate_bpb -from nanochat.checkpoint_manager import load_model -import torch.distributed as dist +import torch +import torch.distributed as dist +import wandb + +from nanochat.checkpoint_manager import load_model, save_checkpoint +from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0 +from nanochat.loss_eval import evaluate_bpb +from nanochat.tokenizer import get_token_bytes from tasks.common import TaskMixture +from tasks.customjson import CustomJSON from tasks.gsm8k import GSM8K from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk -from tasks.customjson import CustomJSON from tasks.spellingbee import SimpleSpelling, SpellingBee # ----------------------------------------------------------------------------- -run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) -device_type = "" # cuda|cpu|mps (empty => autodetect) -model_tag = None # model tag to load the model from (base model or midtrained model) -step = None # step to load the model from (base model or midtrained model) +run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +device_type = "" # cuda|cpu|mps (empty => autodetect) +model_tag = None # model tag to load the model from (base model or midtrained model) +step = None # step to load the model from (base model or midtrained model) dtype = "bfloat16" -num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) +num_iterations = -1 # explicit number of steps of the optimization (-1 = disable) max_seq_len = 2048 device_batch_size = 32 unembedding_lr = 0.004 embedding_lr = 0.2 matrix_lr = 0.02 -init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate +init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate weight_decay = 0.0 -eval_every = 150 # -1 = disable -eval_tokens = 20*524288 +eval_every = 150 # -1 = disable +eval_tokens = 20 * 524288 total_batch_size = 524288 -dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report -config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] -exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file -user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging +dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report +config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, 
float, bool, str))] +exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file +user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging # ----------------------------------------------------------------------------- # Compute init device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 -autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +autocast_ctx = ( + torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() +) synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 @@ -69,13 +72,15 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mi model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step) pretrain_batch_size = meta.get("device_batch_size", None) if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size: - print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?") + print0( + f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?" + ) orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer num_flops_per_token = model.estimate_flops() -tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank -world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks +tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank +world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks assert total_batch_size % world_tokens_per_fwdbwd == 0 grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") @@ -84,48 +89,58 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: { token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) -optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) +optimizers = model.setup_optimizers( + unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay +) adamw_optimizer, muon_optimizer = optimizers # Override the initial learning rate as a fraction of the base learning rate for opt in optimizers: for group in opt.param_groups: group["lr"] = group["lr"] * init_lr_frac - group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later + group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later # Midtraining data mixture and DataLoader base_dir = get_base_dir() identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ - SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", 
split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these - SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') - SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) -]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows -val_dataset = TaskMixture([ - SmolTalk(split="test"), # 24K rows in test set - MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios - GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios -]) # total: 24K + 14K + 1.32K ~= 39K rows +train_dataset = TaskMixture( + [ + SmolTalk(split="train"), # 460K rows of general conversations + MMLU( + subset="auxiliary_train", split="train" + ), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE + GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use + CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations + CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these + SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple') + SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?) + ] +) # total: 460K + 100K + 8K + 200K + 80K = 848K rows +val_dataset = TaskMixture( + [ + SmolTalk(split="test"), # 24K rows in test set + MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios + GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios + ] +) # total: 24K + 14K + 1.32K ~= 39K rows # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) # A big problem is that we don't know the final num_iterations in advance. So we create # these two global variables and update them from within the data generator. 
-last_step = False # we will toggle this to True when we reach the end of the dataset -approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch +last_step = False # we will toggle this to True when we reach the end of the dataset +approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch + + def mid_data_generator(split): global last_step, approx_progress assert split in {"train", "val"}, "split must be 'train' or 'val'" dataset = train_dataset if split == "train" else val_dataset dataset_size = len(dataset) assert dataset_size > 0 - needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets + needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets token_buffer = deque() # CUDA supports memory pinning for faster transfers between CPU and GPU: scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")) - cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents - it = 0 # iteration counter + cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents + it = 0 # iteration counter while True: # Accumulate enough tokens for one iteration before yielding while len(token_buffer) < needed_tokens: @@ -134,49 +149,55 @@ def mid_data_generator(split): token_buffer.extend(ids) cursor += ddp_world_size if cursor >= dataset_size: - cursor -= dataset_size # wrap around for another epoch + cursor -= dataset_size # wrap around for another epoch if split == "train": - last_step = True # toggle last_step to True, which will terminate the training loop + last_step = True # toggle last_step to True, which will terminate the training loop # Stopping condition to respect num_iterations, if given it += 1 if num_iterations > 0 and it >= num_iterations: - last_step = True # toggle last_step to True, which will terminate the training loop + last_step = True # toggle last_step to True, which will terminate the training loop # Build up inputs/targets and yield for i in range(needed_tokens): scratch[i] = token_buffer.popleft() inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + targets = targets_cpu.view(device_batch_size, max_seq_len).to( + device=device, dtype=torch.int64, non_blocking=True + ) if split == "train": if num_iterations > 0: - approx_progress = it / num_iterations # calculate progress from the max number of iterations + approx_progress = it / num_iterations # calculate progress from the max number of iterations else: - approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset + approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset yield inputs, targets + train_loader = mid_data_generator("train") build_val_loader = lambda: mid_data_generator("val") -progress = 0 # will go from 0 to 1 over the course of the epoch +progress = 0 # will go from 0 to 1 over the course of the epoch + # Learning rate scheduler def get_lr_multiplier(progress): # first 80% of training: no decay, then linearly ramp down to 0. 
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2 + # Momentum scheduler for Muon optimizer def get_muon_momentum(it): frac = min(it / 300, 1) momentum = (1 - frac) * 0.85 + frac * 0.95 return momentum + # ----------------------------------------------------------------------------- # Training loop -x, y = next(train_loader) # prefetch the very first batch of data +x, y = next(train_loader) # prefetch the very first batch of data min_val_bpb = float("inf") -smooth_train_loss = 0 # EMA of training loss -ema_beta = 0.9 # EMA decay factor -total_training_time = 0 # total wall-clock time of training +smooth_train_loss = 0 # EMA of training loss +ema_beta = 0.9 # EMA decay factor +total_training_time = 0 # total wall-clock time of training step = 0 while True: flops_so_far = num_flops_per_token * total_batch_size * step @@ -197,26 +218,28 @@ while True: print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") if val_bpb < min_val_bpb: min_val_bpb = val_bpb - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "val/bpb": val_bpb, - }) + wandb_run.log( + { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "val/bpb": val_bpb, + } + ) model.train() # save checkpoint at the end of the run (only on master process) if master_process and last_step and not dry_run: - output_dirname = f"d{depth}" # e.g. d12 + output_dirname = f"d{depth}" # e.g. d12 checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname) save_checkpoint( checkpoint_dir, step, orig_model.state_dict(), - [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly + [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly { "step": step, - "val_bpb": val_bpb, # loss at last step + "val_bpb": val_bpb, # loss at last step "model_config": { "sequence_len": max_seq_len, "vocab_size": tokenizer.get_vocab_size(), @@ -225,8 +248,8 @@ while True: "n_kv_head": model.config.n_kv_head, "n_embd": model.config.n_embd, }, - "user_config": user_config, # inputs to the training script - } + "user_config": user_config, # inputs to the training script + }, ) if last_step: @@ -240,11 +263,11 @@ while True: for micro_step in range(grad_accum_steps): with autocast_ctx: loss = model(x, y) - train_loss = loss.detach() # for logging - loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here + train_loss = loss.detach() # for logging + loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward - progress = max(progress, approx_progress) # only increase progress monotonically + x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + progress = max(progress, approx_progress) # only increase progress monotonically # step the optimizers lrm = get_lr_multiplier(progress) for opt in optimizers: @@ -265,47 +288,55 @@ while True: step += 1 # logging - smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss - debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA + smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss + debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias 
the EMA pct_done = 100 * progress tok_per_sec = int(total_batch_size / dt) flops_per_sec = num_flops_per_token * total_batch_size / dt - promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity - mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % + promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity + mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in % if step > 10: - total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") + total_training_time += dt # only count the time after the first 10 steps + print0( + f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m" + ) if step % 10 == 0: - wandb_run.log({ - "step": step, - "total_training_flops": flops_so_far, - "total_training_time": total_training_time, - "train/loss": debiased_smooth_loss, - "train/lrm": lrm, - "train/dt": dt, - "train/tok_per_sec": tok_per_sec, - "train/mfu": mfu, - }) + wandb_run.log( + { + "step": step, + "total_training_flops": flops_so_far, + "total_training_time": total_training_time, + "train/loss": debiased_smooth_loss, + "train/lrm": lrm, + "train/dt": dt, + "train/tok_per_sec": tok_per_sec, + "train/mfu": mfu, + } + ) # print a few more stats print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") -print0(f"Total training time: {total_training_time/60:.2f}m") +print0(f"Total training time: {total_training_time / 60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report if not dry_run: from nanochat.report import get_report - get_report().log(section="Midtraining", data=[ - user_config, # CLI args - { # stats about the training setup - "Number of iterations": step, - "DDP world size": ddp_world_size, - }, - { # stats about training outcomes - "Minimum validation bpb": min_val_bpb, - } - ]) + + get_report().log( + section="Midtraining", + data=[ + user_config, # CLI args + { # stats about the training setup + "Number of iterations": step, + "DDP world size": ddp_world_size, + }, + { # stats about training outcomes + "Minimum validation bpb": min_val_bpb, + }, + ], + ) # cleanup -wandb_run.finish() # wandb run finish +wandb_run.finish() # wandb run finish compute_cleanup() diff --git a/scripts/tok_eval.py b/scripts/tok_eval.py index 9233d71..9b52907 100644 --- a/scripts/tok_eval.py +++ b/scripts/tok_eval.py @@ -2,8 +2,8 @@ Evaluate compression ratio of the tokenizer. 
""" -from nanochat.tokenizer import get_tokenizer, RustBPETokenizer from nanochat.dataset import parquets_iter_batched +from nanochat.tokenizer import RustBPETokenizer, get_tokenizer # Random text I got from a random website this morning news_text = r""" @@ -165,11 +165,10 @@ tokenizer_results = {} vocab_sizes = {} for tokenizer_name in ["gpt2", "gpt4", "ours"]: - if tokenizer_name == "gpt2": - tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer + tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer elif tokenizer_name == "gpt4": - tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer + tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer else: tokenizer = get_tokenizer() @@ -183,11 +182,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]: encoded_bytes = text.encode('utf-8') ratio = len(encoded_bytes) / len(encoded) - tokenizer_results[tokenizer_name][name] = { - 'bytes': len(encoded_bytes), - 'tokens': len(encoded), - 'ratio': ratio - } + tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio} # ANSI color codes GREEN = '\033[92m' @@ -195,11 +190,12 @@ RED = '\033[91m' RESET = '\033[0m' # Print vocab sizes -print(f"\nVocab sizes:") +print("\nVocab sizes:") print(f"GPT-2: {vocab_sizes['gpt2']}") print(f"GPT-4: {vocab_sizes['gpt4']}") print(f"Ours: {vocab_sizes['ours']}") + def print_comparison(baseline_name, baseline_results, ours_results, all_text): """Print comparison table between baseline tokenizer and ours.""" print(f"\nComparison with {baseline_name}:") @@ -230,13 +226,16 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text): better = "Tie" diff_color = "" - print(f"{name:<10} {baseline_data['bytes']:<8} " - f"{baseline_color}{baseline_data['tokens']:<7}{RESET} " - f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} " - f"{ours_color}{ours_data['tokens']:<7}{RESET} " - f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} " - f"{diff_color}{relative_diff:+7.1f}%{RESET} " - f"{better:<10}") + print( + f"{name:<10} {baseline_data['bytes']:<8} " + f"{baseline_color}{baseline_data['tokens']:<7}{RESET} " + f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} " + f"{ours_color}{ours_data['tokens']:<7}{RESET} " + f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} " + f"{diff_color}{relative_diff:+7.1f}%{RESET} " + f"{better:<10}" + ) + # Print comparisons print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text) @@ -244,6 +243,7 @@ print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], # Log to report from nanochat.report import get_report + lines = [] for baseline_name in ["GPT-2", "GPT-4"]: baseline_key = baseline_name.lower().replace('-', '') @@ -251,15 +251,26 @@ for baseline_name in ["GPT-2", "GPT-4"]: ours_results = tokenizer_results['ours'] lines.append(f"### Comparison with {baseline_name}") lines.append("") - lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |") + lines.append( + "| Text Type | Bytes | " + + baseline_name + + " Tokens | " + + baseline_name + + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |" + ) lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|") for name, text in all_text: baseline_data = baseline_results[name] ours_data = ours_results[name] 
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100 - lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |") + lines.append( + f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |" + ) lines.append("") report_markdown = "\n".join(lines) -get_report().log(section="Tokenizer evaluation", data=[ - report_markdown, -]) +get_report().log( + section="Tokenizer evaluation", + data=[ + report_markdown, + ], +) diff --git a/scripts/tok_train.py b/scripts/tok_train.py index c2faf17..d971d43 100644 --- a/scripts/tok_train.py +++ b/scripts/tok_train.py @@ -2,19 +2,24 @@ Train a tokenizer using the HuggingFace Tokenizers library. In the style of GPT-4 tokenizer. """ + +import argparse import os import time -import argparse + import torch -from nanochat.tokenizer import RustBPETokenizer + from nanochat.common import get_base_dir from nanochat.dataset import parquets_iter_batched +from nanochat.tokenizer import RustBPETokenizer # ----------------------------------------------------------------------------- # Parse command line arguments parser = argparse.ArgumentParser(description='Train a BPE tokenizer') -parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') +parser.add_argument( + '--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)' +) parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') args = parser.parse_args() @@ -25,6 +30,7 @@ print(f"vocab_size: {args.vocab_size:,}") # ----------------------------------------------------------------------------- # Text iterator + def text_iterator(): """ 1) Flatten the batches into a single iterator @@ -36,11 +42,13 @@ def text_iterator(): for doc in batch: doc_text = doc if len(doc_text) > args.doc_cap: - doc_text = doc_text[:args.doc_cap] + doc_text = doc_text[: args.doc_cap] nchars += len(doc_text) yield doc_text if nchars > args.max_chars: return + + text_iter = text_iterator() # ----------------------------------------------------------------------------- @@ -78,11 +86,11 @@ special_set = set(tokenizer.get_special_tokens()) token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)] token_bytes = [] for token_id in range(vocab_size): - token_str = token_strings[token_id] # the Python string representation of this token + token_str = token_strings[token_id] # the Python string representation of this token if token_str in special_set: - token_bytes.append(0) # special characters are not counted + token_bytes.append(0) # special characters are not counted else: - id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token + id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token token_bytes.append(id_bytes) token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu') token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt") @@ -92,15 +100,19 @@ print(f"Saved token_bytes to {token_bytes_path}") # Log to report from nanochat.report import get_report + token_bytes_nonzero = 
(token_bytes[token_bytes > 0]).to(dtype=torch.float32) -get_report().log(section="Tokenizer training", data=[ - vars(args), # argparse command line arguments - {"train_time": train_time}, - {"num_special_tokens": len(special_set)}, - { - "token_bytes_min": int(token_bytes_nonzero.min().item()), - "token_bytes_max": int(token_bytes_nonzero.max().item()), - "token_bytes_mean": token_bytes_nonzero.mean().item(), - "token_bytes_std": token_bytes_nonzero.std().item(), - } -]) +get_report().log( + section="Tokenizer training", + data=[ + vars(args), # argparse command line arguments + {"train_time": train_time}, + {"num_special_tokens": len(special_set)}, + { + "token_bytes_min": int(token_bytes_nonzero.min().item()), + "token_bytes_max": int(token_bytes_nonzero.max().item()), + "token_bytes_mean": token_bytes_nonzero.mean().item(), + "token_bytes_std": token_bytes_nonzero.std().item(), + }, + ], +) diff --git a/tasks/arc.py b/tasks/arc.py index 862cca9..239af82 100644 --- a/tasks/arc.py +++ b/tasks/arc.py @@ -4,10 +4,11 @@ https://huggingface.co/datasets/allenai/ai2_arc """ from datasets import load_dataset + from tasks.common import Task, render_mc -class ARC(Task): +class ARC(Task): def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge" @@ -23,26 +24,25 @@ class ARC(Task): def get_example(self, index): row = self.ds[index] - question = row["question"] # the question text - choices = row["choices"]["text"] # the text of each choice - answer_string = row["answerKey"] # e.g. "A", "B", "C", "D" - letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"] - assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check + question = row["question"] # the question text + choices = row["choices"]["text"] # the text of each choice + answer_string = row["answerKey"] # e.g. "A", "B", "C", "D" + letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"] + assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check # create and return the Conversation object user_message = render_mc(question, letters, choices) - messages = [ - {"role": "user", "content": user_message}, - {"role": "assistant", "content": answer_string} - ] + messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}] conversation = { "messages": messages, - "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters + "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters } return conversation def evaluate(self, conversation, assistant_response): # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. - assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" - assistant_message = conversation['messages'][-1]['content'] # e.g. "A" + assert assistant_response in conversation['letters'], ( + f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}" + ) + assistant_message = conversation['messages'][-1]['content'] # e.g. 
"A" return assistant_response == assistant_message diff --git a/tasks/common.py b/tasks/common.py index dcd2e91..250d619 100644 --- a/tasks/common.py +++ b/tasks/common.py @@ -7,6 +7,7 @@ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk. import random + class Task: """ Base class of a Task. Allows for lightweight slicing of the underlying dataset. @@ -18,7 +19,7 @@ class Task: assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}" assert step >= 1, f"Step must be strictly positive, got {step}" self.start = start - self.stop = stop # could be None here + self.stop = stop # could be None here self.step = step @property @@ -37,8 +38,8 @@ class Task: stop = self.num_examples() if self.stop is None else self.stop step = self.step span = stop - start - num = (span + step - 1) // step # ceil_div(span, step) - assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns + num = (span + step - 1) // step # ceil_div(span, step) + assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns return num def __getitem__(self, index: int): @@ -81,7 +82,9 @@ class TaskMixture(Task): Access conversations according to a deterministic shuffle of all examples. This ensures tasks are mixed throughout training, regardless of dataset size. """ - assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations" + assert 0 <= index < self.num_conversations, ( + f"Index {index} out of range for mixture with {self.num_conversations} conversations" + ) task_idx, local_idx = self.index_map[index] return self.tasks[task_idx][local_idx] @@ -102,7 +105,9 @@ class TaskSequence(Task): return self.num_conversations def get_example(self, index): - assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations" + assert 0 <= index < self.num_conversations, ( + f"Index {index} out of range for sequence with {self.num_conversations} conversations" + ) for task_idx, task_length in enumerate(self.lengths): if index < task_length: return self.tasks[task_idx][index] diff --git a/tasks/customjson.py b/tasks/customjson.py index e1b5f0b..66edc96 100644 --- a/tasks/customjson.py +++ b/tasks/customjson.py @@ -3,10 +3,12 @@ CustomJSON task for loading conversations from JSONL files. Each line in the JSONL file should be a JSON array of messages. """ -import os import json +import os + from tasks.common import Task + class CustomJSON(Task): """ Load conversations from a JSONL file. 
@@ -25,14 +27,18 @@ class CustomJSON(Task): print("-" * 80) print(f"Warning: File {filepath} does not exist") print("HINT (Oct 21 2025)") - print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations") + print( + "If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations" + ) print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139") print("Quick fix: simply run the following command to download the file and you're done:") - print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl") + print( + f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl" + ) print("-" * 80) else: - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, encoding='utf-8') as f: for line in f: line = line.strip() if not line: # skip empty lines @@ -46,7 +52,9 @@ class CustomJSON(Task): assert "role" in message, f"Message {i} missing 'role' field" assert "content" in message, f"Message {i} missing 'content' field" expected_role = "user" if i % 2 == 0 else "assistant" - assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + assert message["role"] == expected_role, ( + f"Message {i} has role {message['role']} but should be {expected_role}" + ) assert isinstance(message["content"], str), f"Message {i} content must be a string" self.conversations.append(messages) @@ -62,4 +70,3 @@ class CustomJSON(Task): "messages": messages, } return conversation - diff --git a/tasks/gsm8k.py b/tasks/gsm8k.py index c05e21c..76bbcff 100644 --- a/tasks/gsm8k.py +++ b/tasks/gsm8k.py @@ -15,11 +15,14 @@ Notice that GSM8K uses tool calls inside << >> tags. """ import re + from datasets import load_dataset + from tasks.common import Task - GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)") + + def extract_answer(completion): """ Extract the numerical answer after #### marker. @@ -35,7 +38,6 @@ def extract_answer(completion): class GSM8K(Task): - def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic" @@ -50,10 +52,10 @@ class GSM8K(Task): return len(self.ds) def get_example(self, index): - """ Get a single problem from the dataset. """ + """Get a single problem from the dataset.""" row = self.ds[index] - question = row['question'] # string of the question prompt - answer = row['answer'] # string of the full solution and the answer after #### marker + question = row['question'] # string of the question prompt + answer = row['answer'] # string of the full solution and the answer after #### marker # Create and return the Conversation object # This is tricky because GSM8K uses tool calls, which we need to parse here. 
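The actual parsing falls outside the hunks shown here, so as an aside, a minimal sketch of splitting a GSM8K solution on its << >> calculator tags; the regex, helper name, and part type labels are illustrative and not the ones used in tasks/gsm8k.py:

import re

CALC_RE = re.compile(r"<<([^<>]*)>>")  # hypothetical helper, matches spans like <<3+4=7>>

def split_parts(answer: str):
    """Split a solution string into text parts and calculator-call parts."""
    parts, pos = [], 0
    for m in CALC_RE.finditer(answer):
        if m.start() > pos:
            parts.append({"type": "text", "text": answer[pos:m.start()]})
        parts.append({"type": "calculator", "text": m.group(1)})  # e.g. "3+4=7"
        pos = m.end()
    if pos < len(answer):
        parts.append({"type": "text", "text": answer[pos:]})
    return parts

example = "She buys 3+4=<<3+4=7>>7 apples in total.\n#### 7"
print(split_parts(example))
# -> text "She buys 3+4=", calculator "3+4=7", text "7 apples in total.\n#### 7"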
assistant_message_parts = [] @@ -76,8 +78,8 @@ class GSM8K(Task): assistant_message_parts.append({"type": "text", "text": part}) # No put it all together messages = [ - {"role": "user", "content": question}, # note: simple string - {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts) + {"role": "user", "content": question}, # note: simple string + {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts) ] conversation = { "messages": messages, @@ -99,7 +101,7 @@ class GSM8K(Task): assistant_message = conversation['messages'][-1] assert assistant_message['role'] == "assistant", "Last message must be from the Assistant" assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts" - last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K + last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K # Extract both the ground truth answer and the predicted answer ref_num = extract_answer(last_text_part) pred_num = extract_answer(assistant_response) diff --git a/tasks/humaneval.py b/tasks/humaneval.py index e9dd489..21d3ffa 100644 --- a/tasks/humaneval.py +++ b/tasks/humaneval.py @@ -5,10 +5,13 @@ It is a coding benchmark. """ import re + from datasets import load_dataset + from nanochat.execution import execute_code from tasks.common import Task + def extract_imports(prompt): """Extract import statements from the beginning of a code block.""" imports = [] @@ -21,6 +24,7 @@ def extract_imports(prompt): break return '\n'.join(imports) + def extract_program(completion): """ Extract Python code from LLM completion. @@ -44,8 +48,8 @@ def extract_program(completion): # No code blocks found, return the whole completion return completion.strip() -class HumanEval(Task): +class HumanEval(Task): def __init__(self, **kwargs): super().__init__(**kwargs) self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42) @@ -58,12 +62,12 @@ class HumanEval(Task): return len(self.ds) def get_example(self, index): - """ Get a single problem from the dataset. """ + """Get a single problem from the dataset.""" row = self.ds[index] - prompt = row['prompt'] # prompts in HumanEval are the beginning of the program - solution = row['canonical_solution'] # the correct continuation of the program - entry_point = row['entry_point'] # the function to check - test = row['test'] # the test cases + prompt = row['prompt'] # prompts in HumanEval are the beginning of the program + solution = row['canonical_solution'] # the correct continuation of the program + entry_point = row['entry_point'] # the function to check + test = row['test'] # the test cases complete_solution = f"{prompt}\n{solution}" messages = [ {"role": "user", "content": prompt}, @@ -71,13 +75,13 @@ class HumanEval(Task): ] conversation = { "messages": messages, - "entry_point": entry_point, # needed during evaluation - "test": test, # needed during evaluation + "entry_point": entry_point, # needed during evaluation + "test": test, # needed during evaluation } return conversation def evaluate(self, conversation, completion): - """ Given (conversation, completion), return boolean success of the completion. 
""" + """Given (conversation, completion), return boolean success of the completion.""" # the prompt will contain the imports and the function signature imports = extract_imports(conversation['messages'][0]['content']) # the completion will usually contain the whole function diff --git a/tasks/mmlu.py b/tasks/mmlu.py index 3ba2254..d95f9c6 100644 --- a/tasks/mmlu.py +++ b/tasks/mmlu.py @@ -4,12 +4,71 @@ https://huggingface.co/datasets/cais/mmlu """ from datasets import load_dataset + from tasks.common import Task, render_mc -class MMLU(Task): +class MMLU(Task): letters = ('A', 'B', 'C', 'D') - groups = ('abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions') + groups = ( + 'abstract_algebra', + 'anatomy', + 'astronomy', + 'business_ethics', + 'clinical_knowledge', + 'college_biology', + 'college_chemistry', + 'college_computer_science', + 'college_mathematics', + 'college_medicine', + 'college_physics', + 'computer_security', + 'conceptual_physics', + 'econometrics', + 'electrical_engineering', + 'elementary_mathematics', + 'formal_logic', + 'global_facts', + 'high_school_biology', + 'high_school_chemistry', + 'high_school_computer_science', + 'high_school_european_history', + 'high_school_geography', + 'high_school_government_and_politics', + 'high_school_macroeconomics', + 'high_school_mathematics', + 'high_school_microeconomics', + 'high_school_physics', + 'high_school_psychology', + 'high_school_statistics', + 'high_school_us_history', + 'high_school_world_history', + 'human_aging', + 'human_sexuality', + 'international_law', + 'jurisprudence', + 'logical_fallacies', + 'machine_learning', + 'management', + 'marketing', + 'medical_genetics', + 'miscellaneous', + 'moral_disputes', + 'moral_scenarios', + 'nutrition', + 'philosophy', + 'prehistory', + 'professional_accounting', + 'professional_law', + 'professional_medicine', + 'professional_psychology', + 'public_relations', + 'security_studies', + 'sociology', + 'us_foreign_policy', + 'virology', + 'world_religions', + ) def __init__(self, subset, split, **kwargs): super().__init__(**kwargs) @@ -33,28 +92,27 @@ class MMLU(Task): def get_example(self, index): row = self.ds[index] - question = row["question"] # the question text - choices = row["choices"] # the text of each choice - answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) - subject = row["subject"] # e.g. 
"college_biology", "college_chemistry", etc. + question = row["question"] # the question text + choices = row["choices"] # the text of each choice + answer = row["answer"] # index of the answer, e.g. 0,1,2,3 (for A,B,C,D) + subject = row["subject"] # e.g. "college_biology", "college_chemistry", etc. assert len(choices) == 4, "MMLU should have 4 choices" # create and return the Conversation object user_message = render_mc(question, self.letters, choices) assistant_message = self.letters[answer] - messages = [ - {"role": "user", "content": user_message}, - {"role": "assistant", "content": assistant_message} - ] + messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}] conversation = { "messages": messages, - "subject": subject, # might be useful later for grouping metrics by subject - "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters + "subject": subject, # might be useful later for grouping metrics by subject + "letters": self.letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters } return conversation def evaluate(self, conversation, assistant_response): # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it. - assert assistant_response in self.letters, f"MMLU answer {assistant_response} is expected to be one of {self.letters}" - assistant_message = conversation['messages'][-1]['content'] # e.g. "A" + assert assistant_response in self.letters, ( + f"MMLU answer {assistant_response} is expected to be one of {self.letters}" + ) + assistant_message = conversation['messages'][-1]['content'] # e.g. "A" return assistant_response == assistant_message diff --git a/tasks/smoltalk.py b/tasks/smoltalk.py index b4d4f5f..47fb0fb 100644 --- a/tasks/smoltalk.py +++ b/tasks/smoltalk.py @@ -5,10 +5,12 @@ We use the "smol" version, which is more appropriate for smaller models. """ from datasets import load_dataset + from tasks.common import Task + class SmolTalk(Task): - """ smol-smoltalk dataset. train is 460K rows, test is 24K rows. """ + """smol-smoltalk dataset. train is 460K rows, test is 24K rows.""" def __init__(self, split, **kwargs): super().__init__(**kwargs) @@ -29,14 +31,16 @@ class SmolTalk(Task): assert len(messages) >= 1 first_message = messages[0] if first_message["role"] == "system": - rest_messages = messages[1:] # optional system message is OK + rest_messages = messages[1:] # optional system message is OK else: rest_messages = messages assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages" for i, message in enumerate(rest_messages): # user and assistant alternate as user,assistant,user,assistant,... 
expected_role = "user" if i % 2 == 0 else "assistant" - assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}" + assert message["role"] == expected_role, ( + f"Message {i} has role {message['role']} but should be {expected_role}" + ) assert isinstance(message["content"], str), "Content must be a string" # --------------------------------------------------------------------- # create and return the Conversation object (ok to emit the system message too) diff --git a/tasks/spellingbee.py b/tasks/spellingbee.py index 3b45305..4d57697 100644 --- a/tasks/spellingbee.py +++ b/tasks/spellingbee.py @@ -26,10 +26,11 @@ To preview a few example conversations, run: python -m tasks.spellingbee """ -import re import random -from tasks.common import Task +import re + from nanochat.common import download_file_with_lock +from tasks.common import Task # Letters of the alphabet LETTERS = "abcdefghijklmnopqrstuvwxyz" @@ -38,6 +39,8 @@ WORD_LIST_URL = "https://raw.githubusercontent.com/dwyl/english-words/refs/heads # Identical to gsm8k's answer extraction ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)") + + def extract_answer(completion): """ Extract the numerical answer after #### marker. @@ -49,6 +52,7 @@ def extract_answer(completion): return match_str return None + # User message templates for data augmentation USER_MSG_TEMPLATES = [ "How many {letter} are in the word {word}", @@ -110,8 +114,8 @@ USER_MSG_TEMPLATES = [ "{word}に{letter}が何回出てくる", ] -class SpellingBee(Task): +class SpellingBee(Task): def __init__(self, size=1000, split="train", **kwargs): super().__init__(**kwargs) assert split in ["train", "test"], "SpellingBee split must be train|test" @@ -119,7 +123,7 @@ class SpellingBee(Task): self.split = split filename = WORD_LIST_URL.split("/")[-1] word_list_path = download_file_with_lock(WORD_LIST_URL, filename) - with open(word_list_path, 'r', encoding='utf-8') as f: + with open(word_list_path, encoding='utf-8') as f: words = [line.strip() for line in f] self.words = words @@ -131,7 +135,7 @@ class SpellingBee(Task): return self.size def get_example(self, index): - seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 rng = random.Random(seed) # pick a random word @@ -148,12 +152,12 @@ class SpellingBee(Task): if rng.random() < 0.3: template = template.lower() quote_options = ['', "'", '"'] - letter_quote = rng.choice(quote_options) # is the letter quoted? - word_quote = rng.choice(quote_options) # is the word quoted? + letter_quote = rng.choice(quote_options) # is the letter quoted? + word_quote = rng.choice(quote_options) # is the word quoted? letter_wrapped = f"{letter_quote}{letter}{letter_quote}" word_wrapped = f"{word_quote}{word}{word_quote}" user_msg = template.format(letter=letter_wrapped, word=word_wrapped) - if rng.random() < 0.5: # 50% of people don't even use question marks + if rng.random() < 0.5: # 50% of people don't even use question marks user_msg += "?" 
# Now create the ideal assistant response - build as parts (text + tool calls) @@ -190,13 +194,12 @@ Then count the occurrences of '{letter}': # Part 4: Python output assistant_parts.append({"type": "python_output", "text": str(count)}) # Part 5: Final answer - assistant_parts.append({"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}) + assistant_parts.append( + {"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"} + ) # return the full conversation - messages = [ - {"role": "user", "content": user_msg}, - {"role": "assistant", "content": assistant_parts} - ] + messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}] conversation = { "messages": messages, } @@ -222,7 +225,7 @@ Then count the occurrences of '{letter}': return is_correct def reward(self, conversation, assistant_response): - """ Use simple 0-1 reward just like gsm8k.""" + """Use simple 0-1 reward just like gsm8k.""" is_correct = self.evaluate(conversation, assistant_response) is_correct_float = float(is_correct) return is_correct_float @@ -238,10 +241,10 @@ class SimpleSpelling(Task): self.split = split filename = WORD_LIST_URL.split("/")[-1] word_list_path = download_file_with_lock(WORD_LIST_URL, filename) - with open(word_list_path, 'r', encoding='utf-8') as f: + with open(word_list_path, encoding='utf-8') as f: words = [line.strip() for line in f] rng = random.Random(42) - rng.shuffle(words) # use a different word order than the SpellingBee task + rng.shuffle(words) # use a different word order than the SpellingBee task self.words = words @property @@ -252,7 +255,7 @@ class SimpleSpelling(Task): return self.size def get_example(self, index): - seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 + seed = index if self.split == "train" else -(index + 1) # avoid collision at 0 rng = random.Random(seed) # pick a random word word = rng.choice(self.words) @@ -260,7 +263,7 @@ class SimpleSpelling(Task): # return the full conversation messages = [ {"role": "user", "content": f"Spell the word: {word}"}, - {"role": "assistant", "content": f"{word}:{word_letters}"} + {"role": "assistant", "content": f"{word}:{word_letters}"}, ] conversation = { "messages": messages, @@ -269,7 +272,6 @@ class SimpleSpelling(Task): if __name__ == "__main__": - # preview the SpellingBee task, first 10 examples task = SpellingBee() for i in range(10): diff --git a/tests/test_engine.py b/tests/test_engine.py index 7403b36..71cf9bb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -5,8 +5,10 @@ python -m pytest tests/test_engine.py -v """ import torch + from nanochat.engine import KVCache + def test_kv_cache_resize(): """ The KV cache was not resized correctly, more information here: @@ -21,11 +23,7 @@ def test_kv_cache_resize(): num_layers = 6 kv_cache = KVCache( - batch_size=batch_size, - num_heads=num_heads, - seq_len=seq_len, - head_dim=head_dim, - num_layers=num_layers + batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers ) # Insert a single token with a distinct fill value to all layers @@ -47,7 +45,9 @@ def test_kv_cache_resize(): insert_token(4) # Verify that the cache actually resized new_seq_len = kv_cache.kv_cache.shape[4] - assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}" + assert new_seq_len > original_seq_len, ( + f"Cache did not resize: 
original seq_len={original_seq_len}, new seq_len={new_seq_len}" + ) # Verify that the original 4 tokens are still intact after resize for layer_idx in range(num_layers): @@ -57,8 +57,12 @@ def test_kv_cache_resize(): expected_v = float(token_idx * 100) actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :] actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :] - assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" - assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" + assert (actual_k == expected_k).all(), ( + f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}" + ) + assert (actual_v == expected_v).all(), ( + f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}" + ) # And that the original cache matches resized cache original_k = original_cache[layer_idx, 0, :, :, token_idx, :] original_v = original_cache[layer_idx, 1, :, :, token_idx, :] diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index aca67fc..99eef83 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -18,18 +18,23 @@ python -m pytest tests/test_rustbpe.py -v -s -v is verbose, -s is show prints """ -import regex as re -from collections import Counter, defaultdict import time -import rustbpe -import tiktoken -import pytest +from collections import Counter, defaultdict -GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" +import pytest +import regex as re +import tiktoken + +import rustbpe + +GPT4_SPLIT_PATTERN = ( + r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" +) # ----------------------------------------------------------------------------- # Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe + def get_stats(ids, counts=None): """ Given a list of integers, return a dictionary of counts of consecutive pairs @@ -37,10 +42,11 @@ def get_stats(ids, counts=None): Optionally allows to update an existing dictionary of counts """ counts = {} if counts is None else counts - for pair in zip(ids, ids[1:]): # iterate consecutive elements + for pair in zip(ids, ids[1:]): # iterate consecutive elements counts[pair] = counts.get(pair, 0) + 1 return counts + def merge(ids, pair, idx): """ In the list of integers (ids), replace all consecutive occurrences @@ -51,7 +57,7 @@ def merge(ids, pair, idx): i = 0 while i < len(ids): # if not at the very last position AND the pair matches, replace it - if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: + if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]: newids.append(idx) i += 2 else: @@ -59,8 +65,8 @@ def merge(ids, pair, idx): i += 1 return newids -class RegexTokenizer: +class RegexTokenizer: def __init__(self, pattern=None): """ - pattern: optional string to override the default (GPT-4 split pattern) @@ -68,7 +74,7 @@ class RegexTokenizer: example: {'<|endoftext|>': 100257} """ self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern - self.merges = {} # (int, int) -> int + self.merges = {} # (int, int) -> int self.compiled_pattern = re.compile(self.pattern) self.special_tokens = {} self.inverse_special_tokens = {} @@ -97,8 +103,8 @@ class RegexTokenizer: ids = [list(ch.encode("utf-8")) for ch in text_chunks] # iteratively merge the most common pairs to create 
new tokens - merges = {} # (int, int) -> int - vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes + merges = {} # (int, int) -> int + vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes for i in range(num_merges): # count the number of times every consecutive pair appears stats = {} @@ -125,11 +131,11 @@ class RegexTokenizer: vocab[idx] = vocab[pair[0]] + vocab[pair[1]] # prints if verbose: - print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") + print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") # save class variables - self.merges = merges # used in encode() - self.vocab = vocab # used in decode() + self.merges = merges # used in encode() + self.vocab = vocab # used in decode() return ambiguous def _encode_chunk(self, text_bytes): @@ -145,7 +151,7 @@ class RegexTokenizer: # just the first pair in the list, arbitrarily # we can detect this terminating case by a membership check if pair not in self.merges: - break # nothing else can be merged anymore + break # nothing else can be merged anymore # otherwise let's merge the best pair (lowest merge index) idx = self.merges[pair] ids = merge(ids, pair, idx) @@ -158,14 +164,16 @@ class RegexTokenizer: # all chunks of text are encoded separately, then results are joined ids = [] for chunk in text_chunks: - chunk_bytes = chunk.encode("utf-8") # raw bytes + chunk_bytes = chunk.encode("utf-8") # raw bytes chunk_ids = self._encode_chunk(chunk_bytes) ids.extend(chunk_ids) return ids + # ----------------------------------------------------------------------------- # Faster Python tokenizer, optimized version of the reference tokenizer + def fast_merge_inplace(ids, pair, idx): """ In the list of integers (ids), replace all consecutive occurrences @@ -175,16 +183,15 @@ def fast_merge_inplace(ids, pair, idx): # Find all positions where the pair occurs i = 0 while i < len(ids) - 1: - if ids[i] == pair[0] and ids[i+1] == pair[1]: + if ids[i] == pair[0] and ids[i + 1] == pair[1]: ids[i] = idx - ids.pop(i+1) + ids.pop(i + 1) else: i += 1 return ids class FastRegexTokenizer: - def __init__(self, pattern=None): """ - pattern: optional string to override the default (GPT-4 split pattern) @@ -229,8 +236,8 @@ class FastRegexTokenizer: # input text preprocessing ids = [list(ch.encode("utf-8")) for ch in unique_chunks] # iteratively merge the most common pairs to create new tokens - merges = {} # (int, int) -> int - vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes + merges = {} # (int, int) -> int + vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes # Initial count: build stats and position tracking stats = defaultdict(int) @@ -262,31 +269,31 @@ class FastRegexTokenizer: chunk_count = chunk_counts[chunk_idx] ix = 0 while ix < len(chunk_ids) - 1: - if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]: + if chunk_ids[ix] == pair[0] and chunk_ids[ix + 1] == pair[1]: # Track what pairs are being removed/added # Remove: (prev, A), (A, B), (B, next) if ix > 0: - old_left = (chunk_ids[ix-1], chunk_ids[ix]) + old_left = (chunk_ids[ix - 1], chunk_ids[ix]) count_changes[old_left] -= chunk_count # The merged pair disappears count_changes[pair] -= chunk_count if ix + 2 < len(chunk_ids): - old_right = (chunk_ids[ix+1], chunk_ids[ix+2]) + old_right = (chunk_ids[ix + 1], chunk_ids[ix + 2]) count_changes[old_right] -= chunk_count # Apply the merge chunk_ids[ix] = idx - chunk_ids.pop(ix+1) + chunk_ids.pop(ix + 1) # Add: (prev, 
C), (C, next) if ix > 0: - new_left = (chunk_ids[ix-1], chunk_ids[ix]) + new_left = (chunk_ids[ix - 1], chunk_ids[ix]) count_changes[new_left] += chunk_count if ix + 1 < len(chunk_ids): - new_right = (chunk_ids[ix], chunk_ids[ix+1]) + new_right = (chunk_ids[ix], chunk_ids[ix + 1]) count_changes[new_right] += chunk_count else: ix += 1 @@ -302,8 +309,9 @@ class FastRegexTokenizer: # Update positions for changed pairs - only check affected chunks for chunk_idx in affected_chunks: chunk_ids = ids[chunk_idx] - contains_pair = any((chunk_ids[j], chunk_ids[j+1]) == changed_pair - for j in range(len(chunk_ids) - 1)) + contains_pair = any( + (chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1) + ) if contains_pair: positions[changed_pair].add(chunk_idx) else: @@ -318,8 +326,8 @@ class FastRegexTokenizer: vocab[idx] = vocab[pair[0]] + vocab[pair[1]] # save class variables - self.merges = merges # used in encode() - self.vocab = vocab # used in decode() + self.merges = merges # used in encode() + self.vocab = vocab # used in decode() def register_special_tokens(self, special_tokens): # special_tokens is a dictionary of str -> int @@ -354,7 +362,7 @@ class FastRegexTokenizer: # just the first pair in the list, arbitrarily # we can detect this terminating case by a membership check if pair not in self.merges: - break # nothing else can be merged anymore + break # nothing else can be merged anymore # otherwise let's merge the best pair (lowest merge index) idx = self.merges[pair] ids = fast_merge_inplace(ids, pair, idx) @@ -367,18 +375,20 @@ class FastRegexTokenizer: # all chunks of text are encoded separately, then results are joined ids = [] for chunk in text_chunks: - chunk_bytes = chunk.encode("utf-8") # raw bytes + chunk_bytes = chunk.encode("utf-8") # raw bytes chunk_ids = self._encode_chunk(chunk_bytes) ids.extend(chunk_ids) return ids + # ----------------------------------------------------------------------------- # HuggingFace tokenizer +from tokenizers import Regex, decoders, pre_tokenizers from tokenizers import Tokenizer as HFTokenizer -from tokenizers import pre_tokenizers, decoders, Regex from tokenizers.models import BPE from tokenizers.trainers import BpeTrainer + class HuggingFaceTokenizer: """Light wrapper around HuggingFace Tokenizer for some utilities""" @@ -389,19 +399,23 @@ class HuggingFaceTokenizer: def train_from_iterator(cls, text_iterator, vocab_size): # train from an iterator of text # Configure the HuggingFace Tokenizer - tokenizer = HFTokenizer(BPE( - byte_fallback=True, # needed! - unk_token=None, - fuse_unk=False, - )) + tokenizer = HFTokenizer( + BPE( + byte_fallback=True, # needed! + unk_token=None, + fuse_unk=False, + ) + ) # Normalizer: None tokenizer.normalizer = None # Pre-tokenizer: GPT-4 style - gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! - tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ - pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), - pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False) - ]) + gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!! 
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), + ] + ) # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) tokenizer.decoder = decoders.ByteLevel() # Post-processor: None @@ -410,9 +424,9 @@ class HuggingFaceTokenizer: trainer = BpeTrainer( vocab_size=vocab_size, show_progress=True, - min_frequency=0, # no minimum frequency + min_frequency=0, # no minimum frequency initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), - special_tokens=[], # no special tokens + special_tokens=[], # no special tokens ) # Kick off the training tokenizer.train_from_iterator(text_iterator, trainer) @@ -422,15 +436,19 @@ class HuggingFaceTokenizer: ids = self.tokenizer.encode(text, add_special_tokens=False).ids return ids + # ----------------------------------------------------------------------------- # Test all of the above + @pytest.fixture(scope="module") def enwik8_path(): """Fixture to download and cache enwik8 dataset.""" import os import zipfile + from nanochat.common import get_base_dir + base_dir = get_base_dir() # download and unzip enwik8 to .cache directory enwik8_url = "https://mattmahoney.net/dc/enwik8.zip" @@ -439,6 +457,7 @@ def enwik8_path(): if not os.path.exists(enwik8_local_path): print(f"Downloading enwik8 to {enwik8_local_path_zip}") import requests + response = requests.get(enwik8_url) with open(enwik8_local_path_zip, "wb") as f: f.write(response.content) @@ -455,15 +474,17 @@ def enwik8_path(): @pytest.fixture(scope="module") def enwik8_small(enwik8_path): """Fixture providing 100KB of enwik8 for quick tests.""" - with open(enwik8_path, "r", encoding="utf-8") as f: + with open(enwik8_path, encoding="utf-8") as f: return f.read(100_000) + @pytest.fixture(scope="module") def enwik8_large(enwik8_path): """Fixture providing 10MB of enwik8 for performance tests.""" - with open(enwik8_path, "r", encoding="utf-8") as f: + with open(enwik8_path, encoding="utf-8") as f: return f.read(10**7) + def time_function(func, *args, **kwargs): """Time a function call and return the result and elapsed time""" start_time = time.time() @@ -472,6 +493,7 @@ def time_function(func, *args, **kwargs): elapsed = end_time - start_time return result, elapsed + def test_correctness(enwik8_small): """Test that all tokenizer implementations produce the same results.""" text = enwik8_small @@ -482,7 +504,9 @@ def test_correctness(enwik8_small): print("\nTraining slow reference...") slow_reference_tokenizer = RegexTokenizer() ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size) - slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text) + slow_reference_ids, slow_reference_encode_time = time_function( + slow_reference_tokenizer.encode_ordinary, encode_text + ) print(f"Slow reference train time: {slow_reference_train_time:.4f}s") print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s") print(slow_reference_ids[:20]) @@ -497,7 +521,9 @@ def test_correctness(enwik8_small): print("\nTraining fast reference...") fast_reference_tokenizer = FastRegexTokenizer() _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size) - fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text) + fast_reference_ids, 
fast_reference_encode_time = time_function( + fast_reference_tokenizer.encode_ordinary, encode_text + ) print(f"Fast reference train time: {fast_reference_train_time:.4f}s") print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s") print(fast_reference_ids[:20]) @@ -589,14 +615,16 @@ def test_training_performance(enwik8_large): assert hf_train_time > 0, "Training should take some time" # Print comparison - print(f"\n📊 Performance comparison:") + print("\n📊 Performance comparison:") print(f" RustBPE: {rustbpe_train_time:.4f}s") print(f" HuggingFace: {hf_train_time:.4f}s") - print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x") + print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x") + def test_interface(enwik8_small): """Test the RustBPETokenizer interface for training, encoding, decoding, and serialization.""" import tempfile + from nanochat.tokenizer import RustBPETokenizer # Simple train test
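
# -----------------------------------------------------------------------------
# Illustrative sketch: exercise the in-place merge helper defined earlier in this file
# on a toy input (the ids, the pair, and the new token id 256 below are made up).


def test_fast_merge_inplace_toy():
    ids = [1, 2, 3, 1, 2, 1]
    out = fast_merge_inplace(ids, (1, 2), 256)
    assert out == [256, 3, 256, 1]  # both (1, 2) occurrences replaced by 256
    assert out is ids  # the helper mutates and returns the same list object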
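
# Illustrative sketch: a miniature version of the correctness comparison above, using a
# made-up toy string instead of enwik8. The flag returned by RegexTokenizer.train is
# assumed to indicate that some merge step had tied pair counts (i.e. the result is
# order-dependent), so the equality check is only made when training was unambiguous.


def test_correctness_toy():
    text = "hello hello world, hello big world"
    vocab_size = 256 + 8  # 256 byte tokens plus 8 merges

    slow = RegexTokenizer()
    ambiguous = slow.train(text, vocab_size)

    fast = FastRegexTokenizer()
    fast.train(text, vocab_size)

    # with no tie-broken merges, both implementations should produce identical ids
    if not ambiguous:
        assert slow.encode_ordinary("hello world") == fast.encode_ordinary("hello world")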