mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
also allow regenerating assistant message by clicking it, and make sure to feed good seed to generate
This commit is contained in:
parent
2846999b8f
commit
4346536ab2
173
nanochat/ui.html
173
nanochat/ui.html
|
|
@ -108,6 +108,15 @@
|
||||||
background: transparent;
|
background: transparent;
|
||||||
border: none;
|
border: none;
|
||||||
padding: 0.25rem 0;
|
padding: 0.25rem 0;
|
||||||
|
cursor: pointer;
|
||||||
|
border-radius: 0.5rem;
|
||||||
|
padding: 0.5rem;
|
||||||
|
margin-left: -0.5rem;
|
||||||
|
transition: background-color 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message.assistant .message-content:hover {
|
||||||
|
background-color: #f9fafb;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message.user .message-content {
|
.message.user .message-content {
|
||||||
|
|
@ -325,6 +334,17 @@
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add click handler for assistant messages to enable regeneration
|
||||||
|
if (role === 'assistant' && messageIndex !== null) {
|
||||||
|
contentDiv.setAttribute('data-message-index', messageIndex);
|
||||||
|
contentDiv.setAttribute('title', 'Click to regenerate this response');
|
||||||
|
contentDiv.addEventListener('click', function() {
|
||||||
|
if (!isGenerating) {
|
||||||
|
regenerateMessage(messageIndex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
messageDiv.appendChild(contentDiv);
|
messageDiv.appendChild(contentDiv);
|
||||||
chatWrapper.appendChild(messageDiv);
|
chatWrapper.appendChild(messageDiv);
|
||||||
|
|
||||||
|
|
@ -358,6 +378,99 @@
|
||||||
chatInput.focus();
|
chatInput.focus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function generateAssistantResponse() {
|
||||||
|
isGenerating = true;
|
||||||
|
sendButton.disabled = true;
|
||||||
|
|
||||||
|
const assistantContent = addMessage('assistant', '');
|
||||||
|
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${API_URL}/chat/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
messages: messages,
|
||||||
|
temperature: currentTemperature,
|
||||||
|
top_k: currentTopK,
|
||||||
|
max_tokens: 512
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let fullResponse = '';
|
||||||
|
assistantContent.textContent = '';
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
const chunk = decoder.decode(value);
|
||||||
|
const lines = chunk.split('\n');
|
||||||
|
|
||||||
|
for (const line of lines) {
|
||||||
|
if (line.startsWith('data: ')) {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(line.slice(6));
|
||||||
|
if (data.token) {
|
||||||
|
fullResponse += data.token;
|
||||||
|
assistantContent.textContent = fullResponse;
|
||||||
|
chatContainer.scrollTop = chatContainer.scrollHeight;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const assistantMessageIndex = messages.length;
|
||||||
|
messages.push({ role: 'assistant', content: fullResponse });
|
||||||
|
|
||||||
|
// Add click handler to regenerate this assistant message
|
||||||
|
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
|
||||||
|
assistantContent.setAttribute('title', 'Click to regenerate this response');
|
||||||
|
assistantContent.addEventListener('click', function() {
|
||||||
|
if (!isGenerating) {
|
||||||
|
regenerateMessage(assistantMessageIndex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error:', error);
|
||||||
|
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
||||||
|
} finally {
|
||||||
|
isGenerating = false;
|
||||||
|
sendButton.disabled = !chatInput.value.trim();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function regenerateMessage(messageIndex) {
|
||||||
|
// Find the message in the messages array
|
||||||
|
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
||||||
|
|
||||||
|
const messageToRegenerate = messages[messageIndex];
|
||||||
|
if (messageToRegenerate.role !== 'assistant') return;
|
||||||
|
|
||||||
|
// Remove this message and all subsequent messages from the array
|
||||||
|
messages = messages.slice(0, messageIndex);
|
||||||
|
|
||||||
|
// Remove message elements from DOM starting from messageIndex
|
||||||
|
const allMessages = chatWrapper.querySelectorAll('.message');
|
||||||
|
for (let i = messageIndex; i < allMessages.length; i++) {
|
||||||
|
allMessages[i].remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regenerate the assistant response
|
||||||
|
await generateAssistantResponse();
|
||||||
|
}
|
||||||
|
|
||||||
function handleSlashCommand(command) {
|
function handleSlashCommand(command) {
|
||||||
const parts = command.trim().split(/\s+/);
|
const parts = command.trim().split(/\s+/);
|
||||||
const cmd = parts[0].toLowerCase();
|
const cmd = parts[0].toLowerCase();
|
||||||
|
|
@ -419,72 +532,14 @@
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
isGenerating = true;
|
|
||||||
chatInput.value = '';
|
chatInput.value = '';
|
||||||
chatInput.style.height = 'auto';
|
chatInput.style.height = 'auto';
|
||||||
sendButton.disabled = true;
|
|
||||||
|
|
||||||
const userMessageIndex = messages.length;
|
const userMessageIndex = messages.length;
|
||||||
messages.push({ role: 'user', content: message });
|
messages.push({ role: 'user', content: message });
|
||||||
addMessage('user', message, userMessageIndex);
|
addMessage('user', message, userMessageIndex);
|
||||||
|
|
||||||
const assistantContent = addMessage('assistant', '');
|
await generateAssistantResponse();
|
||||||
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${API_URL}/chat/completions`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
messages: messages,
|
|
||||||
temperature: currentTemperature,
|
|
||||||
top_k: currentTopK,
|
|
||||||
max_tokens: 512
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`HTTP error! status: ${response.status}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let fullResponse = '';
|
|
||||||
assistantContent.textContent = '';
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
const { done, value } = await reader.read();
|
|
||||||
if (done) break;
|
|
||||||
|
|
||||||
const chunk = decoder.decode(value);
|
|
||||||
const lines = chunk.split('\n');
|
|
||||||
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.startsWith('data: ')) {
|
|
||||||
try {
|
|
||||||
const data = JSON.parse(line.slice(6));
|
|
||||||
if (data.token) {
|
|
||||||
fullResponse += data.token;
|
|
||||||
assistantContent.textContent = fullResponse;
|
|
||||||
chatContainer.scrollTop = chatContainer.scrollHeight;
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.push({ role: 'assistant', content: fullResponse });
|
|
||||||
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error:', error);
|
|
||||||
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
|
||||||
} finally {
|
|
||||||
isGenerating = false;
|
|
||||||
sendButton.disabled = !chatInput.value.trim();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sendButton.disabled = false;
|
sendButton.disabled = false;
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
@ -268,7 +269,8 @@ async def generate_stream(
|
||||||
num_samples=1,
|
num_samples=1,
|
||||||
max_tokens=max_new_tokens,
|
max_tokens=max_new_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_k=top_k
|
top_k=top_k,
|
||||||
|
seed=random.randint(0, 2**31 - 1)
|
||||||
):
|
):
|
||||||
token = token_column[0]
|
token = token_column[0]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user