nanochat/vertex_pipelines/pipeline.py
google-labs-jules[bot] 2781d216c6 feat: Refactor nanochat to run on Vertex AI Pipelines
This refactoring enables the nanochat project to be executed as a scalable and robust pipeline on Vertex AI.

The monolithic `speedrun.sh` script has been decomposed into a series of containerized components orchestrated by a Kubeflow pipeline.

The codebase has been updated to use Google Cloud Storage for artifact management, allowing for seamless data sharing between pipeline steps.

A `Dockerfile` and Python wrappers for each pipeline step have been added to the `vertex_pipelines` directory.

The `README.md` has been updated with instructions on how to build the Docker image and run the Vertex AI pipeline.
2025-11-04 01:26:51 +00:00

75 lines
2.5 KiB
Python

import kfp
from kfp.v2 import dsl
from kfp.v2.compiler import Compiler
from google.cloud import aiplatform
@dsl.container_component
def _nanochat_step(script: str, docker_image_uri: str, gcs_bucket: str):
    """Generic nanochat step: run `python <script> --gcs-bucket <bucket>` in the image.

    NOTE(review): some older KFP 2.x releases require `image` to be a static
    string rather than an input placeholder — confirm against the pinned SDK.
    """
    return dsl.ContainerSpec(
        image=docker_image_uri,
        command=["python", script],
        args=["--gcs-bucket", gcs_bucket],
    )


@dsl.container_component
def _nanochat_train_step(script: str, docker_image_uri: str, gcs_bucket: str, wandb_run: str):
    """Training-flavored nanochat step: same as `_nanochat_step` plus `--wandb-run`."""
    return dsl.ContainerSpec(
        image=docker_image_uri,
        command=["python", script],
        args=["--gcs-bucket", gcs_bucket, "--wandb-run", wandb_run],
    )


@dsl.pipeline(name="nanochat-pipeline")
def nanochat_pipeline(gcs_bucket: str, docker_image_uri: str, wandb_run: str = "dummy"):
    """A Vertex AI pipeline for training and evaluating a nanochat model.

    Stages run strictly in sequence: tokenizer -> pretraining -> midtraining
    -> sft -> report. All stages share one GCS bucket for artifact handoff.

    Args:
        gcs_bucket: GCS bucket used for artifact exchange between steps.
        docker_image_uri: Container image that holds the nanochat codebase.
        wandb_run: Weights & Biases run name passed to the training steps.
    """
    # NOTE: `dsl.ContainerOp` is a KFP v1 API that does not exist in the v2
    # SDK imported by this module; v2 pipelines compose container components.
    tokenizer_task = _nanochat_step(
        script="vertex_pipelines/tokenizer_step.py",
        docker_image_uri=docker_image_uri,
        gcs_bucket=gcs_bucket,
    ).set_display_name("tokenizer")
    pretraining_task = _nanochat_train_step(
        script="vertex_pipelines/pretraining_step.py",
        docker_image_uri=docker_image_uri,
        gcs_bucket=gcs_bucket,
        wandb_run=wandb_run,
    ).set_display_name("pretraining").after(tokenizer_task)
    midtraining_task = _nanochat_train_step(
        script="vertex_pipelines/midtraining_step.py",
        docker_image_uri=docker_image_uri,
        gcs_bucket=gcs_bucket,
        wandb_run=wandb_run,
    ).set_display_name("midtraining").after(pretraining_task)
    sft_task = _nanochat_train_step(
        script="vertex_pipelines/sft_step.py",
        docker_image_uri=docker_image_uri,
        gcs_bucket=gcs_bucket,
        wandb_run=wandb_run,
    ).set_display_name("sft").after(midtraining_task)
    _nanochat_step(
        script="vertex_pipelines/report_step.py",
        docker_image_uri=docker_image_uri,
        gcs_bucket=gcs_bucket,
    ).set_display_name("report").after(sft_task)
if __name__ == "__main__":
    import argparse

    # CLI: all GCP-side settings are required; only the region has a default.
    cli = argparse.ArgumentParser()
    for flag in ("--gcp-project", "--gcs-bucket", "--pipeline-root", "--docker-image-uri"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument("--region", type=str, default="us-central1")
    opts = cli.parse_args()

    # Compile the pipeline definition into a local JSON spec for submission.
    Compiler().compile(
        pipeline_func=nanochat_pipeline,
        package_path="nanochat_pipeline.json",
    )

    # Point the SDK at the target project/region, then submit the job and
    # block until it finishes.
    aiplatform.init(project=opts.gcp_project, location=opts.region)
    aiplatform.PipelineJob(
        display_name="nanochat-pipeline",
        template_path="nanochat_pipeline.json",
        pipeline_root=opts.pipeline_root,
        parameter_values={
            "gcs_bucket": opts.gcs_bucket,
            "docker_image_uri": opts.docker_image_uri,
        },
    ).run()