also allow regenerating assistant message by clicking it, and make sure to feed good seed to generate

This commit is contained in:
Andrej Karpathy 2025-10-16 01:28:37 +00:00
parent 2846999b8f
commit 4346536ab2
2 changed files with 117 additions and 60 deletions

View File

@ -108,6 +108,15 @@
background: transparent;
border: none;
padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;
margin-left: -0.5rem;
transition: background-color 0.2s ease;
}
.message.assistant .message-content:hover {
background-color: #f9fafb;
}
.message.user .message-content {
@ -325,6 +334,17 @@
});
}
// Add click handler for assistant messages to enable regeneration
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
});
}
messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv);
@ -358,6 +378,99 @@
chatInput.focus();
}
async function generateAssistantResponse() {
isGenerating = true;
sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
try {
const response = await fetch(`${API_URL}/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages: messages,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
assistantContent.textContent = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
}
}
}
}
const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
});
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();
}
}
async function regenerateMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToRegenerate = messages[messageIndex];
if (messageToRegenerate.role !== 'assistant') return;
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Regenerate the assistant response
await generateAssistantResponse();
}
function handleSlashCommand(command) {
const parts = command.trim().split(/\s+/);
const cmd = parts[0].toLowerCase();
@ -419,72 +532,14 @@
return;
}
isGenerating = true;
chatInput.value = '';
chatInput.style.height = 'auto';
sendButton.disabled = true;
const userMessageIndex = messages.length;
messages.push({ role: 'user', content: message });
addMessage('user', message, userMessageIndex);
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
try {
const response = await fetch(`${API_URL}/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages: messages,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
assistantContent.textContent = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
}
}
}
}
messages.push({ role: 'assistant', content: fullResponse });
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();
}
await generateAssistantResponse();
}
sendButton.disabled = false;

View File

@ -36,6 +36,7 @@ import os
import torch
import asyncio
import logging
import random
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
@ -268,7 +269,8 @@ async def generate_stream(
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]