# nanochat/vertex_pipelines/midtraining_step.py
import os
import subprocess
import argparse

from google.cloud import storage


def download_directory_from_gcs(bucket_name, gcs_path, local_path):
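    """Download every object under gs://<bucket_name>/<gcs_path> into
    local_path, preserving the relative directory structure."""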
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=gcs_path)
    for blob in blobs:
        if blob.name.endswith("/"):
            continue
        relative_path = os.path.relpath(blob.name, gcs_path)
        local_file = os.path.join(local_path, relative_path)
        os.makedirs(os.path.dirname(local_file), exist_ok=True)
        blob.download_to_filename(local_file)
        print(f"Downloaded gs://{bucket_name}/{blob.name} to {local_file}")


def upload_directory_to_gcs(local_path, bucket_name, gcs_path):
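    """Upload the contents of local_path to gs://<bucket_name>/<gcs_path>,
    mirroring the local directory layout."""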
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    for root, _, files in os.walk(local_path):
        for file in files:
            local_file = os.path.join(root, file)
            relative_path = os.path.relpath(local_file, local_path)
            blob_path = os.path.join(gcs_path, relative_path)
            blob = bucket.blob(blob_path)
            # upload_from_filename opens and closes the file for us,
            # avoiding the leaked handle from open(local_file, 'rb')
            blob.upload_from_filename(local_file)
            print(f"Uploaded {local_file} to gs://{bucket_name}/{blob_path}")


def main():
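    """Run the midtraining step: skip if a checkpoint already exists in GCS;
    otherwise pull the tokenizer, base checkpoints, and report dir, run
    scripts.mid_train and scripts.chat_eval, and push the results back."""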
    parser = argparse.ArgumentParser()
    parser.add_argument("--gcs-bucket", type=str, required=True, help="GCS bucket for artifacts")
    parser.add_argument("--wandb-run", type=str, default="dummy", help="Wandb run name")
    parser.add_argument("--vertex-experiment", type=str, default="", help="Vertex AI experiment name")
    parser.add_argument("--vertex-tensorboard", type=str, default="", help="Vertex AI TensorBoard resource name")
    parser.add_argument("--device-batch-size", type=int, default=16, help="Device batch size")
    args = parser.parse_args()

    # Parse the bucket name and optional prefix,
    # e.g. "gs://my-bucket/runs/run1" -> bucket_name="my-bucket", prefix="runs/run1"
    if args.gcs_bucket.startswith("gs://"):
        bucket_name = args.gcs_bucket.replace("gs://", "").split("/")[0]
        prefix_parts = args.gcs_bucket.replace("gs://", "").split("/")[1:]
        prefix = "/".join(prefix_parts) if prefix_parts else ""
    else:
        bucket_name = args.gcs_bucket
        prefix = ""

    # Check if a midtraining checkpoint already exists (checkpoint detection)
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    gcs_mid_ckpt_path = os.path.join(prefix, "mid_checkpoints") if prefix else "mid_checkpoints"
    # Check for model.pt (the key checkpoint file).
    # Note: mid_train.py saves to f"d{depth}", where depth defaults to 20 (inherited from the base model).
    depth = 20
    gcs_mid_ckpt_path = os.path.join(gcs_mid_ckpt_path, f"d{depth}")
    checkpoint_exists = bucket.blob(os.path.join(gcs_mid_ckpt_path, "model.pt")).exists()
    if checkpoint_exists:
        print(f"✓ Midtraining checkpoint already exists in gs://{bucket_name}/{gcs_mid_ckpt_path}")
        print("Skipping midtraining (already completed)")
        return
    print("Midtraining checkpoint not found. Running midtraining...")

    # Set local tmp dir for temporary files
    local_base_dir = "/tmp/nanochat"
    os.makedirs(local_base_dir, exist_ok=True)

    # Download tokenizer from GCS
    print("Downloading tokenizer from GCS...")
    gcs_tokenizer_path = os.path.join(prefix, "tokenizer") if prefix else "tokenizer"
    local_tokenizer_dir = os.path.join(local_base_dir, "tokenizer")
    download_directory_from_gcs(bucket_name, gcs_tokenizer_path, local_tokenizer_dir)

    # Download base checkpoints from GCS
    print("Downloading base checkpoints from GCS...")
    gcs_base_checkpoints_path = os.path.join(prefix, "base_checkpoints") if prefix else "base_checkpoints"
    local_base_checkpoints_dir = os.path.join(local_base_dir, "base_checkpoints")
    download_directory_from_gcs(bucket_name, gcs_base_checkpoints_path, local_base_checkpoints_dir)

    # Download report dir from GCS
    print("Downloading report dir from GCS...")
    gcs_report_path = os.path.join(prefix, "report") if prefix else "report"
    local_report_dir = os.path.join(local_base_dir, "report")
    download_directory_from_gcs(bucket_name, gcs_report_path, local_report_dir)
    # Ensure report directory exists even if nothing was downloaded
    os.makedirs(local_report_dir, exist_ok=True)

    try:
        # Download the identity conversations dataset, which midtraining needs.
        # scripts/mid_train.py does not appear to fetch it itself (speedrun.sh
        # downloads it with curl), so do the same here up front.
        print("Downloading identity conversations...")
        subprocess.run([
            "curl", "-L", "-o",
            f"{local_base_dir}/identity_conversations.jsonl",
            "https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl"
        ], check=True)
        # Mid-train the model
        print("Starting midtraining...")
        env = os.environ.copy()
        env["NANOCHAT_BASE_DIR"] = local_base_dir
        subprocess.run([
            "torchrun", "--standalone", "--nproc_per_node=1",
            "-m", "scripts.mid_train",
            f"--device_batch_size={args.device_batch_size}",
            f"--wandb_run_name={args.wandb_run}",
            f"--vertex_experiment={args.vertex_experiment}",
            f"--vertex_tensorboard={args.vertex_tensorboard}"
        ], check=True, env=env)
        # Evaluate the midtrained model
        print("Running chat_eval (mid)...")
        subprocess.run([
            "torchrun", "--standalone", "--nproc_per_node=1",
            "-m", "scripts.chat_eval", "--",
            "-i", "mid"
        ], check=True, env=env)
    except subprocess.CalledProcessError as e:
        print(f"Error during midtraining steps: {e}")
        raise

    # Upload checkpoints to GCS
    print("Uploading artifacts to GCS...")
    # Upload mid_checkpoints
    local_checkpoints_dir = os.path.join(local_base_dir, "mid_checkpoints")
    gcs_checkpoints_path = os.path.join(prefix, "mid_checkpoints") if prefix else "mid_checkpoints"
    if os.path.exists(local_checkpoints_dir):
        upload_directory_to_gcs(local_checkpoints_dir, bucket_name, gcs_checkpoints_path)
    else:
        print(f"Warning: {local_checkpoints_dir} does not exist.")
    # Upload report dir
    if os.path.exists(local_report_dir):
        upload_directory_to_gcs(local_report_dir, bucket_name, gcs_report_path)

if __name__ == "__main__":
    main()
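
# Example invocation, run from the repo root (a sketch; the bucket name and
# prefix below are hypothetical):
#   python vertex_pipelines/midtraining_step.py \
#       --gcs-bucket gs://my-nanochat-bucket/runs/run1 \
#       --device-batch-size 16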