indent fix, limited device options to cuda/mps

This commit is contained in:
Jeroen Versteege 2025-10-15 08:28:34 +02:00
parent 428b902763
commit 90e3bc778b
3 changed files with 3 additions and 3 deletions

View File

@ -98,7 +98,7 @@ def compute_init(device_type="cuda"):
# Reproducibility
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
torch.cuda.manual_seed(42)
elif device_type == "mps":
torch.mps.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later

View File

@ -17,7 +17,7 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('-d', '--device', type=str, default='cuda', help='Device to run the model on: cuda|mps')
parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
args = parser.parse_args()

View File

@ -29,7 +29,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
parser.add_argument('-d', '--device', type=str, default='cuda', help='Device to run the model on: cuda|mps')
parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)