mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
indent fix, limited device options to cuda/mps
This commit is contained in:
parent
428b902763
commit
90e3bc778b
|
|
@ -98,7 +98,7 @@ def compute_init(device_type="cuda"):
|
|||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
elif device_type == "mps":
|
||||
torch.mps.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('-d', '--device', type=str, default='cuda', help='Device to run the model on: cuda|mps')
|
||||
parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
|
|||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
parser.add_argument('-d', '--device', type=str, default='cuda', help='Device to run the model on: cuda|mps')
|
||||
parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user