diff --git a/dev/gen_synthetic_data.py b/dev/gen_synthetic_data.py index 73f4ac9..068824f 100644 --- a/dev/gen_synthetic_data.py +++ b/dev/gen_synthetic_data.py @@ -37,7 +37,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from nanochat.common import get_base_dir -api_key = open("openroutertoken.txt", 'r', encoding='utf-8').read().strip() +api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip() url = "https://openrouter.ai/api/v1/chat/completions" headers = { @@ -45,7 +45,7 @@ headers = { "Content-Type": "application/json" } -readme = open("README.md", 'r', encoding='utf-8').read().strip() +readme = open("README.md", "r", encoding="utf-8").read().strip() prompt = r""" I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want: diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py index e1a7d91..378b0ed 100644 --- a/nanochat/checkpoint_manager.py +++ b/nanochat/checkpoint_manager.py @@ -34,7 +34,7 @@ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data) log0(f"Saved optimizer file to: {optimizer_path}") # Save the metadata dict as json meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "w", encoding='utf-8') as f: + with open(meta_path, "w", encoding="utf-8") as f: json.dump(meta_data, f, indent=2) log0(f"Saved metadata file to: {meta_path}") @@ -50,7 +50,7 @@ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False): optimizer_data = torch.load(optimizer_path, map_location=device) # Load the metadata meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json") - with open(meta_path, "r", encoding='utf-8') as f: + with open(meta_path, "r", encoding="utf-8") as f: meta_data = json.load(f) return model_data, optimizer_data, meta_data diff --git a/nanochat/report.py b/nanochat/report.py index 2f65e9d..0b0ebd7 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -241,7 +241,7 @@ class Report: slug = slugify(section) file_name = f"{slug}.md" file_path = os.path.join(self.report_dir, file_name) - with open(file_path, "w", encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.write(f"## {section}\n") f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") for item in data: @@ -272,11 +272,11 @@ class Report: final_metrics = {} # the most important final metrics we'll add as table at the end start_time = None end_time = None - with open(report_file, "w", encoding='utf-8') as out_file: + with open(report_file, "w", encoding="utf-8") as out_file: # write the header first header_file = os.path.join(report_dir, "header.md") if os.path.exists(header_file): - with open(header_file, "r", encoding='utf-8') as f: + with open(header_file, "r", encoding="utf-8") as f: header_content = f.read() out_file.write(header_content) start_time = extract_timestamp(header_content, "Run started:") @@ -293,7 +293,7 @@ class Report: if not os.path.exists(section_file): print(f"Warning: {section_file} does not exist, skipping") continue - with open(section_file, "r", encoding='utf-8') as in_file: + with open(section_file, "r", encoding="utf-8") as in_file: section = in_file.read() # Extract timestamp from this section (the last section's timestamp will "stick" as end_time) if "rl" not in file_name: @@ -373,7 +373,7 @@ class Report: header_file = os.path.join(self.report_dir, "header.md") header = generate_header() start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(header_file, "w", encoding='utf-8') as f: + with open(header_file, "w", encoding="utf-8") as f: f.write(header) f.write(f"Run started: {start_time}\n\n---\n\n") print(f"Reset report and wrote header to {header_file}") diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 5d0b44a..4b67b62 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -243,7 +243,7 @@ app.add_middleware( async def root(): """Serve the chat UI.""" ui_html_path = os.path.join("nanochat", "ui.html") - with open(ui_html_path, "r", encoding='utf-8') as f: + with open(ui_html_path, "r", encoding="utf-8") as f: html_content = f.read() # Replace the API_URL to use the same origin html_content = html_content.replace( diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index bad3c92..aca67fc 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -455,13 +455,13 @@ def enwik8_path(): @pytest.fixture(scope="module") def enwik8_small(enwik8_path): """Fixture providing 100KB of enwik8 for quick tests.""" - with open(enwik8_path, "r", encoding='utf-8') as f: + with open(enwik8_path, "r", encoding="utf-8") as f: return f.read(100_000) @pytest.fixture(scope="module") def enwik8_large(enwik8_path): """Fixture providing 10MB of enwik8 for performance tests.""" - with open(enwik8_path, "r", encoding='utf-8') as f: + with open(enwik8_path, "r", encoding="utf-8") as f: return f.read(10**7) def time_function(func, *args, **kwargs):