# nanochat/vertex_pipelines/pipeline.py
import os

from kfp import dsl
from kfp.compiler import Compiler
from google.cloud import aiplatform
from google_cloud_pipeline_components.v1.custom_job import CustomTrainingJobOp

# Default accelerator type; also used as the CLI default in __main__ below.
ACCELERATOR_TYPE = 'NVIDIA_L4'

# Read the Docker image URI from an environment variable. This lets us compile
# the pipeline against a specific image without passing it as a PipelineParam,
# since dsl.ContainerSpec needs the image to be a compile-time constant.
DOCKER_IMAGE_URI = os.environ.get("DOCKER_IMAGE_URI", "gcr.io/nzp-nanochat/nanochat:latest")
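
# A minimal usage sketch (image path and tag are illustrative, not from this repo):
#   export DOCKER_IMAGE_URI=gcr.io/my-project/nanochat:v1
#   python vertex_pipelines/pipeline.py --gcs-bucket gs://my-bucket ...
# The compiled template then bakes that image into every step.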

@dsl.container_component
def tokenizer_step(gcs_bucket: str) -> dsl.ContainerSpec:
    """
    Tokenizer component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/tokenizer_step.py"],
        args=["--gcs-bucket", gcs_bucket],
    )


@dsl.container_component
def midtraining_step(gcs_bucket: str, wandb_run: str, vertex_experiment: str, vertex_tensorboard: str) -> dsl.ContainerSpec:
    """
    Midtraining component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/midtraining_step.py"],
        args=["--gcs-bucket", gcs_bucket, "--wandb-run", wandb_run, "--vertex-experiment", vertex_experiment, "--vertex-tensorboard", vertex_tensorboard],
    )


@dsl.container_component
def sft_step(gcs_bucket: str, wandb_run: str, vertex_experiment: str, vertex_tensorboard: str) -> dsl.ContainerSpec:
    """
    SFT component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/sft_step.py"],
        args=["--gcs-bucket", gcs_bucket, "--wandb-run", wandb_run, "--vertex-experiment", vertex_experiment, "--vertex-tensorboard", vertex_tensorboard],
    )


@dsl.container_component
def data_download_step(gcs_bucket: str, num_shards: int = 50) -> dsl.ContainerSpec:
    """
    Data download component - downloads training data from HuggingFace to GCS.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/data_download_step.py"],
        # Pass the parameter placeholder directly; KFP stringifies it at runtime.
        args=["--gcs-bucket", gcs_bucket, "--num-shards", num_shards],
    )


@dsl.container_component
def report_step(gcs_bucket: str) -> dsl.ContainerSpec:
    """
    Report component.
    """
    return dsl.ContainerSpec(
        image=DOCKER_IMAGE_URI,
        command=["python", "vertex_pipelines/report_step.py"],
        args=["--gcs-bucket", gcs_bucket],
    )
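
# Each step script can also be exercised outside the pipeline for debugging; a
# hypothetical local smoke test (flags assumed from the component args above):
#   python vertex_pipelines/tokenizer_step.py --gcs-bucket gs://my-bucket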

# dsl.pipeline is a decorator that runs at module import time, so a module-level
# pipeline function would have its structure frozen as soon as this file loads.
# Because the worker pool specs depend on the accelerator arguments, we build the
# pipeline function inside a factory that closes over those arguments instead.
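# Sketch of the intended flow (this is exactly what __main__ below does):
#   pipeline_func = create_pipeline_func("NVIDIA_TESLA_A100", 8, is_preemptible=True)
#   Compiler().compile(pipeline_func=pipeline_func, package_path="nanochat_pipeline.json")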
def create_pipeline_func(accelerator_type, accelerator_count, is_preemptible):
    # NOTE: is_preemptible is currently unused here; the scheduling strategy is
    # resolved in __main__ and passed in as the scheduling_strategy parameter.

    @dsl.pipeline(
        name="nanochat-pipeline",
        description="A pipeline to train NanoChat",
    )
    def nanochat_pipeline(
        gcs_bucket: str,
        project: str,
        location: str,
        wandb_run: str = "dummy",
        vertex_experiment: str = "",
        vertex_tensorboard: str = "",
        num_data_shards: int = 20,
        scheduling_strategy: str = "FLEX_START",
        max_wait_duration: str = "0s",
        service_account: str = "",
        device_batch_size: int = 8,
    ):
        # Data download step
        data_download_task = data_download_step(
            gcs_bucket=gcs_bucket,
            num_shards=num_data_shards,
        )
        data_download_task.set_cpu_limit('8').set_memory_limit('32G')

        # Tokenizer step
        tokenizer_task = tokenizer_step(gcs_bucket=gcs_bucket)
        tokenizer_task.set_cpu_limit('8').set_memory_limit('32G')

        # Map the requested accelerator to the machine type GCP pairs with it.
        # A100 40GB GPUs come only on A2 shapes: a2-highgpu-1g/-2g/-4g/-8g carry
        # 1/2/4/8 GPUs respectively. L4 GPUs come on G2 shapes: g2-standard-4
        # carries one, g2-standard-96 carries eight.
        machine_type = "n1-standard-16"  # Fallback for unmapped combinations
        if accelerator_type == "NVIDIA_TESLA_A100":
            if accelerator_count == 1:
                machine_type = "a2-highgpu-1g"
            elif accelerator_count == 2:
                machine_type = "a2-highgpu-2g"
            elif accelerator_count == 4:
                machine_type = "a2-highgpu-4g"
            elif accelerator_count == 8:
                machine_type = "a2-highgpu-8g"
        elif accelerator_type == "NVIDIA_L4":
            if accelerator_count == 1:
                machine_type = "g2-standard-4"
            elif accelerator_count == 8:
                machine_type = "g2-standard-96"

        # Pretraining runs as a CustomTrainingJobOp so we can request GPUs, a
        # large boot disk, and a scheduling strategy. It uses the same image as
        # the lightweight container components above.
        worker_pool_specs = [{
            "machine_spec": {
                "machine_type": machine_type,
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            },
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/pretraining_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                    "--device-batch-size", str(device_batch_size),
                ],
            },
        }]

        # Scheduling strategy is a runtime parameter. Common values:
        #   FLEX_START: Dynamic Workload Scheduler - queues the job until resources free up
        #   SPOT: preemptible capacity (deprecated in favor of FLEX_START)
        #   STANDARD: standard on-demand capacity
        # max_wait_duration: "0s" = wait indefinitely, "3600s" = 1 hour, "86400s" = 24 hours
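        # Example: strategy="FLEX_START" with max_wait_duration="86400s" queues
        # the job for up to 24 hours for capacity before giving up.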
        pretraining_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-pretraining-job",
            worker_pool_specs=worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            timeout="604800s",  # 7 days
            restart_job_on_worker_restart=True,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
            service_account=service_account,
            tensorboard=vertex_tensorboard,
        ).after(tokenizer_task)
        # CustomTrainingJobOp surfaces the underlying job's GCP resource name as
        # its output. GPU resources are declared in worker_pool_specs, so there
        # is no need to set resources/accelerators on the KFP task itself.

        # Mid-training step - same machine/accelerator shape as pretraining
        mid_worker_pool_specs = [{
            "machine_spec": worker_pool_specs[0]["machine_spec"],
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/midtraining_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                    "--device-batch-size", str(device_batch_size),
                ],
            },
        }]
        midtraining_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-midtraining-job",
            worker_pool_specs=mid_worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            service_account=service_account,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
        ).after(pretraining_task)

        # SFT step - same machine/accelerator shape as pretraining
        sft_worker_pool_specs = [{
            "machine_spec": worker_pool_specs[0]["machine_spec"],
            "replica_count": 1,
            "disk_spec": {
                "boot_disk_type": "pd-ssd",
                "boot_disk_size_gb": 500,
            },
            "container_spec": {
                "image_uri": DOCKER_IMAGE_URI,
                "command": ["python", "vertex_pipelines/sft_step.py"],
                "args": [
                    "--gcs-bucket", gcs_bucket,
                    "--wandb-run", wandb_run,
                    "--vertex-experiment", vertex_experiment,
                    "--vertex-tensorboard", vertex_tensorboard,
                ],
            },
        }]
        sft_task = CustomTrainingJobOp(
            project=project,
            location=location,
            display_name="nanochat-sft-job",
            worker_pool_specs=sft_worker_pool_specs,
            base_output_directory=f"{gcs_bucket}/pipeline_root",
            service_account=service_account,
            strategy=scheduling_strategy,
            max_wait_duration=max_wait_duration,
        ).after(midtraining_task)

        report_task = report_step(gcs_bucket=gcs_bucket).after(sft_task)
        report_task.set_cpu_limit('2').set_memory_limit('8G')

    return nanochat_pipeline
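

# Example invocation (all values illustrative):
#   python vertex_pipelines/pipeline.py \
#     --gcp-project my-project --gcs-bucket gs://my-bucket \
#     --pipeline-root gs://my-bucket/pipeline_root --region us-central1 \
#     --accelerator-type NVIDIA_TESLA_A100 --accelerator-count 8 \
#     --preemptible true
# Omitting --gcp-project compiles the template without submitting a job.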
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gcp-project", type=str, required=False)  # Optional: without it we only compile
    parser.add_argument("--gcs-bucket", type=str, required=True)
    parser.add_argument("--pipeline-root", type=str, required=False)
    parser.add_argument("--region", type=str, default="us-central1")
    parser.add_argument("--wandb-run", type=str, default="dummy")
    parser.add_argument("--vertex-experiment", type=str, default="")
    parser.add_argument("--vertex-tensorboard", type=str, default="")
    parser.add_argument("--accelerator-type", type=str, default=ACCELERATOR_TYPE)
    parser.add_argument("--accelerator-count", type=int, default=1)
    parser.add_argument("--num-data-shards", type=int, default=20)
    parser.add_argument("--preemptible", type=str, default="false")
    parser.add_argument("--scheduling-strategy", type=str, default=None, help="Scheduling strategy: FLEX_START, SPOT, or STANDARD")
    parser.add_argument("--max-wait-duration", type=str, default=None, help="Max wait duration for FLEX_START, e.g., '0s', '3600s'")
    parser.add_argument("--service-account", type=str, default="", help="Service account to run the pipeline")
    parser.add_argument("--device-batch-size", type=int, default=8, help="Batch size per device")
    parser.add_argument("--template-path", type=str, default="nanochat_pipeline.json")
    args = parser.parse_args()

    is_preemptible = args.preemptible.lower() == "true"
    # Smart defaults for the scheduling strategy: preemptible runs queue via
    # FLEX_START and wait indefinitely; on-demand runs use STANDARD capacity
    # with a 24-hour cap on queueing.
    if args.scheduling_strategy is None:
        scheduling_strategy = "FLEX_START" if is_preemptible else "STANDARD"
    else:
        scheduling_strategy = args.scheduling_strategy
    if args.max_wait_duration is None:
        max_wait_duration = "0s" if is_preemptible else "86400s"
    else:
        max_wait_duration = args.max_wait_duration
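    # Worked example: `--preemptible true` with neither override resolves to
    # scheduling_strategy="FLEX_START" and max_wait_duration="0s" (wait forever).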

    # Create the pipeline function with the captured accelerator arguments,
    # then compile it to a reusable template.
    pipeline_func = create_pipeline_func(
        accelerator_type=args.accelerator_type,
        accelerator_count=args.accelerator_count,
        is_preemptible=is_preemptible,
    )
    Compiler().compile(
        pipeline_func=pipeline_func,
        package_path=args.template_path,
    )

    # Initialize the Vertex AI SDK and submit the pipeline. Skipped when no
    # project is given, in which case this script only compiles the template.
    if args.gcp_project:
        aiplatform.init(project=args.gcp_project, location=args.region)
        job = aiplatform.PipelineJob(
            display_name="nanochat-pipeline",
            template_path=args.template_path,
            pipeline_root=args.pipeline_root,
            parameter_values={
                "gcs_bucket": args.gcs_bucket,
                "project": args.gcp_project,
                "location": args.region,
                "wandb_run": args.wandb_run,
                "vertex_experiment": args.vertex_experiment,
                "vertex_tensorboard": args.vertex_tensorboard,
                "num_data_shards": args.num_data_shards,
                "scheduling_strategy": scheduling_strategy,
                "max_wait_duration": max_wait_duration,
                "service_account": args.service_account,
                "device_batch_size": args.device_batch_size,
            },
        )
        # Run the pipeline. A dedicated service account is optional but recommended.
        job.run(
            service_account=args.service_account or None,
            sync=True,  # Block until completion or failure so we see the logs
        )