Compare commits

...

10 Commits

Author                  SHA1        Message                                                        Date
Sermet Pekin            6181b08c14  Merge 33ddc13ed4 into f66a780f68                               2025-11-14 14:44:02 -08:00
Andrej                  f66a780f68  Fix torch.dtype mismatching when running engine inline test.  2025-11-14 07:28:29 -08:00
Andrej                  4763ce612a  Small fixes to typos                                           2025-11-14 07:25:59 -08:00
Sofie Van Landeghem     c6f5bd67db  revert change of base to sft for quick inline test            2025-11-14 12:20:03 +01:00
svlandeg                a2fb3c83a6  fix typos                                                      2025-11-14 11:20:25 +01:00
svlandeg                e5efb4b471  add test_engine.py to file structure                           2025-11-14 11:13:42 +01:00
howardgao@outlook.com   b399e43168  fix engine test bug                                            2025-11-06 08:56:45 +08:00
svlandeg                52e85aaf80  Merge branch 'master' into fix/typo                            2025-11-02 13:41:13 +01:00
svlandeg                70319851fc  fix typo                                                       2025-10-29 19:48:34 +01:00
Sermet Pekin            33ddc13ed4  Improve configurator: add testable parse_args() and           2025-10-23 09:56:52 +03:00
                                    ConfigManager class. Refactored configurator for better
                                    architecture and control over global variable injection.
6 changed files with 85 additions and 52 deletions

View File

@@ -184,6 +184,7 @@ python -m pytest tests/test_rustbpe.py -v -s
 │   ├── smoltalk.py         # Conglomerate dataset of SmolTalk from HF
 │   └── spellingbee.py      # Task teaching model to spell/count letters
 ├── tests
+│   ├── test_engine.py
 │   └── test_rustbpe.py
 └── uv.lock
 ```
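
The new test can presumably be run the same way as the existing one referenced in the hunk header, i.e. python -m pytest tests/test_engine.py -v -s.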

View File

@@ -1,56 +1,82 @@
 """
-Poor Man's Configurator. Probably a terrible idea. Example usage:
+Poor Man's Configurator v3. Clean refactored version. Example usage:
 $ python train.py config/override_file.py --batch_size=32
 this will first run config/override_file.py, then override batch_size to 32
 
 The code in this file will be run as follows from e.g. train.py:
 >>> exec(open('configurator.py').read())
 
 So it's not a Python module, it's just shuttling this code away from train.py
 The code in this script then overrides the globals()
 
-I know people are not going to love this, I just really dislike configuration
-complexity and having to prepend config. to every single variable. If someone
-comes up with a better simple Python solution I am all ears.
+Improved version with better separation of concerns and cleaner architecture
+while maintaining the same simple usage pattern.
 """
 import os
 import sys
 from ast import literal_eval
 
-def print0(s="",**kwargs):
-    ddp_rank = int(os.environ.get('RANK', 0))
-    if ddp_rank == 0:
+def print0(s="", **kwargs):
+    """Print only from rank 0 in distributed settings"""
+    if int(os.environ.get("RANK", 0)) == 0:
         print(s, **kwargs)
 
-for arg in sys.argv[1:]:
-    if '=' not in arg:
-        # assume it's the name of a config file
-        assert not arg.startswith('--')
-        config_file = arg
-        print0(f"Overriding config with {config_file}:")
-        with open(config_file) as f:
-            print0(f.read())
-        exec(open(config_file).read())
-    else:
-        # assume it's a --key=value argument
-        assert arg.startswith('--')
-        key, val = arg.split('=')
-        key = key[2:]
-        if key in globals():
+class ConfigManager:
+    """Clean configurator with explicit global injection control"""
+
+    def __init__(self):
+        self.config = {}
+
+    def load(self, config_file=None, **overrides):
+        """Load configuration from file and apply overrides"""
+        if config_file:
+            with open(config_file) as f:
+                exec(f.read(), {}, self.config)
+        self.config.update(overrides)
+        return self
+
+    def inject_globals(self):
+        """Inject config vars into global namespace"""
+        for k, v in self.config.items():
+            if not k.startswith("_"):
+                globals()[k] = v
+
+def parse_args(args=None):
+    """Parse command line arguments like original version"""
+    args = args or sys.argv
+    config_file, overrides = None, {}
+    for arg in args[1:]:
+        if "=" not in arg:
+            # Config file
+            assert not arg.startswith("--"), f"Invalid config file: {arg}"
+            config_file = arg
+            print0(f"Overriding config with {config_file}:")
+            with open(config_file) as f:
+                print0(f.read())
+        else:
+            # Key=value override
+            assert arg.startswith("--"), f"Override must start with --: {arg}"
+            key, val = arg.split("=", 1)
+            key = key[2:]
+            if key not in globals():
+                raise ValueError(f"Unknown config key: {key}")
+            # Try to parse the value
             try:
-                # attempt to eval it it (e.g. if bool, number, or etc)
                 attempt = literal_eval(val)
             except (SyntaxError, ValueError):
                 # if that goes wrong, just use the string
                 attempt = val
-            # ensure the types match ok
+            # Type check if global has a non-None value
            if globals()[key] is not None:
-                attempt_type = type(attempt)
-                default_type = type(globals()[key])
+                default_type, attempt_type = type(globals()[key]), type(attempt)
                 assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
-            # cross fingers
             print0(f"Overriding: {key} = {attempt}")
-            globals()[key] = attempt
-        else:
-            raise ValueError(f"Unknown config key: {key}")
+            overrides[key] = attempt
+    return config_file, overrides
+
+# Execute configuration
+config_file, overrides = parse_args()
+ConfigManager().load(config_file, **overrides).inject_globals()
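
From the caller's point of view the contract is unchanged: train.py still execs configurator.py, a bare argument names a config file, and --key=value flags override existing globals. What the refactor buys is that parse_args() returns its results instead of writing into globals() inline, which is what makes it unit-testable. A minimal sketch of that round trip, assuming it runs in the same namespace as the configurator code (as it does under train.py's exec) and using a hypothetical batch_size default:

```python
# Hypothetical default; in real use train.py defines globals like this
# before exec'ing configurator.py.
batch_size = 64

# parse_args() validates and type-checks, but only *returns* the overrides
config_file, overrides = parse_args(["train.py", "--batch_size=32"])
assert config_file is None and overrides == {"batch_size": 32}

# ConfigManager applies them explicitly, replacing the old in-place loop
ConfigManager().load(config_file, **overrides).inject_globals()
assert batch_size == 32
```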

View File

@@ -17,8 +17,9 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
+from contextlib import nullcontext
 
 # -----------------------------------------------------------------------------
 # Calculator tool helpers
@@ -328,6 +329,9 @@ if __name__ == "__main__":
     import time
     # init compute
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+    device_type = autodetect_device_type()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+
     # load the model and tokenizer
     model, tokenizer, meta = load_model("base", device, phase="eval")
     bos_token_id = tokenizer.get_bos_token_id()
@@ -340,10 +344,11 @@
     torch.cuda.synchronize()
     t0 = time.time()
     stream = model.generate(prompt_tokens, **kwargs)
-    for token in stream:
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token in stream:
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
@@ -355,11 +360,12 @@
     stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
     torch.cuda.synchronize()
     t0 = time.time()
-    for token_column, token_masks in stream:
-        token = token_column[0] # only print out the first row
-        generated_tokens.append(token)
-        chunk = tokenizer.decode([token])
-        print(chunk, end="", flush=True)
+    with autocast_ctx:
+        for token_column, token_masks in stream:
+            token = token_column[0] # only print out the first row
+            generated_tokens.append(token)
+            chunk = tokenizer.decode([token])
+            print(chunk, end="", flush=True)
     print()
     torch.cuda.synchronize()
     t1 = time.time()
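
The dtype fix is the same on both timing paths: token generation now runs under a bfloat16 autocast context on CUDA and a no-op context elsewhere, so the model.generate() path and the Engine path no longer disagree on activation dtypes. A minimal sketch of the pattern in isolation (this autodetect helper is a simplified stand-in for nanochat's autodetect_device_type):

```python
from contextlib import nullcontext
import torch

def autodetect_device_type() -> str:
    # simplified stand-in; the real helper may also detect e.g. mps
    return "cuda" if torch.cuda.is_available() else "cpu"

device_type = autodetect_device_type()
# bf16 autocast only on CUDA; otherwise fall back to plain fp32
autocast_ctx = (
    torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    if device_type == "cuda"
    else nullcontext()
)

with autocast_ctx:
    pass  # the token-streaming loops above are wrapped exactly like this
```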

View File

@@ -9,9 +9,9 @@ import torch.distributed as dist
 
 def evaluate_bpb(model, batches, steps, token_bytes):
     """
     Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
-    which is a tokenization vocab size-indepedent metric, meaning you are still comparing
+    which is a tokenization vocab size-independent metric, meaning you are still comparing
     apples:apples if you change the vocab size. The way this works is that instead of just
-    calculating the average loss as usual, you calculate the sum loss, and indepependently
+    calculating the average loss as usual, you calculate the sum loss, and independently
     also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
     the number of bytes that the target tokens represent.
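
As the docstring says, the normalization is sum-based rather than mean-based: accumulate the total cross-entropy (in nats) and the total byte length of the target tokens, convert nats to bits, and divide. A sketch of just that arithmetic (illustrative, not evaluate_bpb's actual signature):

```python
import math

def bits_per_byte(sum_loss_nats: float, sum_bytes: int) -> float:
    # cross-entropy comes out of the model in nats; dividing by ln(2)
    # converts to bits, and sum_bytes normalizes per target byte
    return sum_loss_nats / (math.log(2) * sum_bytes)

# e.g. a summed loss of ~693.147 nats over 1000 target bytes is 1.0 bpb
assert abs(bits_per_byte(693.147, 1000) - 1.0) < 1e-4
```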

View File

@@ -1,6 +1,6 @@
 """
 Evaluate the Chat model.
-All the generic code lives here, and all the evlauation-specific
+All the generic code lives here, and all the evaluation-specific
 code lives in nanochat directory and is imported from here.
 
 Example runs:

View File

@@ -192,7 +192,7 @@ for step in range(num_iterations):
         })
         model.train()
 
-    # evlauate accuracy of the multiple choice tasks (which are quick to run)
+    # evaluate accuracy of the multiple choice tasks (which are quick to run)
     if last_step or (step > 0 and step % eval_metrics_every == 0):
         model.eval()
         metrics = {}