mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
add basic logging to chat_web, which i think might be fun
This commit is contained in:
parent
52bfeea8bd
commit
03fa673b7d
|
|
@@ -35,6 +35,7 @@ import json
|
|||
import os
|
||||
import torch
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
|
@@ -70,6 +71,14 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
|
|||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for conversation traffic
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
|
||||
@dataclass
|
||||
|
|
@@ -273,6 +282,12 @@ async def chat_completions(request: ChatRequest):
|
|||
# Basic validation to prevent abuse
|
||||
validate_chat_request(request)
|
||||
|
||||
# Log incoming conversation to console
|
||||
logger.info("="*20)
|
||||
for i, message in enumerate(request.messages):
|
||||
logger.info(f"[{message.role.upper()}]: {message.content}")
|
||||
logger.info("-"*20)
|
||||
|
||||
# Acquire a worker from the pool (will wait if all are busy)
|
||||
worker_pool = app.state.worker_pool
|
||||
worker = await worker_pool.acquire_worker()
|
||||
|
|
@@ -299,6 +314,7 @@ async def chat_completions(request: ChatRequest):
|
|||
conversation_tokens.append(assistant_start)
|
||||
|
||||
# Streaming response with worker release after completion
|
||||
response_tokens = []
|
||||
async def stream_and_release():
|
||||
try:
|
||||
async for chunk in generate_stream(
|
||||
|
|
@@ -308,8 +324,16 @@ async def chat_completions(request: ChatRequest):
|
|||
max_new_tokens=request.max_tokens,
|
||||
top_k=request.top_k
|
||||
):
|
||||
# Accumulate response for logging
|
||||
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||
if "token" in chunk_data:
|
||||
response_tokens.append(chunk_data["token"])
|
||||
yield chunk
|
||||
finally:
|
||||
# Log the assistant response to console
|
||||
full_response = "".join(response_tokens)
|
||||
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
|
||||
logger.info("="*20)
|
||||
# Release worker back to pool after streaming is done
|
||||
await worker_pool.release_worker(worker)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user