# Mirror of https://github.com/karpathy/nanochat.git
# Synced 2026-03-25 14:15:21 +00:00 (235 lines, 7.7 KiB, Python)
import io
|
|
import logging
|
|
import os
|
|
import types
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import nanochat.common as common
|
|
|
|
|
|
def test_colored_formatter_info_and_non_info():
    """Check that ColoredFormatter output still contains the message text.

    One INFO record and one WARNING record are formatted; the test only
    asserts that the message fragments survive formatting.
    """
    formatter = common.ColoredFormatter("%(levelname)s %(message)s")

    # INFO record: each fragment of the message must appear in the output.
    info_record = logging.LogRecord(
        "x", logging.INFO, __file__, 1, "Shard 2: 10 MB 5%", (), None
    )
    info_text = formatter.format(info_record)
    for fragment in ("Shard 2", "10 MB", "5%"):
        assert fragment in info_text

    # Non-INFO record: the plain message must still be present.
    warning_record = logging.LogRecord(
        "x", logging.WARNING, __file__, 1, "warn", (), None
    )
    warning_text = formatter.format(warning_record)
    assert "warn" in warning_text
|
|
|
|
|
|
def test_setup_default_logging(monkeypatch):
    """Check the keyword arguments setup_default_logging hands to basicConfig.

    ``logging.basicConfig`` is replaced by a recorder so nothing is actually
    configured; the test then inspects the captured ``level`` and ``handlers``.
    """
    captured = {}

    def fake_basic_config(**kwargs):
        # Remember the most recent call's keyword arguments.
        captured["kwargs"] = kwargs

    monkeypatch.setattr(common.logging, "basicConfig", fake_basic_config)

    common.setup_default_logging()

    passed = captured["kwargs"]
    assert passed["level"] == logging.INFO
    assert len(passed["handlers"]) == 1
|
|
|
|
|
|
def test_get_base_dir_env_and_default(monkeypatch, tmp_path):
    """Check get_base_dir with and without the NANOCHAT_BASE_DIR override.

    First the environment variable points at a directory under ``tmp_path``;
    then it is removed and ``os.path.expanduser`` is stubbed to return
    ``tmp_path`` so the default location also lands inside the temp directory.
    In both cases the returned path must exist as a directory on return.
    """
    # Case 1: explicit override through the environment.
    monkeypatch.setenv("NANOCHAT_BASE_DIR", str(tmp_path / "from_env"))
    env_dir = common.get_base_dir()
    assert env_dir.endswith("from_env")
    assert os.path.isdir(env_dir)

    # Case 2: no override; the home directory lookup is redirected to tmp_path.
    monkeypatch.delenv("NANOCHAT_BASE_DIR", raising=False)
    monkeypatch.setattr(common.os.path, "expanduser", lambda _: str(tmp_path))
    default_dir = common.get_base_dir()
    # Compare against a tail built with the platform separator (and normalize
    # the returned path) so the check does not depend on "/" being the separator.
    expected_tail = os.path.join(".cache", "nanochat")
    assert os.path.normpath(default_dir).endswith(expected_tail)
    assert os.path.isdir(default_dir)
