mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-10 01:40:17 +00:00
Refactor chat_web startup for safe imports
This commit is contained in:
parent
63395bbade
commit
7b4549495c
|
|
@ -60,6 +60,8 @@ MAX_TOP_K = 200
|
|||
MIN_MAX_TOKENS = 1
|
||||
MAX_MAX_TOKENS = 4096
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
||||
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
||||
|
|
@ -82,18 +84,30 @@ def parse_args(argv=None):
|
|||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
args = parse_args()
|
||||
args = parse_args([])
|
||||
device_type = None
|
||||
ddp = ddp_rank = ddp_local_rank = ddp_world_size = None
|
||||
device = None
|
||||
_runtime_initialized = False
|
||||
|
||||
# Configure logging for conversation traffic
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
def configure_runtime(parsed_args=None):
|
||||
global args, device_type, ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, _runtime_initialized
|
||||
|
||||
if parsed_args is not None:
|
||||
args = parsed_args
|
||||
if _runtime_initialized:
|
||||
return
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
_runtime_initialized = True
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
|
|
@ -106,9 +120,10 @@ class Worker:
|
|||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
def __init__(self, runtime_device_type: str, num_gpus: Optional[int] = None):
|
||||
self.device_type = runtime_device_type
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
if self.device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
|
|
@ -120,16 +135,16 @@ class WorkerPool:
|
|||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
assert self.device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
|
||||
if device_type == "cuda":
|
||||
if self.device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
device = torch.device(self.device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {self.device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
@ -228,8 +243,9 @@ def validate_chat_request(request: ChatRequest):
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load models on all GPUs on startup."""
|
||||
configure_runtime()
|
||||
print("Loading nanochat models across GPUs...")
|
||||
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
|
||||
app.state.worker_pool = WorkerPool(device_type, num_gpus=args.num_gpus)
|
||||
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
|
||||
print(f"Server ready at http://localhost:{args.port}")
|
||||
yield
|
||||
|
|
@ -414,6 +430,7 @@ async def stats():
|
|||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
configure_runtime(parse_args())
|
||||
print(f"Starting NanoChat Web Server")
|
||||
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def test_engine_kv_cache_uses_compute_dtype(monkeypatch):
|
|||
|
||||
|
||||
def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
|
||||
monkeypatch.setattr(sys, "argv", ["chat_web_test"])
|
||||
monkeypatch.setattr(sys, "argv", ["chat_web_test", "--definitely-not-a-valid-flag"])
|
||||
sys.modules.pop("scripts.chat_web", None)
|
||||
chat_web = importlib.import_module("scripts.chat_web")
|
||||
|
||||
|
|
@ -29,6 +29,8 @@ def test_chat_web_rejects_system_role_with_consistent_error(monkeypatch):
|
|||
messages=[chat_web.ChatMessage(role="system", content="You are helpful.")]
|
||||
)
|
||||
|
||||
assert chat_web._runtime_initialized is False
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
chat_web.validate_chat_request(request)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user