diff --git a/nanochat/common.py b/nanochat/common.py index ee02a6e..d4a9828 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -5,10 +5,10 @@ Common utilities for nanochat. import os import re import logging -import fcntl import urllib.request import torch import torch.distributed as dist +from filelock import FileLock class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" @@ -70,13 +70,11 @@ def download_file_with_lock(url, filename, postprocess_fn=None): if os.path.exists(file_path): return file_path - with open(lock_path, 'w', encoding='utf-8') as lock_file: - + with FileLock(lock_path): # Only a single rank can acquire this lock # All other ranks block until it is released - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - # Recheck after acquiring lock (another process may have downloaded it) + # Recheck after acquiring lock if os.path.exists(file_path): return file_path @@ -94,12 +92,6 @@ def download_file_with_lock(url, filename, postprocess_fn=None): if postprocess_fn is not None: postprocess_fn(file_path) - # Clean up the lock file after the lock is released - try: - os.remove(lock_path) - except OSError: - pass # Ignore if already removed by another process - return file_path def print0(s="",**kwargs):