|
|
|
|
|
|
def test_download_file_with_lock_paths(monkeypatch, tmp_path):
    """Check download_file_with_lock for an existing file and for a download.

    ``get_base_dir`` is redirected to ``tmp_path``. The first call runs with
    the target file already on disk; the second runs after deleting it, with
    ``urlopen`` replaced by a stub response that yields a fixed payload.
    """
    monkeypatch.setattr(common, "get_base_dir", lambda: str(tmp_path))
    target = tmp_path / "f.bin"

    # Case 1: the file is already there, so its path should simply be returned.
    # (urlopen has not been stubbed yet at this point.)
    target.write_bytes(b"x")
    existing = common.download_file_with_lock("http://x", "f.bin")
    assert existing == str(target)

    # Case 2: the file is missing and has to be fetched through the stub.
    target.unlink()
    payload = b"hello"
    postprocessed = {"post": None}

    class Resp:
        """Minimal context-manager stand-in for the object urlopen returns."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the with-block.
            return False

        def read(self):
            return payload

    def remember_postprocess(p):
        # Record the path handed to the post-processing hook.
        postprocessed["post"] = p

    monkeypatch.setattr(common.urllib.request, "urlopen", lambda _url: Resp())

    downloaded = common.download_file_with_lock(
        "http://example.test",
        "f.bin",
        postprocess_fn=remember_postprocess,
    )

    assert downloaded == str(target)
    assert target.read_bytes() == payload
    assert postprocessed["post"] == str(target)
|
|
|
|
|
|
def test_download_file_with_lock_recheck_after_lock(monkeypatch, tmp_path):
    """Check that a file appearing while waiting for the lock is not re-downloaded.

    ``os.path.exists`` is faked so the target path is reported missing on the
    first query and present from the second query onward. ``FileLock`` is
    replaced by a no-op context manager, and ``urlopen`` is replaced by a guard
    that fails the test, so a regression cannot trigger a real network request.
    """
    monkeypatch.setattr(common, "get_base_dir", lambda: str(tmp_path))
    file_path = str(tmp_path / "race.bin")

    calls = {"n": 0}
    real_exists = common.os.path.exists

    def fake_exists(path):
        if path == file_path:
            calls["n"] += 1
            # first check (before lock): missing; second check (inside lock): present
            return calls["n"] >= 2
        # Any other path is answered by the genuine implementation.
        return real_exists(path)

    class Lock:
        """No-op replacement for FileLock; accepts and ignores the lock path."""

        def __init__(self, _p):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

    def forbid_download(*_args, **_kwargs):
        # Reaching this means the existence re-check did not short-circuit.
        raise AssertionError("download attempted although the file exists after locking")

    monkeypatch.setattr(common.os.path, "exists", fake_exists)
    monkeypatch.setattr(common, "FileLock", Lock)
    monkeypatch.setattr(common.urllib.request, "urlopen", forbid_download)

    got = common.download_file_with_lock("http://unused", "race.bin")

    assert got == file_path
    # The path must have been queried at least twice for the fake to report it present.
    assert calls["n"] >= 2
|
|
|
|
|
|
def test_print0_and_banner(monkeypatch, capsys):
    """Check print0 output for RANK 1 and RANK 0, then the banner output.

    ``capsys.readouterr()`` drains captured stdout after each step, so every
    assertion only sees the output produced since the previous read.
    """
    # RANK=1: nothing should reach stdout.
    monkeypatch.setenv("RANK", "1")
    common.print0("hidden")
    silent = capsys.readouterr()
    assert silent.out == ""

    # RANK=0: the message should be printed.
    monkeypatch.setenv("RANK", "0")
    common.print0("shown")
    visible = capsys.readouterr()
    assert "shown" in visible.out

    # The banner is expected to contain a run of block characters.
    common.print_banner()
    banner = capsys.readouterr()
    assert "█████" in banner.out
|
|
|
|
|
|
def test_ddp_helpers(monkeypatch):
    """Check the DDP helper functions against the RANK/LOCAL_RANK/WORLD_SIZE env vars.

    Covers three situations: variables absent, variables set, and
    ``torch.distributed`` reporting itself as available and initialized.
    """
    # No DDP variables in the environment.
    for name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
        monkeypatch.delenv(name, raising=False)
    assert common.is_ddp_requested() is False
    assert common.get_dist_info() == (False, 0, 0, 1)

    # All three variables set; get_dist_info should echo them back as ints.
    ddp_env = {"RANK": "3", "LOCAL_RANK": "2", "WORLD_SIZE": "8"}
    for name, value in ddp_env.items():
        monkeypatch.setenv(name, value)
    assert common.is_ddp_requested() is True
    assert common.get_dist_info() == (True, 3, 2, 8)

    # Pretend the process group is up.
    for probe in ("is_available", "is_initialized"):
        monkeypatch.setattr(common.dist, probe, lambda: True)
    assert common.is_ddp_initialized() is True
|
|
|
|
|
|
def test_autodetect_device_type(monkeypatch):
    """Check autodetect_device_type for the cuda, mps and cpu outcomes.

    Backend availability is toggled step by step via ``is_available`` stubs.
    """

    def force_availability(backend, available):
        # Make backend.is_available() report the given flag.
        monkeypatch.setattr(backend, "is_available", lambda: available)

    cuda = common.torch.cuda
    mps = common.torch.backends.mps

    # CUDA reported available -> "cuda".
    force_availability(cuda, True)
    assert common.autodetect_device_type() == "cuda"

    # CUDA unavailable, MPS available -> "mps".
    force_availability(cuda, False)
    force_availability(mps, True)
    assert common.autodetect_device_type() == "mps"

    # Neither available -> "cpu".
    force_availability(mps, False)
    assert common.autodetect_device_type() == "cpu"
|
|
|
|
|
|
def test_compute_init_cpu_and_cleanup(monkeypatch):
    """Check compute_init("cpu") without DDP and that compute_cleanup tears down DDP.

    ``get_dist_info`` is stubbed to report a single non-distributed process.
    For the cleanup half, ``is_ddp_initialized`` is forced to True and
    ``destroy_process_group`` is replaced by a recorder.
    """
    monkeypatch.setattr(common, "get_dist_info", lambda: (False, 0, 0, 1))

    init_result = common.compute_init("cpu")
    ddp_flag = init_result[0]
    device = init_result[-1]
    assert ddp_flag is False
    assert device.type == "cpu"

    destroyed = {"destroy": 0}

    def fake_destroy_process_group():
        # Flag that the process group teardown was requested.
        destroyed["destroy"] = 1

    monkeypatch.setattr(common, "is_ddp_initialized", lambda: True)
    monkeypatch.setattr(common.dist, "destroy_process_group", fake_destroy_process_group)

    common.compute_cleanup()

    assert destroyed["destroy"] == 1
|
|
|
|
|
|
def test_compute_init_mps_and_cuda_paths(monkeypatch):
    """Check compute_init for "mps" and "cuda" with a distributed setup reported.

    ``get_dist_info`` is stubbed to report DDP with rank 1, local rank 0 and
    world size 2. For the cuda path every torch / torch.distributed call the
    test cares about is replaced by a recorder so no GPU or process group is
    needed, and the recorded arguments are asserted afterwards.
    """
    monkeypatch.setattr(common, "get_dist_info", lambda: (True, 1, 0, 2))

    # --- mps path ---
    monkeypatch.setattr(common.torch.backends.mps, "is_available", lambda: True)
    mps_result = common.compute_init("mps")
    assert mps_result[0] is True
    assert mps_result[-1].type == "mps"

    # --- cuda path ---
    monkeypatch.setattr(common.torch.cuda, "is_available", lambda: True)
    seen = {}

    def fake_set_device(d):
        seen["device"] = d

    def fake_init_process_group(**kw):
        seen["init_kw"] = kw

    def fake_barrier():
        seen["barrier"] = 1

    def fake_matmul_precision(x):
        seen["matmul"] = x

    def fake_manual_seed(x):
        seen["seed"] = x

    def fake_cuda_manual_seed(x):
        seen["cuda_seed"] = x

    monkeypatch.setattr(common.torch.cuda, "set_device", fake_set_device)
    monkeypatch.setattr(common.dist, "init_process_group", fake_init_process_group)
    monkeypatch.setattr(common.dist, "barrier", fake_barrier)
    monkeypatch.setattr(common.torch, "set_float32_matmul_precision", fake_matmul_precision)
    monkeypatch.setattr(common.torch, "manual_seed", fake_manual_seed)
    monkeypatch.setattr(common.torch.cuda, "manual_seed", fake_cuda_manual_seed)

    cuda_result = common.compute_init("cuda")

    assert cuda_result[0] is True
    assert cuda_result[-1].type == "cuda"
    assert seen["matmul"] == "high"
    assert seen["seed"] == 42
    assert seen["cuda_seed"] == 42
    assert "device_id" in seen["init_kw"]
    assert seen["barrier"] == 1
    assert seen["device"].type == "cuda"
|
|
|
|
|
|
def test_compute_init_assertions(monkeypatch):
    """Check that compute_init raises AssertionError for unusable device types.

    Covers "cuda" with CUDA reported unavailable, "mps" with MPS reported
    unavailable, and an unrecognized device string.
    """

    def expect_rejected(device_type):
        # compute_init must fail its own assertion for this device type.
        with pytest.raises(AssertionError):
            common.compute_init(device_type)

    monkeypatch.setattr(common.torch.cuda, "is_available", lambda: False)
    expect_rejected("cuda")

    monkeypatch.setattr(common.torch.backends.mps, "is_available", lambda: False)
    expect_rejected("mps")

    expect_rejected("bad")
|
|
|
|
|
|
def test_dummy_wandb_and_peak_flops(monkeypatch):
    """Check DummyWandb's no-op methods and get_peak_flops for several device names.

    Covers two NVIDIA names with fixed expected values, an Intel
    "Data Center GPU Max" name evaluated against a stubbed ``torch.xpu``, and
    an unknown name that should yield infinity and log a warning.
    """
    dummy = common.DummyWandb()
    assert dummy.log() is None
    assert dummy.finish() is None

    assert common.get_peak_flops("NVIDIA H100 NVL") == 835e12
    assert common.get_peak_flops("A100-SXM") == 312e12

    class Props:
        """Stub device properties exposing only max_compute_units."""

        max_compute_units = 8

    class XPU:
        """Stub for torch.xpu that returns the fixed Props above."""

        @staticmethod
        def get_device_properties(_name):
            return Props()

    # raising=False lets the stub be installed even on torch builds that do
    # not define a `torch.xpu` attribute; monkeypatch removes it again afterwards.
    monkeypatch.setattr(common.torch, "xpu", XPU(), raising=False)
    pvc = common.get_peak_flops("Data Center GPU Max 1550")
    # The factor 8 matches Props.max_compute_units; the other factors are
    # presumably constants inside get_peak_flops (verify against common).
    assert pvc == 512 * 8 * 1300 * 10**6

    warned = {}

    def fake_warning(msg):
        # Keep the last warning message for inspection.
        warned["msg"] = msg

    monkeypatch.setattr(common.logger, "warning", fake_warning)
    unknown = common.get_peak_flops("mystery gpu")
    assert unknown == float("inf")
    assert "Peak flops undefined" in warned["msg"]
|