diff --git a/tasks/customjson.py b/tasks/customjson.py
index aeb1a3f7..9b9102e6 100644
--- a/tasks/customjson.py
+++ b/tasks/customjson.py
@@ -33,22 +33,15 @@ class CustomJSON(Task):
 
         else:
             with open(filepath, 'r', encoding='utf-8') as f:
-                for line in f:
-                    line = line.strip()
+                for line_num, raw_line in enumerate(f, start=1):
+                    line = raw_line.strip()
                     if not line:  # skip empty lines
                         continue
-                    messages = json.loads(line)
-                    # Validate the conversation structure
-                    assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
-                    assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
-                    # Validate message structure and alternating roles
-                    for i, message in enumerate(messages):
-                        assert "role" in message, f"Message {i} missing 'role' field"
-                        assert "content" in message, f"Message {i} missing 'content' field"
-                        expected_role = "user" if i % 2 == 0 else "assistant"
-                        assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
-                        assert isinstance(message["content"], str), f"Message {i} content must be a string"
-
+                    try:
+                        messages = json.loads(line)
+                    except json.JSONDecodeError as exc:
+                        raise ValueError(f"{filepath}:{line_num}: invalid JSON ({exc.msg})") from exc
+                    self._validate_conversation(messages, filepath, line_num)
                     self.conversations.append(messages)
 
         self.length = len(self.conversations)
@@ -56,10 +49,51 @@ class CustomJSON(Task):
     def num_examples(self):
         return self.length
 
+    @staticmethod
+    def _validate_conversation(messages, filepath, line_num):
+        """Check that one parsed JSONL record is a well-formed conversation.
+
+        A valid conversation is a list of at least two message objects whose
+        roles alternate user/assistant (starting with user) and whose
+        'content' values are strings.
+
+        Args:
+            messages: Value decoded from one line of the JSONL file.
+            filepath: Path of the file being loaded; used only in error text.
+            line_num: 1-based line number of the record; used only in error text.
+
+        Raises:
+            ValueError: If any check fails; the message is prefixed with
+                "<filepath>:<line_num>:".
+        """
+        if not isinstance(messages, list):
+            raise ValueError(
+                f"{filepath}:{line_num}: expected a JSON array of messages, got {type(messages).__name__}"
+            )
+        if len(messages) < 2:
+            raise ValueError(
+                f"{filepath}:{line_num}: conversation must have at least 2 messages, got {len(messages)}"
+            )
+        for i, message in enumerate(messages):
+            if not isinstance(message, dict):
+                raise ValueError(
+                    f"{filepath}:{line_num}: message {i} must be an object, got {type(message).__name__}"
+                )
+            if "role" not in message:
+                raise ValueError(f"{filepath}:{line_num}: message {i} missing 'role' field")
+            if "content" not in message:
+                raise ValueError(f"{filepath}:{line_num}: message {i} missing 'content' field")
+            expected_role = "user" if i % 2 == 0 else "assistant"
+            if message["role"] != expected_role:
+                raise ValueError(
+                    f"{filepath}:{line_num}: message {i} has role {message['role']} but should be {expected_role}"
+                )
+            if not isinstance(message["content"], str):
+                raise ValueError(f"{filepath}:{line_num}: message {i} content must be a string")
+
     def get_example(self, index):
         messages = self.conversations[index]
         conversation = {
             "messages": messages,
         }
         return conversation
-
diff --git a/tests/test_customjson.py b/tests/test_customjson.py
new file mode 100644
index 00000000..281250fd
--- /dev/null
+++ b/tests/test_customjson.py
@@ -0,0 +1,81 @@
+import json
+import re
+import tempfile
+import unittest
+from pathlib import Path
+
+from tasks.customjson import CustomJSON
+
+
+class CustomJSONTests(unittest.TestCase):
+    """Tests for CustomJSON loading and its file:line validation errors."""
+
+    def _write_jsonl(self, directory, name, lines):
+        """Write *lines* as a newline-terminated file in *directory*; return its path."""
+        path = Path(directory) / name
+        path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+        return path
+
+    def test_loads_valid_conversations(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = self._write_jsonl(
+                tmpdir,
+                "valid.jsonl",
+                [
+                    json.dumps(
+                        [
+                            {"role": "user", "content": "Hi"},
+                            {"role": "assistant", "content": "Hello"},
+                        ]
+                    )
+                ],
+            )
+
+            task = CustomJSON(str(path))
+
+            self.assertEqual(task.num_examples(), 1)
+            self.assertEqual(task.get_example(0)["messages"][1]["content"], "Hello")
+
+    def test_invalid_json_reports_file_and_line(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = self._write_jsonl(
+                tmpdir,
+                "broken.jsonl",
+                [
+                    json.dumps(
+                        [
+                            {"role": "user", "content": "Hi"},
+                            {"role": "assistant", "content": "Hello"},
+                        ]
+                    ),
+                    '{"role": "user"',
+                ],
+            )
+
+            # re.escape: the temp path may contain regex metacharacters (e.g. "\" on Windows).
+            expected = re.escape(f"{path}:2: invalid JSON")
+            with self.assertRaisesRegex(ValueError, expected):
+                CustomJSON(str(path))
+
+    def test_invalid_message_structure_reports_file_and_line(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = self._write_jsonl(
+                tmpdir,
+                "wrong-role.jsonl",
+                [
+                    json.dumps(
+                        [
+                            {"role": "assistant", "content": "Hi"},
+                            {"role": "assistant", "content": "Hello"},
+                        ]
+                    )
+                ],
+            )
+
+            # re.escape: the temp path may contain regex metacharacters (e.g. "\" on Windows).
+            expected = re.escape(f"{path}:1: message 0 has role assistant but should be user")
+            with self.assertRaisesRegex(ValueError, expected):
+                CustomJSON(str(path))
+
+
+if __name__ == "__main__":
+    unittest.main()