Compare commits


3 Commits

Author SHA1 Message Date
Eyal Frishman
59ed9392ed Add pre-commit documentation to README and GitHub workflow
Sets up a pre-commit workflow to automate code linting and formatting.

This ensures code quality and consistency by running checks before code is committed.
2025-12-05 19:59:38 +02:00
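
For orientation, pre-commit is normally wired up locally with `pre-commit install` (so the hooks run before every commit) and exercised in CI with `pre-commit run --all-files`. Below is a minimal sketch of driving that same check from Python, e.g. from a helper script; the flags are standard pre-commit options, but nothing in this sketch is taken from the changeset itself:

```python
# Hedged sketch (not from this repo): run all configured pre-commit hooks against
# the whole repository, e.g. as a CI step. Assumes `pre-commit` is installed.
import subprocess
import sys


def run_hooks() -> int:
    # --all-files checks every tracked file, not just the staged ones;
    # --show-diff-on-failure prints the auto-fixes so CI logs show what to change.
    result = subprocess.run(
        ["pre-commit", "run", "--all-files", "--show-diff-on-failure"],
        check=False,
    )
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_hooks())
```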
Eyal Frishman
449494c8b6 Fix (automatically) all pre-commit errors 2025-12-05 19:59:35 +02:00
Eyal Frishman
6587063479 Add pre-commit hooks for code formatting (not yet executed)
Adds pre-commit configuration to automate code formatting and linting.

This includes:
- ruff-check for linting
- ruff-format for code formatting
- pre-commit-hooks for various checks (whitespace, large files, etc.)
- codespell for fixing common misspellings in text

Special configurations added to `pyproject.toml`:
- Line length: 120 (more flexible than black's 88)
- Quote style: preserve (keeps existing quotes)
Rename ruff hook (avoid using legacy alias)
2025-12-05 19:59:35 +02:00
40 changed files with 778 additions and 1742 deletions
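
Most of the churn in these 40 files is mechanical reformatting: code that black had wrapped at its 88-column limit is re-flowed by ruff-format at the 120-column limit named in the commit message, with existing quote characters left alone. As an illustration, the sketch below takes one call from the checkpoint diff further down and shows it both ways; the variable values are invented so the fragment runs on its own:

```python
import os

# Hypothetical values, only to make the fragment self-contained; in the repo this
# line sits inside a function, so indentation pushed it past black's 88 columns.
checkpoint_dir, step, rank = "checkpoints", 1000, 0

# Old style: black wrapped the call across three lines to stay under 88 columns.
optimizer_path = os.path.join(
    checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"
)

# New style: ruff-format with line-length = 120 joins it back onto one line, and
# quote-style = "preserve" leaves the existing quote characters untouched.
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
```

The `preserve` quote style also explains why single- and double-quoted strings coexist in the hunks below: ruff-format does not normalize one into the other.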

View File

@ -1,19 +1,4 @@
repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.2
hooks:
- id: pyupgrade
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
@ -28,10 +13,11 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/psf/black
rev: 25.11.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.8
hooks:
- id: black
- id: ruff-check
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1 # Use the latest stable version

View File

@ -133,11 +133,10 @@ Linting and formatting are enforced with [pre-commit](https://pre-commit.com/) b
Hook coverage (auto-fixes most issues; review and stage the changes afterward):
- [`autoflake`](https://github.com/PyCQA/autoflake): strips unused imports and variables.
- [`pyupgrade`](https://github.com/asottile/pyupgrade): rewrites code to the latest supported Python syntax.
- [`isort`](https://github.com/PyCQA/isort): keeps import blocks consistently ordered and grouped.
- [`ruff`](https://github.com/astral-sh/ruff): a fast Rust-based linter and formatter that replaces multiple tools:
- **Linting** (`ruff-check`): removes unused imports (like autoflake), upgrades syntax (like pyupgrade), and sorts imports (like isort).
- **Formatting** (`ruff-format`): applies consistent code formatting (like black), with quote style preserved.
- [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks): repo hygiene (trim trailing whitespace, enforce LF endings/newlines, detect merge conflicts, block oversized files).
- [`black`](https://github.com/psf/black): applies the canonical Python formatting.
- [`codespell`](https://github.com/codespell-project/codespell): catches common spelling mistakes in code and docs (add false positives to `[tool.codespell].ignore-words-list` in `pyproject.toml`).
## File structure
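
As the updated README text above describes, `ruff-check` folds the jobs of the removed autoflake, pyupgrade, and isort hooks into a single linter. The sketch below shows what those auto-fixes look like on a hypothetical module, assuming the corresponding rules (`F401` for unused imports, the `I` rules for import sorting, the `UP` rules for syntax upgrades) are enabled in the project's ruff configuration:

```python
# Hypothetical module, not taken from the repo, showing what `ruff check --fix`
# cleans up when the F401, I, and UP rules are enabled.

import sys  # unused: removed by rule F401, the job autoflake used to do
from typing import Optional  # Optional[...] annotations: rewritten by the UP rules

import os  # import order and grouping: normalized by the I rules, as isort did


def cache_dir(override: Optional[str] = None) -> str:
    return override or os.path.join(os.path.expanduser("~"), ".cache")


# After the fixes the module boils down to:
#
#   import os
#
#   def cache_dir(override: str | None = None) -> str:
#       return override or os.path.join(os.path.expanduser("~"), ".cache")
```

`ruff-format` then covers the layout side that black used to handle, so the two hooks together replace all four of the removed tools.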

View File

@ -269,9 +269,7 @@ hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split(
"\n"
)
""".strip().split("\n")
prompt = prompt.replace("%README%", readme)
@ -294,10 +292,7 @@ response_format = {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'",
},
"content": {
"type": "string",
"description": "The message content",
},
"content": {"type": "string", "description": "The message content"},
},
"required": ["role", "content"],
"additionalProperties": False,
@ -331,15 +326,15 @@ def generate_conversation(idx: int):
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
payload = copy.deepcopy(base_payload)
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
payload["messages"] = [{"role": "user", "content": modified_prompt}]
payload['messages'] = [{"role": "user", "content": modified_prompt}]
response = requests.post(url, headers=headers, json=payload)
result = response.json()
content = result["choices"][0]["message"]["content"]
content = result['choices'][0]['message']['content']
# Parse the JSON response and unpack the messages
conversation_data = json.loads(content)
messages = conversation_data["messages"]
messages = conversation_data['messages']
return messages
@ -359,11 +354,8 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers.
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = [
executor.submit(generate_conversation, idx) for idx in range(num_conversations)
]
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
# Process results as they complete
for future in as_completed(futures):
@ -373,13 +365,13 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Lightly validate the conversation structure
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
assert (
message["role"] == expected_role
), f"Message {i} has role {message['role']} but should be {expected_role}"
assert message['role'] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
# If all looks good, write the messages to file
with open(output_file, "a") as f:
f.write(json.dumps(messages) + "\n")
with open(output_file, 'a') as f:
f.write(json.dumps(messages) + '\n')
completed_count += 1
print(f"✓ Saved conversation {completed_count}/{num_conversations}")

View File

@ -48,14 +48,12 @@ total_docs_processed = 0
total_time_spent = 0
t0 = time.time()
for doc in ds:
text = doc["text"]
text = doc['text']
shard_docs.append(text)
shard_characters += len(text)
collected_enough_chars = shard_characters >= chars_per_shard
docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
if (
collected_enough_chars and docs_multiple_of_row_group_size
): # leads to ~100MB of text (compressed)
if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
shard_table = pa.Table.from_pydict({"text": shard_docs})
pq.write_table(

View File

@ -40,35 +40,33 @@ class DistAdamW(torch.optim.Optimizer):
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(
dist.reduce_scatter_tensor(
grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True
).get_future()
dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
)
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
beta1, beta2 = group["betas"]
eps = group["eps"]
wd = group["weight_decay"]
params = group["params"]
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for base in range(len(params)):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size : (rank + 1) * rank_size]
lr = group["lr"] * getattr(p, "lr_mul", 1.0)
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
# State init
if not state:
state["step"] = torch.tensor(0, dtype=torch.int64, device=p.device)
state["exp_avg"] = torch.zeros_like(p_slice)
state["exp_avg_sq"] = torch.zeros_like(p_slice)
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
state["step"] += 1
t = state["step"]
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
state['exp_avg'] = torch.zeros_like(p_slice)
state['exp_avg_sq'] = torch.zeros_like(p_slice)
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
state['step'] += 1
t = state['step']
# weight decay
if wd != 0:
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
@ -85,7 +83,5 @@ class DistAdamW(torch.optim.Optimizer):
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
idx += 1
all_reduce_futures.append(
dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
)
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
torch.futures.collect_all(all_reduce_futures).wait()

View File

@ -20,13 +20,11 @@ logger = logging.getLogger(__name__)
def log0(message):
if int(os.environ.get("RANK", 0)) == 0:
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(
checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0
):
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state parameters
@ -40,9 +38,7 @@ def save_checkpoint(
logger.info(f"Saved metadata to: {meta_path}")
# Note that optimizer state is sharded across ranks, so each rank must save its own.
if optimizer_data is not None:
optimizer_path = os.path.join(
checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"
)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
@ -54,9 +50,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(
checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"
)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
@ -74,15 +68,10 @@ def build_model(checkpoint_dir, step, device, phase):
- meta data saved during base model training
"""
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
model_data, optimizer_data, meta_data = load_checkpoint(
checkpoint_dir, step, device, load_optimizer=False
)
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
if device.type in {"cpu", "mps"}:
# Convert bfloat16 tensors to float for CPU inference
model_data = {
k: v.float() if v.dtype == torch.bfloat16 else v
for k, v in model_data.items()
}
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
model_config_kwargs = meta_data["model_config"]
@ -108,11 +97,7 @@ def build_model(checkpoint_dir, step, device, phase):
def find_largest_model(checkpoint_dir):
# attempt to guess the model tag: take the biggest model available
model_tags = [
f
for f in os.listdir(checkpoint_dir)
if os.path.isdir(os.path.join(checkpoint_dir, f))
]
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
if not model_tags:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
# 1) normally all model tags are of the form d<number>, try that first:
@ -126,9 +111,7 @@ def find_largest_model(checkpoint_dir):
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates[0][1]
# 2) if that failed, take the most recently updated model:
model_tags.sort(
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
)
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
return model_tags[0]
@ -137,9 +120,7 @@ def find_last_step(checkpoint_dir):
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
if not checkpoint_files:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
last_step = int(
max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
)
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
return last_step

View File

@ -17,45 +17,33 @@ class ColoredFormatter(logging.Formatter):
# ANSI color codes
COLORS = {
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
}
RESET = "\033[0m"
BOLD = "\033[1m"
RESET = '\033[0m'
BOLD = '\033[1m'
def format(self, record):
# Add color to the level name
levelname = record.levelname
if levelname in self.COLORS:
record.levelname = (
f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
)
record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
# Format the message
message = super().format(record)
# Add color to specific parts of the message
if levelname == "INFO":
if levelname == 'INFO':
# Highlight numbers and percentages
message = re.sub(
r"(\d+\.?\d*\s*(?:GB|MB|%|docs))",
rf"{self.BOLD}\1{self.RESET}",
message,
)
message = re.sub(
r"(Shard \d+)",
rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}',
message,
)
message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
return message
def setup_default_logging():
handler = logging.StreamHandler()
handler.setFormatter(
ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logging.basicConfig(level=logging.INFO, handlers=[handler])
@ -101,7 +89,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
content = response.read() # bytes
# Write to local file
with open(file_path, "wb") as f:
with open(file_path, 'wb') as f:
f.write(content)
print(f"Downloaded to {file_path}")
@ -113,7 +101,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
def print0(s="", **kwargs):
ddp_rank = int(os.environ.get("RANK", 0))
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
@ -135,15 +123,15 @@ def print_banner():
def is_ddp():
# TODO is there a proper way
return int(os.environ.get("RANK", -1)) != -1
return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
if is_ddp():
assert all(var in os.environ for var in ["RANK", "LOCAL_RANK", "WORLD_SIZE"])
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
return True, ddp_rank, ddp_local_rank, ddp_world_size
else:
return False, 0, 0, 1
@ -166,13 +154,13 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert (
torch.cuda.is_available()
), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
assert torch.cuda.is_available(), (
"Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
)
if device_type == "mps":
assert (
torch.backends.mps.is_available()
), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
assert torch.backends.mps.is_available(), (
"Your PyTorch installation is not configured for MPS but device_type is 'mps'"
)
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
@ -185,9 +173,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision(
"high"
) # uses tf32 instead of fp32 for matmuls
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

View File

@ -20,15 +20,15 @@ from ast import literal_eval
def print0(s="", **kwargs):
ddp_rank = int(os.environ.get("RANK", 0))
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
for arg in sys.argv[1:]:
if "=" not in arg:
if '=' not in arg:
# assume it's the name of a config file
assert not arg.startswith("--")
assert not arg.startswith('--')
config_file = arg
print0(f"Overriding config with {config_file}:")
with open(config_file) as f:
@ -36,8 +36,8 @@ for arg in sys.argv[1:]:
exec(open(config_file).read())
else:
# assume it's a --key=value argument
assert arg.startswith("--")
key, val = arg.split("=")
assert arg.startswith('--')
key, val = arg.split('=')
key = key[2:]
if key in globals():
try:
@ -50,9 +50,7 @@ for arg in sys.argv[1:]:
if globals()[key] is not None:
attempt_type = type(attempt)
default_type = type(globals()[key])
assert (
attempt_type == default_type
), f"Type mismatch: {attempt_type} != {default_type}"
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
# cross fingers
print0(f"Overriding: {key} = {attempt}")
globals()[key] = attempt

View File

@ -26,12 +26,8 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
"fewshot_examples": fewshot_examples,
"continuation_delimiter": continuation_delimiter,
"item": item,
}
prompts = [template.render(choice=choice, **context) for choice in item["choices"]]
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
return prompts
@ -45,15 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
"fewshot_examples": fewshot_examples,
"continuation_delimiter": continuation_delimiter,
"item": item,
}
prompts = [
template.render(context=context_option, **context)
for context_option in item["context_options"]
]
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
prompts = [template.render(context=context_option, **context) for context_option in item['context_options']]
return prompts
@ -71,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
"fewshot_examples": fewshot_examples,
"continuation_delimiter": continuation_delimiter,
"item": item,
}
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
# Return two prompts: without and with the continuation
prompt_without = template.render(include_continuation=False, **context)
prompt_with = template.render(include_continuation=True, **context)
@ -87,13 +72,13 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
return [prompt_without, prompt_with]
def find_common_length(token_sequences, direction="left"):
def find_common_length(token_sequences, direction='left'):
"""
Find the length of the common prefix or suffix across token sequences
- direction: 'left' for prefix, 'right' for suffix
"""
min_len = min(len(seq) for seq in token_sequences)
indices = {"left": range(min_len), "right": range(-1, -min_len - 1, -1)}[direction]
indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction]
# Find the first position where the token sequences differ
for i, idx in enumerate(indices):
token = token_sequences[0][idx]
@ -115,7 +100,7 @@ def batch_sequences_mc(tokenizer, prompts):
# In multiple choice, contexts are the same but the continuation is different (common prefix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each continuation
answer_start_idx = find_common_length(tokens, direction="left")
answer_start_idx = find_common_length(tokens, direction='left')
start_indices = [answer_start_idx] * len(prompts)
end_indices = [len(x) for x in tokens]
return tokens, start_indices, end_indices
@ -125,7 +110,7 @@ def batch_sequences_schema(tokenizer, prompts):
# In schema tasks, contexts vary but continuation is the same (common suffix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each context
suffix_length = find_common_length(tokens, direction="right")
suffix_length = find_common_length(tokens, direction='right')
end_indices = [len(x) for x in tokens]
start_indices = [ei - suffix_length for ei in end_indices]
return tokens, start_indices, end_indices
@ -136,12 +121,8 @@ def batch_sequences_lm(tokenizer, prompts):
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
tokens_without, tokens_with = tokens
start_idx, end_idx = len(tokens_without), len(tokens_with)
assert (
start_idx < end_idx
), "prompt without is supposed to be a prefix of prompt with"
assert (
tokens_without == tokens_with[:start_idx]
), "prompt without is supposed to be a prefix of prompt with"
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
return [tokens_with], [start_idx], [end_idx]
@ -158,12 +139,10 @@ def forward_model(model, input_ids):
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction="none",
outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float("nan")
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
return losses, predictions
@ -173,9 +152,9 @@ def forward_model(model, input_ids):
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
"""Evaluate a single example, return True if correct, False otherwise"""
item = data[idx]
task_type = task_meta["task_type"]
num_fewshot = task_meta["num_fewshot"]
continuation_delimiter = task_meta["continuation_delimiter"]
task_type = task_meta['task_type']
num_fewshot = task_meta['num_fewshot']
continuation_delimiter = task_meta['continuation_delimiter']
# Sample few-shot examples (excluding current item)
fewshot_examples = []
@ -186,13 +165,13 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
fewshot_examples = [data[i] for i in fewshot_indices]
# Render prompts and batch sequences based on task type
if task_type == "multiple_choice":
if task_type == 'multiple_choice':
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
elif task_type == "schema":
elif task_type == 'schema':
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
elif task_type == "language_modeling":
elif task_type == 'language_modeling':
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
else:
@ -200,7 +179,7 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
# In these cases, we have to truncate sequences to max length and adjust the indices
if hasattr(model, "max_seq_len") and model.max_seq_len is not None:
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
max_tokens = model.max_seq_len
new_tokens, new_start_idxs, new_end_idxs = [], [], []
for t, s, e in zip(tokens, start_idxs, end_idxs):
@ -226,7 +205,7 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
losses, predictions = forward_model(model, input_ids)
# See if the losses/predictions come out correctly
if task_type == "language_modeling":
if task_type == 'language_modeling':
# language modeling task is currently always batch size 1
si = start_idxs[0]
ei = end_idxs[0]
@ -234,14 +213,11 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
predicted_tokens = predictions[0, si - 1 : ei - 1]
actual_tokens = input_ids[0, si:ei]
is_correct = torch.all(predicted_tokens == actual_tokens).item()
elif task_type in ["multiple_choice", "schema"]:
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [
losses[i, si - 1 : ei - 1].mean().item()
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))
]
mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
pred_idx = mean_losses.index(min(mean_losses))
is_correct = pred_idx == item["gold"]
is_correct = pred_idx == item['gold']
else:
raise ValueError(f"Unsupported task type: {task_type}")

View File

@ -9,13 +9,7 @@ from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader_with_state(
B,
T,
split,
tokenizer_threads=4,
tokenizer_batch_size=128,
device="cuda",
resume_state_dict=None,
B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
@ -37,12 +31,8 @@ def tokenizing_distributed_data_loader_with_state(
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = (
resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
)
resume_rg_idx = (
resume_state_dict["rg_idx"] if resume_state_dict is not None else None
)
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
@ -51,21 +41,15 @@ def tokenizing_distributed_data_loader_with_state(
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = (
resume_rg_idx // ddp_world_size
) # in units of ddp_world_size
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = (
None # set to None as we only want to do this a single time
)
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column(
"text"
).to_pylist() # each batch is a parquet group, e.g. 1024 rows
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx)
@ -85,28 +69,20 @@ def tokenizing_distributed_data_loader_with_state(
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(
doc_batch, prepend=bos_token, num_threads=tokenizer_threads
)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(
tokens, dtype=torch.long, pin_memory=use_cuda_optimizations
) # in PyTorch, long=int64
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(
device=device, non_blocking=use_cuda_optimizations
)
targets = targets_cpu.view(B, T).to(
device=device, non_blocking=use_cuda_optimizations
)
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {
"pq_idx": pq_idx,
"rg_idx": rg_idx,
@ -116,7 +92,5 @@ def tokenizing_distributed_data_loader_with_state(
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(
*args, **kwargs
):
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets

View File

@ -21,13 +21,9 @@ from nanochat.common import get_base_dir
# The specifics of the current pretraining dataset
# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = (
"https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
)
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = (
lambda index: f"shard_{index:05d}.parquet"
) # format of the filenames
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)
@ -39,13 +35,7 @@ os.makedirs(DATA_DIR, exist_ok=True)
def list_parquet_files(data_dir=None):
"""Looks into a data dir and returns full paths to all parquet files."""
data_dir = DATA_DIR if data_dir is None else data_dir
parquet_files = sorted(
[
f
for f in os.listdir(data_dir)
if f.endswith(".parquet") and not f.endswith(".tmp")
]
)
parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')])
parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
return parquet_paths
@ -63,7 +53,7 @@ def parquets_iter_batched(split, start=0, step=1):
pf = pq.ParquetFile(filepath)
for rg_idx in range(start, pf.num_row_groups, step):
rg = pf.read_row_group(rg_idx)
texts = rg.column("text").to_pylist()
texts = rg.column('text').to_pylist()
yield texts
@ -89,11 +79,9 @@ def download_single_file(index):
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
with open(temp_path, "wb") as f:
for chunk in response.iter_content(
chunk_size=1024 * 1024
): # 1MB chunks
temp_path = filepath + ".tmp"
with open(temp_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
if chunk:
f.write(chunk)
# Move temp file to final location
@ -101,10 +89,10 @@ def download_single_file(index):
print(f"Successfully downloaded {filename}")
return True
except (requests.RequestException, OSError) as e:
except (OSError, requests.RequestException) as e:
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
# Clean up any partial files
for path in [filepath + f".tmp", filepath]:
for path in [filepath + ".tmp", filepath]:
if os.path.exists(path):
try:
os.remove(path)
@ -123,30 +111,18 @@ def download_single_file(index):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download FineWeb-Edu 100BT dataset shards"
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
parser.add_argument(
"-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable"
)
parser.add_argument(
"-n",
"--num-files",
type=int,
default=-1,
help="Number of shards to download (default: -1), -1 = disable",
)
parser.add_argument(
"-w",
"--num-workers",
type=int,
default=4,
help="Number of parallel download workers (default: 4)",
"-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)"
)
args = parser.parse_args()
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num))
print(
f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers..."
)
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
print(f"Target directory: {DATA_DIR}")
print()
with Pool(processes=args.num_workers) as pool:

View File

@ -64,38 +64,36 @@ def use_calculator(expr):
# Check if it's a string operation we support
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
allowed_chars = (
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
)
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
if not all([x in allowed_chars for x in expr]):
return None
# Disallow dangerous patterns
dangerous_patterns = [
"__",
"import",
"exec",
"eval",
"compile",
"open",
"file",
"input",
"raw_input",
"globals",
"locals",
"vars",
"dir",
"getattr",
"setattr",
"delattr",
"hasattr",
'__',
'import',
'exec',
'eval',
'compile',
'open',
'file',
'input',
'raw_input',
'globals',
'locals',
'vars',
'dir',
'getattr',
'setattr',
'delattr',
'hasattr',
]
expr_lower = expr.lower()
if any(pattern in expr_lower for pattern in dangerous_patterns):
return None
# Only allow .count() method for now (can expand later)
if ".count(" not in expr:
if '.count(' not in expr:
return None
# Evaluate with timeout
@ -137,9 +135,7 @@ class KVCache:
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
elif ix == 2:
# batch_size can be expanded
assert (
dim1 == dim2 or dim2 == 1
), f"Batch dim mismatch: {dim1} != {dim2}"
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
elif ix == 4:
# seq_len: self must be longer than other
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
@ -161,17 +157,11 @@ class KVCache:
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (
t_needed + 1023
) & ~1023 # then round up to the nearest multiple of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(
additional_shape, dtype=k.dtype, device=k.device
)
self.kv_cache = torch.cat(
[self.kv_cache, additional_cache], dim=4
).contiguous()
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
@ -211,9 +201,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
self.current_tokens = (
current_tokens or []
) # Current token sequence for this row
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
@ -221,25 +209,14 @@ class RowState:
class Engine:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer # needed for tool use
@torch.inference_mode()
def generate(
self,
tokens,
num_samples=1,
max_tokens=None,
temperature=1.0,
top_k=None,
seed=42,
):
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
"""Same as generate, but does single prefill and then clones the KV cache."""
assert isinstance(tokens, list) and isinstance(
tokens[0], int
), "expecting list of ints"
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
device = self.model.get_device()
rng = torch.Generator(device=device)
rng.manual_seed(seed)
@ -255,11 +232,7 @@ class Engine:
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {
"num_heads": m.n_kv_head,
"head_dim": m.n_embd // m.n_head,
"num_layers": m.n_layer,
}
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
@ -272,11 +245,7 @@ class Engine:
sampled_tokens = next_ids[:, 0].tolist()
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (
(len(tokens) + max_tokens)
if max_tokens is not None
else self.model.config.sequence_len
)
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
@ -302,36 +271,24 @@ class Engine:
# Get sampled tokens - either from prefill or from forward pass
if first_iteration:
# Use the tokens we already sampled from prefill
sampled_tokens = [
sampled_tokens[0]
] * num_samples # Broadcast first token to all rows
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
# TODO: we should sample a token for each row instead of broadcasting
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(
ids, kv_cache=kv_cache_decode
) # (B, T, vocab_size)
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
token_masks = (
[]
) # contains the mask (was it sampled (1) or forced (0)?) along each row
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
for i, state in enumerate(row_states):
# Select the next token in this row
is_forced = (
len(state.forced_tokens) > 0
) # are there tokens waiting to be forced in deque?
token_masks.append(
0 if is_forced else 1
) # mask is 0 if forced, 1 if sampled
next_token = (
state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
)
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update the state of this row to include the next token
state.current_tokens.append(next_token)
@ -360,9 +317,7 @@ class Engine:
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(
1
)
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
def generate_batch(self, tokens, num_samples=1, **kwargs):
"""
@ -400,9 +355,7 @@ if __name__ == "__main__":
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# load the model and tokenizer
@ -411,9 +364,7 @@ if __name__ == "__main__":
# common hyperparameters
kwargs = dict(max_tokens=64, temperature=0.0)
# set the starting prompt
prompt_tokens = tokenizer.encode(
"The chemical formula of water is", prepend=bos_token_id
)
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
# generate the reference sequence using the model.generate() function
generated_tokens = []
torch.cuda.synchronize()
@ -432,9 +383,7 @@ if __name__ == "__main__":
# generate tokens with Engine
generated_tokens = []
engine = Engine(model, tokenizer)
stream = engine.generate(
prompt_tokens, num_samples=1, **kwargs
) # note: runs in fp32
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
t0 = time.time()
with autocast_ctx:

View File

@ -30,7 +30,6 @@ import platform
import signal
import tempfile
from dataclasses import dataclass
from typing import Optional
# -----------------------------------------------------------------------------
@ -42,7 +41,7 @@ class ExecutionResult:
success: bool
stdout: str
stderr: str
error: Optional[str] = None
error: str | None = None
timeout: bool = False
memory_exceeded: bool = False
@ -133,7 +132,7 @@ def chdir(root):
os.chdir(cwd)
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
def reliability_guard(maximum_memory_bytes: int | None = None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
@ -150,15 +149,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
# These resource limit calls seem to fail on macOS (Darwin), skip?
import resource
resource.setrlimit(
resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
)
resource.setrlimit(
resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
)
resource.setrlimit(
resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
)
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
@ -220,12 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
sys.modules["tkinter"] = None
def _unsafe_execute(
code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict
):
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict):
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
@ -307,7 +297,7 @@ def _unsafe_execute(
def execute_code(
code: str,
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
"""
Execute Python code in a sandboxed environment.
@ -331,9 +321,7 @@ def execute_code(
manager = multiprocessing.Manager()
result_dict = manager.dict()
p = multiprocessing.Process(
target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict)
)
p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict))
p.start()
p.join(timeout=timeout + 1)

View File

@ -75,9 +75,7 @@ class CausalSelfAttention(nn.Module):
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(
k, cos, sin
) # QK rotary embedding
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = (
q.transpose(1, 2),
@ -89,9 +87,7 @@ class CausalSelfAttention(nn.Module):
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(
2
) # number of keys/values in total (in the cache + current forward pass)
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = (
@ -100,31 +96,21 @@ class CausalSelfAttention(nn.Module):
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(
q, k, v, is_causal=True, enable_gqa=enable_gqa
)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(
q, k, v, is_causal=False, enable_gqa=enable_gqa
)
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros(
(Tq, Tk), dtype=torch.bool, device=q.device
) # True = keep, False = mask
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(
torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
)
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa
)
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
@ -164,9 +150,7 @@ class GPT(nn.Module):
self.transformer = nn.ModuleDict(
{
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList(
[Block(config, layer_idx) for layer_idx in range(config.n_layer)]
),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
@ -174,14 +158,10 @@ class GPT(nn.Module):
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = (
config.sequence_len * 10
) # 10X over-compute should be enough, TODO make nicer?
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer(
"cos", cos, persistent=False
) # persistent=False means it's not saved to the checkpoint
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
@ -226,10 +206,7 @@ class GPT(nn.Module):
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = (
cos[None, :, None, :],
sin[None, :, None, :],
) # add batch and head dims for later broadcasting
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
def get_device(self):
@ -248,25 +225,19 @@ class GPT(nn.Module):
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
def setup_optimizers(
self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
):
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
assert len(list(self.parameters())) == len(matrix_params) + len(
embedding_params
) + len(lm_head_params)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(
f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}"
)
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
@ -285,23 +256,20 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizers
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(
1
), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert (
idx.device == self.cos.device
), f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert T <= self.cos.size(1), (
f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
)
assert idx.device == self.cos.device, (
f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
)
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = (
self.cos[:, T0 : T0 + T],
self.sin[:, T0 : T0 + T],
) # truncate cache to current sequence length
cos_sin = self.cos[:, T0 : T0 + T], self.sin[:, T0 : T0 + T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
@ -319,10 +287,7 @@ class GPT(nn.Module):
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-1,
reduction=loss_reduction,
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction
)
return loss
else:
@ -351,7 +316,7 @@ class GPT(nn.Module):
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)

View File

@ -33,20 +33,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction="none") # (B, T)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (
y.int() < 0
).any(): # mps does not currently have kernel for < 0 for int64, only int32
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
y_safe = torch.where(valid, y, torch.zeros_like(y))
# map valid targets to their byte length; ignored targets contribute 0 bytes
num_bytes2d = torch.where(
valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype)
)
num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype))
total_nats += (loss2d * (num_bytes2d > 0)).sum()
total_bytes += num_bytes2d.sum()
else:
@ -63,6 +59,6 @@ def evaluate_bpb(model, batches, steps, token_bytes):
total_nats = total_nats.item()
total_bytes = total_bytes.item()
if total_bytes == 0:
return float("inf")
return float('inf')
bpb = total_nats / (math.log(2) * total_bytes)
return bpb

View File

@ -113,22 +113,13 @@ class DistMuon(torch.optim.Optimizer):
ns_steps: number of NewtonSchulz iterations for the orthogonalization
"""
def __init__(
self,
params,
lr: float = 0.02,
momentum: float = 0.95,
nesterov: bool = True,
ns_steps: int = 5,
):
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
# Group all parameters by their shape
shapes = sorted(
{p.shape for p in params}
) # sort to ensure consistent / deterministic ordering
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
param_groups = []
for shape in shapes:
group_params = [p for p in params if p.shape == shape]
@ -136,12 +127,8 @@ class DistMuon(torch.optim.Optimizer):
assert all(p.device == device for p in group_params)
assert all(p.dtype == dtype for p in group_params)
if rank == 0:
print(
f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}"
)
param_groups.append(
dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))
)
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
super().__init__(param_groups, defaults)
@torch.no_grad()
@ -150,9 +137,9 @@ class DistMuon(torch.optim.Optimizer):
world_size = dist.get_world_size()
# Ensure all grads exist
assert all(
p.grad is not None for group in self.param_groups for p in group["params"]
), "All params must have grads"
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), (
"All params must have grads"
)
# Kick off all the reduce scatter operations to average up the gradients across all ranks
all_reduce_futures = []
@ -168,15 +155,9 @@ class DistMuon(torch.optim.Optimizer):
# pad rs_input with the zero buffer to complete the group
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
# the output buffer gets strided across the group based on the rank
rs_output = (
params[owner_idx].grad
if owner_idx < len(params)
else torch.empty_like(zero_buffer)
)
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
# reduce scatter the gradients within this group of world_size params
work = dist.reduce_scatter(
rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True
).get_future()
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
all_reduce_futures.append(work)
# Now each rank computes the update and gathers
@ -188,13 +169,9 @@ class DistMuon(torch.optim.Optimizer):
# Go through params in groups of world_size.
for base_i in range(0, len(params), world_size):
# The compute owner of each param is rank i % world_size
owner_idx = (
base_i + rank
) # calculate the index of the param that this rank owns
owner_idx = base_i + rank # calculate the index of the param that this rank owns
# Wait for the reduce scatter to complete
all_reduce_futures[
future_idx
].wait() # possibly later we could use wait_any polling instead
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
future_idx += 1
# Owner computes the Muon update, result is in its param
if owner_idx < len(params):
@ -212,12 +189,7 @@ class DistMuon(torch.optim.Optimizer):
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i : base_i + world_size]
ag_output.extend(
[
torch.empty_like(zero_buffer)
for _ in range(world_size - len(ag_output))
]
) # pad
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)
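
The reformatted loop above is easier to follow once the ownership rule is spelled out: parameters are processed in groups of `world_size`, and within each group the rank at offset `base_i + rank` receives the reduce-scattered (averaged) gradient, computes the Muon update on its parameter, and then serves as the source for the all_gather that replicates the result back to every rank. A minimal sketch of just the ownership pattern, assuming a hypothetical 4-rank run over 10 parameters (no `torch.distributed` involved):

```python
# Minimal sketch (hypothetical 4-rank run over 10 parameters, no
# torch.distributed) of the round-robin ownership used by DistMuon above:
# params are walked in groups of world_size, and owner_idx = base_i + rank
# decides which rank receives the averaged gradient, computes the Muon
# update, and feeds the subsequent all_gather.
world_size = 4   # assumed number of ranks
num_params = 10  # assumed number of 2D parameters

for rank in range(world_size):
    owned = []
    for base_i in range(0, num_params, world_size):
        owner_idx = base_i + rank      # this rank's slot within the group
        if owner_idx < num_params:     # the trailing group may be short
            owned.append(owner_idx)
    print(f"rank {rank} owns params {owned}")
# rank 0 owns params [0, 4, 8]
# rank 1 owns params [1, 5, 9]
# rank 2 owns params [2, 6]
# rank 3 owns params [3, 7]
```
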
View File

@ -17,9 +17,7 @@ import torch
def run_command(cmd):
"""Run a shell command and return output, or None if it fails."""
try:
result = subprocess.run(
cmd, shell=True, capture_output=True, text=True, timeout=5
)
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return result.stdout.strip()
return None
@ -30,16 +28,16 @@ def run_command(cmd):
def get_git_info():
"""Get current git commit, branch, and dirty status."""
info = {}
info["commit"] = run_command("git rev-parse --short HEAD") or "unknown"
info["branch"] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
# Check if repo is dirty (has uncommitted changes)
status = run_command("git status --porcelain")
info["dirty"] = bool(status) if status is not None else False
info['dirty'] = bool(status) if status is not None else False
# Get commit message
info["message"] = run_command("git log -1 --pretty=%B") or ""
info["message"] = info["message"].split("\n")[0][:80] # First line, truncated
info['message'] = run_command("git log -1 --pretty=%B") or ""
info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
return info
@ -68,20 +66,20 @@ def get_system_info():
info = {}
# Basic system info
info["hostname"] = socket.gethostname()
info["platform"] = platform.system()
info["python_version"] = platform.python_version()
info["torch_version"] = torch.__version__
info['hostname'] = socket.gethostname()
info['platform'] = platform.system()
info['python_version'] = platform.python_version()
info['torch_version'] = torch.__version__
# CPU and memory
info["cpu_count"] = psutil.cpu_count(logical=False)
info["cpu_count_logical"] = psutil.cpu_count(logical=True)
info["memory_gb"] = psutil.virtual_memory().total / (1024**3)
info['cpu_count'] = psutil.cpu_count(logical=False)
info['cpu_count_logical'] = psutil.cpu_count(logical=True)
info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
# User and environment
info["user"] = os.environ.get("USER", "unknown")
info["nanochat_base_dir"] = os.environ.get("NANOCHAT_BASE_DIR", "out")
info["working_dir"] = os.getcwd()
info['user'] = os.environ.get('USER', 'unknown')
info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
info['working_dir'] = os.getcwd()
return info
@ -165,18 +163,16 @@ Generated: {timestamp}
"""
# bloat metrics: package all of the source code and assess its weight
packaged = run_command(
'files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml'
)
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
num_chars = len(packaged)
num_lines = len(packaged.split("\n"))
num_files = len([x for x in packaged.split("\n") if x.startswith("<source>")])
num_lines = len(packaged.split('\n'))
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
num_tokens = num_chars // 4 # assume approximately 4 chars per token
# count dependencies via uv.lock
uv_lock_lines = 0
if os.path.exists("uv.lock"):
with open("uv.lock", encoding="utf-8") as f:
if os.path.exists('uv.lock'):
with open('uv.lock', encoding='utf-8') as f:
uv_lock_lines = len(f.readlines())
header += f"""
@ -231,7 +227,7 @@ def extract(section, keys):
def extract_timestamp(content, prefix):
"""Extract timestamp from content with given prefix."""
for line in content.split("\n"):
for line in content.split('\n'):
if line.startswith(prefix):
time_str = line.split(":", 1)[1].strip()
try:
@ -255,9 +251,7 @@ class Report:
file_path = os.path.join(self.report_dir, file_name)
with open(file_path, "w", encoding="utf-8") as f:
f.write(f"## {section}\n")
f.write(
f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
for item in data:
if not item:
# skip falsy values like None or empty dict etc.
@ -283,9 +277,7 @@ class Report:
report_dir = self.report_dir
report_file = os.path.join(report_dir, "report.md")
print(f"Generating report to {report_file}")
final_metrics = (
{}
) # the most important final metrics we'll add as table at the end
final_metrics = {} # the most important final metrics we'll add as table at the end
start_time = None
end_time = None
with open(report_file, "w", encoding="utf-8") as out_file:
@ -297,18 +289,12 @@ class Report:
out_file.write(header_content)
start_time = extract_timestamp(header_content, "Run started:")
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(
r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL
)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = (
None # will cause us to not write the total wall clock time
)
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(
f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?"
)
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)
@ -329,9 +315,7 @@ class Report:
if file_name == "chat-evaluation-sft.md":
final_metrics["sft"] = extract(section, chat_metrics)
if file_name == "chat-evaluation-rl.md":
final_metrics["rl"] = extract(
section, "GSM8K"
) # RL only evals GSM8K
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
# append this section of the report
out_file.write(section)
out_file.write("\n")
@ -345,9 +329,7 @@ class Report:
for stage_metrics in final_metrics.values():
all_metrics.update(stage_metrics.keys())
# Custom ordering: CORE first, ChatCORE last, rest in middle
all_metrics = sorted(
all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x)
)
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
# Fixed column widths
stages = ["base", "mid", "sft", "rl"]
metric_width = 15
@ -380,7 +362,7 @@ class Report:
else:
out_file.write("Total wall clock time: unknown\n")
# also cp the report.md file to current directory
print(f"Copying report.md to current directory for convenience")
print("Copying report.md to current directory for convenience")
shutil.copy(report_file, "report.md")
return report_file
@ -432,9 +414,7 @@ def get_report():
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Generate or reset nanochat training reports."
)
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
parser.add_argument(
"command",
nargs="?",
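
One non-obvious expression in this file's diff is the summary-table ordering `sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))`. A minimal sketch of how that key behaves, with hypothetical metric names:

```python
# Minimal sketch of the summary-table ordering above. False sorts before
# True, so the key (x != "CORE", x == "ChatCORE", x) places CORE first,
# ChatCORE last, and the remaining metrics alphabetically in between.
# The metric names here are hypothetical examples.
metrics = {"GSM8K", "ChatCORE", "ARC-Easy", "CORE", "MMLU"}
ordered = sorted(metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
print(ordered)  # ['CORE', 'ARC-Easy', 'GSM8K', 'MMLU', 'ChatCORE']
```
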
View File

@ -27,13 +27,14 @@ SPECIAL_TOKENS = [
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
# -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Regex
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import decoders, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
@ -75,14 +76,10 @@ class HuggingFaceTokenizer:
# NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
# (but I haven't validated this! TODO)
gpt4_split_regex = Regex(
SPLIT_PATTERN
) # huggingface demands that you wrap it in Regex!!
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(
pattern=gpt4_split_regex, behavior="isolated", invert=False
),
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
]
)
@ -119,15 +116,11 @@ class HuggingFaceTokenizer:
assert isinstance(text, str)
ids = []
if prepend is not None:
prepend_id = (
prepend if isinstance(prepend, int) else self.encode_special(prepend)
)
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
ids.append(prepend_id)
ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
if append is not None:
append_id = (
append if isinstance(append, int) else self.encode_special(append)
)
append_id = append if isinstance(append, int) else self.encode_special(append)
ids.append(append_id)
return ids
@ -183,20 +176,14 @@ class RustBPETokenizer:
tokenizer = rustbpe.Tokenizer()
# the special tokens are inserted later in __init__, we don't train them here
vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
assert (
vocab_size_no_special >= 256
), f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
tokenizer.train_from_iterator(
text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN
)
assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
# 2) construct the associated tiktoken encoding for inference
pattern = tokenizer.get_pattern()
mergeable_ranks_list = tokenizer.get_mergeable_ranks()
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
tokens_offset = len(mergeable_ranks)
special_tokens = {
name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)
}
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
enc = tiktoken.Encoding(
name="rustbpe",
pat_str=pattern,
@ -242,13 +229,9 @@ class RustBPETokenizer:
# text can be either a string or a list of strings
if prepend is not None:
prepend_id = (
prepend if isinstance(prepend, int) else self.encode_special(prepend)
)
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
if append is not None:
append_id = (
append if isinstance(append, int) else self.encode_special(append)
)
append_id = append if isinstance(append, int) else self.encode_special(append)
if isinstance(text, str):
ids = self.enc.encode_ordinary(text)
@ -305,12 +288,8 @@ class RustBPETokenizer:
# some conversation surgery is necessary here for now...
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert (
messages[1]["role"] == "user"
), "System message must be followed by a user message"
messages[1]["content"] = (
messages[0]["content"] + "\n\n" + messages[1]["content"]
)
assert messages[1]["role"] == "user", "System message must be followed by a user message"
messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
messages = messages[1:]
else:
messages = conversation["messages"]
@ -318,36 +297,28 @@ class RustBPETokenizer:
# fetch all the special tokens we need
bos = self.get_bos_token_id()
user_start, user_end = self.encode_special(
"<|user_start|>"
), self.encode_special("<|user_end|>")
assistant_start, assistant_end = self.encode_special(
"<|assistant_start|>"
), self.encode_special("<|assistant_end|>")
python_start, python_end = self.encode_special(
"<|python_start|>"
), self.encode_special("<|python_end|>")
output_start, output_end = self.encode_special(
"<|output_start|>"
), self.encode_special("<|output_end|>")
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
assistant_start, assistant_end = (
self.encode_special("<|assistant_start|>"),
self.encode_special("<|assistant_end|>"),
)
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
# now we can tokenize the conversation
add_tokens(bos, 0)
for i, message in enumerate(messages):
# some sanity checking here around assumptions, to prevent footguns
must_be_from = "user" if i % 2 == 0 else "assistant"
assert (
message["role"] == must_be_from
), f"Message {i} is from {message['role']} but should be from {must_be_from}"
assert message["role"] == must_be_from, (
f"Message {i} is from {message['role']} but should be from {must_be_from}"
)
# content can be either a simple string or a list of parts (e.g. containing tool calls)
content = message["content"]
if message["role"] == "user":
assert isinstance(
content, str
), "User messages are simply expected to be strings"
assert isinstance(content, str), "User messages are simply expected to be strings"
value_ids = self.encode(content)
add_tokens(user_start, 0)
add_tokens(value_ids, 0)
@ -388,10 +359,10 @@ class RustBPETokenizer:
def visualize_tokenization(self, ids, mask, with_token_id=False):
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
RED = "\033[91m"
GREEN = "\033[92m"
RESET = "\033[0m"
GRAY = "\033[90m"
RED = '\033[91m'
GREEN = '\033[92m'
RESET = '\033[0m'
GRAY = '\033[90m'
tokens = []
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
token_str = self.decode([token_id])
@ -399,7 +370,7 @@ class RustBPETokenizer:
tokens.append(f"{color}{token_str}{RESET}")
if with_token_id:
tokens.append(f"{GRAY}({token_id}){RESET}")
return "|".join(tokens)
return '|'.join(tokens)
def render_for_completion(self, conversation):
"""
@ -410,9 +381,7 @@ class RustBPETokenizer:
# We have some surgery to do: we need to pop the last message (of the Assistant)
conversation = copy.deepcopy(conversation) # avoid mutating the original
messages = conversation["messages"]
assert (
messages[-1]["role"] == "assistant"
), "Last message must be from the Assistant"
assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
messages.pop() # remove the last message (of the Assistant) inplace
# Now tokenize the conversation
@ -445,9 +414,9 @@ def get_token_bytes(device="cpu"):
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
assert os.path.exists(
token_bytes_path
), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
assert os.path.exists(token_bytes_path), (
f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
)
with open(token_bytes_path, "rb") as f:
token_bytes = torch.load(f, map_location=device)
return token_bytes
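
The `SPLIT_PATTERN` near the top of this file's diff limits digit runs to at most two characters per pre-token (the commit only re-wraps the string). A small sketch of what that means in practice, assuming the third-party `regex` package, since the pattern relies on the `\p{L}`/`\p{N}` Unicode property classes that the stdlib `re` module lacks:

```python
# Minimal sketch of the digit rule in SPLIT_PATTERN (unchanged by this
# commit, only re-wrapped): \p{N}{1,2} caps each numeric pre-token at two
# digits. Assumes the third-party `regex` package for the \p{...} classes.
import regex

SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

print(regex.findall(SPLIT_PATTERN, "year 2025!"))
# ['year', ' ', '20', '25', '!']  -- "2025" becomes two 2-digit chunks;
# with a GPT-4 style \p{N}{1,3} it would have been '202' and '5'.
```
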
View File

@ -77,18 +77,24 @@ conflicts = [
],
]
[tool.autoflake]
remove-all-unused-imports = true
in-place = true
[tool.ruff]
target-version = "py310"
line-length = 120
fix = true
unsafe-fixes = true
[tool.isort]
profile = "black"
[tool.ruff.lint]
select = [
"F", # Pyflakes (unused imports) - replaces autoflake
"I", # isort - replaces isort
"UP", # pyupgrade - replaces pyupgrade
]
[tool.pyupgrade]
py310-plus = true
[tool.ruff.lint.isort]
known-first-party = ["nanochat"]
[tool.black]
target-version = ["py310"]
[tool.ruff.format]
quote-style = "preserve"
[tool.codespell]
write-changes = true
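
The `select` list above maps the retired hooks onto ruff rule families: `F` covers the unused-import cleanup autoflake did, `I` covers isort, and `UP` covers pyupgrade, while `quote-style = "preserve"` keeps ruff-format from normalizing quotes the way black would. As a hedged illustration (made-up function, not taken from this repo), this is the kind of change the `UP` rules make under `target-version = "py310"`:

```python
# Hypothetical before/after illustrating the kind of rewrite the "UP"
# (pyupgrade) rules apply with target-version = "py310": typing.Optional
# becomes a PEP 604 union. The function names are made up; both forms are
# valid Python 3.10+ and behave identically.
from typing import Optional


def before(x: Optional[int]) -> Optional[str]:
    return None if x is None else str(x)


def after(x: int | None) -> str | None:
    return None if x is None else str(x)


print(before(3), after(None))  # 3 None
```
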
View File

@ -48,7 +48,7 @@ def place_eval_bundle(file_path):
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(file_path, "r") as zip_ref:
with zipfile.ZipFile(file_path, 'r') as zip_ref:
zip_ref.extractall(tmpdir)
extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
shutil.move(extracted_bundle_dir, eval_bundle_dir)
@ -65,23 +65,21 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
# Download the eval bundle to disk (and unzip if needed)
if not os.path.exists(eval_bundle_dir):
download_file_with_lock(
EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle
)
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
with open(config_path, encoding="utf-8") as f:
with open(config_path, encoding='utf-8') as f:
config = yaml.safe_load(f)
tasks = config["icl_tasks"]
tasks = config['icl_tasks']
# Load random baseline values from eval metadata
random_baselines = {}
with open(eval_meta_data, encoding="utf-8") as f:
with open(eval_meta_data, encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
task_name = row["Eval Task"]
random_baseline = row["Random baseline"]
task_name = row['Eval Task']
random_baseline = row['Random baseline']
random_baselines[task_name] = float(random_baseline)
# Evaluate each task
@ -89,21 +87,18 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
centered_results = {}
for task in tasks:
start_time = time.time()
label = task["label"]
label = task['label']
task_meta = {
"task_type": task["icl_task_type"],
"dataset_uri": task["dataset_uri"],
"num_fewshot": task["num_fewshot"][0],
"continuation_delimiter": task.get("continuation_delimiter", " "),
'task_type': task['icl_task_type'],
'dataset_uri': task['dataset_uri'],
'num_fewshot': task['num_fewshot'][0],
'continuation_delimiter': task.get('continuation_delimiter', ' '),
}
print0(
f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ",
end="",
)
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
# Load data for this task
data_path = os.path.join(data_base_path, task_meta["dataset_uri"])
with open(data_path, encoding="utf-8") as f:
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, encoding='utf-8') as f:
data = [json.loads(line.strip()) for line in f]
# shuffle the data because in many cases it appears ordered but we want
@ -118,21 +113,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
results[label] = accuracy
random_baseline = random_baselines[label]
centered_result = (accuracy - 0.01 * random_baseline) / (
1.0 - 0.01 * random_baseline
)
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
centered_results[label] = centered_result
end_time = time.time()
print0(
f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s"
)
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
core_metric = sum(centered_results.values()) / len(centered_results)
out = {
"results": results,
"centered_results": centered_results,
"core_metric": core_metric,
}
out = {"results": results, "centered_results": centered_results, "core_metric": core_metric}
return out
@ -173,24 +160,15 @@ def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--hf-path", type=str, default=None, help="HuggingFace model path to evaluate"
)
parser.add_argument(
"--max-per-task",
type=int,
default=-1,
help="Max examples per task to evaluate (-1 = disable)",
)
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
args = parser.parse_args()
# distributed / precision setup
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Load model and tokenizer from command line or from file system
@ -221,18 +199,16 @@ def main():
results = out["results"]
centered_results = out["centered_results"]
core_metric = out["core_metric"]
with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
for label in results:
f.write(
f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n"
)
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
# Print the content of the csv file to console too
print0("=" * 80)
print0(f"Model: {model_name}")
print0("=" * 80)
with open(output_csv_path, encoding="utf-8") as f:
with open(output_csv_path, encoding='utf-8') as f:
print0(f.read())
# Log to report
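
For reference on the arithmetic being reflowed above: the centered result rescales raw accuracy so that chance performance maps to 0 and perfect performance to 1, and the CORE metric is the mean of these centered scores. A small worked example with hypothetical numbers:

```python
# Worked example (hypothetical numbers) of the centering above. Baselines
# come from the eval metadata CSV in percent, so a 4-way multiple-choice
# task has random_baseline = 25.0 and chance-level accuracy centers to 0.
def centered(accuracy: float, random_baseline: float) -> float:
    return (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)

print(centered(0.625, 25.0))  # 0.5 -> halfway between chance and perfect
print(centered(0.250, 25.0))  # 0.0 -> exactly chance level
```
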
View File

@ -13,12 +13,7 @@ from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import (
autodetect_device_type,
compute_cleanup,
compute_init,
print0,
)
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.engine import Engine
from nanochat.loss_eval import evaluate_bpb
@ -30,35 +25,25 @@ split_tokens = 20 * 524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(
open(os.path.join("nanochat", "configurator.py")).read()
) # overrides from command line or config file
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# Load the base model and the tokenizer
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model(
"base", device, phase="eval", model_tag=model_tag, step=model_step
)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
assert (
split_tokens % tokens_per_step == 0
), "split_tokens must be divisible by tokens_per_step"
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(
device_batch_size, sequence_len, split_name, device=device
)
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
@ -80,9 +65,7 @@ if ddp_rank == 0:
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(
tokens, num_samples=1, max_tokens=16, temperature=0
)
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0(sample_str)
samples.append(sample_str)
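
The assertion on `split_tokens` above is pure bookkeeping; a quick sketch with hypothetical values (per-device batch 32, sequence length 2048, 8 ranks) shows how the number of evaluation steps per split falls out:

```python
# Quick arithmetic sketch (hypothetical values) for the divisibility check
# above: each eval step covers device_batch_size * sequence_len *
# ddp_world_size tokens, and split_tokens must be a whole multiple of that.
device_batch_size, sequence_len, ddp_world_size = 32, 2048, 8
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
split_tokens = 20 * 524288

assert split_tokens % tokens_per_step == 0
print(tokens_per_step, split_tokens // tokens_per_step)  # 524288 20
```
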
View File

@ -30,10 +30,7 @@ from nanochat.common import (
print0,
print_banner,
)
from nanochat.dataloader import (
tokenizing_distributed_data_loader,
tokenizing_distributed_data_loader_with_state,
)
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.engine import Engine
from nanochat.gpt import GPT, GPTConfig
from nanochat.loss_eval import evaluate_bpb
@ -48,16 +45,16 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wa
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = (
20 # the depth of the Transformer model to train, rest of the kwargs are derived
)
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
# Training horizon. Only one of these 3 will be used, in this order of precedence.
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
target_flops = (
-1.0
) # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
target_param_data_ratio = (
20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # total desired batch size, in #tokens
@ -69,33 +66,19 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = (
-1
) # resume training from this step of the optimization (-1 = disable)
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20 * 524288 # number of tokens to evaluate val loss on
core_metric_every = (
2000 # every how many steps to evaluate the core metric (-1 = disable)
)
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = (
-1
) # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# Output
model_tag = (
"" # optionally override the model tag for the output checkpoint directory name
)
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [
k
for k, v in globals().items()
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
exec(
open(os.path.join("nanochat", "configurator.py")).read()
) # overrides from command line or config file
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
@ -104,20 +87,14 @@ device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(project="nanochat", name=run, config=user_config)
)
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer()
@ -127,15 +104,9 @@ print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
num_layers = depth
model_dim = (
depth * 64
) # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(
1, (model_dim + 127) // 128
) # head dim 128 (the division here is ceil div)
num_kv_heads = (
num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
)
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
@ -143,21 +114,13 @@ print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = (
device_batch_size * max_seq_len
) # tokens per iteration for a single rank
world_tokens_per_fwdbwd = (
tokens_per_fwdbwd * ddp_world_size
) # total tokens per iteration for all ranks
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(
f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}"
)
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(
f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}"
)
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
@ -191,9 +154,7 @@ if resuming:
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch.compile(
model, dynamic=False
) # the inputs to model will never change shape so dynamic=False is safe
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
@ -211,25 +172,18 @@ elif target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
print0(
f"Calculated number of iterations from target data:param ratio: {num_iterations:,}"
)
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(
f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}"
) # Chinchilla is ~20
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers
@ -241,22 +195,14 @@ if resuming:
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = (
None if not resuming else meta_data["dataloader_state_dict"]
)
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(
device_batch_size,
max_seq_len,
split="train",
device=device,
resume_state_dict=dataloader_resume_state_dict,
device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict
)
build_val_loader = lambda: tokenizing_distributed_data_loader(
device_batch_size, max_seq_len, split="val", device=device
)
x, y, dataloader_state_dict = next(
train_loader
) # kick off load of the very first batch of data
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@ -300,9 +246,7 @@ else:
# -----------------------------------------------------------------------------
# Training loop
while True:
last_step = (
step == num_iterations
) # loop runs num_iterations+1 times so that we can eval/save at the end
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
@ -328,14 +272,10 @@ while True:
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if core_metric_every > 0 and (
last_step or (step > 0 and step % core_metric_every == 0)
):
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(
orig_model, tokenizer, device, max_per_task=core_metric_max_per_task
)
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log(
{
@ -364,19 +304,12 @@ while True:
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(
tokens, num_samples=1, max_tokens=16, temperature=0
)
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (
step > 0
and step != resume_from_step
and save_every > 0
and step % save_every == 0
):
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
@ -412,9 +345,7 @@ while True:
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = (
loss / grad_accum_steps
) # each .backward() is a grad sum => normalize loss here
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y, dataloader_state_dict = next(
train_loader
@ -422,12 +353,8 @@ while True:
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled:
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(
orig_model.parameters(), grad_clip
)
grad_norm = (
grad_norm_tensor.item()
) # GPU tensor -> CPU float (note: cpu-gpu sync point)
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
@ -446,24 +373,18 @@ while True:
# logging
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = (
ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
) # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (
1 - ema_beta ** (step + 1)
) # debias the EMA
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = (
989e12 * ddp_world_size
) # bfloat16 H100 SXM and without 2:4 sparsity
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
print0(
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m"
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 100 == 0:
log_data = {
@ -485,7 +406,7 @@ while True:
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Total training time: {total_training_time / 60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
@ -512,7 +433,7 @@ get_report().log(
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Total training time": f"{total_training_time / 60:.2f}m",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
},
],
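
Several of the one-liners restored above derive the model shape and the gradient-accumulation factor from a few knobs. A hedged sketch with hypothetical values (depth 20, a single rank) showing how the numbers fall out:

```python
# Hedged sketch (hypothetical single-rank, depth-20 run) of the derived
# hyperparameters reflowed above: width is depth * 64, head count is a
# ceil-division by 128, and gradient accumulation bridges the per-device
# micro-batch and the total batch size in tokens.
depth, max_seq_len = 20, 2048
device_batch_size, total_batch_size, ddp_world_size = 32, 524288, 1

model_dim = depth * 64                                                 # 1280
num_heads = max(1, (model_dim + 127) // 128)                           # ceil(1280/128) = 10
tokens_per_fwdbwd = device_batch_size * max_seq_len * ddp_world_size   # 65536
assert total_batch_size % tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // tokens_per_fwdbwd               # 524288 / 65536 = 8
print(model_dim, num_heads, grad_accum_steps)                          # 1280 10 8
```
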
View File

@ -14,61 +14,38 @@ from nanochat.checkpoint_manager import load_model
from nanochat.common import autodetect_device_type, compute_init
from nanochat.engine import Engine
parser = argparse.ArgumentParser(description="Chat with the model")
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument(
"-i", "--source", type=str, default="sft", help="Source of the model: sft|mid|rl"
)
parser.add_argument(
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
)
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
parser.add_argument(
"-p",
"--prompt",
'--device-type',
type=str,
default="",
help="Prompt the model, get a single response back",
)
parser.add_argument(
"-t", "--temperature", type=float, default=0.6, help="Temperature for generation"
)
parser.add_argument(
"-k", "--top-k", type=int, default=50, help="Top-k sampling parameter"
)
parser.add_argument(
"--device-type",
type=str,
default="",
choices=["cuda", "cpu", "mps"],
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
)
parser.add_argument(
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
# Init the model and tokenizer
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
if device_type == "cuda"
else nullcontext()
)
model, tokenizer, meta = load_model(
args.source, device, phase="eval", model_tag=args.model_tag, step=args.step
)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special(
"<|user_start|>"
), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special(
"<|assistant_start|>"
), tokenizer.encode_special("<|assistant_end|>")
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = (
tokenizer.encode_special("<|assistant_start|>"),
tokenizer.encode_special("<|assistant_end|>"),
)
# Create Engine for efficient generation
engine = Engine(model, tokenizer)
@ -82,7 +59,6 @@ print("-" * 50)
conversation_tokens = [bos]
while True:
if args.prompt:
# Get the prompt from the launch command
user_input = args.prompt
@ -95,11 +71,11 @@ while True:
break
# Handle special commands
if user_input.lower() in ["quit", "exit"]:
if user_input.lower() in ['quit', 'exit']:
print("Goodbye!")
break
if user_input.lower() == "clear":
if user_input.lower() == 'clear':
conversation_tokens = [bos]
print("Conversation cleared.")
continue
@ -123,9 +99,7 @@ while True:
response_tokens = []
print("\nAssistant: ", end="", flush=True)
with autocast_ctx:
for token_column, token_masks in engine.generate(
conversation_tokens, **generate_kwargs
):
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
token = token_column[0] # pop the batch dimension (num_samples=1)
response_tokens.append(token)
token_text = tokenizer.decode([token])
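
The loop above drives generation by appending to a single running list of token ids rather than re-rendering the whole conversation each turn: `<|bos|>` once, each user turn wrapped in `<|user_start|>`/`<|user_end|>`, then `<|assistant_start|>` so the model completes the assistant turn until it emits `<|assistant_end|>`. A hedged sketch of that layout, with made-up integer ids standing in for the real special-token ids and tokenizer output:

```python
# Hedged sketch (made-up integer ids) of the running token stream the CLI
# above maintains: BOS once, user turn wrapped in user_start/user_end, then
# assistant_start appended so generation streams until assistant_end.
bos, user_start, user_end = 0, 1, 2
assistant_start, assistant_end = 3, 4
user_text_ids = [101, 102]            # stand-in for tokenizer.encode(user_input)
assistant_text_ids = [201, 202, 203]  # stand-in for sampled response tokens

conversation_tokens = [bos]
conversation_tokens += [user_start, *user_text_ids, user_end, assistant_start]
conversation_tokens += [*assistant_text_ids, assistant_end]
print(conversation_tokens)  # [0, 1, 101, 102, 2, 3, 201, 202, 203, 4]
```
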
View File

@ -16,13 +16,7 @@ import torch
import torch.distributed as dist
from nanochat.checkpoint_manager import load_model
from nanochat.common import (
autodetect_device_type,
compute_cleanup,
compute_init,
get_dist_info,
print0,
)
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0
from nanochat.engine import Engine
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
@ -35,25 +29,12 @@ from tasks.spellingbee import SpellingBee
def run_generative_eval(
task_object,
tokenizer,
model,
engine,
num_samples,
max_new_tokens,
temperature,
top_k,
max_problems=None,
task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None
):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device()
num_problems = (
len(task_object)
if max_problems is None
else min(len(task_object), max_problems)
)
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
# Run the evaluation
num_passed, total = 0, 0
@ -72,13 +53,9 @@ def run_generative_eval(
)
# Decode the completions as text
prefix_length = len(encoded_prompt)
completions = [
tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results
]
completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
# Evaluate success criteria
outcomes = [
task_object.evaluate(conversation, completion) for completion in completions
]
outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
passed = any(outcomes)
# Keep stats
@ -86,11 +63,7 @@ def run_generative_eval(
num_passed += int(passed)
# Logging (overwrite the same line in the console)
print(
f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)",
end="",
flush=True,
)
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100 * num_passed / total:.2f}%)", end='', flush=True)
# Finish the in-place progress line with a newline before final summary
print()
@ -105,7 +78,7 @@ def run_generative_eval(
total = total_tensor.item()
print0("=" * 50)
print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
print0(f"Final: {num_passed}/{total} ({100 * num_passed / total:.2f}%)")
# Return the accuracy
return num_passed / total
@ -118,26 +91,17 @@ def run_generative_eval(
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device()
bos = (
tokenizer.get_bos_token_id()
) # use BOS as pad token is ok, these positions are ignored
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
# We'll process batches of independent problems at a time because there is no sampling needed
num_problems = (
len(task_object)
if max_problems is None
else min(len(task_object), max_problems)
)
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
ceil_div = lambda x, y: -(-x // y)
num_batches = ceil_div(num_problems, batch_size)
# Run the evaluation
letter_to_id_cache = (
{}
) # many letters will repeat often, let's save the tokenizer some work
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
num_passed, total = 0, 0
for i in range(ddp_rank, num_batches, ddp_world_size):
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
@ -145,16 +109,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
conversations = [task_object[ii] for ii in range(i0, i1)]
prompt_ids = [
tokenizer.render_for_completion(conversation)
for conversation in conversations
tokenizer.render_for_completion(conversation) for conversation in conversations
] # TODO: remake the way this works
max_length = max(len(ids) for ids in prompt_ids)
answer_time_positions = [
len(ids) - 1 for ids in prompt_ids
] # where the last token is (and the predicted answer)
padded_prompt_ids = [
ids + [bos] * (max_length - len(ids)) for ids in prompt_ids
]
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
@ -167,14 +128,12 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
# letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
for idx, conversation in enumerate(conversations):
# get the token ids of all the available letters of this problem
letters = conversation["letters"]
letters = conversation['letters']
letter_ids = []
for letter in letters:
if not letter in letter_to_id_cache:
encoded_letter = tokenizer.encode(letter)
assert (
len(encoded_letter) == 1
), "Each letter must be a single token"
assert len(encoded_letter) == 1, "Each letter must be a single token"
letter_to_id_cache[letter] = encoded_letter[0]
letter_ids.append(letter_to_id_cache[letter])
# focus logits just down to the answer position and the available letters of the answer
@ -198,7 +157,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
total = total_tensor.item()
average = num_passed / total
print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)")
return average
@ -219,16 +178,16 @@ def run_chat_eval(
):
# Create the evaluation object
task_module = {
"HumanEval": HumanEval,
"MMLU": partial(MMLU, subset="all", split="test"),
"ARC-Easy": partial(ARC, subset="ARC-Easy", split="test"),
"ARC-Challenge": partial(ARC, subset="ARC-Challenge", split="test"),
"GSM8K": partial(GSM8K, subset="main", split="test"),
"SpellingBee": partial(SpellingBee, size=256, split="test"),
'HumanEval': HumanEval,
'MMLU': partial(MMLU, subset="all", split="test"),
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
'GSM8K': partial(GSM8K, subset="main", split="test"),
'SpellingBee': partial(SpellingBee, size=256, split="test"),
}[task_name]
task_object = task_module()
# Run the evaluation
if task_object.eval_type == "generative":
if task_object.eval_type == 'generative':
acc = run_generative_eval(
task_object,
tokenizer,
@ -240,10 +199,8 @@ def run_chat_eval(
top_k,
max_problems=max_problems,
)
elif task_object.eval_type == "categorical":
acc = run_categorical_eval(
task_object, tokenizer, model, batch_size, max_problems=max_problems
)
elif task_object.eval_type == 'categorical':
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
else:
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
return acc
@ -251,87 +208,55 @@ def run_chat_eval(
# -----------------------------------------------------------------------------
if __name__ == "__main__":
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
parser.add_argument(
"-i",
"--source",
type=str,
required=True,
help="Source of the model: sft|mid|rl",
)
parser.add_argument(
"-a",
"--task-name",
'-a',
'--task-name',
type=str,
default=None,
help="Task name. Default = all tasks. Use | to split multiple tasks.",
)
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0)
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
parser.add_argument('-n', '--num-samples', type=int, default=1)
parser.add_argument('-k', '--top-k', type=int, default=50)
parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
parser.add_argument(
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
)
parser.add_argument("-t", "--temperature", type=float, default=0.0)
parser.add_argument("-m", "--max-new-tokens", type=int, default=512)
parser.add_argument("-n", "--num-samples", type=int, default=1)
parser.add_argument("-k", "--top-k", type=int, default=50)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=8,
help="Batch size for categorical evaluation",
)
parser.add_argument(
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
)
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
parser.add_argument(
"-x", "--max-problems", type=int, default=None, help="Max problems to evaluate"
)
parser.add_argument(
"--device-type",
'--device-type',
type=str,
default="",
choices=["cuda", "cpu", "mps"],
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
args = parser.parse_args()
device_type = (
autodetect_device_type() if args.device_type == "" else args.device_type
)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
)
model, tokenizer, meta = load_model(
args.source, device, phase="eval", model_tag=args.model_tag, step=args.step
)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)
# Get the tasks to evaluate on
all_tasks = [
"ARC-Easy",
"ARC-Challenge",
"MMLU",
"GSM8K",
"HumanEval",
"SpellingBee",
]
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
baseline_accuracies = {
"ARC-Easy": 0.25, # multiple choice 1 of 4 => 25%
"ARC-Challenge": 0.25, # multiple choice 1 of 4 => 25%
"MMLU": 0.25, # multiple choice 1 of 4 => 25%
"GSM8K": 0.0, # open-ended => 0%
"HumanEval": 0.0, # open-ended => 0%
"SpellingBee": 0.0, # open-ended => 0%
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
'GSM8K': 0.0, # open-ended => 0%
'HumanEval': 0.0, # open-ended => 0%
'SpellingBee': 0.0, # open-ended => 0%
}
task_names = all_tasks if args.task_name is None else args.task_name.split("|")
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
# Run all the task evaluations sequentially
results = {}
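
In the categorical path above, padding with the BOS token is harmless because only the logits at each prompt's final real position are inspected (and only at the candidate letter ids). A minimal sketch of the collation with hypothetical token ids:

```python
# Minimal sketch (hypothetical token ids) of the padding/collation above:
# unequal-length prompts are right-padded with the BOS id so they batch
# together, and each row's answer is read from the logits at the index of
# its last real token.
bos = 0
prompt_ids = [[5, 6, 7], [8, 9], [1, 2, 3, 4]]
max_length = max(len(ids) for ids in prompt_ids)              # 4
answer_time_positions = [len(ids) - 1 for ids in prompt_ids]  # [2, 1, 3]
padded = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
print(padded)                 # [[5, 6, 7, 0], [8, 9, 0, 0], [1, 2, 3, 4]]
print(answer_time_positions)  # [2, 1, 3]
```
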
View File

@ -24,13 +24,7 @@ import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import (
DummyWandb,
compute_cleanup,
compute_init,
get_base_dir,
print0,
)
from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
@ -39,9 +33,7 @@ run = "dummy" # wandb run name
source = "sft" # mid|sft
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = (
16 # in total and across all ranks (note: examples, not samples/completions!)
)
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
@ -56,30 +48,20 @@ save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [
k
for k, v in globals().items()
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
exec(
open(os.path.join("nanochat", "configurator.py")).read()
) # overrides from command line or config file
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == "float32" else torch.bfloat16
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(project="nanochat-rl", name=run, config=user_config)
)
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval")
@ -103,7 +85,6 @@ def get_batch():
ddp_rank, len(train_task), ddp_world_size
) # each rank is responsible for different examples in the training data
for example_idx in itertools.cycle(rank_indices):
# First get the full conversation of both user and assistant messages
conversation = train_task[example_idx]
@ -116,13 +97,9 @@ def get_batch():
model.eval() # ensure the model is in eval mode
generated_token_sequences = []
masks = []
num_sampling_steps = (
num_samples // device_batch_size
) # go sequentially to prevent OOMs
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = (
hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
) # positive half of int32
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
@ -149,21 +126,16 @@ def get_batch():
# Pad the sequences so that their lengths (in time) match
max_length = max(len(seq) for seq in generated_token_sequences)
padded_generated_token_sequences = [
seq + [assistant_end] * (max_length - len(seq))
for seq in generated_token_sequences
seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences
]
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
# Stack up the sequences and masks into PyTorch tensors
ids = torch.tensor(
padded_generated_token_sequences, dtype=torch.long, device=device
)
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
# Generate autoregressive inputs and targets to the Transformer
inputs = ids[:, :-1]
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
targets[mask_ids[:, 1:] == 0] = (
-1
) # <-- inplace modification right here. -1 is the ignore index
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
@ -177,14 +149,7 @@ def get_batch():
# -----------------------------------------------------------------------------
# Simple evaluation loop for GSM8K pass@k
def run_gsm8k_eval(
task,
tokenizer,
engine,
max_examples=None,
num_samples=1,
max_completion_tokens=256,
temperature=0.0,
top_k=50,
task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50
):
"""
Evaluates GSM8K task and returns a list of records of evaluation outcomes.
@ -192,23 +157,15 @@ def run_gsm8k_eval(
do the reduction across ranks. This is the responsibility of the caller.
Because the evaluation can take a while, this function will yield records one by one.
"""
max_examples = (
min(max_examples, len(task)) if max_examples is not None else len(task)
)
max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
for idx in range(ddp_rank, max_examples, ddp_world_size):
conversation = task[idx]
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
# Generate k samples using batched generation inside the Engine
assert (
num_samples <= device_batch_size
) # usually this is true. we can add a loop if not...
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch(
tokens,
num_samples=num_samples,
max_tokens=max_completion_tokens,
temperature=temperature,
top_k=top_k,
tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k
)
# Check each sample for correctness
outcomes = []
@ -240,9 +197,7 @@ optimizers = model.setup_optimizers(
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group[
"lr"
] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
@ -252,52 +207,33 @@ def get_lr_multiplier(it):
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(
f"Total sequences per step: {examples_per_step * num_samples}"
) # total batch size in sequences/step
assert (
examples_per_step % ddp_world_size == 0
), "Desired examples per step must be divisible by the number of ranks"
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
batch_iterator = get_batch()
for step in range(num_steps):
# Evaluate the model once in a while and log to wandb
if step % eval_every == 0:
model.eval()
passk = torch.zeros(
device_batch_size, device=device
) # pass@k for k=1..device_batch_size
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(
val_task,
tokenizer,
engine,
num_samples=device_batch_size,
max_examples=eval_examples,
temperature=1.0,
val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0
)
records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1):
passk[k - 1] = sum(
any(o["is_correct"] for o in r["outcomes"][:k]) for r in records
)
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
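        # A minimal standalone sketch with toy records (hypothetical values, same formula as above):
        # pass@k marks an example as solved if ANY of its first k samples is correct, so it is non-decreasing in k.
        toy_records = [
            {"outcomes": [{"is_correct": False}, {"is_correct": True}]},
            {"outcomes": [{"is_correct": True}, {"is_correct": False}]},
        ]
        toy_pass_at_1 = sum(any(o["is_correct"] for o in r["outcomes"][:1]) for r in toy_records)  # -> 1
        toy_pass_at_2 = sum(any(o["is_correct"] for o in r["outcomes"][:2]) for r in toy_records)  # -> 2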
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
if ddp:
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
passk = passk / num_records.item() # normalize by the total number of records
print_passk = [
f"Pass@{k}: {passk[k - 1].item():.4f}"
for k in range(1, device_batch_size + 1)
]
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {
f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)
}
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
wandb_run.log(
{
"step": step,
@ -310,9 +246,7 @@ for step in range(num_steps):
sequence_lengths = []
for example_step in range(examples_per_rank):
# Get one batch corresponding to one example in the training dataset
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(
batch_iterator
)
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
@ -327,9 +261,7 @@ for step in range(num_steps):
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx:
logp = -model(inputs, targets, loss_reduction="none").view_as(
inputs
) # (B, T)
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank
@ -351,9 +283,7 @@ for step in range(num_steps):
mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
if ddp: # aggregate across ranks
mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
mean_sequence_length_tensor = torch.tensor(
mean_sequence_length, dtype=torch.float, device=device
)
mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
mean_reward = mean_reward_tensor.item()
@ -385,16 +315,12 @@ for step in range(num_steps):
)
# Master process saves the model once in a while. Skip first step. Save last step.
if master_process and (
(step > 0 and step % save_every == 0) or step == num_steps - 1
):
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
model_config_kwargs = (
model.config.__dict__
) # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,

View File

@ -20,14 +20,7 @@ import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import (
DummyWandb,
autodetect_device_type,
compute_cleanup,
compute_init,
get_base_dir,
print0,
)
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.arc import ARC
@ -50,9 +43,7 @@ dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM
# optimization
num_epochs = 1
num_iterations = (
-1
) # override number of iterations (-1 = disable, use num_epochs to derive it)
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
@ -65,14 +56,8 @@ eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol
config_keys = [
k
for k, v in globals().items()
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
exec(
open(os.path.join("nanochat", "configurator.py")).read()
) # overrides from command line or config file
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
@ -80,56 +65,38 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logg
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if dtype == "float32" else torch.bfloat16
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
if device_type == "cuda"
else nullcontext()
)
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(
project="nanochat-sft", name=run, config=user_config, save_code=True
)
else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
)
# Load the model and tokenizer
model, tokenizer, meta = load_model(
source, device, phase="train", model_tag=model_tag, step=step
)
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
identity_conversations_filepath = os.path.join(
get_base_dir(), "identity_conversations.jsonl"
)
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture(
[
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(
filepath=identity_conversations_filepath
), # 1K rows of synthetic identity conversations
SimpleSpelling(
size=300, split="train"
), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(
size=300, split="train"
), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
val_ds = SmolTalk(
split="test"
) # general conversations, 24K rows (though we don't actually use all of it)
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------
# DataLoader
@ -143,9 +110,7 @@ def sft_data_generator(dataset, batch_size):
# prepares a list of tokenized conversations into a batch and yields
def collate_and_yield(batch):
nrows = len(batch)
ncols = (
max(len(ids) for ids, mask in batch) - 1
) # seq of n creates inputs/targets of n-1
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
for i, (ids, mask) in enumerate(batch):
@ -178,9 +143,9 @@ examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert (
target_examples_per_step % examples_per_step == 0
), "Target examples per step must be divisible by examples per step"
assert target_examples_per_step % examples_per_step == 0, (
"Target examples per step must be divisible by examples per step"
)
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
@ -204,9 +169,7 @@ optimizers = model.setup_optimizers(
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group[
"lr"
] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# -----------------------------------------------------------------------------
# Training loop
@ -269,7 +232,7 @@ for step in range(num_iterations):
batch_size=device_batch_size * 2,
max_problems=eval_metrics_max_problems,
)
metrics_str = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log(
{
@ -283,17 +246,13 @@ for step in range(num_iterations):
break
# evaluate the gradient
num_tokens = torch.tensor(
0, device=device
) # the number of "active" tokens of supervision seen
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
for micro_step in range(grad_accum_steps):
train_inputs, train_targets = next(train_iter)
with autocast_ctx:
loss = model(train_inputs, train_targets)
train_loss = loss.detach() # for logging
loss = (
loss / grad_accum_steps
) # each .backward() is a grad sum => normalize loss here
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() # accumulate the gradient
num_tokens += (train_targets >= 0).sum()
if ddp:
@ -332,9 +291,7 @@ if master_process:
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
model_config_kwargs = (
model.config.__dict__
) # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,
step,

View File

@ -36,9 +36,9 @@ import json
import logging
import os
import random
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, nullcontext
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional
import torch
from fastapi import FastAPI, HTTPException
@ -61,61 +61,33 @@ MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description="NanoChat Web Server")
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument(
"-n", "--num-gpus", type=int, default=1, help="Number of GPUs to use (default: 1)"
)
parser.add_argument(
"-i", "--source", type=str, default="sft", help="Source of the model: sft|mid|rl"
)
parser.add_argument(
"-t",
"--temperature",
type=float,
default=0.8,
help="Default temperature for generation",
)
parser.add_argument(
"-k", "--top-k", type=int, default=50, help="Default top-k sampling parameter"
)
parser.add_argument(
"-m",
"--max-tokens",
type=int,
default=512,
help="Default max tokens for generation",
)
parser.add_argument(
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
)
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
parser.add_argument(
"-p", "--port", type=int, default=8000, help="Port to run the server on"
)
parser.add_argument(
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
)
parser.add_argument(
"--device-type",
'--device-type',
type=str,
default="",
choices=["cuda", "cpu", "mps"],
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
)
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host to bind the server to"
default='',
choices=['cuda', 'cpu', 'mps'],
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
@ -132,28 +104,23 @@ class Worker:
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
def __init__(self, num_gpus: int | None = None):
if num_gpus is None:
if device_type == "cuda":
num_gpus = torch.cuda.device_count()
else:
num_gpus = 1 # e.g. cpu|mps
self.num_gpus = num_gpus
self.workers: List[Worker] = []
self.workers: list[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(
self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None
):
async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
if self.num_gpus > 1:
assert (
device_type == "cuda"
), "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
for gpu_id in range(self.num_gpus):
if device_type == "cuda":
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
@ -161,23 +128,13 @@ class WorkerPool:
device = torch.device(device_type) # e.g. cpu|mps
print(f"Loading model on {device_type}...")
model, tokenizer, _ = load_model(
source, device, phase="eval", model_tag=model_tag, step=step
)
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
)
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx,
)
worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx)
self.workers.append(worker)
await self.available_workers.put(worker)
@ -198,10 +155,10 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel):
messages: List[ChatMessage]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
messages: list[ChatMessage]
temperature: float | None = None
max_tokens: int | None = None
top_k: int | None = None
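# A minimal illustrative sketch (hypothetical function, not tied to this file): ruff-check's
# pyupgrade-style rules rewrite typing.Optional / typing.List into the PEP 604 / PEP 585 forms
# used in the fields above, which need no imports from `typing` on Python 3.10+.
def pick_first(items: list[str], default: str | None = None) -> str | None:
    return items[0] if items else default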
def validate_chat_request(request: ChatRequest):
@ -219,9 +176,7 @@ def validate_chat_request(request: ChatRequest):
total_length = 0
for i, message in enumerate(request.messages):
if not message.content:
raise HTTPException(
status_code=400, detail=f"Message {i} has empty content"
)
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
msg_length = len(message.content)
if msg_length > MAX_MESSAGE_LENGTH:
@ -241,32 +196,26 @@ def validate_chat_request(request: ChatRequest):
for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'",
status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
)
# Validate temperature
if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException(
status_code=400,
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}",
status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
)
# Validate top_k
if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException(
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}",
)
raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
# Validate max_tokens
if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException(
status_code=400,
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}",
status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
)
@ -275,9 +224,7 @@ async def lifespan(app: FastAPI):
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(
args.source, model_tag=args.model_tag, step=args.step
)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@ -301,8 +248,7 @@ async def root():
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';",
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';"
)
return HTMLResponse(content=html_content)
@ -352,7 +298,7 @@ async def generate_stream(
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith("<EFBFBD>"):
if not current_text.endswith('<EFBFBD>'):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text) :]
if new_text: # Only yield if there's new content
@ -435,14 +381,12 @@ async def chat_completions(request: ChatRequest):
@app.get("/health")
async def health():
"""Health check endpoint."""
worker_pool = getattr(app.state, "worker_pool", None)
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": (
worker_pool.available_workers.qsize() if worker_pool else 0
),
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0,
}
@ -453,19 +397,14 @@ async def stats():
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers)
- worker_pool.available_workers.qsize(),
"workers": [
{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers
],
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers],
}
if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server")
print(
f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}"
)
print("Starting NanoChat Web Server")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)
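# A minimal client-side usage sketch (assumes the server above is already running locally on
# the default port 8000; urllib/json are standard library, not repo dependencies):
import json
import urllib.request
with urllib.request.urlopen("http://localhost:8000/health") as resp:
    health = json.loads(resp.read())
print(health["status"], health["ready"], health["available_workers"])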

View File

@ -21,14 +21,7 @@ import torch.distributed as dist
import wandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.common import (
DummyWandb,
autodetect_device_type,
compute_cleanup,
compute_init,
get_base_dir,
print0,
)
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
from nanochat.loss_eval import evaluate_bpb
from nanochat.tokenizer import get_token_bytes
from tasks.common import TaskMixture
@ -56,14 +49,8 @@ eval_every = 150 # -1 = disable
eval_tokens = 20 * 524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [
k
for k, v in globals().items()
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
exec(
open(os.path.join("nanochat", "configurator.py")).read()
) # overrides from command line or config file
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
@ -72,25 +59,17 @@ device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = (
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
if device_type == "cuda"
else nullcontext()
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
)
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = (
DummyWandb()
if use_dummy_wandb
else wandb.init(project="nanochat-mid", name=run, config=user_config)
)
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model(
"base", device, phase="train", model_tag=model_tag, step=step
)
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
print0(
@ -100,38 +79,25 @@ orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = (
device_batch_size * max_seq_len
) # tokens per iteration for a single rank
world_tokens_per_fwdbwd = (
tokens_per_fwdbwd * ddp_world_size
) # total tokens per iteration for all ranks
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(
f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}"
)
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(
f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}"
)
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers
# Override the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["initial_lr"] = group[
"lr"
] # save the initial learning so we can decay easily later
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
@ -142,32 +108,18 @@ train_dataset = TaskMixture(
MMLU(
subset="auxiliary_train", split="train"
), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(
subset="main", split="train"
), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(
filepath=identity_conversations_filepath
), # 1000 rows of synthetic identity conversations
CustomJSON(
filepath=identity_conversations_filepath
), # let's do 2 epochs of these
SimpleSpelling(
size=200000, split="train"
), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(
size=80000, split="train"
), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
val_dataset = TaskMixture(
[
SmolTalk(split="test"), # 24K rows in test set
MMLU(
subset="all", split="test", stop=5200
), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(
subset="main", split="test", stop=420
), # 1.32K rows in test set, use only 420 to match the train ratios
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]
) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
@ -183,14 +135,10 @@ def mid_data_generator(split):
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
needed_tokens = (
device_batch_size * max_seq_len + 1
) # to form one training batch of inputs,targets
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.empty(
needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")
)
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
it = 0 # iteration counter
while True:
@ -207,29 +155,21 @@ def mid_data_generator(split):
# Stopping condition to respect num_iterations, if given
it += 1
if num_iterations > 0 and it >= num_iterations:
last_step = (
True # toggle last_step to True, which will terminate the training loop
)
last_step = True # toggle last_step to True, which will terminate the training loop
# Build up inputs/targets and yield
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(
device=device, dtype=torch.int32, non_blocking=True
)
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(device_batch_size, max_seq_len).to(
device=device, dtype=torch.int64, non_blocking=True
)
if split == "train":
if num_iterations > 0:
approx_progress = (
it / num_iterations
) # calculate progress from the max number of iterations
approx_progress = it / num_iterations # calculate progress from the max number of iterations
else:
approx_progress = (
cursor / dataset_size
) # approximate progress as a fraction of the dataset
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
yield inputs, targets
@ -296,9 +236,7 @@ while True:
checkpoint_dir,
step,
orig_model.state_dict(),
[
opt.state_dict() for opt in optimizers
], # TODO: make sure saving across ranks is done correctly
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
"step": step,
"val_bpb": val_bpb, # loss at last step
@ -326,16 +264,10 @@ while True:
with autocast_ctx:
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = (
loss / grad_accum_steps
) # each .backward() is a grad sum => normalize loss here
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(
train_loader
) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(
progress, approx_progress
) # only increase progress monotonically
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizers
lrm = get_lr_multiplier(progress)
for opt in optimizers:
@ -356,23 +288,17 @@ while True:
step += 1
# logging
smooth_train_loss = (
ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
) # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (
1 - ema_beta ** (step + 1)
) # debias the EMA
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = (
989e12 * ddp_world_size
) # bfloat16 H100 SXM and without 2:4 sparsity
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m"
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
)
if step % 10 == 0:
wandb_run.log(
@ -390,7 +316,7 @@ while True:
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Total training time: {total_training_time / 60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report

View File

@ -165,15 +165,10 @@ tokenizer_results = {}
vocab_sizes = {}
for tokenizer_name in ["gpt2", "gpt4", "ours"]:
if tokenizer_name == "gpt2":
tokenizer = RustBPETokenizer.from_pretrained(
"gpt2"
) # gpt-2 base model tokenizer
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
elif tokenizer_name == "gpt4":
tokenizer = RustBPETokenizer.from_pretrained(
"cl100k_base"
) # gpt-4 base model tokenizer
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
else:
tokenizer = get_tokenizer()
@ -185,21 +180,17 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
decoded = tokenizer.decode(encoded)
assert decoded == text
encoded_bytes = text.encode("utf-8")
encoded_bytes = text.encode('utf-8')
ratio = len(encoded_bytes) / len(encoded)
tokenizer_results[tokenizer_name][name] = {
"bytes": len(encoded_bytes),
"tokens": len(encoded),
"ratio": ratio,
}
tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio}
# ANSI color codes
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"
GREEN = '\033[92m'
RED = '\033[91m'
RESET = '\033[0m'
# Print vocab sizes
print(f"\nVocab sizes:")
print("\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}")
print(f"GPT-4: {vocab_sizes['gpt4']}")
print(f"Ours: {vocab_sizes['ours']}")
@ -209,12 +200,8 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
"""Print comparison table between baseline tokenizer and ours."""
print(f"\nComparison with {baseline_name}:")
print("=" * 95)
print(
f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}"
)
print(
f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}"
)
print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}")
print("-" * 95)
for name, text in all_text:
@ -223,16 +210,14 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
# Calculate relative difference (positive means ours is better, negative means worse)
# Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens
relative_diff = (
(baseline_data["tokens"] - ours_data["tokens"]) / baseline_data["tokens"]
) * 100
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
# Determine which has better compression (higher ratio = better)
if baseline_data["ratio"] > ours_data["ratio"]:
if baseline_data['ratio'] > ours_data['ratio']:
baseline_color, ours_color = GREEN, RED
better = baseline_name
diff_color = RED
elif ours_data["ratio"] > baseline_data["ratio"]:
elif ours_data['ratio'] > baseline_data['ratio']:
baseline_color, ours_color = RED, GREEN
better = "Ours"
diff_color = GREEN
@ -253,21 +238,17 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
# Print comparisons
print_comparison(
"GPT-2", tokenizer_results["gpt2"], tokenizer_results["ours"], all_text
)
print_comparison(
"GPT-4", tokenizer_results["gpt4"], tokenizer_results["ours"], all_text
)
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text)
# Log to report
from nanochat.report import get_report
lines = []
for baseline_name in ["GPT-2", "GPT-4"]:
baseline_key = baseline_name.lower().replace("-", "")
baseline_key = baseline_name.lower().replace('-', '')
baseline_results = tokenizer_results[baseline_key]
ours_results = tokenizer_results["ours"]
ours_results = tokenizer_results['ours']
lines.append(f"### Comparison with {baseline_name}")
lines.append("")
lines.append(
@ -277,15 +258,11 @@ for baseline_name in ["GPT-2", "GPT-4"]:
+ baseline_name
+ " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |"
)
lines.append(
"|-----------|-------|--------------|--------------|-------------|------------|-----------------|"
)
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
for name, text in all_text:
baseline_data = baseline_results[name]
ours_data = ours_results[name]
relative_diff = (
(baseline_data["tokens"] - ours_data["tokens"]) / baseline_data["tokens"]
) * 100
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
lines.append(
f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |"
)

View File

@ -16,25 +16,12 @@ from nanochat.tokenizer import RustBPETokenizer
# -----------------------------------------------------------------------------
# Parse command line arguments
parser = argparse.ArgumentParser(description="Train a BPE tokenizer")
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument(
"--max_chars",
type=int,
default=10_000_000_000,
help="Maximum characters to train on (default: 10B)",
)
parser.add_argument(
"--doc_cap",
type=int,
default=10_000,
help="Maximum characters per document (default: 10,000)",
)
parser.add_argument(
"--vocab_size",
type=int,
default=65536,
help="Vocabulary size (default: 65536 = 2^16)",
'--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)'
)
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")
@ -99,17 +86,13 @@ special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
token_bytes = []
for token_id in range(vocab_size):
token_str = token_strings[
token_id
] # the Python string representation of this token
token_str = token_strings[token_id] # the Python string representation of this token
if token_str in special_set:
token_bytes.append(0) # special characters are not counted
else:
id_bytes = len(
token_str.encode("utf-8")
) # number of bytes that make up this token
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
token_bytes.append(id_bytes)
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device="cpu")
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
with open(token_bytes_path, "wb") as f:
torch.save(token_bytes, f)

View File

@ -9,23 +9,15 @@ from tasks.common import Task, render_mc
class ARC(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in [
"ARC-Easy",
"ARC-Challenge",
], "ARC subset must be ARC-Easy or ARC-Challenge"
assert split in [
"train",
"validation",
"test",
], "ARC split must be train|validation|test"
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)
@property
def eval_type(self):
return "categorical"
return 'categorical'
def num_examples(self):
return len(self.ds)
@ -36,15 +28,10 @@ class ARC(Task):
choices = row["choices"]["text"] # the text of each choice
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
assert (
answer_string in letters
), f"ARC answer {answer_string} must be one of {letters}" # sanity check
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
# create and return the Conversation object
user_message = render_mc(question, letters, choices)
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": answer_string},
]
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}]
conversation = {
"messages": messages,
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
@ -54,8 +41,8 @@ class ARC(Task):
def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert (
assistant_response in conversation["letters"]
), f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
assistant_message = conversation["messages"][-1]["content"] # e.g. "A"
assert assistant_response in conversation['letters'], (
f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message

View File

@ -16,9 +16,7 @@ class Task:
def __init__(self, start=0, stop=None, step=1):
# allows a lightweight logical view over a dataset
assert start >= 0, f"Start must be non-negative, got {start}"
assert (
stop is None or stop >= start
), f"Stop should be greater than or equal to start, got {stop} and {start}"
assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
assert step >= 1, f"Step must be strictly positive, got {step}"
self.start = start
self.stop = stop # could be None here
@ -84,9 +82,9 @@ class TaskMixture(Task):
Access conversations according to a deterministic shuffle of all examples.
This ensures tasks are mixed throughout training, regardless of dataset size.
"""
assert (
0 <= index < self.num_conversations
), f"Index {index} out of range for mixture with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for mixture with {self.num_conversations} conversations"
)
task_idx, local_idx = self.index_map[index]
return self.tasks[task_idx][local_idx]
@ -107,9 +105,9 @@ class TaskSequence(Task):
return self.num_conversations
def get_example(self, index):
assert (
0 <= index < self.num_conversations
), f"Index {index} out of range for sequence with {self.num_conversations} conversations"
assert 0 <= index < self.num_conversations, (
f"Index {index} out of range for sequence with {self.num_conversations} conversations"
)
for task_idx, task_length in enumerate(self.lengths):
if index < task_length:
return self.tasks[task_idx][index]
@ -133,9 +131,7 @@ def render_mc(question, letters, choices):
about this too much, but smaller models do care about some of these details.
"""
query = f"Multiple Choice question: {question}\n"
query += "".join(
[f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]
)
query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
query += "\nRespond only with the letter of the correct answer."
return query
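# A minimal usage sketch (toy question and choices, hypothetical inputs):
toy_query = render_mc("What color is the sky?", ["A", "B"], ["green", "blue"])
# toy_query ==
#   "Multiple Choice question: What color is the sky?\n- green=A\n- blue=B\n\nRespond only with the letter of the correct answer."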

View File

@ -30,44 +30,32 @@ class CustomJSON(Task):
print(
"If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations"
)
print(
"See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139"
)
print(
"Quick fix: simply run the following command to download the file and you're done:"
)
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
print("Quick fix: simply run the following command to download the file and you're done:")
print(
f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
)
print("-" * 80)
else:
with open(filepath, encoding="utf-8") as f:
with open(filepath, encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line: # skip empty lines
continue
messages = json.loads(line)
# Validate the conversation structure
assert isinstance(
messages, list
), f"Expected list of messages, got {type(messages)}"
assert (
len(messages) >= 2
), f"Conversation must have at least 2 messages, got {len(messages)}"
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
# Validate message structure and alternating roles
for i, message in enumerate(messages):
assert "role" in message, f"Message {i} missing 'role' field"
assert (
"content" in message
), f"Message {i} missing 'content' field"
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert (
message["role"] == expected_role
), f"Message {i} has role {message['role']} but should be {expected_role}"
assert isinstance(
message["content"], str
), f"Message {i} content must be a string"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), f"Message {i} content must be a string"
self.conversations.append(messages)

View File

@ -38,7 +38,6 @@ def extract_answer(completion):
class GSM8K(Task):
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
@ -47,7 +46,7 @@ class GSM8K(Task):
@property
def eval_type(self):
return "generative"
return 'generative'
def num_examples(self):
return len(self.ds)
@ -55,39 +54,32 @@ class GSM8K(Task):
def get_example(self, index):
"""Get a single problem from the dataset."""
row = self.ds[index]
question = row["question"] # string of the question prompt
answer = row[
"answer"
] # string of the full solution and the answer after #### marker
question = row['question'] # string of the question prompt
answer = row['answer'] # string of the full solution and the answer after #### marker
# Create and return the Conversation object
# This is tricky because GSM8K uses tool calls, which we need to parse here.
assistant_message_parts = []
parts = re.split(r"(<<[^>]+>>)", answer)
parts = re.split(r'(<<[^>]+>>)', answer)
for part in parts:
if part.startswith("<<") and part.endswith(">>"):
if part.startswith('<<') and part.endswith('>>'):
# This is a calculator tool call
inner = part[2:-2] # Remove << >>
# Split on = to get expression and result
if "=" in inner:
expr, result = inner.rsplit("=", 1)
if '=' in inner:
expr, result = inner.rsplit('=', 1)
else:
expr, result = inner, ""
# Add the tool call as a part
assistant_message_parts.append({"type": "python", "text": expr})
# Add the result as a part
assistant_message_parts.append(
{"type": "python_output", "text": result}
)
assistant_message_parts.append({"type": "python_output", "text": result})
else:
# Regular text in between tool calls
assistant_message_parts.append({"type": "text", "text": part})
        # Now put it all together
messages = [
{"role": "user", "content": question}, # note: simple string
{
"role": "assistant",
"content": assistant_message_parts,
}, # note: list of parts (as dicts)
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
]
conversation = {
"messages": messages,
@ -104,20 +96,12 @@ class GSM8K(Task):
TODO: Technically, assistant_response should be a Message (either a string or a list of parts)
We can handle this later possibly. For now just assume string.
"""
assert isinstance(
assistant_response, str
), "Assuming simple string response for now"
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer
assistant_message = conversation["messages"][-1]
assert (
assistant_message["role"] == "assistant"
), "Last message must be from the Assistant"
assert isinstance(
assistant_message["content"], list
), "This is expected to be a list of parts"
last_text_part = assistant_message["content"][-1][
"text"
] # this contains the final answer in GSM8K
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)

View File

@ -15,14 +15,14 @@ from tasks.common import Task
def extract_imports(prompt):
"""Extract import statements from the beginning of a code block."""
imports = []
for line in prompt.split("\n"):
for line in prompt.split('\n'):
stripped = line.strip()
if stripped.startswith("import ") or stripped.startswith("from "):
if stripped.startswith('import ') or stripped.startswith('from '):
imports.append(stripped)
elif stripped and not stripped.startswith("#"):
elif stripped and not stripped.startswith('#'):
# Stop at first non-import, non-comment line
break
return "\n".join(imports)
return '\n'.join(imports)
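# A minimal usage sketch (toy prompt, hypothetical): only the leading import lines survive;
# scanning stops at the first non-import, non-comment code line.
toy_prompt = "import math\nfrom typing import List\n\ndef f(x):\n    return math.sqrt(x)\n"
assert extract_imports(toy_prompt) == "import math\nfrom typing import List"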
def extract_program(completion):
@ -38,7 +38,7 @@ def extract_program(completion):
"""
# Try to find markdown code blocks (```python or just ```)
# Match ```python\n...\n``` or ```\n...\n```
pattern = r"```(?:python)?\s*\n(.*?)\n```"
pattern = r'```(?:python)?\s*\n(.*?)\n```'
matches = re.findall(pattern, completion, re.DOTALL)
if matches:
@ -50,14 +50,13 @@ def extract_program(completion):
class HumanEval(Task):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
@property
def eval_type(self):
return "generative"
return 'generative'
def num_examples(self):
return len(self.ds)
@ -65,10 +64,10 @@ class HumanEval(Task):
def get_example(self, index):
"""Get a single problem from the dataset."""
row = self.ds[index]
prompt = row["prompt"] # prompts in HumanEval are the beginning of the program
solution = row["canonical_solution"] # the correct continuation of the program
entry_point = row["entry_point"] # the function to check
test = row["test"] # the test cases
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
solution = row['canonical_solution'] # the correct continuation of the program
entry_point = row['entry_point'] # the function to check
test = row['test'] # the test cases
complete_solution = f"{prompt}\n{solution}"
messages = [
{"role": "user", "content": prompt},
@ -84,7 +83,7 @@ class HumanEval(Task):
def evaluate(self, conversation, completion):
"""Given (conversation, completion), return boolean success of the completion."""
# the prompt will contain the imports and the function signature
imports = extract_imports(conversation["messages"][0]["content"])
imports = extract_imports(conversation['messages'][0]['content'])
# the completion will usually contain the whole function
# but not always with the needed imports, so we manually append them
completion_code = extract_program(completion)
@ -93,7 +92,7 @@ class HumanEval(Task):
+ "\n\n"
+ completion_code
+ "\n\n"
+ conversation["test"]
+ conversation['test']
+ "\n"
+ f"check({conversation['entry_point']})"
)
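
The assembled string above concatenates the prompt's imports, the extracted completion, the HumanEval test suite, and a final `check(entry_point)` call into one runnable program. One way to execute such a program is a subprocess with a timeout, sketched below; nanochat's actual execution path is not shown in this excerpt, so this runner is an illustrative assumption rather than the project's harness:

```python
# Illustrative runner only; the repository's actual execution sandbox is not part of this diff.
import subprocess
import sys
import tempfile


def run_check_program(program: str, timeout_s: float = 10.0) -> bool:
    """Write the assembled program to a temp file and report whether it exits cleanly."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(program)
        path = f.name
    try:
        proc = subprocess.run([sys.executable, path], capture_output=True, timeout=timeout_s)
        return proc.returncode == 0  # the check(entry_point) call raises on any failing test
    except subprocess.TimeoutExpired:
        return False  # treat hangs as failures
```

A production harness would additionally isolate the process (memory limits, no network, restricted filesystem), which this sketch deliberately omits.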

View File

@ -9,80 +9,71 @@ from tasks.common import Task, render_mc
class MMLU(Task):
letters = ("A", "B", "C", "D")
letters = ('A', 'B', 'C', 'D')
groups = (
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
'abstract_algebra',
'anatomy',
'astronomy',
'business_ethics',
'clinical_knowledge',
'college_biology',
'college_chemistry',
'college_computer_science',
'college_mathematics',
'college_medicine',
'college_physics',
'computer_security',
'conceptual_physics',
'econometrics',
'electrical_engineering',
'elementary_mathematics',
'formal_logic',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_computer_science',
'high_school_european_history',
'high_school_geography',
'high_school_government_and_politics',
'high_school_macroeconomics',
'high_school_mathematics',
'high_school_microeconomics',
'high_school_physics',
'high_school_psychology',
'high_school_statistics',
'high_school_us_history',
'high_school_world_history',
'human_aging',
'human_sexuality',
'international_law',
'jurisprudence',
'logical_fallacies',
'machine_learning',
'management',
'marketing',
'medical_genetics',
'miscellaneous',
'moral_disputes',
'moral_scenarios',
'nutrition',
'philosophy',
'prehistory',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_studies',
'sociology',
'us_foreign_policy',
'virology',
'world_religions',
)
def __init__(self, subset, split, **kwargs):
super().__init__(**kwargs)
assert subset in [
"all",
"auxiliary_train",
], f"subset {subset} must be all|auxiliary_train"
assert split in [
"train",
"validation",
"dev",
"test",
], f"split {split} must be train|validation|dev|test"
assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train"
assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test"
if subset == "auxiliary_train":
assert split == "train", "auxiliary_train must be split into train"
self.subset = subset
@ -90,11 +81,11 @@ class MMLU(Task):
self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
if subset == "auxiliary_train":
# I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper
self.ds = self.ds.map(lambda row: row["train"], remove_columns=["train"])
self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train'])
@property
def eval_type(self):
return "categorical"
return 'categorical'
def num_examples(self):
return len(self.ds)
@ -109,10 +100,7 @@ class MMLU(Task):
# create and return the Conversation object
user_message = render_mc(question, self.letters, choices)
assistant_message = self.letters[answer]
messages = [
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message},
]
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}]
conversation = {
"messages": messages,
"subject": subject, # might be useful later for grouping metrics by subject
@ -123,8 +111,8 @@ class MMLU(Task):
def evaluate(self, conversation, assistant_response):
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
assert (
assistant_response in self.letters
), f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
assistant_message = conversation["messages"][-1]["content"] # e.g. "A"
assert assistant_response in self.letters, (
f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
)
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
return assistant_response == assistant_message
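
`get_example` builds the user prompt with `render_mc(question, self.letters, choices)` from `tasks.common`, which is not included in this diff. Below is a hypothetical renderer with the same call shape, purely to illustrate the expected inputs and output; the actual prompt layout used by `render_mc` is an assumption:

```python
# Hypothetical sketch mirroring the call shape render_mc(question, letters, choices);
# the exact prompt layout of tasks.common.render_mc is an assumption.
def render_mc_sketch(question: str, letters: tuple[str, ...], choices: list[str]) -> str:
    lines = [question]
    lines.extend(f"{letter}. {choice}" for letter, choice in zip(letters, choices))
    lines.append("Answer with a single letter.")
    return "\n".join(lines)


print(render_mc_sketch("2 + 2 = ?", ("A", "B", "C", "D"), ["3", "4", "5", "22"]))
```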

View File

@ -15,9 +15,7 @@ class SmolTalk(Task):
def __init__(self, split, **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SmolTalk split must be train|test"
self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(
seed=42
)
self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42)
self.length = len(self.ds)
def num_examples(self):
@ -36,15 +34,13 @@ class SmolTalk(Task):
rest_messages = messages[1:] # optional system message is OK
else:
rest_messages = messages
assert (
len(rest_messages) >= 2
), "SmolTalk messages must have at least 2 messages"
assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages"
for i, message in enumerate(rest_messages):
# user and assistant alternate as user,assistant,user,assistant,...
expected_role = "user" if i % 2 == 0 else "assistant"
assert (
message["role"] == expected_role
), f"Message {i} has role {message['role']} but should be {expected_role}"
assert message["role"] == expected_role, (
f"Message {i} has role {message['role']} but should be {expected_role}"
)
assert isinstance(message["content"], str), "Content must be a string"
# ---------------------------------------------------------------------
# create and return the Conversation object (ok to emit the system message too)

View File

@ -116,7 +116,6 @@ USER_MSG_TEMPLATES = [
class SpellingBee(Task):
def __init__(self, size=1000, split="train", **kwargs):
super().__init__(**kwargs)
assert split in ["train", "test"], "SpellingBee split must be train|test"
@ -124,13 +123,13 @@ class SpellingBee(Task):
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, encoding="utf-8") as f:
with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f]
self.words = words
@property
def eval_type(self):
return "generative"
return 'generative'
def num_examples(self):
return self.size
@ -152,7 +151,7 @@ class SpellingBee(Task):
# 30% chance to lowercase the template (lazy people don't use shift)
if rng.random() < 0.3:
template = template.lower()
quote_options = ["", "'", '"']
quote_options = ['', "'", '"']
letter_quote = rng.choice(quote_options) # is the letter quoted?
word_quote = rng.choice(quote_options) # is the word quoted?
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
@ -188,9 +187,7 @@ Then count the occurrences of '{letter}':
manual_text += f"\nThis gives us {running_count}."
assistant_parts.append({"type": "text", "text": manual_text})
# Part 2: Python verification
assistant_parts.append(
{"type": "text", "text": "\n\nLet me double check this using Python:\n\n"}
)
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
# Part 3: Python tool call
python_expr = f"'{word}'.count('{letter}')"
assistant_parts.append({"type": "python", "text": python_expr})
@ -198,17 +195,11 @@ Then count the occurrences of '{letter}':
assistant_parts.append({"type": "python_output", "text": str(count)})
# Part 5: Final answer
assistant_parts.append(
{
"type": "text",
"text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}",
}
{"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}
)
# return the full conversation
messages = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_parts},
]
messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}]
conversation = {
"messages": messages,
}
@ -219,19 +210,13 @@ Then count the occurrences of '{letter}':
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
Identical to gsm8k's evaluation.
"""
assert isinstance(
assistant_response, str
), "Assuming simple string response for now"
assert isinstance(assistant_response, str), "Assuming simple string response for now"
# First extract the ground truth answer from the conversation
assistant_message = conversation["messages"][-1]
assert (
assistant_message["role"] == "assistant"
), "Last message must be from the Assistant"
assert isinstance(
assistant_message["content"], list
), "This is expected to be a list of parts"
assistant_message = conversation['messages'][-1]
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
# The last text part contains the final answer with ####
last_text_part = assistant_message["content"][-1]["text"]
last_text_part = assistant_message['content'][-1]['text']
# Extract both the ground truth answer and the predicted answer
ref_num = extract_answer(last_text_part)
pred_num = extract_answer(assistant_response)
@ -256,7 +241,7 @@ class SimpleSpelling(Task):
self.split = split
filename = WORD_LIST_URL.split("/")[-1]
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
with open(word_list_path, encoding="utf-8") as f:
with open(word_list_path, encoding='utf-8') as f:
words = [line.strip() for line in f]
rng = random.Random(42)
rng.shuffle(words) # use a different word order than the SpellingBee task
@ -264,7 +249,7 @@ class SimpleSpelling(Task):
@property
def eval_type(self):
return "generative"
return 'generative'
def num_examples(self):
return self.size
@ -287,23 +272,22 @@ class SimpleSpelling(Task):
if __name__ == "__main__":
# preview the SpellingBee task, first 10 examples
task = SpellingBee()
for i in range(10):
ex = task.get_example(i)
print("=" * 100)
print(ex["messages"][0]["content"])
print(ex['messages'][0]['content'])
print("-" * 100)
# Assistant content is now a list of parts
assistant_parts = ex["messages"][1]["content"]
assistant_parts = ex['messages'][1]['content']
for part in assistant_parts:
if part["type"] == "text":
print(part["text"], end="")
elif part["type"] == "python":
print(f"<<{part['text']}=", end="")
elif part["type"] == "python_output":
print(f"{part['text']}>>", end="")
if part['type'] == 'text':
print(part['text'], end='')
elif part['type'] == 'python':
print(f"<<{part['text']}=", end='')
elif part['type'] == 'python_output':
print(f"{part['text']}>>", end='')
print()
print("-" * 100)

View File

@ -23,26 +23,14 @@ def test_kv_cache_resize():
num_layers = 6
kv_cache = KVCache(
batch_size=batch_size,
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers,
batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers
)
# Insert a single token with a distinct fill value to all layers
def insert_token(token_idx):
for layer_idx in range(num_layers):
k = torch.full(
(batch_size, num_heads, 1, head_dim),
fill_value=float(token_idx),
dtype=torch.float32,
)
v = torch.full(
(batch_size, num_heads, 1, head_dim),
fill_value=float(token_idx * 100),
dtype=torch.float32,
)
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
kv_cache.insert_kv(layer_idx, k, v)
# Insert 4 tokens (fills the initial seq_len=4)
@ -57,9 +45,9 @@ def test_kv_cache_resize():
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert (
new_seq_len > original_seq_len
), f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
assert new_seq_len > original_seq_len, (
f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
)
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
@ -69,20 +57,14 @@ def test_kv_cache_resize():
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (
actual_k == expected_k
).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (
actual_v == expected_v
).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
assert (actual_k == expected_k).all(), (
f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
)
assert (actual_v == expected_v).all(), (
f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
)
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (
actual_k == original_k
).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (
actual_v == original_v
).all(), (
f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
)
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
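
For orientation, the indexing above implies a cache layout of `(num_layers, 2, batch, heads, seq_len, head_dim)`, with `shape[4]` as the sequence axis. Below is a standalone sketch of the grow-and-copy resize pattern this test exercises, using plain tensors rather than the repository's `KVCache` class, whose internals are not shown here:

```python
# Standalone sketch of the grow-and-copy resize pattern; not nanochat's KVCache implementation.
import torch


def grow_seq_dim(cache: torch.Tensor, new_seq_len: int) -> torch.Tensor:
    """Return a larger cache along the sequence axis, preserving existing entries."""
    layers, kv, batch, heads, seq, head_dim = cache.shape
    assert new_seq_len >= seq
    grown = torch.zeros(layers, kv, batch, heads, new_seq_len, head_dim,
                        dtype=cache.dtype, device=cache.device)
    grown[..., :seq, :] = cache  # copy previously written tokens into the new buffer
    return grown


cache = torch.arange(2 * 2 * 1 * 1 * 4 * 3, dtype=torch.float32).reshape(2, 2, 1, 1, 4, 3)
bigger = grow_seq_dim(cache, 8)
assert torch.equal(bigger[..., :4, :], cache)
```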

View File

@ -27,7 +27,9 @@ import tiktoken
import rustbpe
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
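
The split pattern relies on possessive quantifiers (`?+`, `++`), which the standard-library `re` module rejects, so it is normally applied with the third-party `regex` package. A small usage sketch (using the `GPT4_SPLIT_PATTERN` defined just above):

```python
# The stdlib `re` module rejects the possessive quantifiers `?+` / `++`, so use `regex`.
import regex

chunks = regex.findall(GPT4_SPLIT_PATTERN, "I'm 25, hello!!")
print(chunks)  # ['I', "'m", ' ', '25', ',', ' hello', '!!']
```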
# -----------------------------------------------------------------------------
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
@ -65,7 +67,6 @@ def merge(ids, pair, idx):
class RegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
@ -114,9 +115,7 @@ class RegexTokenizer:
pair = max(stats, key=stats.get)
# check if the merge is ambiguous - i.e. the max value is not unique
pair_count = stats[pair]
pairs_with_max_count = [
pair for pair, count in stats.items() if count == pair_count
]
pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
if len(pairs_with_max_count) > 1:
# print the top 10 pairs with their counts
# print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
@ -132,9 +131,7 @@ class RegexTokenizer:
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(
f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences"
)
print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
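
The loop above is the classic BPE training step: count adjacent pairs, pick the most frequent, and replace every occurrence with a freshly allocated token id. A compact, self-contained sketch of that step is shown below; the helper names are illustrative, not the module's exact API:

```python
# Compact, self-contained sketch of one BPE training iteration; names are illustrative.
def count_pairs(ids: list[int]) -> dict[tuple[int, int], int]:
    counts: dict[tuple[int, int], int] = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def apply_merge(ids: list[int], pair: tuple[int, int], idx: int) -> list[int]:
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)  # replace the matched pair with its new token id
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


ids = [1, 2, 1, 2, 3]
stats = count_pairs(ids)          # {(1, 2): 2, (2, 1): 1, (2, 3): 1}
best = max(stats, key=stats.get)  # (1, 2) is the most frequent pair
assert apply_merge(ids, best, 256) == [256, 256, 3]
```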
@ -195,7 +192,6 @@ def fast_merge_inplace(ids, pair, idx):
class FastRegexTokenizer:
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
@ -245,9 +241,7 @@ class FastRegexTokenizer:
# Initial count: build stats and position tracking
stats = defaultdict(int)
positions = defaultdict(
set
) # pair -> set of chunk indices that contain this pair
positions = defaultdict(set) # pair -> set of chunk indices that contain this pair
for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
for pair in zip(chunk_ids, chunk_ids[1:]):
@ -316,8 +310,7 @@ class FastRegexTokenizer:
for chunk_idx in affected_chunks:
chunk_ids = ids[chunk_idx]
contains_pair = any(
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair
for j in range(len(chunk_ids) - 1)
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1)
)
if contains_pair:
positions[changed_pair].add(chunk_idx)
@ -390,9 +383,8 @@ class FastRegexTokenizer:
# -----------------------------------------------------------------------------
# HuggingFace tokenizer
from tokenizers import Regex
from tokenizers import Regex, decoders, pre_tokenizers
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import decoders, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
@ -417,14 +409,10 @@ class HuggingFaceTokenizer:
# Normalizer: None
tokenizer.normalizer = None
# Pre-tokenizer: GPT-4 style
gpt4_split_regex = Regex(
GPT4_SPLIT_PATTERN
) # huggingface demands that you wrap it in Regex!!
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(
pattern=gpt4_split_regex, behavior="isolated", invert=False
),
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
]
)
@ -515,9 +503,7 @@ def test_correctness(enwik8_small):
# Train slow reference
print("\nTraining slow reference...")
slow_reference_tokenizer = RegexTokenizer()
ambiguous_flag, slow_reference_train_time = time_function(
slow_reference_tokenizer.train, text, vocab_size
)
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
slow_reference_ids, slow_reference_encode_time = time_function(
slow_reference_tokenizer.encode_ordinary, encode_text
)
@ -526,21 +512,15 @@ def test_correctness(enwik8_small):
print(slow_reference_ids[:20])
if ambiguous_flag:
print(
"‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size"
)
print(
"The implementation could be correct but we might see different results below"
)
print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
print("The implementation could be correct but we might see different results below")
else:
print("✅ Merge order is NOT ambiguous")
# Train fast reference
print("\nTraining fast reference...")
fast_reference_tokenizer = FastRegexTokenizer()
_, fast_reference_train_time = time_function(
fast_reference_tokenizer.train, text, vocab_size
)
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
fast_reference_ids, fast_reference_encode_time = time_function(
fast_reference_tokenizer.encode_ordinary, encode_text
)
@ -549,16 +529,12 @@ def test_correctness(enwik8_small):
print(fast_reference_ids[:20])
# Assert fast equals slow
assert (
fast_reference_ids == slow_reference_ids
), "Fast reference should match slow reference"
assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
print("✅ Fast == Slow")
# Train HuggingFace
print("\nTraining HuggingFace...")
hf_tokenizer, hf_train_time = time_function(
HuggingFaceTokenizer.train_from_iterator, [text], vocab_size
)
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
print(f"HuggingFace train time: {hf_train_time:.4f}s")
print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
@ -577,20 +553,14 @@ def test_correctness(enwik8_small):
return False
return True
assert custom_match(
hf_ids, fast_reference_ids
), "HuggingFace should match fast reference"
assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
print("✅ HuggingFace == Fast")
# Finally use our own Rust implementation
print("\nTraining rustbpe...")
rustbpe_tokenizer = rustbpe.Tokenizer()
_, rustbpe_train_time = time_function(
rustbpe_tokenizer.train_from_iterator, [text], vocab_size
)
rustbpe_ids, rustbpe_encode_time = time_function(
rustbpe_tokenizer.encode, encode_text
)
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
print(rustbpe_ids[:20])
@ -634,25 +604,21 @@ def test_training_performance(enwik8_large):
# Train rustbpe
print("\nTraining rustbpe...")
rustbpe_tokenizer = rustbpe.Tokenizer()
_, rustbpe_train_time = time_function(
rustbpe_tokenizer.train_from_iterator, [text], vocab_size
)
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
assert rustbpe_train_time > 0, "Training should take some time"
# Train HuggingFace
print("\nTraining HuggingFace...")
hf_tokenizer, hf_train_time = time_function(
HuggingFaceTokenizer.train_from_iterator, [text], vocab_size
)
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
print(f"HuggingFace train time: {hf_train_time:.4f}s")
assert hf_train_time > 0, "Training should take some time"
# Print comparison
print(f"\n📊 Performance comparison:")
print("\n📊 Performance comparison:")
print(f" RustBPE: {rustbpe_train_time:.4f}s")
print(f" HuggingFace: {hf_train_time:.4f}s")
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x")
def test_interface(enwik8_small):
@ -664,9 +630,7 @@ def test_interface(enwik8_small):
# Simple train test
vocab_size = 300
tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
assert (
tok.get_vocab_size() == vocab_size
), f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
print(f"✅ Trained tokenizer with vocab size {vocab_size}")
# Encode/decode text
@ -676,24 +640,18 @@ def test_interface(enwik8_small):
print(f"IDs: {ids}")
decoded = tok.decode(ids)
print(f"Decoded: {decoded}")
assert (
decoded == encode_text
), f"Decoded text doesn't match: {decoded} != {encode_text}"
assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
print("✅ Encode/decode test passed")
# Encode batch test
ids_new = tok.encode([encode_text, encode_text])
assert all(
x == ids for x in ids_new
), "Batch encoding should produce identical results"
assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
print("✅ Encode batch OK")
# append/prepend functionality
ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
bos_token_id = tok.encode_special("<|bos|>")
assert ids_special == [bos_token_id] + ids + [
bos_token_id
], "Special tokens not correctly added"
assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
print("✅ append/prepend OK")
# Save/load test through a temporary directory