mirror of https://github.com/karpathy/nanochat.git
Merge 90e3bc778b into 67aaca98f5 (commit e4692ea3a7)
@@ -89,15 +89,18 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
+def compute_init(device_type="cuda"):
     """Basic initialization that we keep doing over and over, so make common."""
 
-    # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    # CUDA or MPS is currently required
+    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "CUDA or MPS is needed for a distributed run atm"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
+    elif device_type == "mps":
+        torch.mps.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
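For orientation, the seeding change in one self-contained sketch (compute_init's device_type parameter and the seed 42 come from the diff; the helper name seed_all and the availability guards are illustrative additions):

import torch

def seed_all(device_type="cuda", seed=42):
    # The CPU RNG and each accelerator RNG are separate streams, so seed both.
    torch.manual_seed(seed)
    if device_type == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # seeds the current CUDA device
    elif device_type == "mps" and torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)   # MPS keeps its own RNG state

seed_all("mps")  # e.g. on an Apple Silicon laptop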
@@ -109,12 +112,15 @@ def compute_init():
     # Distributed setup: Distributed Data Parallel (DDP), optional
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     if ddp:
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
+        device = torch.device(device_type, ddp_local_rank)
+        if device_type == "cuda":
+            torch.cuda.set_device(device) # make "cuda" default to this device
+        # torch.mps has no set_device(): MPS exposes a single device, so there
+        # is nothing to pin in the "mps" case
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type)
 
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
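The non-distributed branch is the one MPS users will actually hit (the NCCL backend in the DDP branch is CUDA-only). A minimal sketch of that path, assuming the device_type values from the diff; pick_device and the CPU fallback are illustrative, not in the PR:

import torch

def pick_device(device_type="cuda"):
    # Mirror the single-process branch: no DDP, just bind the requested backend.
    if device_type == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if device_type == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")  # conservative fallback

device = pick_device("mps")
x = torch.randn(4, 4, device=device)  # tensors now land on the chosen device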
@@ -17,11 +17,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
 parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
+parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
 args = parser.parse_args()
 
 # Init the model and tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)
+autocast_ctx = torch.amp.autocast(device_type=args.device, dtype=torch.bfloat16)
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
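How the new flag flows into mixed precision, as a runnable sketch (the -d/--device flag and the autocast call mirror the diff; the matmul is a stand-in for the model forward). Note that torch.amp.autocast only accepts "mps" as a device_type in recent PyTorch releases, so this assumes a 2.x install with MPS autocast support:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'])
args = parser.parse_args()

autocast_ctx = torch.amp.autocast(device_type=args.device, dtype=torch.bfloat16)
with autocast_ctx:
    x = torch.randn(2, 8, device=args.device)
    w = torch.randn(8, 8, device=args.device)
    y = x @ w  # matmul runs in bfloat16 under autocast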
@@ -29,10 +29,11 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
 args = parser.parse_args()
 
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)
+autocast_ctx = torch.amp.autocast(device_type=args.device, dtype=torch.bfloat16)
 
 class ChatMessage(BaseModel):
     role: str
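Since the web server defaults to -d cuda, a request for MPS on a CUDA-less box (or vice versa) is worth catching at startup. Not part of the PR: a hypothetical validate_device helper sketching that check before compute_init runs:

import torch

def validate_device(device_type: str) -> str:
    # Fail fast with a readable error instead of a deep stack trace later.
    available = {
        "cuda": torch.cuda.is_available(),
        "mps": torch.backends.mps.is_available(),
    }
    if not available.get(device_type, False):
        raise SystemExit(f"requested device '{device_type}' is not available here")
    return device_type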