mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Add basic logging to chat_web, which I think might be fun
This commit is contained in:
parent
52bfeea8bd
commit
03fa673b7d
|
|
@ -35,6 +35,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
@ -70,6 +71,14 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
|
||||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Configure logging for conversation traffic
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -273,6 +282,12 @@ async def chat_completions(request: ChatRequest):
|
||||||
# Basic validation to prevent abuse
|
# Basic validation to prevent abuse
|
||||||
validate_chat_request(request)
|
validate_chat_request(request)
|
||||||
|
|
||||||
|
# Log incoming conversation to console
|
||||||
|
logger.info("="*20)
|
||||||
|
for i, message in enumerate(request.messages):
|
||||||
|
logger.info(f"[{message.role.upper()}]: {message.content}")
|
||||||
|
logger.info("-"*20)
|
||||||
|
|
||||||
# Acquire a worker from the pool (will wait if all are busy)
|
# Acquire a worker from the pool (will wait if all are busy)
|
||||||
worker_pool = app.state.worker_pool
|
worker_pool = app.state.worker_pool
|
||||||
worker = await worker_pool.acquire_worker()
|
worker = await worker_pool.acquire_worker()
|
||||||
|
|
@ -299,6 +314,7 @@ async def chat_completions(request: ChatRequest):
|
||||||
conversation_tokens.append(assistant_start)
|
conversation_tokens.append(assistant_start)
|
||||||
|
|
||||||
# Streaming response with worker release after completion
|
# Streaming response with worker release after completion
|
||||||
|
response_tokens = []
|
||||||
async def stream_and_release():
|
async def stream_and_release():
|
||||||
try:
|
try:
|
||||||
async for chunk in generate_stream(
|
async for chunk in generate_stream(
|
||||||
|
|
@ -308,8 +324,16 @@ async def chat_completions(request: ChatRequest):
|
||||||
max_new_tokens=request.max_tokens,
|
max_new_tokens=request.max_tokens,
|
||||||
top_k=request.top_k
|
top_k=request.top_k
|
||||||
):
|
):
|
||||||
|
# Accumulate response for logging
|
||||||
|
chunk_data = json.loads(chunk.replace("data: ", "").strip())
|
||||||
|
if "token" in chunk_data:
|
||||||
|
response_tokens.append(chunk_data["token"])
|
||||||
yield chunk
|
yield chunk
|
||||||
finally:
|
finally:
|
||||||
|
# Log the assistant response to console
|
||||||
|
full_response = "".join(response_tokens)
|
||||||
|
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
|
||||||
|
logger.info("="*20)
|
||||||
# Release worker back to pool after streaming is done
|
# Release worker back to pool after streaming is done
|
||||||
await worker_pool.release_worker(worker)
|
await worker_pool.release_worker(worker)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user