mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Compare commits
3 Commits
75b89c3666
...
59ed9392ed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59ed9392ed | ||
|
|
449494c8b6 | ||
|
|
6587063479 |
|
|
@ -1,19 +1,4 @@
|
|||
repos:
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.2
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-isort
|
||||
rev: v5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
|
|
@ -28,10 +13,11 @@ repos:
|
|||
- id: mixed-line-ending
|
||||
args: [--fix=lf]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 25.11.0
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.8
|
||||
hooks:
|
||||
- id: black
|
||||
- id: ruff-check
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.1 # Use the latest stable version
|
||||
|
|
|
|||
|
|
@ -133,11 +133,10 @@ Linting and formatting are enforced with [pre-commit](https://pre-commit.com/) b
|
|||
|
||||
Hook coverage (auto-fixes most issues; review and stage the changes afterward):
|
||||
|
||||
- [`autoflake`](https://github.com/PyCQA/autoflake): strips unused imports and variables.
|
||||
- [`pyupgrade`](https://github.com/asottile/pyupgrade): rewrites code to the latest supported Python syntax.
|
||||
- [`isort`](https://github.com/PyCQA/isort): keeps import blocks consistently ordered and grouped.
|
||||
- [`ruff`](https://github.com/astral-sh/ruff): a fast Rust-based linter and formatter that replaces multiple tools:
|
||||
- **Linting** (`ruff-check`): removes unused imports (like autoflake), upgrades syntax (like pyupgrade), and sorts imports (like isort).
|
||||
- **Formatting** (`ruff-format`): applies consistent code formatting (like black), with quote style preserved.
|
||||
- [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks): repo hygiene (trim trailing whitespace, enforce LF endings/newlines, detect merge conflicts, block oversized files).
|
||||
- [`black`](https://github.com/psf/black): applies the canonical Python formatting.
|
||||
- [`codespell`](https://github.com/codespell-project/codespell): catches common spelling mistakes in code and docs (add false positives to `[tool.codespell].ignore-words-list` in `pyproject.toml`).
|
||||
|
||||
## File structure
|
||||
|
|
|
|||
|
|
@ -269,9 +269,7 @@ hola hola, todo bien?
|
|||
hej, hur är läget
|
||||
ahoj, jak se máš
|
||||
γειά, τι κάνεις
|
||||
""".strip().split(
|
||||
"\n"
|
||||
)
|
||||
""".strip().split("\n")
|
||||
|
||||
prompt = prompt.replace("%README%", readme)
|
||||
|
||||
|
|
@ -294,10 +292,7 @@ response_format = {
|
|||
"type": "string",
|
||||
"description": "The role of the speaker, either 'user' or 'assistant'",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content",
|
||||
},
|
||||
"content": {"type": "string", "description": "The message content"},
|
||||
},
|
||||
"required": ["role", "content"],
|
||||
"additionalProperties": False,
|
||||
|
|
@ -331,15 +326,15 @@ def generate_conversation(idx: int):
|
|||
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
|
||||
payload = copy.deepcopy(base_payload)
|
||||
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
|
||||
payload["messages"] = [{"role": "user", "content": modified_prompt}]
|
||||
payload['messages'] = [{"role": "user", "content": modified_prompt}]
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
content = result['choices'][0]['message']['content']
|
||||
|
||||
# Parse the JSON response and unpack the messages
|
||||
conversation_data = json.loads(content)
|
||||
messages = conversation_data["messages"]
|
||||
messages = conversation_data['messages']
|
||||
|
||||
return messages
|
||||
|
||||
|
|
@ -359,11 +354,8 @@ print(f"Generating {num_conversations} conversations with {num_workers} workers.
|
|||
completed_count = 0
|
||||
error_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
|
||||
# Submit all tasks
|
||||
futures = [
|
||||
executor.submit(generate_conversation, idx) for idx in range(num_conversations)
|
||||
]
|
||||
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
|
||||
|
||||
# Process results as they complete
|
||||
for future in as_completed(futures):
|
||||
|
|
@ -373,13 +365,13 @@ with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
|||
# Lightly validate the conversation structure
|
||||
for i, message in enumerate(messages):
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert (
|
||||
message["role"] == expected_role
|
||||
), f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert message['role'] == expected_role, (
|
||||
f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
)
|
||||
|
||||
# If all looks good, write the messages to file
|
||||
with open(output_file, "a") as f:
|
||||
f.write(json.dumps(messages) + "\n")
|
||||
with open(output_file, 'a') as f:
|
||||
f.write(json.dumps(messages) + '\n')
|
||||
completed_count += 1
|
||||
print(f"✓ Saved conversation {completed_count}/{num_conversations}")
|
||||
|
||||
|
|
|
|||
|
|
@ -48,14 +48,12 @@ total_docs_processed = 0
|
|||
total_time_spent = 0
|
||||
t0 = time.time()
|
||||
for doc in ds:
|
||||
text = doc["text"]
|
||||
text = doc['text']
|
||||
shard_docs.append(text)
|
||||
shard_characters += len(text)
|
||||
collected_enough_chars = shard_characters >= chars_per_shard
|
||||
docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
|
||||
if (
|
||||
collected_enough_chars and docs_multiple_of_row_group_size
|
||||
): # leads to ~100MB of text (compressed)
|
||||
if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
|
||||
shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
|
||||
shard_table = pa.Table.from_pydict({"text": shard_docs})
|
||||
pq.write_table(
|
||||
|
|
|
|||
|
|
@ -40,35 +40,33 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_scatter_futures.append(
|
||||
dist.reduce_scatter_tensor(
|
||||
grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
)
|
||||
grad_slices.append(grad_slice)
|
||||
|
||||
idx = 0
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group["betas"]
|
||||
eps = group["eps"]
|
||||
wd = group["weight_decay"]
|
||||
params = group["params"]
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
wd = group['weight_decay']
|
||||
params = group['params']
|
||||
for base in range(len(params)):
|
||||
reduce_scatter_futures[idx].wait()
|
||||
p = params[base]
|
||||
rank_size = p.shape[0] // world_size
|
||||
p_slice = p[rank * rank_size : (rank + 1) * rank_size]
|
||||
lr = group["lr"] * getattr(p, "lr_mul", 1.0)
|
||||
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
|
||||
state = self.state[p]
|
||||
g_slice = grad_slices[idx]
|
||||
# State init
|
||||
if not state:
|
||||
state["step"] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state["exp_avg"] = torch.zeros_like(p_slice)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p_slice)
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
state["step"] += 1
|
||||
t = state["step"]
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
t = state['step']
|
||||
# weight decay
|
||||
if wd != 0:
|
||||
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
|
||||
|
|
@ -85,7 +83,5 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
idx += 1
|
||||
all_reduce_futures.append(
|
||||
dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
||||
)
|
||||
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
torch.futures.collect_all(all_reduce_futures).wait()
|
||||
|
|
|
|||
|
|
@ -20,13 +20,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def log0(message):
|
||||
if int(os.environ.get("RANK", 0)) == 0:
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0
|
||||
):
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state parameters
|
||||
|
|
@ -40,9 +38,7 @@ def save_checkpoint(
|
|||
logger.info(f"Saved metadata to: {meta_path}")
|
||||
# Note that optimizer state is sharded across ranks, so each rank must save its own.
|
||||
if optimizer_data is not None:
|
||||
optimizer_path = os.path.join(
|
||||
checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"
|
||||
)
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
torch.save(optimizer_data, optimizer_path)
|
||||
logger.info(f"Saved optimizer state to: {optimizer_path}")
|
||||
|
||||
|
|
@ -54,9 +50,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
|||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(
|
||||
checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"
|
||||
)
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
|
|
@ -74,15 +68,10 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
- meta data saved during base model training
|
||||
"""
|
||||
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(
|
||||
checkpoint_dir, step, device, load_optimizer=False
|
||||
)
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
||||
if device.type in {"cpu", "mps"}:
|
||||
# Convert bfloat16 tensors to float for CPU inference
|
||||
model_data = {
|
||||
k: v.float() if v.dtype == torch.bfloat16 else v
|
||||
for k, v in model_data.items()
|
||||
}
|
||||
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
|
||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
|
|
@ -108,11 +97,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
|
||||
def find_largest_model(checkpoint_dir):
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
model_tags = [
|
||||
f
|
||||
for f in os.listdir(checkpoint_dir)
|
||||
if os.path.isdir(os.path.join(checkpoint_dir, f))
|
||||
]
|
||||
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
|
||||
if not model_tags:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
# 1) normally all model tags are of the form d<number>, try that first:
|
||||
|
|
@ -126,9 +111,7 @@ def find_largest_model(checkpoint_dir):
|
|||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
# 2) if that failed, take the most recently updated model:
|
||||
model_tags.sort(
|
||||
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True
|
||||
)
|
||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
|
||||
return model_tags[0]
|
||||
|
||||
|
||||
|
|
@ -137,9 +120,7 @@ def find_last_step(checkpoint_dir):
|
|||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
|
||||
if not checkpoint_files:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
last_step = int(
|
||||
max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files)
|
||||
)
|
||||
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
|
||||
return last_step
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,45 +17,33 @@ class ColoredFormatter(logging.Formatter):
|
|||
|
||||
# ANSI color codes
|
||||
COLORS = {
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
'DEBUG': '\033[36m', # Cyan
|
||||
'INFO': '\033[32m', # Green
|
||||
'WARNING': '\033[33m', # Yellow
|
||||
'ERROR': '\033[31m', # Red
|
||||
'CRITICAL': '\033[35m', # Magenta
|
||||
}
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
|
||||
def format(self, record):
|
||||
# Add color to the level name
|
||||
levelname = record.levelname
|
||||
if levelname in self.COLORS:
|
||||
record.levelname = (
|
||||
f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
|
||||
)
|
||||
record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
|
||||
# Format the message
|
||||
message = super().format(record)
|
||||
# Add color to specific parts of the message
|
||||
if levelname == "INFO":
|
||||
if levelname == 'INFO':
|
||||
# Highlight numbers and percentages
|
||||
message = re.sub(
|
||||
r"(\d+\.?\d*\s*(?:GB|MB|%|docs))",
|
||||
rf"{self.BOLD}\1{self.RESET}",
|
||||
message,
|
||||
)
|
||||
message = re.sub(
|
||||
r"(Shard \d+)",
|
||||
rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}',
|
||||
message,
|
||||
)
|
||||
message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
|
||||
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
|
||||
return message
|
||||
|
||||
|
||||
def setup_default_logging():
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logging.basicConfig(level=logging.INFO, handlers=[handler])
|
||||
|
||||
|
||||
|
|
@ -101,7 +89,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
content = response.read() # bytes
|
||||
|
||||
# Write to local file
|
||||
with open(file_path, "wb") as f:
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
print(f"Downloaded to {file_path}")
|
||||
|
||||
|
|
@ -113,7 +101,7 @@ def download_file_with_lock(url, filename, postprocess_fn=None):
|
|||
|
||||
|
||||
def print0(s="", **kwargs):
|
||||
ddp_rank = int(os.environ.get("RANK", 0))
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
print(s, **kwargs)
|
||||
|
||||
|
|
@ -135,15 +123,15 @@ def print_banner():
|
|||
|
||||
def is_ddp():
|
||||
# TODO is there a proper way
|
||||
return int(os.environ.get("RANK", -1)) != -1
|
||||
return int(os.environ.get('RANK', -1)) != -1
|
||||
|
||||
|
||||
def get_dist_info():
|
||||
if is_ddp():
|
||||
assert all(var in os.environ for var in ["RANK", "LOCAL_RANK", "WORLD_SIZE"])
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
||||
ddp_rank = int(os.environ['RANK'])
|
||||
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
||||
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
||||
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
||||
else:
|
||||
return False, 0, 0, 1
|
||||
|
|
@ -166,13 +154,13 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert (
|
||||
torch.cuda.is_available()
|
||||
), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
assert torch.cuda.is_available(), (
|
||||
"Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
)
|
||||
if device_type == "mps":
|
||||
assert (
|
||||
torch.backends.mps.is_available()
|
||||
), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
assert torch.backends.mps.is_available(), (
|
||||
"Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
)
|
||||
|
||||
# Reproducibility
|
||||
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
||||
|
|
@ -185,9 +173,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
|
||||
# Precision
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision(
|
||||
"high"
|
||||
) # uses tf32 instead of fp32 for matmuls
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
|
|
|||
|
|
@ -20,15 +20,15 @@ from ast import literal_eval
|
|||
|
||||
|
||||
def print0(s="", **kwargs):
|
||||
ddp_rank = int(os.environ.get("RANK", 0))
|
||||
ddp_rank = int(os.environ.get('RANK', 0))
|
||||
if ddp_rank == 0:
|
||||
print(s, **kwargs)
|
||||
|
||||
|
||||
for arg in sys.argv[1:]:
|
||||
if "=" not in arg:
|
||||
if '=' not in arg:
|
||||
# assume it's the name of a config file
|
||||
assert not arg.startswith("--")
|
||||
assert not arg.startswith('--')
|
||||
config_file = arg
|
||||
print0(f"Overriding config with {config_file}:")
|
||||
with open(config_file) as f:
|
||||
|
|
@ -36,8 +36,8 @@ for arg in sys.argv[1:]:
|
|||
exec(open(config_file).read())
|
||||
else:
|
||||
# assume it's a --key=value argument
|
||||
assert arg.startswith("--")
|
||||
key, val = arg.split("=")
|
||||
assert arg.startswith('--')
|
||||
key, val = arg.split('=')
|
||||
key = key[2:]
|
||||
if key in globals():
|
||||
try:
|
||||
|
|
@ -50,9 +50,7 @@ for arg in sys.argv[1:]:
|
|||
if globals()[key] is not None:
|
||||
attempt_type = type(attempt)
|
||||
default_type = type(globals()[key])
|
||||
assert (
|
||||
attempt_type == default_type
|
||||
), f"Type mismatch: {attempt_type} != {default_type}"
|
||||
assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
|
||||
# cross fingers
|
||||
print0(f"Overriding: {key} = {attempt}")
|
||||
globals()[key] = attempt
|
||||
|
|
|
|||
|
|
@ -26,12 +26,8 @@ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
"fewshot_examples": fewshot_examples,
|
||||
"continuation_delimiter": continuation_delimiter,
|
||||
"item": item,
|
||||
}
|
||||
prompts = [template.render(choice=choice, **context) for choice in item["choices"]]
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
|
||||
return prompts
|
||||
|
||||
|
||||
|
|
@ -45,15 +41,8 @@ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
"fewshot_examples": fewshot_examples,
|
||||
"continuation_delimiter": continuation_delimiter,
|
||||
"item": item,
|
||||
}
|
||||
prompts = [
|
||||
template.render(context=context_option, **context)
|
||||
for context_option in item["context_options"]
|
||||
]
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
prompts = [template.render(context=context_option, **context) for context_option in item['context_options']]
|
||||
return prompts
|
||||
|
||||
|
||||
|
|
@ -71,11 +60,7 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
|
|||
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
|
||||
template = Template(template_str)
|
||||
fewshot_examples = fewshot_examples or []
|
||||
context = {
|
||||
"fewshot_examples": fewshot_examples,
|
||||
"continuation_delimiter": continuation_delimiter,
|
||||
"item": item,
|
||||
}
|
||||
context = {'fewshot_examples': fewshot_examples, 'continuation_delimiter': continuation_delimiter, 'item': item}
|
||||
# Return two prompts: without and with the continuation
|
||||
prompt_without = template.render(include_continuation=False, **context)
|
||||
prompt_with = template.render(include_continuation=True, **context)
|
||||
|
|
@ -87,13 +72,13 @@ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
|
|||
return [prompt_without, prompt_with]
|
||||
|
||||
|
||||
def find_common_length(token_sequences, direction="left"):
|
||||
def find_common_length(token_sequences, direction='left'):
|
||||
"""
|
||||
Find the length of the common prefix or suffix across token sequences
|
||||
- direction: 'left' for prefix, 'right' for suffix
|
||||
"""
|
||||
min_len = min(len(seq) for seq in token_sequences)
|
||||
indices = {"left": range(min_len), "right": range(-1, -min_len - 1, -1)}[direction]
|
||||
indices = {'left': range(min_len), 'right': range(-1, -min_len - 1, -1)}[direction]
|
||||
# Find the first position where the token sequences differ
|
||||
for i, idx in enumerate(indices):
|
||||
token = token_sequences[0][idx]
|
||||
|
|
@ -115,7 +100,7 @@ def batch_sequences_mc(tokenizer, prompts):
|
|||
# In multiple choice, contexts are the same but the continuation is different (common prefix)
|
||||
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
||||
# figure out the start and end of each continuation
|
||||
answer_start_idx = find_common_length(tokens, direction="left")
|
||||
answer_start_idx = find_common_length(tokens, direction='left')
|
||||
start_indices = [answer_start_idx] * len(prompts)
|
||||
end_indices = [len(x) for x in tokens]
|
||||
return tokens, start_indices, end_indices
|
||||
|
|
@ -125,7 +110,7 @@ def batch_sequences_schema(tokenizer, prompts):
|
|||
# In schema tasks, contexts vary but continuation is the same (common suffix)
|
||||
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
||||
# figure out the start and end of each context
|
||||
suffix_length = find_common_length(tokens, direction="right")
|
||||
suffix_length = find_common_length(tokens, direction='right')
|
||||
end_indices = [len(x) for x in tokens]
|
||||
start_indices = [ei - suffix_length for ei in end_indices]
|
||||
return tokens, start_indices, end_indices
|
||||
|
|
@ -136,12 +121,8 @@ def batch_sequences_lm(tokenizer, prompts):
|
|||
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
||||
tokens_without, tokens_with = tokens
|
||||
start_idx, end_idx = len(tokens_without), len(tokens_with)
|
||||
assert (
|
||||
start_idx < end_idx
|
||||
), "prompt without is supposed to be a prefix of prompt with"
|
||||
assert (
|
||||
tokens_without == tokens_with[:start_idx]
|
||||
), "prompt without is supposed to be a prefix of prompt with"
|
||||
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
|
||||
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
|
||||
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
|
||||
return [tokens_with], [start_idx], [end_idx]
|
||||
|
||||
|
|
@ -158,12 +139,10 @@ def forward_model(model, input_ids):
|
|||
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
|
||||
# Calculate cross entropy at all positions
|
||||
losses = torch.nn.functional.cross_entropy(
|
||||
outputs.view(batch_size * seq_len, -1),
|
||||
target_ids.view(batch_size * seq_len),
|
||||
reduction="none",
|
||||
outputs.view(batch_size * seq_len, -1), target_ids.view(batch_size * seq_len), reduction='none'
|
||||
).view(batch_size, seq_len)
|
||||
# Set the last column to be nan because there is no autoregressive loss there
|
||||
losses[:, -1] = float("nan")
|
||||
losses[:, -1] = float('nan')
|
||||
# Get the argmax predictions at each position
|
||||
predictions = outputs.argmax(dim=-1)
|
||||
return losses, predictions
|
||||
|
|
@ -173,9 +152,9 @@ def forward_model(model, input_ids):
|
|||
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
||||
"""Evaluate a single example, return True if correct, False otherwise"""
|
||||
item = data[idx]
|
||||
task_type = task_meta["task_type"]
|
||||
num_fewshot = task_meta["num_fewshot"]
|
||||
continuation_delimiter = task_meta["continuation_delimiter"]
|
||||
task_type = task_meta['task_type']
|
||||
num_fewshot = task_meta['num_fewshot']
|
||||
continuation_delimiter = task_meta['continuation_delimiter']
|
||||
|
||||
# Sample few-shot examples (excluding current item)
|
||||
fewshot_examples = []
|
||||
|
|
@ -186,13 +165,13 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
fewshot_examples = [data[i] for i in fewshot_indices]
|
||||
|
||||
# Render prompts and batch sequences based on task type
|
||||
if task_type == "multiple_choice":
|
||||
if task_type == 'multiple_choice':
|
||||
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
|
||||
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
|
||||
elif task_type == "schema":
|
||||
elif task_type == 'schema':
|
||||
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
|
||||
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
|
||||
elif task_type == "language_modeling":
|
||||
elif task_type == 'language_modeling':
|
||||
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
|
||||
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
|
||||
else:
|
||||
|
|
@ -200,7 +179,7 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
|
||||
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
|
||||
# In these cases, we have to truncate sequences to max length and adjust the indices
|
||||
if hasattr(model, "max_seq_len") and model.max_seq_len is not None:
|
||||
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
|
||||
max_tokens = model.max_seq_len
|
||||
new_tokens, new_start_idxs, new_end_idxs = [], [], []
|
||||
for t, s, e in zip(tokens, start_idxs, end_idxs):
|
||||
|
|
@ -226,7 +205,7 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
losses, predictions = forward_model(model, input_ids)
|
||||
|
||||
# See if the losses/predictions come out correctly
|
||||
if task_type == "language_modeling":
|
||||
if task_type == 'language_modeling':
|
||||
# language modeling task is currently always batch size 1
|
||||
si = start_idxs[0]
|
||||
ei = end_idxs[0]
|
||||
|
|
@ -234,14 +213,11 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
predicted_tokens = predictions[0, si - 1 : ei - 1]
|
||||
actual_tokens = input_ids[0, si:ei]
|
||||
is_correct = torch.all(predicted_tokens == actual_tokens).item()
|
||||
elif task_type in ["multiple_choice", "schema"]:
|
||||
elif task_type in ['multiple_choice', 'schema']:
|
||||
# For MC/schema: find the option with lowest average loss
|
||||
mean_losses = [
|
||||
losses[i, si - 1 : ei - 1].mean().item()
|
||||
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))
|
||||
]
|
||||
mean_losses = [losses[i, si - 1 : ei - 1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
pred_idx = mean_losses.index(min(mean_losses))
|
||||
is_correct = pred_idx == item["gold"]
|
||||
is_correct = pred_idx == item['gold']
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,13 +9,7 @@ from nanochat.tokenizer import get_tokenizer
|
|||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state(
|
||||
B,
|
||||
T,
|
||||
split,
|
||||
tokenizer_threads=4,
|
||||
tokenizer_batch_size=128,
|
||||
device="cuda",
|
||||
resume_state_dict=None,
|
||||
B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
|
||||
):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
|
@ -37,12 +31,8 @@ def tokenizing_distributed_data_loader_with_state(
|
|||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = (
|
||||
resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
)
|
||||
resume_rg_idx = (
|
||||
resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
)
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
|
|
@ -51,21 +41,15 @@ def tokenizing_distributed_data_loader_with_state(
|
|||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if resume_rg_idx is not None:
|
||||
base_idx = (
|
||||
resume_rg_idx // ddp_world_size
|
||||
) # in units of ddp_world_size
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
resume_rg_idx = (
|
||||
None # set to None as we only want to do this a single time
|
||||
)
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column(
|
||||
"text"
|
||||
).to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i : i + tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
|
|
@ -85,28 +69,20 @@ def tokenizing_distributed_data_loader_with_state(
|
|||
# Accumulate enough tokens for one iteration before yielding.
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch, (pq_idx, rg_idx) = next(batches)
|
||||
token_lists = tokenizer.encode(
|
||||
doc_batch, prepend=bos_token, num_threads=tokenizer_threads
|
||||
)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(
|
||||
tokens, dtype=torch.long, pin_memory=use_cuda_optimizations
|
||||
) # in PyTorch, long=int64
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(
|
||||
device=device, non_blocking=use_cuda_optimizations
|
||||
)
|
||||
targets = targets_cpu.view(B, T).to(
|
||||
device=device, non_blocking=use_cuda_optimizations
|
||||
)
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {
|
||||
"pq_idx": pq_idx,
|
||||
"rg_idx": rg_idx,
|
||||
|
|
@ -116,7 +92,5 @@ def tokenizing_distributed_data_loader_with_state(
|
|||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(
|
||||
*args, **kwargs
|
||||
):
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -21,13 +21,9 @@ from nanochat.common import get_base_dir
|
|||
# The specifics of the current pretraining dataset
|
||||
|
||||
# The URL on the internet where the data is hosted and downloaded from on demand
|
||||
BASE_URL = (
|
||||
"https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
|
||||
)
|
||||
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
|
||||
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
|
||||
index_to_filename = (
|
||||
lambda index: f"shard_{index:05d}.parquet"
|
||||
) # format of the filenames
|
||||
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
||||
base_dir = get_base_dir()
|
||||
DATA_DIR = os.path.join(base_dir, "base_data")
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
|
|
@ -39,13 +35,7 @@ os.makedirs(DATA_DIR, exist_ok=True)
|
|||
def list_parquet_files(data_dir=None):
|
||||
"""Looks into a data dir and returns full paths to all parquet files."""
|
||||
data_dir = DATA_DIR if data_dir is None else data_dir
|
||||
parquet_files = sorted(
|
||||
[
|
||||
f
|
||||
for f in os.listdir(data_dir)
|
||||
if f.endswith(".parquet") and not f.endswith(".tmp")
|
||||
]
|
||||
)
|
||||
parquet_files = sorted([f for f in os.listdir(data_dir) if f.endswith('.parquet') and not f.endswith('.tmp')])
|
||||
parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
|
||||
return parquet_paths
|
||||
|
||||
|
|
@ -63,7 +53,7 @@ def parquets_iter_batched(split, start=0, step=1):
|
|||
pf = pq.ParquetFile(filepath)
|
||||
for rg_idx in range(start, pf.num_row_groups, step):
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
texts = rg.column("text").to_pylist()
|
||||
texts = rg.column('text').to_pylist()
|
||||
yield texts
|
||||
|
||||
|
||||
|
|
@ -89,11 +79,9 @@ def download_single_file(index):
|
|||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
# Write to temporary file first
|
||||
temp_path = filepath + f".tmp"
|
||||
with open(temp_path, "wb") as f:
|
||||
for chunk in response.iter_content(
|
||||
chunk_size=1024 * 1024
|
||||
): # 1MB chunks
|
||||
temp_path = filepath + ".tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
# Move temp file to final location
|
||||
|
|
@ -101,10 +89,10 @@ def download_single_file(index):
|
|||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
|
||||
except (requests.RequestException, OSError) as e:
|
||||
except (OSError, requests.RequestException) as e:
|
||||
print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
||||
# Clean up any partial files
|
||||
for path in [filepath + f".tmp", filepath]:
|
||||
for path in [filepath + ".tmp", filepath]:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
|
|
@ -123,30 +111,18 @@ def download_single_file(index):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download FineWeb-Edu 100BT dataset shards"
|
||||
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
|
||||
parser.add_argument(
|
||||
"-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num-files",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of shards to download (default: -1), -1 = disable",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of parallel download workers (default: 4)",
|
||||
"-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
|
||||
ids_to_download = list(range(num))
|
||||
print(
|
||||
f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers..."
|
||||
)
|
||||
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
|
||||
print(f"Target directory: {DATA_DIR}")
|
||||
print()
|
||||
with Pool(processes=args.num_workers) as pool:
|
||||
|
|
|
|||
|
|
@ -64,38 +64,36 @@ def use_calculator(expr):
|
|||
|
||||
# Check if it's a string operation we support
|
||||
# Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
|
||||
allowed_chars = (
|
||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
|
||||
)
|
||||
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
|
||||
if not all([x in allowed_chars for x in expr]):
|
||||
return None
|
||||
|
||||
# Disallow dangerous patterns
|
||||
dangerous_patterns = [
|
||||
"__",
|
||||
"import",
|
||||
"exec",
|
||||
"eval",
|
||||
"compile",
|
||||
"open",
|
||||
"file",
|
||||
"input",
|
||||
"raw_input",
|
||||
"globals",
|
||||
"locals",
|
||||
"vars",
|
||||
"dir",
|
||||
"getattr",
|
||||
"setattr",
|
||||
"delattr",
|
||||
"hasattr",
|
||||
'__',
|
||||
'import',
|
||||
'exec',
|
||||
'eval',
|
||||
'compile',
|
||||
'open',
|
||||
'file',
|
||||
'input',
|
||||
'raw_input',
|
||||
'globals',
|
||||
'locals',
|
||||
'vars',
|
||||
'dir',
|
||||
'getattr',
|
||||
'setattr',
|
||||
'delattr',
|
||||
'hasattr',
|
||||
]
|
||||
expr_lower = expr.lower()
|
||||
if any(pattern in expr_lower for pattern in dangerous_patterns):
|
||||
return None
|
||||
|
||||
# Only allow .count() method for now (can expand later)
|
||||
if ".count(" not in expr:
|
||||
if '.count(' not in expr:
|
||||
return None
|
||||
|
||||
# Evaluate with timeout
|
||||
|
|
@ -137,9 +135,7 @@ class KVCache:
|
|||
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
|
||||
elif ix == 2:
|
||||
# batch_size can be expanded
|
||||
assert (
|
||||
dim1 == dim2 or dim2 == 1
|
||||
), f"Batch dim mismatch: {dim1} != {dim2}"
|
||||
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
|
||||
elif ix == 4:
|
||||
# seq_len: self must be longer than other
|
||||
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
|
||||
|
|
@ -161,17 +157,11 @@ class KVCache:
|
|||
# Dynamically grow the cache if needed
|
||||
if t1 > self.kv_cache.size(4):
|
||||
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
||||
t_needed = (
|
||||
t_needed + 1023
|
||||
) & ~1023 # then round up to the nearest multiple of 1024
|
||||
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
||||
additional_shape = list(self.kv_cache.shape)
|
||||
additional_shape[4] = t_needed - self.kv_cache.size(4)
|
||||
additional_cache = torch.empty(
|
||||
additional_shape, dtype=k.dtype, device=k.device
|
||||
)
|
||||
self.kv_cache = torch.cat(
|
||||
[self.kv_cache, additional_cache], dim=4
|
||||
).contiguous()
|
||||
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
|
||||
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
|
||||
self.kv_shape = self.kv_cache.shape
|
||||
# Insert k, v into the cache
|
||||
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
||||
|
|
@ -211,9 +201,7 @@ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
|
|||
class RowState:
|
||||
# Per-row state tracking during generation
|
||||
def __init__(self, current_tokens=None):
|
||||
self.current_tokens = (
|
||||
current_tokens or []
|
||||
) # Current token sequence for this row
|
||||
self.current_tokens = current_tokens or [] # Current token sequence for this row
|
||||
self.forced_tokens = deque() # Queue of tokens to force inject
|
||||
self.in_python_block = False # Whether we are inside a python block
|
||||
self.python_expr_tokens = [] # Tokens of the current python expression
|
||||
|
|
@ -221,25 +209,14 @@ class RowState:
|
|||
|
||||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer # needed for tool use
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=None,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
seed=42,
|
||||
):
|
||||
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
|
||||
"""Same as generate, but does single prefill and then clones the KV cache."""
|
||||
assert isinstance(tokens, list) and isinstance(
|
||||
tokens[0], int
|
||||
), "expecting list of ints"
|
||||
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
||||
device = self.model.get_device()
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
|
@ -255,11 +232,7 @@ class Engine:
|
|||
|
||||
# 1) Run a batch 1 prefill of the prompt tokens
|
||||
m = self.model.config
|
||||
kv_model_kwargs = {
|
||||
"num_heads": m.n_kv_head,
|
||||
"head_dim": m.n_embd // m.n_head,
|
||||
"num_layers": m.n_layer,
|
||||
}
|
||||
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
|
||||
kv_cache_prefill = KVCache(
|
||||
batch_size=1,
|
||||
seq_len=len(tokens),
|
||||
|
|
@ -272,11 +245,7 @@ class Engine:
|
|||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# 2) Replicate the KV cache for each sample/row
|
||||
kv_length_hint = (
|
||||
(len(tokens) + max_tokens)
|
||||
if max_tokens is not None
|
||||
else self.model.config.sequence_len
|
||||
)
|
||||
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
|
||||
kv_cache_decode = KVCache(
|
||||
batch_size=num_samples,
|
||||
seq_len=kv_length_hint,
|
||||
|
|
@ -302,36 +271,24 @@ class Engine:
|
|||
# Get sampled tokens - either from prefill or from forward pass
|
||||
if first_iteration:
|
||||
# Use the tokens we already sampled from prefill
|
||||
sampled_tokens = [
|
||||
sampled_tokens[0]
|
||||
] * num_samples # Broadcast first token to all rows
|
||||
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
||||
# TODO: we should sample a token for each row instead of broadcasting
|
||||
first_iteration = False
|
||||
else:
|
||||
# Forward the model and get the next token for each row
|
||||
logits = self.model.forward(
|
||||
ids, kv_cache=kv_cache_decode
|
||||
) # (B, T, vocab_size)
|
||||
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (B, vocab_size) at last time step
|
||||
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
||||
sampled_tokens = next_ids[:, 0].tolist()
|
||||
|
||||
# Process each row: choose the next token, update state, optional tool use
|
||||
token_column = [] # contains the next token id along each row
|
||||
token_masks = (
|
||||
[]
|
||||
) # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
||||
for i, state in enumerate(row_states):
|
||||
# Select the next token in this row
|
||||
is_forced = (
|
||||
len(state.forced_tokens) > 0
|
||||
) # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(
|
||||
0 if is_forced else 1
|
||||
) # mask is 0 if forced, 1 if sampled
|
||||
next_token = (
|
||||
state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
|
||||
)
|
||||
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
||||
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
||||
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
|
||||
token_column.append(next_token)
|
||||
# Update the state of this row to include the next token
|
||||
state.current_tokens.append(next_token)
|
||||
|
|
@ -360,9 +317,7 @@ class Engine:
|
|||
yield token_column, token_masks
|
||||
num_generated += 1
|
||||
# Prepare ids for next iteration
|
||||
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(
|
||||
1
|
||||
)
|
||||
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
|
||||
|
||||
def generate_batch(self, tokens, num_samples=1, **kwargs):
|
||||
"""
|
||||
|
|
@ -400,9 +355,7 @@ if __name__ == "__main__":
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# load the model and tokenizer
|
||||
|
|
@ -411,9 +364,7 @@ if __name__ == "__main__":
|
|||
# common hyperparameters
|
||||
kwargs = dict(max_tokens=64, temperature=0.0)
|
||||
# set the starting prompt
|
||||
prompt_tokens = tokenizer.encode(
|
||||
"The chemical formula of water is", prepend=bos_token_id
|
||||
)
|
||||
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
|
||||
# generate the reference sequence using the model.generate() function
|
||||
generated_tokens = []
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -432,9 +383,7 @@ if __name__ == "__main__":
|
|||
# generate tokens with Engine
|
||||
generated_tokens = []
|
||||
engine = Engine(model, tokenizer)
|
||||
stream = engine.generate(
|
||||
prompt_tokens, num_samples=1, **kwargs
|
||||
) # note: runs in fp32
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
with autocast_ctx:
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ import platform
|
|||
import signal
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
|
@ -42,7 +41,7 @@ class ExecutionResult:
|
|||
success: bool
|
||||
stdout: str
|
||||
stderr: str
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
timeout: bool = False
|
||||
memory_exceeded: bool = False
|
||||
|
||||
|
|
@ -133,7 +132,7 @@ def chdir(root):
|
|||
os.chdir(cwd)
|
||||
|
||||
|
||||
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
||||
def reliability_guard(maximum_memory_bytes: int | None = None):
|
||||
"""
|
||||
This disables various destructive functions and prevents the generated code
|
||||
from interfering with the test (e.g. fork bomb, killing other processes,
|
||||
|
|
@ -150,15 +149,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(
|
||||
resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
|
||||
)
|
||||
resource.setrlimit(
|
||||
resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
|
||||
)
|
||||
resource.setrlimit(
|
||||
resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
|
||||
)
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
||||
|
|
@ -220,12 +213,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
sys.modules["tkinter"] = None
|
||||
|
||||
|
||||
def _unsafe_execute(
|
||||
code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict
|
||||
):
|
||||
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: int | None, result_dict):
|
||||
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
|
||||
with create_tempdir():
|
||||
|
||||
# These system calls are needed when cleaning up tempdir.
|
||||
import os
|
||||
import shutil
|
||||
|
|
@ -307,7 +297,7 @@ def _unsafe_execute(
|
|||
def execute_code(
|
||||
code: str,
|
||||
timeout: float = 5.0, # 5 seconds default
|
||||
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
|
||||
maximum_memory_bytes: int | None = 256 * 1024 * 1024, # 256MB default
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute Python code in a sandboxed environment.
|
||||
|
|
@ -331,9 +321,7 @@ def execute_code(
|
|||
manager = multiprocessing.Manager()
|
||||
result_dict = manager.dict()
|
||||
|
||||
p = multiprocessing.Process(
|
||||
target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict)
|
||||
)
|
||||
p = multiprocessing.Process(target=_unsafe_execute, args=(code, timeout, maximum_memory_bytes, result_dict))
|
||||
p.start()
|
||||
p.join(timeout=timeout + 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -75,9 +75,7 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(
|
||||
k, cos, sin
|
||||
) # QK rotary embedding
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q, k, v = (
|
||||
q.transpose(1, 2),
|
||||
|
|
@ -89,9 +87,7 @@ class CausalSelfAttention(nn.Module):
|
|||
if kv_cache is not None:
|
||||
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(
|
||||
2
|
||||
) # number of keys/values in total (in the cache + current forward pass)
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = (
|
||||
|
|
@ -100,31 +96,21 @@ class CausalSelfAttention(nn.Module):
|
|||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, is_causal=True, enable_gqa=enable_gqa
|
||||
)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, is_causal=False, enable_gqa=enable_gqa
|
||||
)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros(
|
||||
(Tq, Tk), dtype=torch.bool, device=q.device
|
||||
) # True = keep, False = mask
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
if prefix_len > 0: # can't be negative but could be zero
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(
|
||||
torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
|
||||
)
|
||||
y = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa
|
||||
)
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
|
|
@ -164,9 +150,7 @@ class GPT(nn.Module):
|
|||
self.transformer = nn.ModuleDict(
|
||||
{
|
||||
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList(
|
||||
[Block(config, layer_idx) for layer_idx in range(config.n_layer)]
|
||||
),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
}
|
||||
)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
|
@ -174,14 +158,10 @@ class GPT(nn.Module):
|
|||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
self.rotary_seq_len = (
|
||||
config.sequence_len * 10
|
||||
) # 10X over-compute should be enough, TODO make nicer?
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
head_dim = config.n_embd // config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer(
|
||||
"cos", cos, persistent=False
|
||||
) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
|
||||
def init_weights(self):
|
||||
|
|
@ -226,10 +206,7 @@ class GPT(nn.Module):
|
|||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
cos, sin = (
|
||||
cos[None, :, None, :],
|
||||
sin[None, :, None, :],
|
||||
) # add batch and head dims for later broadcasting
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
def get_device(self):
|
||||
|
|
@ -248,25 +225,19 @@ class GPT(nn.Module):
|
|||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
return num_flops_per_token
|
||||
|
||||
def setup_optimizers(
|
||||
self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
|
||||
):
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(
|
||||
embedding_params
|
||||
) + len(lm_head_params)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
||||
# Create the AdamW optimizer for the embedding and lm_head
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}"
|
||||
)
|
||||
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
||||
adam_groups = [
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
|
|
@ -285,23 +256,20 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizers
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(
|
||||
1
|
||||
), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert (
|
||||
idx.device == self.cos.device
|
||||
), f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert T <= self.cos.size(1), (
|
||||
f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
)
|
||||
assert idx.device == self.cos.device, (
|
||||
f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
)
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = (
|
||||
self.cos[:, T0 : T0 + T],
|
||||
self.sin[:, T0 : T0 + T],
|
||||
) # truncate cache to current sequence length
|
||||
cos_sin = self.cos[:, T0 : T0 + T], self.sin[:, T0 : T0 + T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx)
|
||||
|
|
@ -319,10 +287,7 @@ class GPT(nn.Module):
|
|||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)),
|
||||
targets.view(-1),
|
||||
ignore_index=-1,
|
||||
reduction=loss_reduction,
|
||||
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction
|
||||
)
|
||||
return loss
|
||||
else:
|
||||
|
|
@ -351,7 +316,7 @@ class GPT(nn.Module):
|
|||
logits = logits[:, -1, :] # (B, vocab_size)
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float("Inf")
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
if temperature > 0:
|
||||
logits = logits / temperature
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
|
|
|||
|
|
@ -33,20 +33,16 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
batch_iter = iter(batches)
|
||||
for _ in range(steps):
|
||||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction="none") # (B, T)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (
|
||||
y.int() < 0
|
||||
).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
y_safe = torch.where(valid, y, torch.zeros_like(y))
|
||||
# map valid targets to their byte length; ignored targets contribute 0 bytes
|
||||
num_bytes2d = torch.where(
|
||||
valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype)
|
||||
)
|
||||
num_bytes2d = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y, dtype=token_bytes.dtype))
|
||||
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
||||
total_bytes += num_bytes2d.sum()
|
||||
else:
|
||||
|
|
@ -63,6 +59,6 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
total_nats = total_nats.item()
|
||||
total_bytes = total_bytes.item()
|
||||
if total_bytes == 0:
|
||||
return float("inf")
|
||||
return float('inf')
|
||||
bpb = total_nats / (math.log(2) * total_bytes)
|
||||
return bpb
|
||||
|
|
|
|||
|
|
@ -113,22 +113,13 @@ class DistMuon(torch.optim.Optimizer):
|
|||
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 0.02,
|
||||
momentum: float = 0.95,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
):
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
params = list(params)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted(
|
||||
{p.shape for p in params}
|
||||
) # sort to ensure consistent / deterministic ordering
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
|
|
@ -136,12 +127,8 @@ class DistMuon(torch.optim.Optimizer):
|
|||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}"
|
||||
)
|
||||
param_groups.append(
|
||||
dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0]))
|
||||
)
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -150,9 +137,9 @@ class DistMuon(torch.optim.Optimizer):
|
|||
world_size = dist.get_world_size()
|
||||
|
||||
# Ensure all grads exist
|
||||
assert all(
|
||||
p.grad is not None for group in self.param_groups for p in group["params"]
|
||||
), "All params must have grads"
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), (
|
||||
"All params must have grads"
|
||||
)
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
|
|
@ -168,15 +155,9 @@ class DistMuon(torch.optim.Optimizer):
|
|||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = (
|
||||
params[owner_idx].grad
|
||||
if owner_idx < len(params)
|
||||
else torch.empty_like(zero_buffer)
|
||||
)
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(
|
||||
rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
|
|
@ -188,13 +169,9 @@ class DistMuon(torch.optim.Optimizer):
|
|||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = (
|
||||
base_i + rank
|
||||
) # calculate the index of the param that this rank owns
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[
|
||||
future_idx
|
||||
].wait() # possibly later we could use wait_any polling instead
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
|
|
@ -212,12 +189,7 @@ class DistMuon(torch.optim.Optimizer):
|
|||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i : base_i + world_size]
|
||||
ag_output.extend(
|
||||
[
|
||||
torch.empty_like(zero_buffer)
|
||||
for _ in range(world_size - len(ag_output))
|
||||
]
|
||||
) # pad
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ import torch
|
|||
def run_command(cmd):
|
||||
"""Run a shell command and return output, or None if it fails."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, shell=True, capture_output=True, text=True, timeout=5
|
||||
)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
return None
|
||||
|
|
@ -30,16 +28,16 @@ def run_command(cmd):
|
|||
def get_git_info():
|
||||
"""Get current git commit, branch, and dirty status."""
|
||||
info = {}
|
||||
info["commit"] = run_command("git rev-parse --short HEAD") or "unknown"
|
||||
info["branch"] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
|
||||
info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
|
||||
info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
|
||||
|
||||
# Check if repo is dirty (has uncommitted changes)
|
||||
status = run_command("git status --porcelain")
|
||||
info["dirty"] = bool(status) if status is not None else False
|
||||
info['dirty'] = bool(status) if status is not None else False
|
||||
|
||||
# Get commit message
|
||||
info["message"] = run_command("git log -1 --pretty=%B") or ""
|
||||
info["message"] = info["message"].split("\n")[0][:80] # First line, truncated
|
||||
info['message'] = run_command("git log -1 --pretty=%B") or ""
|
||||
info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
|
||||
|
||||
return info
|
||||
|
||||
|
|
@ -68,20 +66,20 @@ def get_system_info():
|
|||
info = {}
|
||||
|
||||
# Basic system info
|
||||
info["hostname"] = socket.gethostname()
|
||||
info["platform"] = platform.system()
|
||||
info["python_version"] = platform.python_version()
|
||||
info["torch_version"] = torch.__version__
|
||||
info['hostname'] = socket.gethostname()
|
||||
info['platform'] = platform.system()
|
||||
info['python_version'] = platform.python_version()
|
||||
info['torch_version'] = torch.__version__
|
||||
|
||||
# CPU and memory
|
||||
info["cpu_count"] = psutil.cpu_count(logical=False)
|
||||
info["cpu_count_logical"] = psutil.cpu_count(logical=True)
|
||||
info["memory_gb"] = psutil.virtual_memory().total / (1024**3)
|
||||
info['cpu_count'] = psutil.cpu_count(logical=False)
|
||||
info['cpu_count_logical'] = psutil.cpu_count(logical=True)
|
||||
info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# User and environment
|
||||
info["user"] = os.environ.get("USER", "unknown")
|
||||
info["nanochat_base_dir"] = os.environ.get("NANOCHAT_BASE_DIR", "out")
|
||||
info["working_dir"] = os.getcwd()
|
||||
info['user'] = os.environ.get('USER', 'unknown')
|
||||
info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
|
||||
info['working_dir'] = os.getcwd()
|
||||
|
||||
return info
|
||||
|
||||
|
|
@ -165,18 +163,16 @@ Generated: {timestamp}
|
|||
"""
|
||||
|
||||
# bloat metrics: package all of the source code and assess its weight
|
||||
packaged = run_command(
|
||||
'files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml'
|
||||
)
|
||||
packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
|
||||
num_chars = len(packaged)
|
||||
num_lines = len(packaged.split("\n"))
|
||||
num_files = len([x for x in packaged.split("\n") if x.startswith("<source>")])
|
||||
num_lines = len(packaged.split('\n'))
|
||||
num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
|
||||
num_tokens = num_chars // 4 # assume approximately 4 chars per token
|
||||
|
||||
# count dependencies via uv.lock
|
||||
uv_lock_lines = 0
|
||||
if os.path.exists("uv.lock"):
|
||||
with open("uv.lock", encoding="utf-8") as f:
|
||||
if os.path.exists('uv.lock'):
|
||||
with open('uv.lock', encoding='utf-8') as f:
|
||||
uv_lock_lines = len(f.readlines())
|
||||
|
||||
header += f"""
|
||||
|
|
@ -231,7 +227,7 @@ def extract(section, keys):
|
|||
|
||||
def extract_timestamp(content, prefix):
|
||||
"""Extract timestamp from content with given prefix."""
|
||||
for line in content.split("\n"):
|
||||
for line in content.split('\n'):
|
||||
if line.startswith(prefix):
|
||||
time_str = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
|
|
@ -255,9 +251,7 @@ class Report:
|
|||
file_path = os.path.join(self.report_dir, file_name)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"## {section}\n")
|
||||
f.write(
|
||||
f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
)
|
||||
f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
for item in data:
|
||||
if not item:
|
||||
# skip falsy values like None or empty dict etc.
|
||||
|
|
@ -283,9 +277,7 @@ class Report:
|
|||
report_dir = self.report_dir
|
||||
report_file = os.path.join(report_dir, "report.md")
|
||||
print(f"Generating report to {report_file}")
|
||||
final_metrics = (
|
||||
{}
|
||||
) # the most important final metrics we'll add as table at the end
|
||||
final_metrics = {} # the most important final metrics we'll add as table at the end
|
||||
start_time = None
|
||||
end_time = None
|
||||
with open(report_file, "w", encoding="utf-8") as out_file:
|
||||
|
|
@ -297,18 +289,12 @@ class Report:
|
|||
out_file.write(header_content)
|
||||
start_time = extract_timestamp(header_content, "Run started:")
|
||||
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
|
||||
bloat_data = re.search(
|
||||
r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL
|
||||
)
|
||||
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
|
||||
bloat_data = bloat_data.group(1) if bloat_data else ""
|
||||
else:
|
||||
start_time = (
|
||||
None # will cause us to not write the total wall clock time
|
||||
)
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
bloat_data = "[bloat data missing]"
|
||||
print(
|
||||
f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?"
|
||||
)
|
||||
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
|
||||
# process all the individual sections
|
||||
for file_name in EXPECTED_FILES:
|
||||
section_file = os.path.join(report_dir, file_name)
|
||||
|
|
@ -329,9 +315,7 @@ class Report:
|
|||
if file_name == "chat-evaluation-sft.md":
|
||||
final_metrics["sft"] = extract(section, chat_metrics)
|
||||
if file_name == "chat-evaluation-rl.md":
|
||||
final_metrics["rl"] = extract(
|
||||
section, "GSM8K"
|
||||
) # RL only evals GSM8K
|
||||
final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
|
||||
# append this section of the report
|
||||
out_file.write(section)
|
||||
out_file.write("\n")
|
||||
|
|
@ -345,9 +329,7 @@ class Report:
|
|||
for stage_metrics in final_metrics.values():
|
||||
all_metrics.update(stage_metrics.keys())
|
||||
# Custom ordering: CORE first, ChatCORE last, rest in middle
|
||||
all_metrics = sorted(
|
||||
all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x)
|
||||
)
|
||||
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
|
||||
# Fixed column widths
|
||||
stages = ["base", "mid", "sft", "rl"]
|
||||
metric_width = 15
|
||||
|
|
@ -380,7 +362,7 @@ class Report:
|
|||
else:
|
||||
out_file.write("Total wall clock time: unknown\n")
|
||||
# also cp the report.md file to current directory
|
||||
print(f"Copying report.md to current directory for convenience")
|
||||
print("Copying report.md to current directory for convenience")
|
||||
shutil.copy(report_file, "report.md")
|
||||
return report_file
|
||||
|
||||
|
|
@ -432,9 +414,7 @@ def get_report():
|
|||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate or reset nanochat training reports."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
|
||||
parser.add_argument(
|
||||
"command",
|
||||
nargs="?",
|
||||
|
|
|
|||
|
|
@ -27,13 +27,14 @@ SPECIAL_TOKENS = [
|
|||
# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
|
||||
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
|
||||
# I haven't validated that this is actually a good idea, TODO.
|
||||
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
SPLIT_PATTERN = (
|
||||
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
|
||||
from tokenizers import Regex
|
||||
from tokenizers import Regex, decoders, pre_tokenizers
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
from tokenizers import decoders, pre_tokenizers
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
|
|
@ -75,14 +76,10 @@ class HuggingFaceTokenizer:
|
|||
# NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
|
||||
# very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
|
||||
# (but I haven't validated this! TODO)
|
||||
gpt4_split_regex = Regex(
|
||||
SPLIT_PATTERN
|
||||
) # huggingface demands that you wrap it in Regex!!
|
||||
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.Split(
|
||||
pattern=gpt4_split_regex, behavior="isolated", invert=False
|
||||
),
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
|
||||
]
|
||||
)
|
||||
|
|
@ -119,15 +116,11 @@ class HuggingFaceTokenizer:
|
|||
assert isinstance(text, str)
|
||||
ids = []
|
||||
if prepend is not None:
|
||||
prepend_id = (
|
||||
prepend if isinstance(prepend, int) else self.encode_special(prepend)
|
||||
)
|
||||
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
|
||||
ids.append(prepend_id)
|
||||
ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
|
||||
if append is not None:
|
||||
append_id = (
|
||||
append if isinstance(append, int) else self.encode_special(append)
|
||||
)
|
||||
append_id = append if isinstance(append, int) else self.encode_special(append)
|
||||
ids.append(append_id)
|
||||
return ids
|
||||
|
||||
|
|
@ -183,20 +176,14 @@ class RustBPETokenizer:
|
|||
tokenizer = rustbpe.Tokenizer()
|
||||
# the special tokens are inserted later in __init__, we don't train them here
|
||||
vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
|
||||
assert (
|
||||
vocab_size_no_special >= 256
|
||||
), f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
|
||||
tokenizer.train_from_iterator(
|
||||
text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN
|
||||
)
|
||||
assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
|
||||
tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
|
||||
# 2) construct the associated tiktoken encoding for inference
|
||||
pattern = tokenizer.get_pattern()
|
||||
mergeable_ranks_list = tokenizer.get_mergeable_ranks()
|
||||
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
|
||||
tokens_offset = len(mergeable_ranks)
|
||||
special_tokens = {
|
||||
name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)
|
||||
}
|
||||
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
|
||||
enc = tiktoken.Encoding(
|
||||
name="rustbpe",
|
||||
pat_str=pattern,
|
||||
|
|
@ -242,13 +229,9 @@ class RustBPETokenizer:
|
|||
# text can be either a string or a list of strings
|
||||
|
||||
if prepend is not None:
|
||||
prepend_id = (
|
||||
prepend if isinstance(prepend, int) else self.encode_special(prepend)
|
||||
)
|
||||
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
|
||||
if append is not None:
|
||||
append_id = (
|
||||
append if isinstance(append, int) else self.encode_special(append)
|
||||
)
|
||||
append_id = append if isinstance(append, int) else self.encode_special(append)
|
||||
|
||||
if isinstance(text, str):
|
||||
ids = self.enc.encode_ordinary(text)
|
||||
|
|
@ -305,12 +288,8 @@ class RustBPETokenizer:
|
|||
# some conversation surgery is necessary here for now...
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
messages = conversation["messages"]
|
||||
assert (
|
||||
messages[1]["role"] == "user"
|
||||
), "System message must be followed by a user message"
|
||||
messages[1]["content"] = (
|
||||
messages[0]["content"] + "\n\n" + messages[1]["content"]
|
||||
)
|
||||
assert messages[1]["role"] == "user", "System message must be followed by a user message"
|
||||
messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
|
||||
messages = messages[1:]
|
||||
else:
|
||||
messages = conversation["messages"]
|
||||
|
|
@ -318,36 +297,28 @@ class RustBPETokenizer:
|
|||
|
||||
# fetch all the special tokens we need
|
||||
bos = self.get_bos_token_id()
|
||||
user_start, user_end = self.encode_special(
|
||||
"<|user_start|>"
|
||||
), self.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = self.encode_special(
|
||||
"<|assistant_start|>"
|
||||
), self.encode_special("<|assistant_end|>")
|
||||
python_start, python_end = self.encode_special(
|
||||
"<|python_start|>"
|
||||
), self.encode_special("<|python_end|>")
|
||||
output_start, output_end = self.encode_special(
|
||||
"<|output_start|>"
|
||||
), self.encode_special("<|output_end|>")
|
||||
user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = (
|
||||
self.encode_special("<|assistant_start|>"),
|
||||
self.encode_special("<|assistant_end|>"),
|
||||
)
|
||||
python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
|
||||
output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
|
||||
|
||||
# now we can tokenize the conversation
|
||||
add_tokens(bos, 0)
|
||||
for i, message in enumerate(messages):
|
||||
|
||||
# some sanity checking here around assumptions, to prevent footguns
|
||||
must_be_from = "user" if i % 2 == 0 else "assistant"
|
||||
assert (
|
||||
message["role"] == must_be_from
|
||||
), f"Message {i} is from {message['role']} but should be from {must_be_from}"
|
||||
assert message["role"] == must_be_from, (
|
||||
f"Message {i} is from {message['role']} but should be from {must_be_from}"
|
||||
)
|
||||
|
||||
# content can be either a simple string or a list of parts (e.g. containing tool calls)
|
||||
content = message["content"]
|
||||
|
||||
if message["role"] == "user":
|
||||
assert isinstance(
|
||||
content, str
|
||||
), "User messages are simply expected to be strings"
|
||||
assert isinstance(content, str), "User messages are simply expected to be strings"
|
||||
value_ids = self.encode(content)
|
||||
add_tokens(user_start, 0)
|
||||
add_tokens(value_ids, 0)
|
||||
|
|
@ -388,10 +359,10 @@ class RustBPETokenizer:
|
|||
|
||||
def visualize_tokenization(self, ids, mask, with_token_id=False):
|
||||
"""Small helper function useful in debugging: visualize the tokenization of render_conversation"""
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
RESET = "\033[0m"
|
||||
GRAY = "\033[90m"
|
||||
RED = '\033[91m'
|
||||
GREEN = '\033[92m'
|
||||
RESET = '\033[0m'
|
||||
GRAY = '\033[90m'
|
||||
tokens = []
|
||||
for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
|
||||
token_str = self.decode([token_id])
|
||||
|
|
@ -399,7 +370,7 @@ class RustBPETokenizer:
|
|||
tokens.append(f"{color}{token_str}{RESET}")
|
||||
if with_token_id:
|
||||
tokens.append(f"{GRAY}({token_id}){RESET}")
|
||||
return "|".join(tokens)
|
||||
return '|'.join(tokens)
|
||||
|
||||
def render_for_completion(self, conversation):
|
||||
"""
|
||||
|
|
@ -410,9 +381,7 @@ class RustBPETokenizer:
|
|||
# We have some surgery to do: we need to pop the last message (of the Assistant)
|
||||
conversation = copy.deepcopy(conversation) # avoid mutating the original
|
||||
messages = conversation["messages"]
|
||||
assert (
|
||||
messages[-1]["role"] == "assistant"
|
||||
), "Last message must be from the Assistant"
|
||||
assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
|
||||
messages.pop() # remove the last message (of the Assistant) inplace
|
||||
|
||||
# Now tokenize the conversation
|
||||
|
|
@ -445,9 +414,9 @@ def get_token_bytes(device="cpu"):
|
|||
base_dir = get_base_dir()
|
||||
tokenizer_dir = os.path.join(base_dir, "tokenizer")
|
||||
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
|
||||
assert os.path.exists(
|
||||
token_bytes_path
|
||||
), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
|
||||
assert os.path.exists(token_bytes_path), (
|
||||
f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
|
||||
)
|
||||
with open(token_bytes_path, "rb") as f:
|
||||
token_bytes = torch.load(f, map_location=device)
|
||||
return token_bytes
|
||||
|
|
|
|||
|
|
@ -77,18 +77,24 @@ conflicts = [
|
|||
],
|
||||
]
|
||||
|
||||
[tool.autoflake]
|
||||
remove-all-unused-imports = true
|
||||
in-place = true
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 120
|
||||
fix = true
|
||||
unsafe-fixes = true
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"F", # Pyflakes (unused imports) - replaces autoflake
|
||||
"I", # isort - replaces isort
|
||||
"UP", # pyupgrade - replaces pyupgrade
|
||||
]
|
||||
|
||||
[tool.pyupgrade]
|
||||
py310-plus = true
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["nanochat"]
|
||||
|
||||
[tool.black]
|
||||
target-version = ["py310"]
|
||||
[tool.ruff.format]
|
||||
quote-style = "preserve"
|
||||
|
||||
[tool.codespell]
|
||||
write-changes = true
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def place_eval_bundle(file_path):
|
|||
base_dir = get_base_dir()
|
||||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmpdir)
|
||||
extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
|
||||
shutil.move(extracted_bundle_dir, eval_bundle_dir)
|
||||
|
|
@ -65,23 +65,21 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
|
||||
# Download the eval bundle to disk (and unzip if needed)
|
||||
if not os.path.exists(eval_bundle_dir):
|
||||
download_file_with_lock(
|
||||
EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle
|
||||
)
|
||||
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
|
||||
config_path = os.path.join(eval_bundle_dir, "core.yaml")
|
||||
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
|
||||
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
with open(config_path, encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
tasks = config["icl_tasks"]
|
||||
tasks = config['icl_tasks']
|
||||
|
||||
# Load random baseline values from eval metadata
|
||||
random_baselines = {}
|
||||
with open(eval_meta_data, encoding="utf-8") as f:
|
||||
with open(eval_meta_data, encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
task_name = row["Eval Task"]
|
||||
random_baseline = row["Random baseline"]
|
||||
task_name = row['Eval Task']
|
||||
random_baseline = row['Random baseline']
|
||||
random_baselines[task_name] = float(random_baseline)
|
||||
|
||||
# Evaluate each task
|
||||
|
|
@ -89,21 +87,18 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
centered_results = {}
|
||||
for task in tasks:
|
||||
start_time = time.time()
|
||||
label = task["label"]
|
||||
label = task['label']
|
||||
task_meta = {
|
||||
"task_type": task["icl_task_type"],
|
||||
"dataset_uri": task["dataset_uri"],
|
||||
"num_fewshot": task["num_fewshot"][0],
|
||||
"continuation_delimiter": task.get("continuation_delimiter", " "),
|
||||
'task_type': task['icl_task_type'],
|
||||
'dataset_uri': task['dataset_uri'],
|
||||
'num_fewshot': task['num_fewshot'][0],
|
||||
'continuation_delimiter': task.get('continuation_delimiter', ' '),
|
||||
}
|
||||
print0(
|
||||
f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ",
|
||||
end="",
|
||||
)
|
||||
print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
|
||||
|
||||
# Load data for this task
|
||||
data_path = os.path.join(data_base_path, task_meta["dataset_uri"])
|
||||
with open(data_path, encoding="utf-8") as f:
|
||||
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
|
||||
with open(data_path, encoding='utf-8') as f:
|
||||
data = [json.loads(line.strip()) for line in f]
|
||||
|
||||
# shuffle the data because in many cases it appears ordered but we want
|
||||
|
|
@ -118,21 +113,13 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
|
|||
|
||||
results[label] = accuracy
|
||||
random_baseline = random_baselines[label]
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (
|
||||
1.0 - 0.01 * random_baseline
|
||||
)
|
||||
centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
|
||||
centered_results[label] = centered_result
|
||||
end_time = time.time()
|
||||
print0(
|
||||
f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s"
|
||||
)
|
||||
print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
|
||||
|
||||
core_metric = sum(centered_results.values()) / len(centered_results)
|
||||
out = {
|
||||
"results": results,
|
||||
"centered_results": centered_results,
|
||||
"core_metric": core_metric,
|
||||
}
|
||||
out = {"results": results, "centered_results": centered_results, "core_metric": core_metric}
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -173,24 +160,15 @@ def main():
|
|||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--hf-path", type=str, default=None, help="HuggingFace model path to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-per-task",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Max examples per task to evaluate (-1 = disable)",
|
||||
)
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
|
|
@ -221,18 +199,16 @@ def main():
|
|||
results = out["results"]
|
||||
centered_results = out["centered_results"]
|
||||
core_metric = out["core_metric"]
|
||||
with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
|
||||
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
|
||||
f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
|
||||
for label in results:
|
||||
f.write(
|
||||
f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n"
|
||||
)
|
||||
f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
|
||||
f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
|
||||
# Print the content of the csv file to console too
|
||||
print0("=" * 80)
|
||||
print0(f"Model: {model_name}")
|
||||
print0("=" * 80)
|
||||
with open(output_csv_path, encoding="utf-8") as f:
|
||||
with open(output_csv_path, encoding='utf-8') as f:
|
||||
print0(f.read())
|
||||
|
||||
# Log to report
|
||||
|
|
|
|||
|
|
@ -13,12 +13,7 @@ from contextlib import nullcontext
|
|||
import torch
|
||||
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import (
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
print0,
|
||||
)
|
||||
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, print0
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -30,35 +25,25 @@ split_tokens = 20 * 524288 # number of tokens to evaluate per split
|
|||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(
|
||||
open(os.path.join("nanochat", "configurator.py")).read()
|
||||
) # overrides from command line or config file
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model(
|
||||
"base", device, phase="eval", model_tag=model_tag, step=model_step
|
||||
)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
assert (
|
||||
split_tokens % tokens_per_step == 0
|
||||
), "split_tokens must be divisible by tokens_per_step"
|
||||
assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
|
||||
steps = split_tokens // tokens_per_step
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(
|
||||
device_batch_size, sequence_len, split_name, device=device
|
||||
)
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
|
|
@ -80,9 +65,7 @@ if ddp_rank == 0:
|
|||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(
|
||||
tokens, num_samples=1, max_tokens=16, temperature=0
|
||||
)
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0(sample_str)
|
||||
samples.append(sample_str)
|
||||
|
|
|
|||
|
|
@ -30,10 +30,7 @@ from nanochat.common import (
|
|||
print0,
|
||||
print_banner,
|
||||
)
|
||||
from nanochat.dataloader import (
|
||||
tokenizing_distributed_data_loader,
|
||||
tokenizing_distributed_data_loader_with_state,
|
||||
)
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -48,16 +45,16 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wa
|
|||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Model architecture
|
||||
depth = (
|
||||
20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
)
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
# Training horizon. Only one of these 3 will be used, in this order of precedence.
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
target_flops = (
|
||||
-1.0
|
||||
) # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
|
||||
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
target_param_data_ratio = (
|
||||
20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
|
||||
)
|
||||
# Optimization
|
||||
device_batch_size = 32 # per-device batch size (set to not OOM)
|
||||
total_batch_size = 524288 # total desired batch size, in #tokens
|
||||
|
|
@ -69,33 +66,19 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
|||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = (
|
||||
-1
|
||||
) # resume training from this step of the optimization (-1 = disable)
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20 * 524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = (
|
||||
2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
)
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = (
|
||||
-1
|
||||
) # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
# Output
|
||||
model_tag = (
|
||||
"" # optionally override the model tag for the output checkpoint directory name
|
||||
)
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [
|
||||
k
|
||||
for k, v in globals().items()
|
||||
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
|
||||
]
|
||||
exec(
|
||||
open(os.path.join("nanochat", "configurator.py")).read()
|
||||
) # overrides from command line or config file
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
|
@ -104,20 +87,14 @@ device_type = autodetect_device_type() if device_type == "" else device_type
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = (
|
||||
DummyWandb()
|
||||
if use_dummy_wandb
|
||||
else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
)
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
tokenizer = get_tokenizer()
|
||||
|
|
@ -127,15 +104,9 @@ print0(f"Vocab size: {vocab_size:,}")
|
|||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
num_layers = depth
|
||||
model_dim = (
|
||||
depth * 64
|
||||
) # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(
|
||||
1, (model_dim + 127) // 128
|
||||
) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = (
|
||||
num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
)
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
|
|
@ -143,21 +114,13 @@ print0(f"num_kv_heads: {num_kv_heads}")
|
|||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = (
|
||||
device_batch_size * max_seq_len
|
||||
) # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = (
|
||||
tokens_per_fwdbwd * ddp_world_size
|
||||
) # total tokens per iteration for all ranks
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(
|
||||
f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}"
|
||||
)
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(
|
||||
f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}"
|
||||
)
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
|
@ -191,9 +154,7 @@ if resuming:
|
|||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(
|
||||
model, dynamic=False
|
||||
) # the inputs to model will never change shape so dynamic=False is safe
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@ -211,25 +172,18 @@ elif target_param_data_ratio > 0:
|
|||
# calculate the number of iterations from the target param data ratio
|
||||
target_tokens = target_param_data_ratio * num_params
|
||||
num_iterations = target_tokens // total_batch_size
|
||||
print0(
|
||||
f"Calculated number of iterations from target data:param ratio: {num_iterations:,}"
|
||||
)
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(
|
||||
f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}"
|
||||
) # Chinchilla is ~20
|
||||
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
|
|
@ -241,22 +195,14 @@ if resuming:
|
|||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = (
|
||||
None if not resuming else meta_data["dataloader_state_dict"]
|
||||
)
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(
|
||||
device_batch_size,
|
||||
max_seq_len,
|
||||
split="train",
|
||||
device=device,
|
||||
resume_state_dict=dataloader_resume_state_dict,
|
||||
device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict
|
||||
)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(
|
||||
device_batch_size, max_seq_len, split="val", device=device
|
||||
)
|
||||
x, y, dataloader_state_dict = next(
|
||||
train_loader
|
||||
) # kick off load of the very first batch of data
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
|
@ -300,9 +246,7 @@ else:
|
|||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
while True:
|
||||
last_step = (
|
||||
step == num_iterations
|
||||
) # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
|
|
@ -328,14 +272,10 @@ while True:
|
|||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
results = {}
|
||||
if core_metric_every > 0 and (
|
||||
last_step or (step > 0 and step % core_metric_every == 0)
|
||||
):
|
||||
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(
|
||||
orig_model, tokenizer, device, max_per_task=core_metric_max_per_task
|
||||
)
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log(
|
||||
{
|
||||
|
|
@ -364,19 +304,12 @@ while True:
|
|||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(
|
||||
tokens, num_samples=1, max_tokens=16, temperature=0
|
||||
)
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (
|
||||
step > 0
|
||||
and step != resume_from_step
|
||||
and save_every > 0
|
||||
and step % save_every == 0
|
||||
):
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
@ -412,9 +345,7 @@ while True:
|
|||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = (
|
||||
loss / grad_accum_steps
|
||||
) # each .backward() is a grad sum => normalize loss here
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(
|
||||
train_loader
|
||||
|
|
@ -422,12 +353,8 @@ while True:
|
|||
# gradient clipping
|
||||
grad_clip_enabled = grad_clip > 0.0
|
||||
if grad_clip_enabled:
|
||||
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(
|
||||
orig_model.parameters(), grad_clip
|
||||
)
|
||||
grad_norm = (
|
||||
grad_norm_tensor.item()
|
||||
) # GPU tensor -> CPU float (note: cpu-gpu sync point)
|
||||
grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
|
||||
grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(step)
|
||||
for opt in optimizers:
|
||||
|
|
@ -446,24 +373,18 @@ while True:
|
|||
|
||||
# logging
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = (
|
||||
ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
|
||||
) # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (
|
||||
1 - ema_beta ** (step + 1)
|
||||
) # debias the EMA
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = (
|
||||
989e12 * ddp_world_size
|
||||
) # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
|
||||
print0(
|
||||
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m"
|
||||
f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
|
||||
)
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
|
|
@ -485,7 +406,7 @@ while True:
|
|||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Total training time: {total_training_time / 60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
|
|
@ -512,7 +433,7 @@ get_report().log(
|
|||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Total training time": f"{total_training_time / 60:.2f}m",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
},
|
||||
],
|
||||
|
|
|
|||
|
|
@ -14,61 +14,38 @@ from nanochat.checkpoint_manager import load_model
|
|||
from nanochat.common import autodetect_device_type, compute_init
|
||||
from nanochat.engine import Engine
|
||||
|
||||
parser = argparse.ArgumentParser(description="Chat with the model")
|
||||
parser = argparse.ArgumentParser(description='Chat with the model')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument(
|
||||
"-i", "--source", type=str, default="sft", help="Source of the model: sft|mid|rl"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
|
||||
)
|
||||
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--prompt",
|
||||
'--device-type',
|
||||
type=str,
|
||||
default="",
|
||||
help="Prompt the model, get a single response back",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--temperature", type=float, default=0.6, help="Temperature for generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k", "--top-k", type=int, default=50, help="Top-k sampling parameter"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device-type",
|
||||
type=str,
|
||||
default="",
|
||||
choices=["cuda", "cpu", "mps"],
|
||||
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
)
|
||||
model, tokenizer, meta = load_model(
|
||||
args.source, device, phase="eval", model_tag=args.model_tag, step=args.step
|
||||
)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
bos = tokenizer.get_bos_token_id()
|
||||
user_start, user_end = tokenizer.encode_special(
|
||||
"<|user_start|>"
|
||||
), tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = tokenizer.encode_special(
|
||||
"<|assistant_start|>"
|
||||
), tokenizer.encode_special("<|assistant_end|>")
|
||||
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
|
||||
assistant_start, assistant_end = (
|
||||
tokenizer.encode_special("<|assistant_start|>"),
|
||||
tokenizer.encode_special("<|assistant_end|>"),
|
||||
)
|
||||
|
||||
# Create Engine for efficient generation
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
@ -82,7 +59,6 @@ print("-" * 50)
|
|||
conversation_tokens = [bos]
|
||||
|
||||
while True:
|
||||
|
||||
if args.prompt:
|
||||
# Get the prompt from the launch command
|
||||
user_input = args.prompt
|
||||
|
|
@ -95,11 +71,11 @@ while True:
|
|||
break
|
||||
|
||||
# Handle special commands
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
if user_input.lower() in ['quit', 'exit']:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if user_input.lower() == "clear":
|
||||
if user_input.lower() == 'clear':
|
||||
conversation_tokens = [bos]
|
||||
print("Conversation cleared.")
|
||||
continue
|
||||
|
|
@ -123,9 +99,7 @@ while True:
|
|||
response_tokens = []
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(
|
||||
conversation_tokens, **generate_kwargs
|
||||
):
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
|
|
|
|||
|
|
@ -16,13 +16,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import (
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
get_dist_info,
|
||||
print0,
|
||||
)
|
||||
from nanochat.common import autodetect_device_type, compute_cleanup, compute_init, get_dist_info, print0
|
||||
from nanochat.engine import Engine
|
||||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
|
@ -35,25 +29,12 @@ from tasks.spellingbee import SpellingBee
|
|||
|
||||
|
||||
def run_generative_eval(
|
||||
task_object,
|
||||
tokenizer,
|
||||
model,
|
||||
engine,
|
||||
num_samples,
|
||||
max_new_tokens,
|
||||
temperature,
|
||||
top_k,
|
||||
max_problems=None,
|
||||
task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None
|
||||
):
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
|
||||
num_problems = (
|
||||
len(task_object)
|
||||
if max_problems is None
|
||||
else min(len(task_object), max_problems)
|
||||
)
|
||||
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
|
||||
|
||||
# Run the evaluation
|
||||
num_passed, total = 0, 0
|
||||
|
|
@ -72,13 +53,9 @@ def run_generative_eval(
|
|||
)
|
||||
# Decode the completions as text
|
||||
prefix_length = len(encoded_prompt)
|
||||
completions = [
|
||||
tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results
|
||||
]
|
||||
completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
|
||||
# Evaluate success criteria
|
||||
outcomes = [
|
||||
task_object.evaluate(conversation, completion) for completion in completions
|
||||
]
|
||||
outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
|
||||
passed = any(outcomes)
|
||||
|
||||
# Keep stats
|
||||
|
|
@ -86,11 +63,7 @@ def run_generative_eval(
|
|||
num_passed += int(passed)
|
||||
|
||||
# Logging (overwrite the same line in the console)
|
||||
print(
|
||||
f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100 * num_passed / total:.2f}%)", end='', flush=True)
|
||||
|
||||
# Finish the in-place progress line with a newline before final summary
|
||||
print()
|
||||
|
|
@ -105,7 +78,7 @@ def run_generative_eval(
|
|||
total = total_tensor.item()
|
||||
|
||||
print0("=" * 50)
|
||||
print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
|
||||
print0(f"Final: {num_passed}/{total} ({100 * num_passed / total:.2f}%)")
|
||||
|
||||
# Return the accuracy
|
||||
return num_passed / total
|
||||
|
|
@ -118,26 +91,17 @@ def run_generative_eval(
|
|||
|
||||
|
||||
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
bos = (
|
||||
tokenizer.get_bos_token_id()
|
||||
) # use BOS as pad token is ok, these positions are ignored
|
||||
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
|
||||
|
||||
# We'll process batches of independent problems at a time because there is no sampling needed
|
||||
num_problems = (
|
||||
len(task_object)
|
||||
if max_problems is None
|
||||
else min(len(task_object), max_problems)
|
||||
)
|
||||
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
|
||||
ceil_div = lambda x, y: -(-x // y)
|
||||
num_batches = ceil_div(num_problems, batch_size)
|
||||
|
||||
# Run the evaluation
|
||||
letter_to_id_cache = (
|
||||
{}
|
||||
) # many letters will repeat often, let's save the tokenizer some work
|
||||
letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
|
||||
num_passed, total = 0, 0
|
||||
for i in range(ddp_rank, num_batches, ddp_world_size):
|
||||
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
|
||||
|
|
@ -145,16 +109,13 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
|
||||
conversations = [task_object[ii] for ii in range(i0, i1)]
|
||||
prompt_ids = [
|
||||
tokenizer.render_for_completion(conversation)
|
||||
for conversation in conversations
|
||||
tokenizer.render_for_completion(conversation) for conversation in conversations
|
||||
] # TODO: remake the way this works
|
||||
max_length = max(len(ids) for ids in prompt_ids)
|
||||
answer_time_positions = [
|
||||
len(ids) - 1 for ids in prompt_ids
|
||||
] # where the last token is (and the predicted answer)
|
||||
padded_prompt_ids = [
|
||||
ids + [bos] * (max_length - len(ids)) for ids in prompt_ids
|
||||
]
|
||||
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
|
||||
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
|
||||
|
||||
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
|
||||
|
|
@ -167,14 +128,12 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
# letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
|
||||
for idx, conversation in enumerate(conversations):
|
||||
# get the token ids of all the available letters of this problem
|
||||
letters = conversation["letters"]
|
||||
letters = conversation['letters']
|
||||
letter_ids = []
|
||||
for letter in letters:
|
||||
if not letter in letter_to_id_cache:
|
||||
encoded_letter = tokenizer.encode(letter)
|
||||
assert (
|
||||
len(encoded_letter) == 1
|
||||
), "Each letter must be a single token"
|
||||
assert len(encoded_letter) == 1, "Each letter must be a single token"
|
||||
letter_to_id_cache[letter] = encoded_letter[0]
|
||||
letter_ids.append(letter_to_id_cache[letter])
|
||||
# focus logits just down to the answer position and the available letters of the answer
|
||||
|
|
@ -198,7 +157,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
total = total_tensor.item()
|
||||
|
||||
average = num_passed / total
|
||||
print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
|
||||
print0(f"Final: {num_passed}/{total} ({100 * average:.2f}%)")
|
||||
return average
|
||||
|
||||
|
||||
|
|
@ -219,16 +178,16 @@ def run_chat_eval(
|
|||
):
|
||||
# Create the evaluation object
|
||||
task_module = {
|
||||
"HumanEval": HumanEval,
|
||||
"MMLU": partial(MMLU, subset="all", split="test"),
|
||||
"ARC-Easy": partial(ARC, subset="ARC-Easy", split="test"),
|
||||
"ARC-Challenge": partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
"GSM8K": partial(GSM8K, subset="main", split="test"),
|
||||
"SpellingBee": partial(SpellingBee, size=256, split="test"),
|
||||
'HumanEval': HumanEval,
|
||||
'MMLU': partial(MMLU, subset="all", split="test"),
|
||||
'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
|
||||
'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
|
||||
'GSM8K': partial(GSM8K, subset="main", split="test"),
|
||||
'SpellingBee': partial(SpellingBee, size=256, split="test"),
|
||||
}[task_name]
|
||||
task_object = task_module()
|
||||
# Run the evaluation
|
||||
if task_object.eval_type == "generative":
|
||||
if task_object.eval_type == 'generative':
|
||||
acc = run_generative_eval(
|
||||
task_object,
|
||||
tokenizer,
|
||||
|
|
@ -240,10 +199,8 @@ def run_chat_eval(
|
|||
top_k,
|
||||
max_problems=max_problems,
|
||||
)
|
||||
elif task_object.eval_type == "categorical":
|
||||
acc = run_categorical_eval(
|
||||
task_object, tokenizer, model, batch_size, max_problems=max_problems
|
||||
)
|
||||
elif task_object.eval_type == 'categorical':
|
||||
acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
|
||||
else:
|
||||
raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
|
||||
return acc
|
||||
|
|
@ -251,87 +208,55 @@ def run_chat_eval(
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--source",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Source of the model: sft|mid|rl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--task-name",
|
||||
'-a',
|
||||
'--task-name',
|
||||
type=str,
|
||||
default=None,
|
||||
help="Task name. Default = all tasks. Use | to split multiple tasks.",
|
||||
)
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50)
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument(
|
||||
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
|
||||
)
|
||||
parser.add_argument("-t", "--temperature", type=float, default=0.0)
|
||||
parser.add_argument("-m", "--max-new-tokens", type=int, default=512)
|
||||
parser.add_argument("-n", "--num-samples", type=int, default=1)
|
||||
parser.add_argument("-k", "--top-k", type=int, default=50)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size for categorical evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
|
||||
)
|
||||
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
|
||||
parser.add_argument(
|
||||
"-x", "--max-problems", type=int, default=None, help="Max problems to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device-type",
|
||||
'--device-type',
|
||||
type=str,
|
||||
default="",
|
||||
choices=["cuda", "cpu", "mps"],
|
||||
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device_type = (
|
||||
autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
)
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
model, tokenizer, meta = load_model(
|
||||
args.source, device, phase="eval", model_tag=args.model_tag, step=args.step
|
||||
)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
||||
# Get the tasks to evaluate on
|
||||
all_tasks = [
|
||||
"ARC-Easy",
|
||||
"ARC-Challenge",
|
||||
"MMLU",
|
||||
"GSM8K",
|
||||
"HumanEval",
|
||||
"SpellingBee",
|
||||
]
|
||||
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||
baseline_accuracies = {
|
||||
"ARC-Easy": 0.25, # multiple choice 1 of 4 => 25%
|
||||
"ARC-Challenge": 0.25, # multiple choice 1 of 4 => 25%
|
||||
"MMLU": 0.25, # multiple choice 1 of 4 => 25%
|
||||
"GSM8K": 0.0, # open-ended => 0%
|
||||
"HumanEval": 0.0, # open-ended => 0%
|
||||
"SpellingBee": 0.0, # open-ended => 0%
|
||||
'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'MMLU': 0.25, # multiple choice 1 of 4 => 25%
|
||||
'GSM8K': 0.0, # open-ended => 0%
|
||||
'HumanEval': 0.0, # open-ended => 0%
|
||||
'SpellingBee': 0.0, # open-ended => 0%
|
||||
}
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split("|")
|
||||
task_names = all_tasks if args.task_name is None else args.task_name.split('|')
|
||||
|
||||
# Run all the task evaluations sequentially
|
||||
results = {}
|
||||
|
|
|
|||
|
|
@ -24,13 +24,7 @@ import torch.distributed as dist
|
|||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import (
|
||||
DummyWandb,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
get_base_dir,
|
||||
print0,
|
||||
)
|
||||
from nanochat.common import DummyWandb, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
|
|
@ -39,9 +33,7 @@ run = "dummy" # wandb run name
|
|||
source = "sft" # mid|sft
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = (
|
||||
16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
)
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
|
|
@ -56,30 +48,20 @@ save_every = 60 # every how many steps to save the model
|
|||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [
|
||||
k
|
||||
for k, v in globals().items()
|
||||
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
|
||||
]
|
||||
exec(
|
||||
open(os.path.join("nanochat", "configurator.py")).read()
|
||||
) # overrides from command line or config file
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == "float32" else torch.bfloat16
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = (
|
||||
DummyWandb()
|
||||
if use_dummy_wandb
|
||||
else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
)
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
|
|
@ -103,7 +85,6 @@ def get_batch():
|
|||
ddp_rank, len(train_task), ddp_world_size
|
||||
) # each rank is responsible for different examples in the training data
|
||||
for example_idx in itertools.cycle(rank_indices):
|
||||
|
||||
# First get the full conversation of both user and assistant messages
|
||||
conversation = train_task[example_idx]
|
||||
|
||||
|
|
@ -116,13 +97,9 @@ def get_batch():
|
|||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = (
|
||||
num_samples // device_batch_size
|
||||
) # go sequentially to prevent OOMs
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = (
|
||||
hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
|
||||
) # positive half of int32
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
|
|
@ -149,21 +126,16 @@ def get_batch():
|
|||
# Pad the sequences so that their lengths (in time) match
|
||||
max_length = max(len(seq) for seq in generated_token_sequences)
|
||||
padded_generated_token_sequences = [
|
||||
seq + [assistant_end] * (max_length - len(seq))
|
||||
for seq in generated_token_sequences
|
||||
seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences
|
||||
]
|
||||
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
|
||||
# Stack up the sequences and masks into PyTorch tensors
|
||||
ids = torch.tensor(
|
||||
padded_generated_token_sequences, dtype=torch.long, device=device
|
||||
)
|
||||
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
|
||||
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
|
||||
# Generate autoregressive inputs and targets to the Transformer
|
||||
inputs = ids[:, :-1]
|
||||
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
|
||||
targets[mask_ids[:, 1:] == 0] = (
|
||||
-1
|
||||
) # <-- inplace modification right here. -1 is the ignore index
|
||||
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
|
||||
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
|
||||
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
|
||||
|
|
@ -177,14 +149,7 @@ def get_batch():
|
|||
# -----------------------------------------------------------------------------
|
||||
# Simple evaluation loop for GSM8K pass@k
|
||||
def run_gsm8k_eval(
|
||||
task,
|
||||
tokenizer,
|
||||
engine,
|
||||
max_examples=None,
|
||||
num_samples=1,
|
||||
max_completion_tokens=256,
|
||||
temperature=0.0,
|
||||
top_k=50,
|
||||
task, tokenizer, engine, max_examples=None, num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50
|
||||
):
|
||||
"""
|
||||
Evaluates GSM8K task and returns a list of records of evaluation outcomes.
|
||||
|
|
@ -192,23 +157,15 @@ def run_gsm8k_eval(
|
|||
do the reduction across ranks. This is the responsibility of the caller.
|
||||
Because the evaluation can take a while, this function will yield records one by one.
|
||||
"""
|
||||
max_examples = (
|
||||
min(max_examples, len(task)) if max_examples is not None else len(task)
|
||||
)
|
||||
max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
|
||||
for idx in range(ddp_rank, max_examples, ddp_world_size):
|
||||
conversation = task[idx]
|
||||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert (
|
||||
num_samples <= device_batch_size
|
||||
) # usually this is true. we can add a loop if not...
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
max_tokens=max_completion_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
tokens, num_samples=num_samples, max_tokens=max_completion_tokens, temperature=temperature, top_k=top_k
|
||||
)
|
||||
# Check each sample for correctness
|
||||
outcomes = []
|
||||
|
|
@ -240,9 +197,7 @@ optimizers = model.setup_optimizers(
|
|||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group[
|
||||
"lr"
|
||||
] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
|
|
@ -252,52 +207,33 @@ def get_lr_multiplier(it):
|
|||
|
||||
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(
|
||||
f"Total sequences per step: {examples_per_step * num_samples}"
|
||||
) # total batch size in sequences/step
|
||||
assert (
|
||||
examples_per_step % ddp_world_size == 0
|
||||
), "Desired examples per step must be divisible by the number of ranks"
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
|
||||
|
||||
# Kick off the training loop
|
||||
batch_iterator = get_batch()
|
||||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(
|
||||
device_batch_size, device=device
|
||||
) # pass@k for k=1..device_batch_size
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(
|
||||
val_task,
|
||||
tokenizer,
|
||||
engine,
|
||||
num_samples=device_batch_size,
|
||||
max_examples=eval_examples,
|
||||
temperature=1.0,
|
||||
val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0
|
||||
)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
passk[k - 1] = sum(
|
||||
any(o["is_correct"] for o in r["outcomes"][:k]) for r in records
|
||||
)
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [
|
||||
f"Pass@{k}: {passk[k - 1].item():.4f}"
|
||||
for k in range(1, device_batch_size + 1)
|
||||
]
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {
|
||||
f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)
|
||||
}
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
wandb_run.log(
|
||||
{
|
||||
"step": step,
|
||||
|
|
@ -310,9 +246,7 @@ for step in range(num_steps):
|
|||
sequence_lengths = []
|
||||
for example_step in range(examples_per_rank):
|
||||
# Get one batch corresponding to one example in the training dataset
|
||||
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(
|
||||
batch_iterator
|
||||
)
|
||||
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
|
|
@ -327,9 +261,7 @@ for step in range(num_steps):
|
|||
advantages = advantages_all[b0:b1]
|
||||
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
||||
with autocast_ctx:
|
||||
logp = -model(inputs, targets, loss_reduction="none").view_as(
|
||||
inputs
|
||||
) # (B, T)
|
||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
||||
|
|
@ -351,9 +283,7 @@ for step in range(num_steps):
|
|||
mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
|
||||
if ddp: # aggregate across ranks
|
||||
mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
|
||||
mean_sequence_length_tensor = torch.tensor(
|
||||
mean_sequence_length, dtype=torch.float, device=device
|
||||
)
|
||||
mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
|
||||
dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
|
||||
dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
|
||||
mean_reward = mean_reward_tensor.item()
|
||||
|
|
@ -385,16 +315,12 @@ for step in range(num_steps):
|
|||
)
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and (
|
||||
(step > 0 and step % save_every == 0) or step == num_steps - 1
|
||||
):
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
|
||||
model_config_kwargs = (
|
||||
model.config.__dict__
|
||||
) # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
|
|||
|
|
@ -20,14 +20,7 @@ import torch.distributed as dist
|
|||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import (
|
||||
DummyWandb,
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
get_base_dir,
|
||||
print0,
|
||||
)
|
||||
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
from tasks.arc import ARC
|
||||
|
|
@ -50,9 +43,7 @@ dtype = "bfloat16"
|
|||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
num_iterations = (
|
||||
-1
|
||||
) # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
|
|
@ -65,14 +56,8 @@ eval_steps = 100
|
|||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [
|
||||
k
|
||||
for k, v in globals().items()
|
||||
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
|
||||
]
|
||||
exec(
|
||||
open(os.path.join("nanochat", "configurator.py")).read()
|
||||
) # overrides from command line or config file
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
|
@ -80,56 +65,38 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logg
|
|||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
ptdtype = torch.float32 if dtype == "float32" else torch.bfloat16
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
)
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = (
|
||||
DummyWandb()
|
||||
if use_dummy_wandb
|
||||
else wandb.init(
|
||||
project="nanochat-sft", name=run, config=user_config, save_code=True
|
||||
)
|
||||
else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
||||
)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(
|
||||
source, device, phase="train", model_tag=model_tag, step=step
|
||||
)
|
||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||
orig_model = model # original, uncompiled model
|
||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
identity_conversations_filepath = os.path.join(
|
||||
get_base_dir(), "identity_conversations.jsonl"
|
||||
)
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture(
|
||||
[
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
CustomJSON(
|
||||
filepath=identity_conversations_filepath
|
||||
), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(
|
||||
size=300, split="train"
|
||||
), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(
|
||||
size=300, split="train"
|
||||
), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]
|
||||
) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
|
||||
val_ds = SmolTalk(
|
||||
split="test"
|
||||
) # general conversations, 24K rows (though we don't actually use all of it)
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
|
|
@ -143,9 +110,7 @@ def sft_data_generator(dataset, batch_size):
|
|||
# prepares a list of tokenized conversations into a batch and yields
|
||||
def collate_and_yield(batch):
|
||||
nrows = len(batch)
|
||||
ncols = (
|
||||
max(len(ids) for ids, mask in batch) - 1
|
||||
) # seq of n creates inputs/targets of n-1
|
||||
ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
|
||||
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
|
||||
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
|
||||
for i, (ids, mask) in enumerate(batch):
|
||||
|
|
@ -178,9 +143,9 @@ examples_per_step = device_batch_size * ddp_world_size
|
|||
print0(f"Target examples per step: {target_examples_per_step}")
|
||||
print0(f"Device batch size: {device_batch_size}")
|
||||
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
|
||||
assert (
|
||||
target_examples_per_step % examples_per_step == 0
|
||||
), "Target examples per step must be divisible by examples per step"
|
||||
assert target_examples_per_step % examples_per_step == 0, (
|
||||
"Target examples per step must be divisible by examples per step"
|
||||
)
|
||||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
|
|
@ -204,9 +169,7 @@ optimizers = model.setup_optimizers(
|
|||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group[
|
||||
"lr"
|
||||
] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
|
@ -269,7 +232,7 @@ for step in range(num_iterations):
|
|||
batch_size=device_batch_size * 2,
|
||||
max_problems=eval_metrics_max_problems,
|
||||
)
|
||||
metrics_str = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log(
|
||||
{
|
||||
|
|
@ -283,17 +246,13 @@ for step in range(num_iterations):
|
|||
break
|
||||
|
||||
# evaluate the gradient
|
||||
num_tokens = torch.tensor(
|
||||
0, device=device
|
||||
) # the number of "active" tokens of supervision seen
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_iter)
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = (
|
||||
loss / grad_accum_steps
|
||||
) # each .backward() is a grad sum => normalize loss here
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward() # accumulate the gradient
|
||||
num_tokens += (train_targets >= 0).sum()
|
||||
if ddp:
|
||||
|
|
@ -332,9 +291,7 @@ if master_process:
|
|||
depth = model.config.n_layer
|
||||
model_tag = f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
|
||||
model_config_kwargs = (
|
||||
model.config.__dict__
|
||||
) # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
|
@ -61,61 +61,33 @@ MAX_TOP_K = 200
|
|||
MIN_MAX_TOKENS = 1
|
||||
MAX_MAX_TOKENS = 4096
|
||||
|
||||
parser = argparse.ArgumentParser(description="NanoChat Web Server")
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
||||
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument(
|
||||
"-n", "--num-gpus", type=int, default=1, help="Number of GPUs to use (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i", "--source", type=str, default="sft", help="Source of the model: sft|mid|rl"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Default temperature for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k", "--top-k", type=int, default=50, help="Default top-k sampling parameter"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Default max tokens for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--model-tag", type=str, default=None, help="Model tag to load"
|
||||
)
|
||||
parser.add_argument("-s", "--step", type=int, default=None, help="Step to load")
|
||||
parser.add_argument(
|
||||
"-p", "--port", type=int, default=8000, help="Port to run the server on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device-type",
|
||||
'--device-type',
|
||||
type=str,
|
||||
default="",
|
||||
choices=["cuda", "cpu", "mps"],
|
||||
help="Device type for evaluation: cuda|cpu|mps. empty => autodetect",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="0.0.0.0", help="Host to bind the server to"
|
||||
default='',
|
||||
choices=['cuda', 'cpu', 'mps'],
|
||||
help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
|
||||
)
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for conversation traffic
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -132,28 +104,23 @@ class Worker:
|
|||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
def __init__(self, num_gpus: int | None = None):
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.workers: list[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def initialize(
|
||||
self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None
|
||||
):
|
||||
async def initialize(self, source: str, model_tag: str | None = None, step: int | None = None):
|
||||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert (
|
||||
device_type == "cuda"
|
||||
), "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
|
||||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
|
|
@ -161,23 +128,13 @@ class WorkerPool:
|
|||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(
|
||||
source, device, phase="eval", model_tag=model_tag, step=step
|
||||
)
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
engine=engine,
|
||||
tokenizer=tokenizer,
|
||||
autocast_ctx=autocast_ctx,
|
||||
)
|
||||
worker = Worker(gpu_id=gpu_id, device=device, engine=engine, tokenizer=tokenizer, autocast_ctx=autocast_ctx)
|
||||
self.workers.append(worker)
|
||||
await self.available_workers.put(worker)
|
||||
|
||||
|
|
@ -198,10 +155,10 @@ class ChatMessage(BaseModel):
|
|||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
messages: list[ChatMessage]
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
top_k: int | None = None
|
||||
|
||||
|
||||
def validate_chat_request(request: ChatRequest):
|
||||
|
|
@ -219,9 +176,7 @@ def validate_chat_request(request: ChatRequest):
|
|||
total_length = 0
|
||||
for i, message in enumerate(request.messages):
|
||||
if not message.content:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Message {i} has empty content"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
|
||||
|
||||
msg_length = len(message.content)
|
||||
if msg_length > MAX_MESSAGE_LENGTH:
|
||||
|
|
@ -241,32 +196,26 @@ def validate_chat_request(request: ChatRequest):
|
|||
for i, message in enumerate(request.messages):
|
||||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'",
|
||||
status_code=400, detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
|
||||
)
|
||||
|
||||
# Validate temperature
|
||||
if request.temperature is not None:
|
||||
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}",
|
||||
status_code=400, detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
|
||||
)
|
||||
|
||||
# Validate top_k
|
||||
if request.top_k is not None:
|
||||
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
|
||||
|
||||
# Validate max_tokens
|
||||
if request.max_tokens is not None:
|
||||
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}",
|
||||
status_code=400, detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -275,9 +224,7 @@ async def lifespan(app: FastAPI):
|
|||
"""Load models on all GPUs on startup."""
|
||||
print("Loading nanochat models across GPUs...")
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
||||
await app.state.worker_pool.initialize(
|
||||
args.source, model_tag=args.model_tag, step=args.step
|
||||
)
|
||||
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
||||
|
|
@ -301,8 +248,7 @@ async def root():
|
|||
html_content = f.read()
|
||||
# Replace the API_URL to use the same origin
|
||||
html_content = html_content.replace(
|
||||
"const API_URL = `http://${window.location.hostname}:8000`;",
|
||||
"const API_URL = '';",
|
||||
"const API_URL = `http://${window.location.hostname}:8000`;", "const API_URL = '';"
|
||||
)
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
|
@ -352,7 +298,7 @@ async def generate_stream(
|
|||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith("<EFBFBD>"):
|
||||
if not current_text.endswith('<EFBFBD>'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text) :]
|
||||
if new_text: # Only yield if there's new content
|
||||
|
|
@ -435,14 +381,12 @@ async def chat_completions(request: ChatRequest):
|
|||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
worker_pool = getattr(app.state, "worker_pool", None)
|
||||
worker_pool = getattr(app.state, 'worker_pool', None)
|
||||
return {
|
||||
"status": "ok",
|
||||
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
|
||||
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
|
||||
"available_workers": (
|
||||
worker_pool.available_workers.qsize() if worker_pool else 0
|
||||
),
|
||||
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -453,19 +397,14 @@ async def stats():
|
|||
return {
|
||||
"total_workers": len(worker_pool.workers),
|
||||
"available_workers": worker_pool.available_workers.qsize(),
|
||||
"busy_workers": len(worker_pool.workers)
|
||||
- worker_pool.available_workers.qsize(),
|
||||
"workers": [
|
||||
{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers
|
||||
],
|
||||
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
|
||||
"workers": [{"gpu_id": w.gpu_id, "device": str(w.device)} for w in worker_pool.workers],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
print(f"Starting NanoChat Web Server")
|
||||
print(
|
||||
f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}"
|
||||
)
|
||||
print("Starting NanoChat Web Server")
|
||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
|
|
|||
|
|
@ -21,14 +21,7 @@ import torch.distributed as dist
|
|||
import wandb
|
||||
|
||||
from nanochat.checkpoint_manager import load_model, save_checkpoint
|
||||
from nanochat.common import (
|
||||
DummyWandb,
|
||||
autodetect_device_type,
|
||||
compute_cleanup,
|
||||
compute_init,
|
||||
get_base_dir,
|
||||
print0,
|
||||
)
|
||||
from nanochat.common import DummyWandb, autodetect_device_type, compute_cleanup, compute_init, get_base_dir, print0
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from tasks.common import TaskMixture
|
||||
|
|
@ -56,14 +49,8 @@ eval_every = 150 # -1 = disable
|
|||
eval_tokens = 20 * 524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
config_keys = [
|
||||
k
|
||||
for k, v in globals().items()
|
||||
if not k.startswith("_") and isinstance(v, (int, float, bool, str))
|
||||
]
|
||||
exec(
|
||||
open(os.path.join("nanochat", "configurator.py")).read()
|
||||
) # overrides from command line or config file
|
||||
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
|
@ -72,25 +59,17 @@ device_type = autodetect_device_type() if device_type == "" else device_type
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = (
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
|
||||
if device_type == "cuda"
|
||||
else nullcontext()
|
||||
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
)
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = (
|
||||
DummyWandb()
|
||||
if use_dummy_wandb
|
||||
else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
)
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model(
|
||||
"base", device, phase="train", model_tag=model_tag, step=step
|
||||
)
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
|
||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
||||
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
|
||||
print0(
|
||||
|
|
@ -100,38 +79,25 @@ orig_model = model
|
|||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = (
|
||||
device_batch_size * max_seq_len
|
||||
) # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = (
|
||||
tokens_per_fwdbwd * ddp_world_size
|
||||
) # total tokens per iteration for all ranks
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(
|
||||
f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}"
|
||||
)
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(
|
||||
f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}"
|
||||
)
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
|
||||
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay
|
||||
)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["initial_lr"] = group[
|
||||
"lr"
|
||||
] # save the initial learning so we can decay easily later
|
||||
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
|
||||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
|
|
@ -142,32 +108,18 @@ train_dataset = TaskMixture(
|
|||
MMLU(
|
||||
subset="auxiliary_train", split="train"
|
||||
), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(
|
||||
subset="main", split="train"
|
||||
), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(
|
||||
filepath=identity_conversations_filepath
|
||||
), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(
|
||||
filepath=identity_conversations_filepath
|
||||
), # let's do 2 epochs of these
|
||||
SimpleSpelling(
|
||||
size=200000, split="train"
|
||||
), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(
|
||||
size=80000, split="train"
|
||||
), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||
]
|
||||
) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
|
||||
val_dataset = TaskMixture(
|
||||
[
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
MMLU(
|
||||
subset="all", split="test", stop=5200
|
||||
), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(
|
||||
subset="main", split="test", stop=420
|
||||
), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
]
|
||||
) # total: 24K + 14K + 1.32K ~= 39K rows
|
||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
|
|
@ -183,14 +135,10 @@ def mid_data_generator(split):
|
|||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
needed_tokens = (
|
||||
device_batch_size * max_seq_len + 1
|
||||
) # to form one training batch of inputs,targets
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.empty(
|
||||
needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda")
|
||||
)
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
|
|
@ -207,29 +155,21 @@ def mid_data_generator(split):
|
|||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = (
|
||||
True # toggle last_step to True, which will terminate the training loop
|
||||
)
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(
|
||||
device=device, dtype=torch.int32, non_blocking=True
|
||||
)
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(
|
||||
device=device, dtype=torch.int64, non_blocking=True
|
||||
)
|
||||
if split == "train":
|
||||
if num_iterations > 0:
|
||||
approx_progress = (
|
||||
it / num_iterations
|
||||
) # calculate progress from the max number of iterations
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = (
|
||||
cursor / dataset_size
|
||||
) # approximate progress as a fraction of the dataset
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
|
||||
|
|
@ -296,9 +236,7 @@ while True:
|
|||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[
|
||||
opt.state_dict() for opt in optimizers
|
||||
], # TODO: make sure saving across ranks is done correctly
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
|
|
@ -326,16 +264,10 @@ while True:
|
|||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = (
|
||||
loss / grad_accum_steps
|
||||
) # each .backward() is a grad sum => normalize loss here
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(
|
||||
train_loader
|
||||
) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(
|
||||
progress, approx_progress
|
||||
) # only increase progress monotonically
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizers
|
||||
lrm = get_lr_multiplier(progress)
|
||||
for opt in optimizers:
|
||||
|
|
@ -356,23 +288,17 @@ while True:
|
|||
step += 1
|
||||
|
||||
# logging
|
||||
smooth_train_loss = (
|
||||
ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
|
||||
) # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (
|
||||
1 - ema_beta ** (step + 1)
|
||||
) # debias the EMA
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
tok_per_sec = int(total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * total_batch_size / dt
|
||||
promised_flops_per_sec_h100 = (
|
||||
989e12 * ddp_world_size
|
||||
) # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(
|
||||
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m"
|
||||
f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time / 60:.2f}m"
|
||||
)
|
||||
if step % 10 == 0:
|
||||
wandb_run.log(
|
||||
|
|
@ -390,7 +316,7 @@ while True:
|
|||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Total training time: {total_training_time / 60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
# Log to report
|
||||
|
|
|
|||
|
|
@ -165,15 +165,10 @@ tokenizer_results = {}
|
|||
vocab_sizes = {}
|
||||
|
||||
for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
||||
|
||||
if tokenizer_name == "gpt2":
|
||||
tokenizer = RustBPETokenizer.from_pretrained(
|
||||
"gpt2"
|
||||
) # gpt-2 base model tokenizer
|
||||
tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
|
||||
elif tokenizer_name == "gpt4":
|
||||
tokenizer = RustBPETokenizer.from_pretrained(
|
||||
"cl100k_base"
|
||||
) # gpt-4 base model tokenizer
|
||||
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
|
||||
else:
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
|
|
@ -185,21 +180,17 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
|||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == text
|
||||
|
||||
encoded_bytes = text.encode("utf-8")
|
||||
encoded_bytes = text.encode('utf-8')
|
||||
ratio = len(encoded_bytes) / len(encoded)
|
||||
tokenizer_results[tokenizer_name][name] = {
|
||||
"bytes": len(encoded_bytes),
|
||||
"tokens": len(encoded),
|
||||
"ratio": ratio,
|
||||
}
|
||||
tokenizer_results[tokenizer_name][name] = {'bytes': len(encoded_bytes), 'tokens': len(encoded), 'ratio': ratio}
|
||||
|
||||
# ANSI color codes
|
||||
GREEN = "\033[92m"
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
RESET = '\033[0m'
|
||||
|
||||
# Print vocab sizes
|
||||
print(f"\nVocab sizes:")
|
||||
print("\nVocab sizes:")
|
||||
print(f"GPT-2: {vocab_sizes['gpt2']}")
|
||||
print(f"GPT-4: {vocab_sizes['gpt4']}")
|
||||
print(f"Ours: {vocab_sizes['ours']}")
|
||||
|
|
@ -209,12 +200,8 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
|||
"""Print comparison table between baseline tokenizer and ours."""
|
||||
print(f"\nComparison with {baseline_name}:")
|
||||
print("=" * 95)
|
||||
print(
|
||||
f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}"
|
||||
)
|
||||
print(
|
||||
f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}"
|
||||
)
|
||||
print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
|
||||
print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}")
|
||||
print("-" * 95)
|
||||
|
||||
for name, text in all_text:
|
||||
|
|
@ -223,16 +210,14 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
|||
|
||||
# Calculate relative difference (positive means ours is better, negative means worse)
|
||||
# Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens
|
||||
relative_diff = (
|
||||
(baseline_data["tokens"] - ours_data["tokens"]) / baseline_data["tokens"]
|
||||
) * 100
|
||||
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
|
||||
|
||||
# Determine which has better compression (higher ratio = better)
|
||||
if baseline_data["ratio"] > ours_data["ratio"]:
|
||||
if baseline_data['ratio'] > ours_data['ratio']:
|
||||
baseline_color, ours_color = GREEN, RED
|
||||
better = baseline_name
|
||||
diff_color = RED
|
||||
elif ours_data["ratio"] > baseline_data["ratio"]:
|
||||
elif ours_data['ratio'] > baseline_data['ratio']:
|
||||
baseline_color, ours_color = RED, GREEN
|
||||
better = "Ours"
|
||||
diff_color = GREEN
|
||||
|
|
@ -253,21 +238,17 @@ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
|
|||
|
||||
|
||||
# Print comparisons
|
||||
print_comparison(
|
||||
"GPT-2", tokenizer_results["gpt2"], tokenizer_results["ours"], all_text
|
||||
)
|
||||
print_comparison(
|
||||
"GPT-4", tokenizer_results["gpt4"], tokenizer_results["ours"], all_text
|
||||
)
|
||||
print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
|
||||
print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text)
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
||||
lines = []
|
||||
for baseline_name in ["GPT-2", "GPT-4"]:
|
||||
baseline_key = baseline_name.lower().replace("-", "")
|
||||
baseline_key = baseline_name.lower().replace('-', '')
|
||||
baseline_results = tokenizer_results[baseline_key]
|
||||
ours_results = tokenizer_results["ours"]
|
||||
ours_results = tokenizer_results['ours']
|
||||
lines.append(f"### Comparison with {baseline_name}")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
|
|
@ -277,15 +258,11 @@ for baseline_name in ["GPT-2", "GPT-4"]:
|
|||
+ baseline_name
|
||||
+ " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |"
|
||||
)
|
||||
lines.append(
|
||||
"|-----------|-------|--------------|--------------|-------------|------------|-----------------|"
|
||||
)
|
||||
lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
|
||||
for name, text in all_text:
|
||||
baseline_data = baseline_results[name]
|
||||
ours_data = ours_results[name]
|
||||
relative_diff = (
|
||||
(baseline_data["tokens"] - ours_data["tokens"]) / baseline_data["tokens"]
|
||||
) * 100
|
||||
relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
|
||||
lines.append(
|
||||
f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,25 +16,12 @@ from nanochat.tokenizer import RustBPETokenizer
|
|||
# -----------------------------------------------------------------------------
|
||||
# Parse command line arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train a BPE tokenizer")
|
||||
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
|
||||
parser.add_argument(
|
||||
"--max_chars",
|
||||
type=int,
|
||||
default=10_000_000_000,
|
||||
help="Maximum characters to train on (default: 10B)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc_cap",
|
||||
type=int,
|
||||
default=10_000,
|
||||
help="Maximum characters per document (default: 10,000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocab_size",
|
||||
type=int,
|
||||
default=65536,
|
||||
help="Vocabulary size (default: 65536 = 2^16)",
|
||||
'--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)'
|
||||
)
|
||||
parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
|
||||
parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
|
||||
args = parser.parse_args()
|
||||
print(f"max_chars: {args.max_chars:,}")
|
||||
print(f"doc_cap: {args.doc_cap:,}")
|
||||
|
|
@ -99,17 +86,13 @@ special_set = set(tokenizer.get_special_tokens())
|
|||
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
|
||||
token_bytes = []
|
||||
for token_id in range(vocab_size):
|
||||
token_str = token_strings[
|
||||
token_id
|
||||
] # the Python string representation of this token
|
||||
token_str = token_strings[token_id] # the Python string representation of this token
|
||||
if token_str in special_set:
|
||||
token_bytes.append(0) # special characters are not counted
|
||||
else:
|
||||
id_bytes = len(
|
||||
token_str.encode("utf-8")
|
||||
) # number of bytes that make up this token
|
||||
id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
|
||||
token_bytes.append(id_bytes)
|
||||
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device="cpu")
|
||||
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
|
||||
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
|
||||
with open(token_bytes_path, "wb") as f:
|
||||
torch.save(token_bytes, f)
|
||||
|
|
|
|||
31
tasks/arc.py
31
tasks/arc.py
|
|
@ -9,23 +9,15 @@ from tasks.common import Task, render_mc
|
|||
|
||||
|
||||
class ARC(Task):
|
||||
|
||||
def __init__(self, subset, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert subset in [
|
||||
"ARC-Easy",
|
||||
"ARC-Challenge",
|
||||
], "ARC subset must be ARC-Easy or ARC-Challenge"
|
||||
assert split in [
|
||||
"train",
|
||||
"validation",
|
||||
"test",
|
||||
], "ARC split must be train|validation|test"
|
||||
assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
|
||||
assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
|
||||
self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "categorical"
|
||||
return 'categorical'
|
||||
|
||||
def num_examples(self):
|
||||
return len(self.ds)
|
||||
|
|
@ -36,15 +28,10 @@ class ARC(Task):
|
|||
choices = row["choices"]["text"] # the text of each choice
|
||||
answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
|
||||
letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
|
||||
assert (
|
||||
answer_string in letters
|
||||
), f"ARC answer {answer_string} must be one of {letters}" # sanity check
|
||||
assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
|
||||
# create and return the Conversation object
|
||||
user_message = render_mc(question, letters, choices)
|
||||
messages = [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": answer_string},
|
||||
]
|
||||
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": answer_string}]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
"letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
|
||||
|
|
@ -54,8 +41,8 @@ class ARC(Task):
|
|||
def evaluate(self, conversation, assistant_response):
|
||||
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
|
||||
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
|
||||
assert (
|
||||
assistant_response in conversation["letters"]
|
||||
), f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
|
||||
assistant_message = conversation["messages"][-1]["content"] # e.g. "A"
|
||||
assert assistant_response in conversation['letters'], (
|
||||
f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
|
||||
)
|
||||
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
|
||||
return assistant_response == assistant_message
|
||||
|
|
|
|||
|
|
@ -16,9 +16,7 @@ class Task:
|
|||
def __init__(self, start=0, stop=None, step=1):
|
||||
# allows a lightweight logical view over a dataset
|
||||
assert start >= 0, f"Start must be non-negative, got {start}"
|
||||
assert (
|
||||
stop is None or stop >= start
|
||||
), f"Stop should be greater than or equal to start, got {stop} and {start}"
|
||||
assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
|
||||
assert step >= 1, f"Step must be strictly positive, got {step}"
|
||||
self.start = start
|
||||
self.stop = stop # could be None here
|
||||
|
|
@ -84,9 +82,9 @@ class TaskMixture(Task):
|
|||
Access conversations according to a deterministic shuffle of all examples.
|
||||
This ensures tasks are mixed throughout training, regardless of dataset size.
|
||||
"""
|
||||
assert (
|
||||
0 <= index < self.num_conversations
|
||||
), f"Index {index} out of range for mixture with {self.num_conversations} conversations"
|
||||
assert 0 <= index < self.num_conversations, (
|
||||
f"Index {index} out of range for mixture with {self.num_conversations} conversations"
|
||||
)
|
||||
task_idx, local_idx = self.index_map[index]
|
||||
return self.tasks[task_idx][local_idx]
|
||||
|
||||
|
|
@ -107,9 +105,9 @@ class TaskSequence(Task):
|
|||
return self.num_conversations
|
||||
|
||||
def get_example(self, index):
|
||||
assert (
|
||||
0 <= index < self.num_conversations
|
||||
), f"Index {index} out of range for sequence with {self.num_conversations} conversations"
|
||||
assert 0 <= index < self.num_conversations, (
|
||||
f"Index {index} out of range for sequence with {self.num_conversations} conversations"
|
||||
)
|
||||
for task_idx, task_length in enumerate(self.lengths):
|
||||
if index < task_length:
|
||||
return self.tasks[task_idx][index]
|
||||
|
|
@ -133,9 +131,7 @@ def render_mc(question, letters, choices):
|
|||
about this too much, but smaller models do care about some of these details.
|
||||
"""
|
||||
query = f"Multiple Choice question: {question}\n"
|
||||
query += "".join(
|
||||
[f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)]
|
||||
)
|
||||
query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
|
||||
query += "\nRespond only with the letter of the correct answer."
|
||||
return query
|
||||
|
||||
|
|
|
|||
|
|
@ -30,44 +30,32 @@ class CustomJSON(Task):
|
|||
print(
|
||||
"If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations"
|
||||
)
|
||||
print(
|
||||
"See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139"
|
||||
)
|
||||
print(
|
||||
"Quick fix: simply run the following command to download the file and you're done:"
|
||||
)
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(
|
||||
f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
else:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
with open(filepath, encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line: # skip empty lines
|
||||
continue
|
||||
messages = json.loads(line)
|
||||
# Validate the conversation structure
|
||||
assert isinstance(
|
||||
messages, list
|
||||
), f"Expected list of messages, got {type(messages)}"
|
||||
assert (
|
||||
len(messages) >= 2
|
||||
), f"Conversation must have at least 2 messages, got {len(messages)}"
|
||||
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
|
||||
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
|
||||
# Validate message structure and alternating roles
|
||||
for i, message in enumerate(messages):
|
||||
assert "role" in message, f"Message {i} missing 'role' field"
|
||||
assert (
|
||||
"content" in message
|
||||
), f"Message {i} missing 'content' field"
|
||||
assert "content" in message, f"Message {i} missing 'content' field"
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert (
|
||||
message["role"] == expected_role
|
||||
), f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert isinstance(
|
||||
message["content"], str
|
||||
), f"Message {i} content must be a string"
|
||||
assert message["role"] == expected_role, (
|
||||
f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
)
|
||||
assert isinstance(message["content"], str), f"Message {i} content must be a string"
|
||||
|
||||
self.conversations.append(messages)
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ def extract_answer(completion):
|
|||
|
||||
|
||||
class GSM8K(Task):
|
||||
|
||||
def __init__(self, subset, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
|
||||
|
|
@ -47,7 +46,7 @@ class GSM8K(Task):
|
|||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "generative"
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return len(self.ds)
|
||||
|
|
@ -55,39 +54,32 @@ class GSM8K(Task):
|
|||
def get_example(self, index):
|
||||
"""Get a single problem from the dataset."""
|
||||
row = self.ds[index]
|
||||
question = row["question"] # string of the question prompt
|
||||
answer = row[
|
||||
"answer"
|
||||
] # string of the full solution and the answer after #### marker
|
||||
question = row['question'] # string of the question prompt
|
||||
answer = row['answer'] # string of the full solution and the answer after #### marker
|
||||
# Create and return the Conversation object
|
||||
# This is tricky because GSM8K uses tool calls, which we need to parse here.
|
||||
assistant_message_parts = []
|
||||
parts = re.split(r"(<<[^>]+>>)", answer)
|
||||
parts = re.split(r'(<<[^>]+>>)', answer)
|
||||
for part in parts:
|
||||
if part.startswith("<<") and part.endswith(">>"):
|
||||
if part.startswith('<<') and part.endswith('>>'):
|
||||
# This is a calculator tool call
|
||||
inner = part[2:-2] # Remove << >>
|
||||
# Split on = to get expression and result
|
||||
if "=" in inner:
|
||||
expr, result = inner.rsplit("=", 1)
|
||||
if '=' in inner:
|
||||
expr, result = inner.rsplit('=', 1)
|
||||
else:
|
||||
expr, result = inner, ""
|
||||
# Add the tool call as a part
|
||||
assistant_message_parts.append({"type": "python", "text": expr})
|
||||
# Add the result as a part
|
||||
assistant_message_parts.append(
|
||||
{"type": "python_output", "text": result}
|
||||
)
|
||||
assistant_message_parts.append({"type": "python_output", "text": result})
|
||||
else:
|
||||
# Regular text in between tool calls
|
||||
assistant_message_parts.append({"type": "text", "text": part})
|
||||
# No put it all together
|
||||
messages = [
|
||||
{"role": "user", "content": question}, # note: simple string
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_message_parts,
|
||||
}, # note: list of parts (as dicts)
|
||||
{"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
|
||||
]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
|
|
@ -104,20 +96,12 @@ class GSM8K(Task):
|
|||
TODO: Technically, assistant_response should be a Message (either a string or a list of parts)
|
||||
We can handle this later possibly. For now just assume string.
|
||||
"""
|
||||
assert isinstance(
|
||||
assistant_response, str
|
||||
), "Assuming simple string response for now"
|
||||
assert isinstance(assistant_response, str), "Assuming simple string response for now"
|
||||
# First extract the ground truth answer
|
||||
assistant_message = conversation["messages"][-1]
|
||||
assert (
|
||||
assistant_message["role"] == "assistant"
|
||||
), "Last message must be from the Assistant"
|
||||
assert isinstance(
|
||||
assistant_message["content"], list
|
||||
), "This is expected to be a list of parts"
|
||||
last_text_part = assistant_message["content"][-1][
|
||||
"text"
|
||||
] # this contains the final answer in GSM8K
|
||||
assistant_message = conversation['messages'][-1]
|
||||
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
|
||||
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
|
||||
last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
|
||||
# Extract both the ground truth answer and the predicted answer
|
||||
ref_num = extract_answer(last_text_part)
|
||||
pred_num = extract_answer(assistant_response)
|
||||
|
|
|
|||
|
|
@ -15,14 +15,14 @@ from tasks.common import Task
|
|||
def extract_imports(prompt):
|
||||
"""Extract import statements from the beginning of a code block."""
|
||||
imports = []
|
||||
for line in prompt.split("\n"):
|
||||
for line in prompt.split('\n'):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("import ") or stripped.startswith("from "):
|
||||
if stripped.startswith('import ') or stripped.startswith('from '):
|
||||
imports.append(stripped)
|
||||
elif stripped and not stripped.startswith("#"):
|
||||
elif stripped and not stripped.startswith('#'):
|
||||
# Stop at first non-import, non-comment line
|
||||
break
|
||||
return "\n".join(imports)
|
||||
return '\n'.join(imports)
|
||||
|
||||
|
||||
def extract_program(completion):
|
||||
|
|
@ -38,7 +38,7 @@ def extract_program(completion):
|
|||
"""
|
||||
# Try to find markdown code blocks (```python or just ```)
|
||||
# Match ```python\n...\n``` or ```\n...\n```
|
||||
pattern = r"```(?:python)?\s*\n(.*?)\n```"
|
||||
pattern = r'```(?:python)?\s*\n(.*?)\n```'
|
||||
matches = re.findall(pattern, completion, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
|
|
@ -50,14 +50,13 @@ def extract_program(completion):
|
|||
|
||||
|
||||
class HumanEval(Task):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ds = load_dataset("openai/openai_humaneval", split="test").shuffle(seed=42)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "generative"
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return len(self.ds)
|
||||
|
|
@ -65,10 +64,10 @@ class HumanEval(Task):
|
|||
def get_example(self, index):
|
||||
"""Get a single problem from the dataset."""
|
||||
row = self.ds[index]
|
||||
prompt = row["prompt"] # prompts in HumanEval are the beginning of the program
|
||||
solution = row["canonical_solution"] # the correct continuation of the program
|
||||
entry_point = row["entry_point"] # the function to check
|
||||
test = row["test"] # the test cases
|
||||
prompt = row['prompt'] # prompts in HumanEval are the beginning of the program
|
||||
solution = row['canonical_solution'] # the correct continuation of the program
|
||||
entry_point = row['entry_point'] # the function to check
|
||||
test = row['test'] # the test cases
|
||||
complete_solution = f"{prompt}\n{solution}"
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
|
@ -84,7 +83,7 @@ class HumanEval(Task):
|
|||
def evaluate(self, conversation, completion):
|
||||
"""Given (conversation, completion), return boolean success of the completion."""
|
||||
# the prompt will contain the imports and the function signature
|
||||
imports = extract_imports(conversation["messages"][0]["content"])
|
||||
imports = extract_imports(conversation['messages'][0]['content'])
|
||||
# the completion will usually contain the whole function
|
||||
# but not always with the needed imports, so we manually append them
|
||||
completion_code = extract_program(completion)
|
||||
|
|
@ -93,7 +92,7 @@ class HumanEval(Task):
|
|||
+ "\n\n"
|
||||
+ completion_code
|
||||
+ "\n\n"
|
||||
+ conversation["test"]
|
||||
+ conversation['test']
|
||||
+ "\n"
|
||||
+ f"check({conversation['entry_point']})"
|
||||
)
|
||||
|
|
|
|||
146
tasks/mmlu.py
146
tasks/mmlu.py
|
|
@ -9,80 +9,71 @@ from tasks.common import Task, render_mc
|
|||
|
||||
|
||||
class MMLU(Task):
|
||||
|
||||
letters = ("A", "B", "C", "D")
|
||||
letters = ('A', 'B', 'C', 'D')
|
||||
groups = (
|
||||
"abstract_algebra",
|
||||
"anatomy",
|
||||
"astronomy",
|
||||
"business_ethics",
|
||||
"clinical_knowledge",
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_medicine",
|
||||
"college_physics",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"econometrics",
|
||||
"electrical_engineering",
|
||||
"elementary_mathematics",
|
||||
"formal_logic",
|
||||
"global_facts",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"high_school_geography",
|
||||
"high_school_government_and_politics",
|
||||
"high_school_macroeconomics",
|
||||
"high_school_mathematics",
|
||||
"high_school_microeconomics",
|
||||
"high_school_physics",
|
||||
"high_school_psychology",
|
||||
"high_school_statistics",
|
||||
"high_school_us_history",
|
||||
"high_school_world_history",
|
||||
"human_aging",
|
||||
"human_sexuality",
|
||||
"international_law",
|
||||
"jurisprudence",
|
||||
"logical_fallacies",
|
||||
"machine_learning",
|
||||
"management",
|
||||
"marketing",
|
||||
"medical_genetics",
|
||||
"miscellaneous",
|
||||
"moral_disputes",
|
||||
"moral_scenarios",
|
||||
"nutrition",
|
||||
"philosophy",
|
||||
"prehistory",
|
||||
"professional_accounting",
|
||||
"professional_law",
|
||||
"professional_medicine",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_studies",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
"virology",
|
||||
"world_religions",
|
||||
'abstract_algebra',
|
||||
'anatomy',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'clinical_knowledge',
|
||||
'college_biology',
|
||||
'college_chemistry',
|
||||
'college_computer_science',
|
||||
'college_mathematics',
|
||||
'college_medicine',
|
||||
'college_physics',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'econometrics',
|
||||
'electrical_engineering',
|
||||
'elementary_mathematics',
|
||||
'formal_logic',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_computer_science',
|
||||
'high_school_european_history',
|
||||
'high_school_geography',
|
||||
'high_school_government_and_politics',
|
||||
'high_school_macroeconomics',
|
||||
'high_school_mathematics',
|
||||
'high_school_microeconomics',
|
||||
'high_school_physics',
|
||||
'high_school_psychology',
|
||||
'high_school_statistics',
|
||||
'high_school_us_history',
|
||||
'high_school_world_history',
|
||||
'human_aging',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'jurisprudence',
|
||||
'logical_fallacies',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'medical_genetics',
|
||||
'miscellaneous',
|
||||
'moral_disputes',
|
||||
'moral_scenarios',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'prehistory',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_studies',
|
||||
'sociology',
|
||||
'us_foreign_policy',
|
||||
'virology',
|
||||
'world_religions',
|
||||
)
|
||||
|
||||
def __init__(self, subset, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert subset in [
|
||||
"all",
|
||||
"auxiliary_train",
|
||||
], f"subset {subset} must be all|auxiliary_train"
|
||||
assert split in [
|
||||
"train",
|
||||
"validation",
|
||||
"dev",
|
||||
"test",
|
||||
], f"split {split} must be train|validation|dev|test"
|
||||
assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train"
|
||||
assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test"
|
||||
if subset == "auxiliary_train":
|
||||
assert split == "train", "auxiliary_train must be split into train"
|
||||
self.subset = subset
|
||||
|
|
@ -90,11 +81,11 @@ class MMLU(Task):
|
|||
self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
|
||||
if subset == "auxiliary_train":
|
||||
# I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper
|
||||
self.ds = self.ds.map(lambda row: row["train"], remove_columns=["train"])
|
||||
self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train'])
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "categorical"
|
||||
return 'categorical'
|
||||
|
||||
def num_examples(self):
|
||||
return len(self.ds)
|
||||
|
|
@ -109,10 +100,7 @@ class MMLU(Task):
|
|||
# create and return the Conversation object
|
||||
user_message = render_mc(question, self.letters, choices)
|
||||
assistant_message = self.letters[answer]
|
||||
messages = [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": assistant_message},
|
||||
]
|
||||
messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_message}]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
"subject": subject, # might be useful later for grouping metrics by subject
|
||||
|
|
@ -123,8 +111,8 @@ class MMLU(Task):
|
|||
def evaluate(self, conversation, assistant_response):
|
||||
# the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
|
||||
# I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
|
||||
assert (
|
||||
assistant_response in self.letters
|
||||
), f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
|
||||
assistant_message = conversation["messages"][-1]["content"] # e.g. "A"
|
||||
assert assistant_response in self.letters, (
|
||||
f"MMLU answer {assistant_response} is expected to be one of {self.letters}"
|
||||
)
|
||||
assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
|
||||
return assistant_response == assistant_message
|
||||
|
|
|
|||
|
|
@ -15,9 +15,7 @@ class SmolTalk(Task):
|
|||
def __init__(self, split, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SmolTalk split must be train|test"
|
||||
self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(
|
||||
seed=42
|
||||
)
|
||||
self.ds = load_dataset("HuggingFaceTB/smol-smoltalk", split=split).shuffle(seed=42)
|
||||
self.length = len(self.ds)
|
||||
|
||||
def num_examples(self):
|
||||
|
|
@ -36,15 +34,13 @@ class SmolTalk(Task):
|
|||
rest_messages = messages[1:] # optional system message is OK
|
||||
else:
|
||||
rest_messages = messages
|
||||
assert (
|
||||
len(rest_messages) >= 2
|
||||
), "SmolTalk messages must have at least 2 messages"
|
||||
assert len(rest_messages) >= 2, "SmolTalk messages must have at least 2 messages"
|
||||
for i, message in enumerate(rest_messages):
|
||||
# user and assistant alternate as user,assistant,user,assistant,...
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert (
|
||||
message["role"] == expected_role
|
||||
), f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert message["role"] == expected_role, (
|
||||
f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
)
|
||||
assert isinstance(message["content"], str), "Content must be a string"
|
||||
# ---------------------------------------------------------------------
|
||||
# create and return the Conversation object (ok to emit the system message too)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,6 @@ USER_MSG_TEMPLATES = [
|
|||
|
||||
|
||||
class SpellingBee(Task):
|
||||
|
||||
def __init__(self, size=1000, split="train", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert split in ["train", "test"], "SpellingBee split must be train|test"
|
||||
|
|
@ -124,13 +123,13 @@ class SpellingBee(Task):
|
|||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, encoding="utf-8") as f:
|
||||
with open(word_list_path, encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
self.words = words
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "generative"
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
|
@ -152,7 +151,7 @@ class SpellingBee(Task):
|
|||
# 30% chance to lowercase the template (lazy people don't use shift)
|
||||
if rng.random() < 0.3:
|
||||
template = template.lower()
|
||||
quote_options = ["", "'", '"']
|
||||
quote_options = ['', "'", '"']
|
||||
letter_quote = rng.choice(quote_options) # is the letter quoted?
|
||||
word_quote = rng.choice(quote_options) # is the word quoted?
|
||||
letter_wrapped = f"{letter_quote}{letter}{letter_quote}"
|
||||
|
|
@ -188,9 +187,7 @@ Then count the occurrences of '{letter}':
|
|||
manual_text += f"\nThis gives us {running_count}."
|
||||
assistant_parts.append({"type": "text", "text": manual_text})
|
||||
# Part 2: Python verification
|
||||
assistant_parts.append(
|
||||
{"type": "text", "text": "\n\nLet me double check this using Python:\n\n"}
|
||||
)
|
||||
assistant_parts.append({"type": "text", "text": "\n\nLet me double check this using Python:\n\n"})
|
||||
# Part 3: Python tool call
|
||||
python_expr = f"'{word}'.count('{letter}')"
|
||||
assistant_parts.append({"type": "python", "text": python_expr})
|
||||
|
|
@ -198,17 +195,11 @@ Then count the occurrences of '{letter}':
|
|||
assistant_parts.append({"type": "python_output", "text": str(count)})
|
||||
# Part 5: Final answer
|
||||
assistant_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}",
|
||||
}
|
||||
{"type": "text", "text": f"\n\nPython gives us {count}.\n\nMy final answer is:\n\n#### {count}"}
|
||||
)
|
||||
|
||||
# return the full conversation
|
||||
messages = [
|
||||
{"role": "user", "content": user_msg},
|
||||
{"role": "assistant", "content": assistant_parts},
|
||||
]
|
||||
messages = [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_parts}]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
|
|
@ -219,19 +210,13 @@ Then count the occurrences of '{letter}':
|
|||
Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
|
||||
Identical to gsm8k's evaluation.
|
||||
"""
|
||||
assert isinstance(
|
||||
assistant_response, str
|
||||
), "Assuming simple string response for now"
|
||||
assert isinstance(assistant_response, str), "Assuming simple string response for now"
|
||||
# First extract the ground truth answer from the conversation
|
||||
assistant_message = conversation["messages"][-1]
|
||||
assert (
|
||||
assistant_message["role"] == "assistant"
|
||||
), "Last message must be from the Assistant"
|
||||
assert isinstance(
|
||||
assistant_message["content"], list
|
||||
), "This is expected to be a list of parts"
|
||||
assistant_message = conversation['messages'][-1]
|
||||
assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
|
||||
assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
|
||||
# The last text part contains the final answer with ####
|
||||
last_text_part = assistant_message["content"][-1]["text"]
|
||||
last_text_part = assistant_message['content'][-1]['text']
|
||||
# Extract both the ground truth answer and the predicted answer
|
||||
ref_num = extract_answer(last_text_part)
|
||||
pred_num = extract_answer(assistant_response)
|
||||
|
|
@ -256,7 +241,7 @@ class SimpleSpelling(Task):
|
|||
self.split = split
|
||||
filename = WORD_LIST_URL.split("/")[-1]
|
||||
word_list_path = download_file_with_lock(WORD_LIST_URL, filename)
|
||||
with open(word_list_path, encoding="utf-8") as f:
|
||||
with open(word_list_path, encoding='utf-8') as f:
|
||||
words = [line.strip() for line in f]
|
||||
rng = random.Random(42)
|
||||
rng.shuffle(words) # use a different word order than the SpellingBee task
|
||||
|
|
@ -264,7 +249,7 @@ class SimpleSpelling(Task):
|
|||
|
||||
@property
|
||||
def eval_type(self):
|
||||
return "generative"
|
||||
return 'generative'
|
||||
|
||||
def num_examples(self):
|
||||
return self.size
|
||||
|
|
@ -287,23 +272,22 @@ class SimpleSpelling(Task):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# preview the SpellingBee task, first 10 examples
|
||||
task = SpellingBee()
|
||||
for i in range(10):
|
||||
ex = task.get_example(i)
|
||||
print("=" * 100)
|
||||
print(ex["messages"][0]["content"])
|
||||
print(ex['messages'][0]['content'])
|
||||
print("-" * 100)
|
||||
# Assistant content is now a list of parts
|
||||
assistant_parts = ex["messages"][1]["content"]
|
||||
assistant_parts = ex['messages'][1]['content']
|
||||
for part in assistant_parts:
|
||||
if part["type"] == "text":
|
||||
print(part["text"], end="")
|
||||
elif part["type"] == "python":
|
||||
print(f"<<{part['text']}=", end="")
|
||||
elif part["type"] == "python_output":
|
||||
print(f"{part['text']}>>", end="")
|
||||
if part['type'] == 'text':
|
||||
print(part['text'], end='')
|
||||
elif part['type'] == 'python':
|
||||
print(f"<<{part['text']}=", end='')
|
||||
elif part['type'] == 'python_output':
|
||||
print(f"{part['text']}>>", end='')
|
||||
print()
|
||||
print("-" * 100)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,26 +23,14 @@ def test_kv_cache_resize():
|
|||
num_layers = 6
|
||||
|
||||
kv_cache = KVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
seq_len=seq_len,
|
||||
head_dim=head_dim,
|
||||
num_layers=num_layers,
|
||||
batch_size=batch_size, num_heads=num_heads, seq_len=seq_len, head_dim=head_dim, num_layers=num_layers
|
||||
)
|
||||
|
||||
# Insert a single token with a distinct fill value to all layers
|
||||
def insert_token(token_idx):
|
||||
for layer_idx in range(num_layers):
|
||||
k = torch.full(
|
||||
(batch_size, num_heads, 1, head_dim),
|
||||
fill_value=float(token_idx),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
v = torch.full(
|
||||
(batch_size, num_heads, 1, head_dim),
|
||||
fill_value=float(token_idx * 100),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
|
||||
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
|
||||
kv_cache.insert_kv(layer_idx, k, v)
|
||||
|
||||
# Insert 4 tokens (fills the initial seq_len=4)
|
||||
|
|
@ -57,9 +45,9 @@ def test_kv_cache_resize():
|
|||
insert_token(4)
|
||||
# Verify that the cache actually resized
|
||||
new_seq_len = kv_cache.kv_cache.shape[4]
|
||||
assert (
|
||||
new_seq_len > original_seq_len
|
||||
), f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
assert new_seq_len > original_seq_len, (
|
||||
f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
|
||||
)
|
||||
|
||||
# Verify that the original 4 tokens are still intact after resize
|
||||
for layer_idx in range(num_layers):
|
||||
|
|
@ -69,20 +57,14 @@ def test_kv_cache_resize():
|
|||
expected_v = float(token_idx * 100)
|
||||
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (
|
||||
actual_k == expected_k
|
||||
).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
assert (
|
||||
actual_v == expected_v
|
||||
).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
assert (actual_k == expected_k).all(), (
|
||||
f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
|
||||
)
|
||||
assert (actual_v == expected_v).all(), (
|
||||
f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
|
||||
)
|
||||
# And that the original cache matches resized cache
|
||||
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
|
||||
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
|
||||
assert (
|
||||
actual_k == original_k
|
||||
).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (
|
||||
actual_v == original_v
|
||||
).all(), (
|
||||
f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
)
|
||||
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
|
||||
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
||||
|
|
|
|||
|
|
@ -27,7 +27,9 @@ import tiktoken
|
|||
|
||||
import rustbpe
|
||||
|
||||
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
GPT4_SPLIT_PATTERN = (
|
||||
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Reference tokenizer, pretty much copy pasted and pruned a bit from minbpe
|
||||
|
|
@ -65,7 +67,6 @@ def merge(ids, pair, idx):
|
|||
|
||||
|
||||
class RegexTokenizer:
|
||||
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
|
|
@ -114,9 +115,7 @@ class RegexTokenizer:
|
|||
pair = max(stats, key=stats.get)
|
||||
# check if the merge is ambiguous - i.e. the max value is not unique
|
||||
pair_count = stats[pair]
|
||||
pairs_with_max_count = [
|
||||
pair for pair, count in stats.items() if count == pair_count
|
||||
]
|
||||
pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
|
||||
if len(pairs_with_max_count) > 1:
|
||||
# print the top 10 pairs with their counts
|
||||
# print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
|
||||
|
|
@ -132,9 +131,7 @@ class RegexTokenizer:
|
|||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
||||
# prints
|
||||
if verbose:
|
||||
print(
|
||||
f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences"
|
||||
)
|
||||
print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
||||
|
||||
# save class variables
|
||||
self.merges = merges # used in encode()
|
||||
|
|
@ -195,7 +192,6 @@ def fast_merge_inplace(ids, pair, idx):
|
|||
|
||||
|
||||
class FastRegexTokenizer:
|
||||
|
||||
def __init__(self, pattern=None):
|
||||
"""
|
||||
- pattern: optional string to override the default (GPT-4 split pattern)
|
||||
|
|
@ -245,9 +241,7 @@ class FastRegexTokenizer:
|
|||
|
||||
# Initial count: build stats and position tracking
|
||||
stats = defaultdict(int)
|
||||
positions = defaultdict(
|
||||
set
|
||||
) # pair -> set of chunk indices that contain this pair
|
||||
positions = defaultdict(set) # pair -> set of chunk indices that contain this pair
|
||||
|
||||
for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
|
||||
for pair in zip(chunk_ids, chunk_ids[1:]):
|
||||
|
|
@ -316,8 +310,7 @@ class FastRegexTokenizer:
|
|||
for chunk_idx in affected_chunks:
|
||||
chunk_ids = ids[chunk_idx]
|
||||
contains_pair = any(
|
||||
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair
|
||||
for j in range(len(chunk_ids) - 1)
|
||||
(chunk_ids[j], chunk_ids[j + 1]) == changed_pair for j in range(len(chunk_ids) - 1)
|
||||
)
|
||||
if contains_pair:
|
||||
positions[changed_pair].add(chunk_idx)
|
||||
|
|
@ -390,9 +383,8 @@ class FastRegexTokenizer:
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace tokenizer
|
||||
from tokenizers import Regex
|
||||
from tokenizers import Regex, decoders, pre_tokenizers
|
||||
from tokenizers import Tokenizer as HFTokenizer
|
||||
from tokenizers import decoders, pre_tokenizers
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
|
|
@ -417,14 +409,10 @@ class HuggingFaceTokenizer:
|
|||
# Normalizer: None
|
||||
tokenizer.normalizer = None
|
||||
# Pre-tokenizer: GPT-4 style
|
||||
gpt4_split_regex = Regex(
|
||||
GPT4_SPLIT_PATTERN
|
||||
) # huggingface demands that you wrap it in Regex!!
|
||||
gpt4_split_regex = Regex(GPT4_SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.Split(
|
||||
pattern=gpt4_split_regex, behavior="isolated", invert=False
|
||||
),
|
||||
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
|
||||
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False),
|
||||
]
|
||||
)
|
||||
|
|
@ -515,9 +503,7 @@ def test_correctness(enwik8_small):
|
|||
# Train slow reference
|
||||
print("\nTraining slow reference...")
|
||||
slow_reference_tokenizer = RegexTokenizer()
|
||||
ambiguous_flag, slow_reference_train_time = time_function(
|
||||
slow_reference_tokenizer.train, text, vocab_size
|
||||
)
|
||||
ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
|
||||
slow_reference_ids, slow_reference_encode_time = time_function(
|
||||
slow_reference_tokenizer.encode_ordinary, encode_text
|
||||
)
|
||||
|
|
@ -526,21 +512,15 @@ def test_correctness(enwik8_small):
|
|||
print(slow_reference_ids[:20])
|
||||
|
||||
if ambiguous_flag:
|
||||
print(
|
||||
"‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size"
|
||||
)
|
||||
print(
|
||||
"The implementation could be correct but we might see different results below"
|
||||
)
|
||||
print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
|
||||
print("The implementation could be correct but we might see different results below")
|
||||
else:
|
||||
print("✅ Merge order is NOT ambiguous")
|
||||
|
||||
# Train fast reference
|
||||
print("\nTraining fast reference...")
|
||||
fast_reference_tokenizer = FastRegexTokenizer()
|
||||
_, fast_reference_train_time = time_function(
|
||||
fast_reference_tokenizer.train, text, vocab_size
|
||||
)
|
||||
_, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
|
||||
fast_reference_ids, fast_reference_encode_time = time_function(
|
||||
fast_reference_tokenizer.encode_ordinary, encode_text
|
||||
)
|
||||
|
|
@ -549,16 +529,12 @@ def test_correctness(enwik8_small):
|
|||
print(fast_reference_ids[:20])
|
||||
|
||||
# Assert fast equals slow
|
||||
assert (
|
||||
fast_reference_ids == slow_reference_ids
|
||||
), "Fast reference should match slow reference"
|
||||
assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
|
||||
print("✅ Fast == Slow")
|
||||
|
||||
# Train HuggingFace
|
||||
print("\nTraining HuggingFace...")
|
||||
hf_tokenizer, hf_train_time = time_function(
|
||||
HuggingFaceTokenizer.train_from_iterator, [text], vocab_size
|
||||
)
|
||||
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
|
||||
hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
|
||||
print(f"HuggingFace train time: {hf_train_time:.4f}s")
|
||||
print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
|
||||
|
|
@ -577,20 +553,14 @@ def test_correctness(enwik8_small):
|
|||
return False
|
||||
return True
|
||||
|
||||
assert custom_match(
|
||||
hf_ids, fast_reference_ids
|
||||
), "HuggingFace should match fast reference"
|
||||
assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
|
||||
print("✅ HuggingFace == Fast")
|
||||
|
||||
# Finally use our own Rust implementation
|
||||
print("\nTraining rustbpe...")
|
||||
rustbpe_tokenizer = rustbpe.Tokenizer()
|
||||
_, rustbpe_train_time = time_function(
|
||||
rustbpe_tokenizer.train_from_iterator, [text], vocab_size
|
||||
)
|
||||
rustbpe_ids, rustbpe_encode_time = time_function(
|
||||
rustbpe_tokenizer.encode, encode_text
|
||||
)
|
||||
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
|
||||
rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
|
||||
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
|
||||
print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
|
||||
print(rustbpe_ids[:20])
|
||||
|
|
@ -634,25 +604,21 @@ def test_training_performance(enwik8_large):
|
|||
# Train rustbpe
|
||||
print("\nTraining rustbpe...")
|
||||
rustbpe_tokenizer = rustbpe.Tokenizer()
|
||||
_, rustbpe_train_time = time_function(
|
||||
rustbpe_tokenizer.train_from_iterator, [text], vocab_size
|
||||
)
|
||||
_, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
|
||||
print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
|
||||
assert rustbpe_train_time > 0, "Training should take some time"
|
||||
|
||||
# Train HuggingFace
|
||||
print("\nTraining HuggingFace...")
|
||||
hf_tokenizer, hf_train_time = time_function(
|
||||
HuggingFaceTokenizer.train_from_iterator, [text], vocab_size
|
||||
)
|
||||
hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
|
||||
print(f"HuggingFace train time: {hf_train_time:.4f}s")
|
||||
assert hf_train_time > 0, "Training should take some time"
|
||||
|
||||
# Print comparison
|
||||
print(f"\n📊 Performance comparison:")
|
||||
print("\n📊 Performance comparison:")
|
||||
print(f" RustBPE: {rustbpe_train_time:.4f}s")
|
||||
print(f" HuggingFace: {hf_train_time:.4f}s")
|
||||
print(f" Speedup: {hf_train_time/rustbpe_train_time:.2f}x")
|
||||
print(f" Speedup: {hf_train_time / rustbpe_train_time:.2f}x")
|
||||
|
||||
|
||||
def test_interface(enwik8_small):
|
||||
|
|
@ -664,9 +630,7 @@ def test_interface(enwik8_small):
|
|||
# Simple train test
|
||||
vocab_size = 300
|
||||
tok = RustBPETokenizer.train_from_iterator([enwik8_small], vocab_size)
|
||||
assert (
|
||||
tok.get_vocab_size() == vocab_size
|
||||
), f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
|
||||
assert tok.get_vocab_size() == vocab_size, f"Expected vocab size {vocab_size}, got {tok.get_vocab_size()}"
|
||||
print(f"✅ Trained tokenizer with vocab size {vocab_size}")
|
||||
|
||||
# Encode/decode text
|
||||
|
|
@ -676,24 +640,18 @@ def test_interface(enwik8_small):
|
|||
print(f"IDs: {ids}")
|
||||
decoded = tok.decode(ids)
|
||||
print(f"Decoded: {decoded}")
|
||||
assert (
|
||||
decoded == encode_text
|
||||
), f"Decoded text doesn't match: {decoded} != {encode_text}"
|
||||
assert decoded == encode_text, f"Decoded text doesn't match: {decoded} != {encode_text}"
|
||||
print("✅ Encode/decode test passed")
|
||||
|
||||
# Encode batch test
|
||||
ids_new = tok.encode([encode_text, encode_text])
|
||||
assert all(
|
||||
x == ids for x in ids_new
|
||||
), "Batch encoding should produce identical results"
|
||||
assert all(x == ids for x in ids_new), "Batch encoding should produce identical results"
|
||||
print("✅ Encode batch OK")
|
||||
|
||||
# append/prepend functionality
|
||||
ids_special = tok.encode(encode_text, prepend="<|bos|>", append="<|bos|>")
|
||||
bos_token_id = tok.encode_special("<|bos|>")
|
||||
assert ids_special == [bos_token_id] + ids + [
|
||||
bos_token_id
|
||||
], "Special tokens not correctly added"
|
||||
assert ids_special == [bos_token_id] + ids + [bos_token_id], "Special tokens not correctly added"
|
||||
print("✅ append/prepend OK")
|
||||
|
||||
# Save/load test through a temporary directory
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user