This commit is contained in:
陈家名 2026-04-14 13:27:12 -07:00 committed by GitHub
commit da451695ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 108 additions and 15 deletions

View File

@ -33,22 +33,15 @@ class CustomJSON(Task):
else:
with open(filepath, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
for line_num, raw_line in enumerate(f, start=1):
line = raw_line.strip()
if not line: # skip empty lines
continue
messages = json.loads(line)
# Validate the conversation structure
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
# Validate message structure and alternating roles
for i, message in enumerate(messages):
assert "role" in message, f"Message {i} missing 'role' field"
assert "content" in message, f"Message {i} missing 'content' field"
expected_role = "user" if i % 2 == 0 else "assistant"
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
assert isinstance(message["content"], str), f"Message {i} content must be a string"
try:
messages = json.loads(line)
except json.JSONDecodeError as exc:
raise ValueError(f"{filepath}:{line_num}: invalid JSON ({exc.msg})") from exc
self._validate_conversation(messages, filepath, line_num)
self.conversations.append(messages)
self.length = len(self.conversations)
@ -56,10 +49,36 @@ class CustomJSON(Task):
def num_examples(self):
    """Return the number of conversations stored on this task."""
    total = self.length
    return total
@staticmethod
def _validate_conversation(messages, filepath, line_num):
if not isinstance(messages, list):
raise ValueError(
f"{filepath}:{line_num}: expected a JSON array of messages, got {type(messages).__name__}"
)
if len(messages) < 2:
raise ValueError(
f"{filepath}:{line_num}: conversation must have at least 2 messages, got {len(messages)}"
)
for i, message in enumerate(messages):
if not isinstance(message, dict):
raise ValueError(
f"{filepath}:{line_num}: message {i} must be an object, got {type(message).__name__}"
)
if "role" not in message:
raise ValueError(f"{filepath}:{line_num}: message {i} missing 'role' field")
if "content" not in message:
raise ValueError(f"{filepath}:{line_num}: message {i} missing 'content' field")
expected_role = "user" if i % 2 == 0 else "assistant"
if message["role"] != expected_role:
raise ValueError(
f"{filepath}:{line_num}: message {i} has role {message['role']} but should be {expected_role}"
)
if not isinstance(message["content"], str):
raise ValueError(f"{filepath}:{line_num}: message {i} content must be a string")
def get_example(self, index):
    """Return the conversation at ``index`` wrapped as ``{"messages": [...]}``."""
    return {"messages": self.conversations[index]}

74
tests/test_customjson.py Normal file
View File

@ -0,0 +1,74 @@
import json
import tempfile
import unittest
from pathlib import Path
from tasks.customjson import CustomJSON
class CustomJSONTests(unittest.TestCase):
    """Tests for loading and validating JSONL conversations with CustomJSON."""

    def _write_jsonl(self, directory, name, lines):
        """Write ``lines`` to ``directory/name``, one per line, and return the path."""
        path = Path(directory) / name
        path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        return path

    def test_loads_valid_conversations(self):
        """A well-formed two-message conversation loads as one example."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "valid.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "user", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    )
                ],
            )
            task = CustomJSON(str(path))
            self.assertEqual(task.num_examples(), 1)
            self.assertEqual(task.get_example(0)["messages"][1]["content"], "Hello")

    def test_invalid_json_reports_file_and_line(self):
        """A truncated JSON record raises ValueError naming the file and line 2."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "broken.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "user", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    ),
                    '{"role": "user"',
                ],
            )
            with self.assertRaises(ValueError) as ctx:
                CustomJSON(str(path))
            # Literal substring match: the path must not be interpreted as a
            # regex, since it can contain backslashes and dots.
            self.assertIn(f"{path}:2: invalid JSON", str(ctx.exception))

    def test_invalid_message_structure_reports_file_and_line(self):
        """A conversation opening with the wrong role raises ValueError for line 1."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = self._write_jsonl(
                tmpdir,
                "wrong-role.jsonl",
                [
                    json.dumps(
                        [
                            {"role": "assistant", "content": "Hi"},
                            {"role": "assistant", "content": "Hello"},
                        ]
                    )
                ],
            )
            with self.assertRaises(ValueError) as ctx:
                CustomJSON(str(path))
            # Literal substring match, for the same reason as above.
            self.assertIn(
                f"{path}:1: message 0 has role assistant but should be user",
                str(ctx.exception),
            )
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
unittest.main()