mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
also allow regenerating assistant message by clicking it, and make sure to feed good seed to generate
This commit is contained in:
parent
2846999b8f
commit
4346536ab2
173
nanochat/ui.html
173
nanochat/ui.html
|
|
@ -108,6 +108,15 @@
|
|||
background: transparent;
|
||||
border: none;
|
||||
padding: 0.25rem 0;
|
||||
cursor: pointer;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem;
|
||||
margin-left: -0.5rem;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.message.assistant .message-content:hover {
|
||||
background-color: #f9fafb;
|
||||
}
|
||||
|
||||
.message.user .message-content {
|
||||
|
|
@ -325,6 +334,17 @@
|
|||
});
|
||||
}
|
||||
|
||||
// Add click handler for assistant messages to enable regeneration
|
||||
if (role === 'assistant' && messageIndex !== null) {
|
||||
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||
contentDiv.setAttribute('title', 'Click to regenerate this response');
|
||||
contentDiv.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(messageIndex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
messageDiv.appendChild(contentDiv);
|
||||
chatWrapper.appendChild(messageDiv);
|
||||
|
||||
|
|
@ -358,6 +378,99 @@
|
|||
chatInput.focus();
|
||||
}
|
||||
|
||||
async function generateAssistantResponse() {
|
||||
isGenerating = true;
|
||||
sendButton.disabled = true;
|
||||
|
||||
const assistantContent = addMessage('assistant', '');
|
||||
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_URL}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages: messages,
|
||||
temperature: currentTemperature,
|
||||
top_k: currentTopK,
|
||||
max_tokens: 512
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let fullResponse = '';
|
||||
assistantContent.textContent = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
const lines = chunk.split('\n');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.slice(6));
|
||||
if (data.token) {
|
||||
fullResponse += data.token;
|
||||
assistantContent.textContent = fullResponse;
|
||||
chatContainer.scrollTop = chatContainer.scrollHeight;
|
||||
}
|
||||
} catch (e) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const assistantMessageIndex = messages.length;
|
||||
messages.push({ role: 'assistant', content: fullResponse });
|
||||
|
||||
// Add click handler to regenerate this assistant message
|
||||
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
|
||||
assistantContent.setAttribute('title', 'Click to regenerate this response');
|
||||
assistantContent.addEventListener('click', function() {
|
||||
if (!isGenerating) {
|
||||
regenerateMessage(assistantMessageIndex);
|
||||
}
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
sendButton.disabled = !chatInput.value.trim();
|
||||
}
|
||||
}
|
||||
|
||||
async function regenerateMessage(messageIndex) {
|
||||
// Find the message in the messages array
|
||||
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
||||
|
||||
const messageToRegenerate = messages[messageIndex];
|
||||
if (messageToRegenerate.role !== 'assistant') return;
|
||||
|
||||
// Remove this message and all subsequent messages from the array
|
||||
messages = messages.slice(0, messageIndex);
|
||||
|
||||
// Remove message elements from DOM starting from messageIndex
|
||||
const allMessages = chatWrapper.querySelectorAll('.message');
|
||||
for (let i = messageIndex; i < allMessages.length; i++) {
|
||||
allMessages[i].remove();
|
||||
}
|
||||
|
||||
// Regenerate the assistant response
|
||||
await generateAssistantResponse();
|
||||
}
|
||||
|
||||
function handleSlashCommand(command) {
|
||||
const parts = command.trim().split(/\s+/);
|
||||
const cmd = parts[0].toLowerCase();
|
||||
|
|
@ -419,72 +532,14 @@
|
|||
return;
|
||||
}
|
||||
|
||||
isGenerating = true;
|
||||
chatInput.value = '';
|
||||
chatInput.style.height = 'auto';
|
||||
sendButton.disabled = true;
|
||||
|
||||
const userMessageIndex = messages.length;
|
||||
messages.push({ role: 'user', content: message });
|
||||
addMessage('user', message, userMessageIndex);
|
||||
|
||||
const assistantContent = addMessage('assistant', '');
|
||||
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_URL}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages: messages,
|
||||
temperature: currentTemperature,
|
||||
top_k: currentTopK,
|
||||
max_tokens: 512
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let fullResponse = '';
|
||||
assistantContent.textContent = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
const lines = chunk.split('\n');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.slice(6));
|
||||
if (data.token) {
|
||||
fullResponse += data.token;
|
||||
assistantContent.textContent = fullResponse;
|
||||
chatContainer.scrollTop = chatContainer.scrollHeight;
|
||||
}
|
||||
} catch (e) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push({ role: 'assistant', content: fullResponse });
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||
} finally {
|
||||
isGenerating = false;
|
||||
sendButton.disabled = !chatInput.value.trim();
|
||||
}
|
||||
await generateAssistantResponse();
|
||||
}
|
||||
|
||||
sendButton.disabled = false;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ import os
|
|||
import torch
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
|
@ -268,7 +269,8 @@ async def generate_stream(
|
|||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user