mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Merge 6c1322e4ce into b9b6ce137b
This commit is contained in:
commit
adc7a4f632
|
|
@ -33,22 +33,15 @@ class CustomJSON(Task):
|
|||
|
||||
else:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
for line_num, raw_line in enumerate(f, start=1):
|
||||
line = raw_line.strip()
|
||||
if not line: # skip empty lines
|
||||
continue
|
||||
messages = json.loads(line)
|
||||
# Validate the conversation structure
|
||||
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
|
||||
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
|
||||
# Validate message structure and alternating roles
|
||||
for i, message in enumerate(messages):
|
||||
assert "role" in message, f"Message {i} missing 'role' field"
|
||||
assert "content" in message, f"Message {i} missing 'content' field"
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert isinstance(message["content"], str), f"Message {i} content must be a string"
|
||||
|
||||
try:
|
||||
messages = json.loads(line)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"{filepath}:{line_num}: invalid JSON ({exc.msg})") from exc
|
||||
self._validate_conversation(messages, filepath, line_num)
|
||||
self.conversations.append(messages)
|
||||
|
||||
self.length = len(self.conversations)
|
||||
|
|
@ -56,10 +49,36 @@ class CustomJSON(Task):
|
|||
def num_examples(self):
    """Return the number of conversations held by this task."""
    return self.length
|
||||
|
||||
@staticmethod
|
||||
def _validate_conversation(messages, filepath, line_num):
|
||||
if not isinstance(messages, list):
|
||||
raise ValueError(
|
||||
f"{filepath}:{line_num}: expected a JSON array of messages, got {type(messages).__name__}"
|
||||
)
|
||||
if len(messages) < 2:
|
||||
raise ValueError(
|
||||
f"{filepath}:{line_num}: conversation must have at least 2 messages, got {len(messages)}"
|
||||
)
|
||||
for i, message in enumerate(messages):
|
||||
if not isinstance(message, dict):
|
||||
raise ValueError(
|
||||
f"{filepath}:{line_num}: message {i} must be an object, got {type(message).__name__}"
|
||||
)
|
||||
if "role" not in message:
|
||||
raise ValueError(f"{filepath}:{line_num}: message {i} missing 'role' field")
|
||||
if "content" not in message:
|
||||
raise ValueError(f"{filepath}:{line_num}: message {i} missing 'content' field")
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
if message["role"] != expected_role:
|
||||
raise ValueError(
|
||||
f"{filepath}:{line_num}: message {i} has role {message['role']} but should be {expected_role}"
|
||||
)
|
||||
if not isinstance(message["content"], str):
|
||||
raise ValueError(f"{filepath}:{line_num}: message {i} content must be a string")
|
||||
|
||||
def get_example(self, index):
    """Return the conversation at ``index`` as a dict with a "messages" key."""
    return {"messages": self.conversations[index]}
|
||||
|
||||
|
|
|
|||
74
tests/test_customjson.py
Normal file
74
tests/test_customjson.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import json
import re
import tempfile
import unittest
from pathlib import Path

from tasks.customjson import CustomJSON
|
||||
|
||||
|
||||
class CustomJSONTests(unittest.TestCase):
    """Tests for how ``CustomJSON`` loads and validates JSONL conversation files."""

    def _write_jsonl(self, directory, name, lines):
        """Write ``lines`` to ``directory/name``, one per line, and return the path."""
        path = Path(directory) / name
        path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        return path

    def test_loads_valid_conversations(self):
        """A well-formed two-message conversation is loaded and retrievable."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "valid.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "user", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    )
                ],
            )

            task = CustomJSON(str(path))

            self.assertEqual(task.num_examples(), 1)
            self.assertEqual(task.get_example(0)["messages"][1]["content"], "Hello")

    def test_invalid_json_reports_file_and_line(self):
        """Malformed JSON on line 2 raises ValueError naming the file and line."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "broken.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "user", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    ),
                    '{"role": "user"',
                ],
            )

            # The path is literal text, not a pattern: escape it so regex
            # metacharacters in it ("." everywhere, "\" on Windows) cannot
            # break or loosen the match.
            location = re.escape(str(path))
            with self.assertRaisesRegex(ValueError, rf"{location}:2: invalid JSON"):
                CustomJSON(str(path))

    def test_invalid_message_structure_reports_file_and_line(self):
        """A conversation starting with the wrong role raises a located ValueError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "wrong-role.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "assistant", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    )
                ],
            )

            # Escape the path for the same reason as in the invalid-JSON test.
            location = re.escape(str(path))
            with self.assertRaisesRegex(
                ValueError,
                rf"{location}:1: message 0 has role assistant but should be user",
            ):
                CustomJSON(str(path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user