also allow regenerating assistant message by clicking it, and make sure to feed good seed to generate

This commit is contained in:
Andrej Karpathy 2025-10-16 01:28:37 +00:00
parent 2846999b8f
commit 4346536ab2
2 changed files with 117 additions and 60 deletions

View File

@ -108,6 +108,15 @@
background: transparent; background: transparent;
border: none; border: none;
padding: 0.25rem 0; padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;
margin-left: -0.5rem;
transition: background-color 0.2s ease;
}
.message.assistant .message-content:hover {
background-color: #f9fafb;
} }
.message.user .message-content { .message.user .message-content {
@ -325,6 +334,17 @@
}); });
} }
// Add click handler for assistant messages to enable regeneration
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
});
}
messageDiv.appendChild(contentDiv); messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv); chatWrapper.appendChild(messageDiv);
@ -358,6 +378,99 @@
chatInput.focus(); chatInput.focus();
} }
async function generateAssistantResponse() {
isGenerating = true;
sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
try {
const response = await fetch(`${API_URL}/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages: messages,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
assistantContent.textContent = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
}
}
}
}
const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
});
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();
}
}
async function regenerateMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToRegenerate = messages[messageIndex];
if (messageToRegenerate.role !== 'assistant') return;
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Regenerate the assistant response
await generateAssistantResponse();
}
function handleSlashCommand(command) { function handleSlashCommand(command) {
const parts = command.trim().split(/\s+/); const parts = command.trim().split(/\s+/);
const cmd = parts[0].toLowerCase(); const cmd = parts[0].toLowerCase();
@ -419,72 +532,14 @@
return; return;
} }
isGenerating = true;
chatInput.value = ''; chatInput.value = '';
chatInput.style.height = 'auto'; chatInput.style.height = 'auto';
sendButton.disabled = true;
const userMessageIndex = messages.length; const userMessageIndex = messages.length;
messages.push({ role: 'user', content: message }); messages.push({ role: 'user', content: message });
addMessage('user', message, userMessageIndex); addMessage('user', message, userMessageIndex);
const assistantContent = addMessage('assistant', ''); await generateAssistantResponse();
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
try {
const response = await fetch(`${API_URL}/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages: messages,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
assistantContent.textContent = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
}
}
}
}
messages.push({ role: 'assistant', content: fullResponse });
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
} finally {
isGenerating = false;
sendButton.disabled = !chatInput.value.trim();
}
} }
sendButton.disabled = false; sendButton.disabled = false;

View File

@ -36,6 +36,7 @@ import os
import torch import torch
import asyncio import asyncio
import logging import logging
import random
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -268,7 +269,8 @@ async def generate_stream(
num_samples=1, num_samples=1,
max_tokens=max_new_tokens, max_tokens=max_new_tokens,
temperature=temperature, temperature=temperature,
top_k=top_k top_k=top_k,
seed=random.randint(0, 2**31 - 1)
): ):
token = token_column[0] token = token_column[0]