import os

from kfp import dsl
from kfp.compiler import Compiler
from google.cloud import aiplatform
from google_cloud_pipeline_components.v1.custom_job import CustomTrainingJobOp


# Default accelerator type. The actual accelerator is chosen via the arguments to
# create_pipeline_func() below; this global is kept only as a default.
ACCELERATOR_TYPE = 'NVIDIA_L4'

# Read the image URI from an environment variable.
# This allows compiling the pipeline with a specific image without passing it as a
# PipelineParam, which avoids issues with dsl.ContainerSpec.
DOCKER_IMAGE_URI = os.environ.get("DOCKER_IMAGE_URI", "gcr.io/nzp-nanochat/nanochat:latest")
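# For example, to compile against a freshly pushed image (hypothetical URI and
# file name shown):
#   DOCKER_IMAGE_URI=gcr.io/my-project/nanochat:v2 python pipeline.py --gcs-bucket gs://my-bucket

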
@dsl.container_component
def tokenizer_step(gcs_bucket: str) -> dsl.ContainerSpec:
    """
    Tokenizer component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/tokenizer_step.py"],
        args=["--gcs-bucket", gcs_bucket],
    )


@dsl.container_component
def midtraining_step(gcs_bucket: str, wandb_run: str, vertex_experiment: str, vertex_tensorboard: str) -> dsl.ContainerSpec:
    """
    Midtraining component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/midtraining_step.py"],
        args=["--gcs-bucket", gcs_bucket, "--wandb-run", wandb_run, "--vertex-experiment", vertex_experiment, "--vertex-tensorboard", vertex_tensorboard],
    )


@dsl.container_component
def sft_step(gcs_bucket: str, wandb_run: str, vertex_experiment: str, vertex_tensorboard: str) -> dsl.ContainerSpec:
    """
    SFT component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/sft_step.py"],
        args=["--gcs-bucket", gcs_bucket, "--wandb-run", wandb_run, "--vertex-experiment", vertex_experiment, "--vertex-tensorboard", vertex_tensorboard],
    )


@dsl.container_component
def data_download_step(gcs_bucket: str, num_shards: int = 50) -> dsl.ContainerSpec:
    """
    Data download component - downloads training data from HuggingFace to GCS.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/data_download_step.py"],
        args=["--gcs-bucket", gcs_bucket, "--num-shards", str(num_shards)],
    )


@dsl.container_component
def report_step(gcs_bucket: str) -> dsl.ContainerSpec:
    """
    Report component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/report_step.py"],
        args=["--gcs-bucket", gcs_bucket],
    )


# Note: dsl.pipeline is a decorator that runs at module import time, so a module-level
# pipeline function would be compiled/registered immediately, before CLI arguments are
# known. To vary the pipeline structure (accelerator type and count) based on arguments,
# we use a factory that defines the pipeline function with those values captured in its
# closure and returns it.
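# For example (illustrative values):
#   pipeline_func = create_pipeline_func("NVIDIA_TESLA_A100", 8, is_preemptible=False)

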
def create_pipeline_func(accelerator_type, accelerator_count, is_preemptible):
    # is_preemptible is currently unused here: preemption is handled at runtime via the
    # scheduling_strategy / max_wait_duration pipeline parameters set in __main__.
    @dsl.pipeline(
        name="nanochat-pipeline",
        description="A pipeline to train NanoChat",
    )
    def nanochat_pipeline(
        gcs_bucket: str,
        project: str,
        location: str,
        wandb_run: str = "dummy",
        vertex_experiment: str = "",
        vertex_tensorboard: str = "",
        num_data_shards: int = 20,
        scheduling_strategy: str = "FLEX_START",
        max_wait_duration: str = "0s",
        service_account: str = "",
        device_batch_size: int = 8,
    ):
        # Data download step
        data_download_task = data_download_step(
            gcs_bucket=gcs_bucket,
            num_shards=num_data_shards,
        )
        data_download_task.set_cpu_limit('8').set_memory_limit('32G')

        # Tokenizer step
        tokenizer_task = tokenizer_step(gcs_bucket=gcs_bucket)
        tokenizer_task.set_cpu_limit('8').set_memory_limit('32G')

        # Pretraining step using CustomTrainingJobOp.
        # Define worker pool specs; we reuse the same image as the container components.
        worker_pool_specs = [{
            "machine_spec": {
                "machine_type": "n1-standard-16",  # placeholder; refined below based on the accelerator
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            },
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/pretraining_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                    "--device-batch-size", str(device_batch_size),
                ],
            },
        }]

        # Map (accelerator_type, accelerator_count) to a machine type.
        # A100 40GB: a2-highgpu-1g (1 GPU), a2-highgpu-2g (2), a2-highgpu-4g (4), a2-highgpu-8g (8)
        # L4: g2-standard-4 (1 GPU), g2-standard-96 (8 GPUs)
        # Unmapped combinations fall back to n1-standard-16.
        machine_type = "n1-standard-16"  # Default
        if accelerator_type == "NVIDIA_TESLA_A100":
            if accelerator_count == 1: machine_type = "a2-highgpu-1g"
            elif accelerator_count == 2: machine_type = "a2-highgpu-2g"
            elif accelerator_count == 4: machine_type = "a2-highgpu-4g"
            elif accelerator_count == 8: machine_type = "a2-highgpu-8g"
        elif accelerator_type == "NVIDIA_L4":
            if accelerator_count == 1: machine_type = "g2-standard-4"
            elif accelerator_count == 8: machine_type = "g2-standard-96"

        worker_pool_specs[0]["machine_spec"]["machine_type"] = machine_type

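        # Caveat (our assumption about Vertex AI machine families): A100s attach only
        # to a2 machines and L4s only to g2 machines, so the n1-standard-16 fallback is
        # only valid for n1-compatible accelerators (e.g. T4/V100); other combinations
        # would likely be rejected at job submission.
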
        # Scheduling strategy is now a runtime parameter. Common values:
        #   FLEX_START: Dynamic Workload Scheduler - queues jobs when resources are unavailable
        #   SPOT: Spot VMs (the successor to legacy preemptible instances)
        #   STANDARD: Standard on-demand instances
        # max_wait_duration: "0s" = wait indefinitely, "3600s" = 1 hour, "86400s" = 24 hours

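        # For example, scheduling_strategy="FLEX_START" with max_wait_duration="86400s"
        # queues the job and waits up to 24 hours for capacity before giving up.
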
        pretraining_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-pretraining-job",
            worker_pool_specs=worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            timeout="604800s",  # 7 days
            restart_job_on_worker_restart=True,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
            service_account=service_account,
            tensorboard=vertex_tensorboard,
        ).after(tokenizer_task)

        # CustomTrainingJobOp returns a Model (if configured) or just the job resource.
        # We don't need to set resources/accelerators on the task itself because they
        # are in worker_pool_specs.

        # Mid-training step - reuses the same machine/accelerator spec as pretraining
        mid_worker_pool_specs = [{
            "machine_spec": worker_pool_specs[0]["machine_spec"],
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/midtraining_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                    "--device-batch-size", str(device_batch_size),
                ],
            },
        }]

        midtraining_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-midtraining-job",
            worker_pool_specs=mid_worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            service_account=service_account,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
        ).after(pretraining_task)

        # SFT step - reuses the same machine/accelerator spec as pretraining
        sft_worker_pool_specs = [{
            "machine_spec": worker_pool_specs[0]["machine_spec"],
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/sft_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                ],
            },
        }]

        sft_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-sft-job",
            worker_pool_specs=sft_worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            service_account=service_account,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
        ).after(midtraining_task)

        report_task = report_step(gcs_bucket=gcs_bucket).after(sft_task)
        report_task.set_cpu_limit('2').set_memory_limit('8G')

    return nanochat_pipeline


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gcp-project", type=str, required=False)  # Optional: if omitted, only compile the template
    parser.add_argument("--gcs-bucket", type=str, required=True)
    parser.add_argument("--pipeline-root", type=str, required=False)
    parser.add_argument("--region", type=str, default="us-central1")
    parser.add_argument("--wandb-run", type=str, default="dummy")
    parser.add_argument("--vertex-experiment", type=str, default="")
    parser.add_argument("--vertex-tensorboard", type=str, default="")
    parser.add_argument("--accelerator-type", type=str, default="NVIDIA_L4")
    parser.add_argument("--accelerator-count", type=int, default=1)
    parser.add_argument("--num-data-shards", type=int, default=20)
    parser.add_argument("--preemptible", type=str, default="false")
    parser.add_argument("--scheduling-strategy", type=str, default=None, help="Scheduling strategy: FLEX_START, SPOT, or STANDARD")
    parser.add_argument("--max-wait-duration", type=str, default=None, help="Max wait duration for FLEX_START, e.g., '0s', '3600s'")
    parser.add_argument("--service-account", type=str, required=False, help="Service account to run the pipeline")
    parser.add_argument("--device-batch-size", type=int, default=8, help="Batch size per device")
    parser.add_argument("--template_path", type=str, default="nanochat_pipeline.json")
    args = parser.parse_args()

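    # Example invocation (hypothetical project/bucket names):
    #   python pipeline.py --gcp-project my-project --gcs-bucket gs://my-nanochat-bucket \
    #       --accelerator-type NVIDIA_TESLA_A100 --accelerator-count 8 --preemptible true
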
    is_preemptible = args.preemptible.lower() == "true"

    # Set smart defaults for scheduling strategy based on preemptible flag
    if args.scheduling_strategy is None:
        scheduling_strategy = "FLEX_START" if is_preemptible else "STANDARD"
    else:
        scheduling_strategy = args.scheduling_strategy

    if args.max_wait_duration is None:
        max_wait_duration = "0s" if is_preemptible else "86400s"
    else:
        max_wait_duration = args.max_wait_duration

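    # e.g., "--preemptible true" with no explicit overrides yields
    # scheduling_strategy="FLEX_START" and max_wait_duration="0s" (wait indefinitely).
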
    # Create the pipeline function dynamically with the captured arguments
    pipeline_func = create_pipeline_func(
        accelerator_type=args.accelerator_type,
        accelerator_count=args.accelerator_count,
        is_preemptible=is_preemptible,
    )

    Compiler().compile(
        pipeline_func=pipeline_func,
        package_path=args.template_path,
    )

    # Initialize the Vertex AI SDK and submit the pipeline job. If no project was
    # given, stop here: the template has already been compiled above.
    if args.gcp_project:
        aiplatform.init(project=args.gcp_project, location=args.region)

        job = aiplatform.PipelineJob(
            display_name="nanochat-pipeline",
            template_path=args.template_path,
            pipeline_root=args.pipeline_root,
            parameter_values={
                "gcs_bucket": args.gcs_bucket,
                "project": args.gcp_project,
                "location": args.region,
                "wandb_run": args.wandb_run,
                "vertex_experiment": args.vertex_experiment,
                "vertex_tensorboard": args.vertex_tensorboard,
                "num_data_shards": args.num_data_shards,
                "scheduling_strategy": scheduling_strategy,
                "max_wait_duration": max_wait_duration,
                "service_account": args.service_account,
                "device_batch_size": args.device_batch_size,
            },
        )

        # Run the pipeline; service_account is optional but recommended
        job.run(
            service_account=args.service_account,
            sync=True,  # Block until completion or failure to ensure we see logs
        )
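
        # Note: job.run() blocks until the pipeline finishes. aiplatform.PipelineJob
        # also provides submit() for non-blocking, fire-and-forget submission.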