diff --git a/scripts/chat_web.py b/scripts/chat_web.py index ffaf7da..4ccc582 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -109,6 +109,11 @@ class WorkerPool: print(f"Initializing worker pool with {self.num_gpus} GPUs...") if self.num_gpus > 1: assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not." + if device_type == "cuda" and self.num_gpus > torch.cuda.device_count(): + raise ValueError( + f"--num-gpus {self.num_gpus} exceeds available GPUs ({torch.cuda.device_count()}). " + "Use a value <= torch.cuda.device_count()." + ) for gpu_id in range(self.num_gpus):