mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-20 10:23:42 +00:00
Merge branch 'master' into fix/fa3-fallback-mps
This commit is contained in:
commit
38e4e0dd7b
|
|
@ -82,10 +82,10 @@ That said, to give a sense, the example changes needed for the [speedrun.sh](spe
|
|||
python -m nanochat.dataset -n 450 &
|
||||
...
|
||||
# use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device-batch-size=16
|
||||
...
|
||||
# make sure to use the same later during midtraining:
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
|
||||
```
|
||||
|
||||
That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
|
||||
|
|
|
|||
26
dev/LOG.md
26
dev/LOG.md
|
|
@ -4,6 +4,32 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||
|
||||
---
|
||||
|
||||
## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)
|
||||
|
||||
Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. All of these did not help:
|
||||
|
||||
| Idea | Result | Notes |
|
||||
|------|--------|-------|
|
||||
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
|
||||
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
|
||||
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
|
||||
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
|
||||
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |
|
||||
|
||||
Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-15: Olmo pretraining mix (Negative result)
|
||||
|
||||
I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5".). I managed to work around these with some sensible hacks (e.g. reject documents less than 100 characters in length) and tried to process the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.
|
||||
|
||||
I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better that FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score so I'm a bit hesitant in case there was some overfitting to CORE score adjacent data distribution.
|
||||
|
||||
Classifying as negative result and reverting back to FineWeb-edu for now.
|
||||
|
||||
---
|
||||
|
||||
## 2026-01-13: Varlen Attention (Negative Result)
|
||||
|
||||
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
|
||||
|
|
|
|||
|
|
@ -1,11 +1,42 @@
|
|||
"""
|
||||
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
|
||||
Not a general optimizer! But works for our specific use.
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def adamw_step_fused(
|
||||
p: Tensor,
|
||||
grad: Tensor,
|
||||
exp_avg: Tensor,
|
||||
exp_avg_sq: Tensor,
|
||||
step_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
beta1_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
eps_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
The 0-D CPU tensors avoid recompilation when hyperparameter values change.
|
||||
"""
|
||||
# Weight decay (decoupled, applied before the update)
|
||||
p.mul_(1 - lr_t * wd_t)
|
||||
# Update running averages (lerp_ is cleaner and fuses well)
|
||||
exp_avg.lerp_(grad, 1 - beta1_t)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
||||
# Bias corrections
|
||||
bias1 = 1 - beta1_t ** step_t
|
||||
bias2 = 1 - beta2_t ** step_t
|
||||
# Compute update and apply
|
||||
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
||||
step_size = lr_t / bias1
|
||||
p.add_(exp_avg / denom, alpha=-step_size)
|
||||
|
||||
|
||||
class DistAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
|
|
@ -14,7 +45,26 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
"""
|
||||
def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
for group in param_groups:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
sliced = p.numel() >= 1024
|
||||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
|
@ -36,8 +86,7 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
grad_slices.append(grad)
|
||||
else:
|
||||
is_small.append(False)
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
rank_size = grad.shape[0] // world_size
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
|
||||
grad_slices.append(grad_slice)
|
||||
|
|
@ -63,28 +112,27 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
|
||||
# State init
|
||||
if not state:
|
||||
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_slice)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
||||
exp_avg = state['exp_avg']
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
state['step'] += 1
|
||||
t = state['step']
|
||||
# weight decay
|
||||
if wd != 0:
|
||||
eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
|
||||
p_slice.mul_(1 - eff_weight_decay)
|
||||
# update running averages
|
||||
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
|
||||
# bias corrections
|
||||
bias1 = 1 - beta1 ** t
|
||||
bias2 = 1 - beta2 ** t
|
||||
# compute step
|
||||
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
|
||||
step_size = lr / bias1
|
||||
update = exp_avg.div(denom).mul_(step_size)
|
||||
p_slice.add_(other=update, alpha=-1.0)
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
eff_wd = wd * getattr(p, "wd_mul", 1.0)
|
||||
self._step_t.fill_(state['step'])
|
||||
self._lr_t.fill_(lr)
|
||||
self._beta1_t.fill_(beta1)
|
||||
self._beta2_t.fill_(beta2)
|
||||
self._eps_t.fill_(eps)
|
||||
self._wd_t.fill_(eff_wd)
|
||||
|
||||
# Fused update: weight_decay -> momentum -> bias_correction -> param_update
|
||||
adamw_step_fused(
|
||||
p_slice, g_slice, exp_avg, exp_avg_sq,
|
||||
self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
|
||||
)
|
||||
|
||||
# Only large params need all_gather
|
||||
if not is_small[idx]:
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
|
|
|
|||
455
nanochat/muon.py
455
nanochat/muon.py
|
|
@ -1,7 +1,27 @@
|
|||
"""
|
||||
Muon optimizer adapted (simplified) from modded-nanogpt.
|
||||
Muon optimizer adapted and simplified from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
|
||||
Background:
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
|
||||
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Some of the changes in nanochat implementation:
|
||||
- Uses a simpler, more general approach to parameter grouping and stacking
|
||||
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
||||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
|
@ -16,97 +36,61 @@ polar_express_coeffs = [
|
|||
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
||||
]
|
||||
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
|
||||
@torch.compile(dynamic=False, fullgraph=True)
|
||||
def muon_step_fused(
|
||||
stacked_grads: Tensor,
|
||||
stacked_params: Tensor,
|
||||
momentum_buffer: Tensor,
|
||||
second_momentum_buffer: Tensor,
|
||||
momentum_t: Tensor,
|
||||
lr_t: Tensor,
|
||||
wd_t: Tensor,
|
||||
beta2_t: Tensor,
|
||||
ns_steps: int,
|
||||
red_dim: int,
|
||||
) -> None:
|
||||
"""
|
||||
Polar Express Sign Method for orthogonalization.
|
||||
https://arxiv.org/pdf/2505.16932
|
||||
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
||||
|
||||
Alternative to Newton-Schulz iteration with potentially better convergence properties.
|
||||
Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
|
||||
All in one compiled graph to eliminate Python overhead between ops.
|
||||
Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
|
||||
"""
|
||||
assert G.ndim >= 2
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
|
||||
# Nesterov momentum
|
||||
momentum = momentum_t.to(stacked_grads.dtype)
|
||||
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||
|
||||
# Polar express
|
||||
X = g.bfloat16()
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
|
||||
# Ensure spectral norm is at most 1 (with 2% safety factor)
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
||||
|
||||
# Perform the iterations (cap at available coefficients)
|
||||
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
|
||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||
A = X @ X.mT
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
if g.size(-2) > g.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
g = X
|
||||
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.bfloat16()
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
return X
|
||||
|
||||
|
||||
@torch.compile
|
||||
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
|
||||
"""
|
||||
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
|
||||
https://arxiv.org/pdf/2510.05491
|
||||
|
||||
Normalizes updates based on a running estimate of per-row (or per-column) variance.
|
||||
The reduction dimension is determined by the shape of second_momentum_buffer.
|
||||
"""
|
||||
# Determine reduction dimension from buffer shape
|
||||
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
|
||||
|
||||
# Compute per-row/col mean of squared values
|
||||
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = v.size(red_dim)
|
||||
|
||||
# Compute current norm
|
||||
# Variance reduction
|
||||
beta2 = beta2_t.to(g.dtype)
|
||||
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
||||
red_dim_size = g.size(red_dim)
|
||||
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
||||
v_norm = v_norm_sq.sqrt()
|
||||
|
||||
# Update second momentum buffer (EMA of variance)
|
||||
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
||||
|
||||
# Compute scaling factor from second momentum
|
||||
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
||||
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
||||
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
||||
|
||||
# Final scale preserves overall norm while adjusting per-row/col
|
||||
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
||||
return v.mul(final_scale.to(v.dtype))
|
||||
g = g * final_scale.to(g.dtype)
|
||||
|
||||
# Cautious weight decay + parameter update
|
||||
lr = lr_t.to(g.dtype)
|
||||
wd = wd_t.to(g.dtype)
|
||||
mask = (g * stacked_params) >= 0
|
||||
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
|
|
@ -127,94 +111,112 @@ class Muon(torch.optim.Optimizer):
|
|||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params: list[Tensor] = [*params]
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params) # ensure we have a list, not an e.g. (exhaustible) iterator
|
||||
# Group by shape so we can stack tensors
|
||||
shapes = sorted({p.shape for p in params})
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
group = dict(params=[p for p in params if p.numel() == size])
|
||||
param_groups.append(group)
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
param_groups.append(dict(params=group_params))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for p in params:
|
||||
g = p.grad
|
||||
assert g is not None
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
if not params:
|
||||
continue
|
||||
|
||||
# Get or create group-level buffers (stored in first param's state for convenience)
|
||||
state = self.state[params[0]]
|
||||
num_params = len(params) # e.g.: 12 (for a d12 model)
|
||||
# e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
|
||||
shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
|
||||
|
||||
# Momentum for every individual parameter
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
|
||||
|
||||
# Stack grads and params
|
||||
stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
|
||||
stacked_params = torch.stack(params) # (12, 768, 3072)
|
||||
|
||||
# Fill all the 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
stacked_grads,
|
||||
stacked_params,
|
||||
momentum_buffer,
|
||||
second_momentum_buffer,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
|
||||
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
|
||||
Notes:
|
||||
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
||||
params like embeddings or scalars.
|
||||
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
||||
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
||||
consolidate states beforehand.
|
||||
|
||||
Args:
|
||||
params: iterable of Tensors
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton-Schulz iterations for the orthogonalization
|
||||
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
Distributed version of the Muon optimizer.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params = list(params)
|
||||
ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
params = list(params)
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
# Group all parameters by their shape
|
||||
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
||||
shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
|
||||
param_groups = []
|
||||
for shape in shapes:
|
||||
group_params = [p for p in params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
assert all(p.device == device for p in group_params)
|
||||
assert all(p.dtype == dtype for p in group_params)
|
||||
# Compute chunk size for this group (how many params each rank owns)
|
||||
chunk_size = (len(group_params) + world_size - 1) // world_size
|
||||
if rank == 0:
|
||||
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
||||
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
||||
print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
|
||||
param_groups.append(dict(params=group_params, chunk_size=chunk_size))
|
||||
super().__init__(param_groups, defaults)
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
|
@ -224,72 +226,127 @@ class DistMuon(torch.optim.Optimizer):
|
|||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
||||
all_reduce_futures = []
|
||||
# First pass: stack grads and kick off reduce_scatter for each group
|
||||
group_infos = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank
|
||||
# each rank stacks up its chunk of world_size params into a list
|
||||
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
||||
# pad rs_input with the zero buffer to complete the group
|
||||
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
||||
# the output buffer gets strided across the group based on the rank
|
||||
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
||||
# reduce scatter the gradients within this group of world_size params
|
||||
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
all_reduce_futures.append(work)
|
||||
params: list[Tensor] = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
padded_num_params = chunk_size * world_size
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
|
||||
# Now each rank computes the update and gathers
|
||||
future_idx = 0
|
||||
# Stack all gradients into a single tensor (single kernel via torch.stack)
|
||||
grad_stack = torch.stack([p.grad for p in params])
|
||||
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
||||
stacked_grads[:len(params)].copy_(grad_stack)
|
||||
# Zero-pad if we have fewer params than padded size
|
||||
if len(params) < padded_num_params:
|
||||
stacked_grads[len(params):].zero_()
|
||||
|
||||
# Output buffer for this rank's chunk
|
||||
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
# Async reduce_scatter on the stacked tensor
|
||||
reduce_future = dist.reduce_scatter_tensor(
|
||||
grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
|
||||
).get_future()
|
||||
|
||||
group_infos.append(dict(
|
||||
grad_chunk=grad_chunk,
|
||||
reduce_future=reduce_future,
|
||||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
))
|
||||
|
||||
# Second pass: wait for reduce, compute batched updates, kick off all_gather
|
||||
all_gather_futures = []
|
||||
for group in self.param_groups:
|
||||
params = group["params"]
|
||||
zero_buffer = group["zero_buffer"]
|
||||
# Go through params in groups of world_size.
|
||||
for base_i in range(0, len(params), world_size):
|
||||
# The compute owner of each param is rank i % world_size
|
||||
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
||||
# Wait for the reduce scatter to complete
|
||||
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
||||
future_idx += 1
|
||||
# Owner computes the Muon update, result is in its param
|
||||
if owner_idx < len(params):
|
||||
p = params[owner_idx]
|
||||
g = p.grad # now averaged across ranks
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(g)
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
||||
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
||||
all_gather_futures.append(work)
|
||||
for group, info in zip(self.param_groups, group_infos):
|
||||
info["reduce_future"].wait()
|
||||
|
||||
# Wait for all work to finish
|
||||
torch.futures.collect_all(all_gather_futures).wait()
|
||||
params = group["params"]
|
||||
chunk_size = group["chunk_size"]
|
||||
shape = params[0].shape
|
||||
device, dtype = params[0].device, params[0].dtype
|
||||
grad_chunk = info["grad_chunk"]
|
||||
|
||||
# How many params does this rank actually own?
|
||||
start_idx = rank * chunk_size
|
||||
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
||||
|
||||
# Get or create group-level state (stored keyed by first param)
|
||||
state = self.state[params[0]]
|
||||
|
||||
# Momentum buffer
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
||||
momentum_buffer = state["momentum_buffer"]
|
||||
|
||||
# Second momentum buffer is factored, either per-row or per-column
|
||||
if "second_momentum_buffer" not in state:
|
||||
if shape[-2] >= shape[-1]:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
|
||||
second_momentum_buffer = state["second_momentum_buffer"]
|
||||
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
||||
|
||||
# Build updated_params tensor for all_gather
|
||||
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
||||
|
||||
if num_owned > 0:
|
||||
# Stack owned params (single kernel via torch.stack)
|
||||
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
||||
stacked_owned_params = torch.stack(owned_params)
|
||||
|
||||
# Get owned slices of buffers and grads
|
||||
owned_grads = grad_chunk[:num_owned]
|
||||
owned_momentum = momentum_buffer[:num_owned]
|
||||
owned_second_momentum = second_momentum_buffer[:num_owned]
|
||||
|
||||
# Fill 0-D tensors with current values
|
||||
self._momentum_t.fill_(group["momentum"])
|
||||
self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
||||
self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
||||
self._wd_t.fill_(group["weight_decay"])
|
||||
|
||||
# Single fused kernel: momentum -> polar_express -> variance_reduction -> update
|
||||
muon_step_fused(
|
||||
owned_grads,
|
||||
stacked_owned_params,
|
||||
owned_momentum,
|
||||
owned_second_momentum,
|
||||
self._momentum_t,
|
||||
self._lr_t,
|
||||
self._wd_t,
|
||||
self._beta2_t,
|
||||
group["ns_steps"],
|
||||
red_dim,
|
||||
)
|
||||
|
||||
# Copy updated params to output buffer
|
||||
updated_params[:num_owned].copy_(stacked_owned_params)
|
||||
|
||||
# Zero-pad the rest (for ranks that own fewer params)
|
||||
if num_owned < chunk_size:
|
||||
updated_params[num_owned:].zero_()
|
||||
|
||||
# Reuse stacked_grads buffer for all_gather output
|
||||
stacked_params = info["stacked_grads"]
|
||||
|
||||
# Async all_gather to replicate updated params to all ranks
|
||||
gather_future = dist.all_gather_into_tensor(
|
||||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
all_gather_futures.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
for info in all_gather_futures:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ dependencies = [
|
|||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
|
|
@ -208,7 +208,6 @@ if resuming:
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
|
|
@ -370,14 +369,15 @@ while True:
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
# logging (CPU action only)
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ All the generic code lives here, and all the evaluation-specific
|
|||
code lives in nanochat directory and is imported from here.
|
||||
|
||||
Example runs:
|
||||
python -m scripts.chat_eval -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
python -m scripts.chat_eval -i mid -a ARC-Easy
|
||||
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i mid -a ARC-Easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ simpler and more similar to just REINFORCE:
|
|||
|
||||
1) Delete trust region, so there is no KL regularization to a reference model
|
||||
2) We are on policy, so there's no need for PPO ratio+clip.
|
||||
3) We use GAPO style normalization that is token-level, not sequence-level.
|
||||
3) We use DAPO style normalization that is token-level, not sequence-level.
|
||||
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
|
||||
|
||||
1 GPU:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class CustomJSON(Task):
|
|||
print("-" * 80)
|
||||
print(f"Warning: File {filepath} does not exist")
|
||||
print("HINT (Oct 21 2025)")
|
||||
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
|
||||
print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
|
||||
|
|
|
|||
92
uv.lock
92
uv.lock
|
|
@ -1513,6 +1513,7 @@ dependencies = [
|
|||
{ name = "transformers" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wandb" },
|
||||
{ name = "zstandard" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
|
@ -1551,6 +1552,7 @@ requires-dist = [
|
|||
{ name = "transformers", specifier = ">=4.57.3" },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
{ name = "zstandard", specifier = ">=0.25.0" },
|
||||
]
|
||||
provides-extras = ["cpu", "gpu"]
|
||||
|
||||
|
|
@ -3619,3 +3621,93 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstandard"
|
||||
version = "0.25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/aa/3e0508d5a5dd96529cdc5a97011299056e14c6505b678fd58938792794b1/zstandard-0.25.0.tar.gz", hash = "sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b", size = 711513, upload-time = "2025-09-14T22:15:54.002Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/56/7a/28efd1d371f1acd037ac64ed1c5e2b41514a6cc937dd6ab6a13ab9f0702f/zstandard-0.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e59fdc271772f6686e01e1b3b74537259800f57e24280be3f29c8a0deb1904dd", size = 795256, upload-time = "2025-09-14T22:15:56.415Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/34/ef34ef77f1ee38fc8e4f9775217a613b452916e633c4f1d98f31db52c4a5/zstandard-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4d441506e9b372386a5271c64125f72d5df6d2a8e8a2a45a0ae09b03cb781ef7", size = 640565, upload-time = "2025-09-14T22:15:58.177Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/1b/4fdb2c12eb58f31f28c4d28e8dc36611dd7205df8452e63f52fb6261d13e/zstandard-0.25.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:ab85470ab54c2cb96e176f40342d9ed41e58ca5733be6a893b730e7af9c40550", size = 5345306, upload-time = "2025-09-14T22:16:00.165Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/28/a44bdece01bca027b079f0e00be3b6bd89a4df180071da59a3dd7381665b/zstandard-0.25.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e05ab82ea7753354bb054b92e2f288afb750e6b439ff6ca78af52939ebbc476d", size = 5055561, upload-time = "2025-09-14T22:16:02.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/74/68341185a4f32b274e0fc3410d5ad0750497e1acc20bd0f5b5f64ce17785/zstandard-0.25.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:78228d8a6a1c177a96b94f7e2e8d012c55f9c760761980da16ae7546a15a8e9b", size = 5402214, upload-time = "2025-09-14T22:16:04.109Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/67/f92e64e748fd6aaffe01e2b75a083c0c4fd27abe1c8747fee4555fcee7dd/zstandard-0.25.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:2b6bd67528ee8b5c5f10255735abc21aa106931f0dbaf297c7be0c886353c3d0", size = 5449703, upload-time = "2025-09-14T22:16:06.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/e5/6d36f92a197c3c17729a2125e29c169f460538a7d939a27eaaa6dcfcba8e/zstandard-0.25.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4b6d83057e713ff235a12e73916b6d356e3084fd3d14ced499d84240f3eecee0", size = 5556583, upload-time = "2025-09-14T22:16:08.457Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/83/41939e60d8d7ebfe2b747be022d0806953799140a702b90ffe214d557638/zstandard-0.25.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9174f4ed06f790a6869b41cba05b43eeb9a35f8993c4422ab853b705e8112bbd", size = 5045332, upload-time = "2025-09-14T22:16:10.444Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/87/d3ee185e3d1aa0133399893697ae91f221fda79deb61adbe998a7235c43f/zstandard-0.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:25f8f3cd45087d089aef5ba3848cd9efe3ad41163d3400862fb42f81a3a46701", size = 5572283, upload-time = "2025-09-14T22:16:12.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/1d/58635ae6104df96671076ac7d4ae7816838ce7debd94aecf83e30b7121b0/zstandard-0.25.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3756b3e9da9b83da1796f8809dd57cb024f838b9eeafde28f3cb472012797ac1", size = 4959754, upload-time = "2025-09-14T22:16:14.225Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/d6/57e9cb0a9983e9a229dd8fd2e6e96593ef2aa82a3907188436f22b111ccd/zstandard-0.25.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:81dad8d145d8fd981b2962b686b2241d3a1ea07733e76a2f15435dfb7fb60150", size = 5266477, upload-time = "2025-09-14T22:16:16.343Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/a9/ee891e5edf33a6ebce0a028726f0bbd8567effe20fe3d5808c42323e8542/zstandard-0.25.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a5a419712cf88862a45a23def0ae063686db3d324cec7edbe40509d1a79a0aab", size = 5440914, upload-time = "2025-09-14T22:16:18.453Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/08/a8522c28c08031a9521f27abc6f78dbdee7312a7463dd2cfc658b813323b/zstandard-0.25.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e7360eae90809efd19b886e59a09dad07da4ca9ba096752e61a2e03c8aca188e", size = 5819847, upload-time = "2025-09-14T22:16:20.559Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/11/4c91411805c3f7b6f31c60e78ce347ca48f6f16d552fc659af6ec3b73202/zstandard-0.25.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:75ffc32a569fb049499e63ce68c743155477610532da1eb38e7f24bf7cd29e74", size = 5363131, upload-time = "2025-09-14T22:16:22.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/d6/8c4bd38a3b24c4c7676a7a3d8de85d6ee7a983602a734b9f9cdefb04a5d6/zstandard-0.25.0-cp310-cp310-win32.whl", hash = "sha256:106281ae350e494f4ac8a80470e66d1fe27e497052c8d9c3b95dc4cf1ade81aa", size = 436469, upload-time = "2025-09-14T22:16:25.002Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/90/96d50ad417a8ace5f841b3228e93d1bb13e6ad356737f42e2dde30d8bd68/zstandard-0.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:ea9d54cc3d8064260114a0bbf3479fc4a98b21dffc89b3459edd506b69262f6e", size = 506100, upload-time = "2025-09-14T22:16:23.569Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/83/c3ca27c363d104980f1c9cee1101cc8ba724ac8c28a033ede6aab89585b1/zstandard-0.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:933b65d7680ea337180733cf9e87293cc5500cc0eb3fc8769f4d3c88d724ec5c", size = 795254, upload-time = "2025-09-14T22:16:26.137Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/4d/e66465c5411a7cf4866aeadc7d108081d8ceba9bc7abe6b14aa21c671ec3/zstandard-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3f79487c687b1fc69f19e487cd949bf3aae653d181dfb5fde3bf6d18894706f", size = 640559, upload-time = "2025-09-14T22:16:27.973Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/56/354fe655905f290d3b147b33fe946b0f27e791e4b50a5f004c802cb3eb7b/zstandard-0.25.0-cp311-cp311-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:0bbc9a0c65ce0eea3c34a691e3c4b6889f5f3909ba4822ab385fab9057099431", size = 5348020, upload-time = "2025-09-14T22:16:29.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/13/2b7ed68bd85e69a2069bcc72141d378f22cae5a0f3b353a2c8f50ef30c1b/zstandard-0.25.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01582723b3ccd6939ab7b3a78622c573799d5d8737b534b86d0e06ac18dbde4a", size = 5058126, upload-time = "2025-09-14T22:16:31.811Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/dd/fdaf0674f4b10d92cb120ccff58bbb6626bf8368f00ebfd2a41ba4a0dc99/zstandard-0.25.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5f1ad7bf88535edcf30038f6919abe087f606f62c00a87d7e33e7fc57cb69fcc", size = 5405390, upload-time = "2025-09-14T22:16:33.486Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/67/354d1555575bc2490435f90d67ca4dd65238ff2f119f30f72d5cde09c2ad/zstandard-0.25.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:06acb75eebeedb77b69048031282737717a63e71e4ae3f77cc0c3b9508320df6", size = 5452914, upload-time = "2025-09-14T22:16:35.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1f/e9cfd801a3f9190bf3e759c422bbfd2247db9d7f3d54a56ecde70137791a/zstandard-0.25.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9300d02ea7c6506f00e627e287e0492a5eb0371ec1670ae852fefffa6164b072", size = 5559635, upload-time = "2025-09-14T22:16:37.141Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/21/88/5ba550f797ca953a52d708c8e4f380959e7e3280af029e38fbf47b55916e/zstandard-0.25.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bfd06b1c5584b657a2892a6014c2f4c20e0db0208c159148fa78c65f7e0b0277", size = 5048277, upload-time = "2025-09-14T22:16:38.807Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/c0/ca3e533b4fa03112facbe7fbe7779cb1ebec215688e5df576fe5429172e0/zstandard-0.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f373da2c1757bb7f1acaf09369cdc1d51d84131e50d5fa9863982fd626466313", size = 5574377, upload-time = "2025-09-14T22:16:40.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/9b/3fb626390113f272abd0799fd677ea33d5fc3ec185e62e6be534493c4b60/zstandard-0.25.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c0e5a65158a7946e7a7affa6418878ef97ab66636f13353b8502d7ea03c8097", size = 4961493, upload-time = "2025-09-14T22:16:43.3Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/d3/23094a6b6a4b1343b27ae68249daa17ae0651fcfec9ed4de09d14b940285/zstandard-0.25.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c8e167d5adf59476fa3e37bee730890e389410c354771a62e3c076c86f9f7778", size = 5269018, upload-time = "2025-09-14T22:16:45.292Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/a7/bb5a0c1c0f3f4b5e9d5b55198e39de91e04ba7c205cc46fcb0f95f0383c1/zstandard-0.25.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:98750a309eb2f020da61e727de7d7ba3c57c97cf6213f6f6277bb7fb42a8e065", size = 5443672, upload-time = "2025-09-14T22:16:47.076Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/22/503347aa08d073993f25109c36c8d9f029c7d5949198050962cb568dfa5e/zstandard-0.25.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:22a086cff1b6ceca18a8dd6096ec631e430e93a8e70a9ca5efa7561a00f826fa", size = 5822753, upload-time = "2025-09-14T22:16:49.316Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/be/94267dc6ee64f0f8ba2b2ae7c7a2df934a816baaa7291db9e1aa77394c3c/zstandard-0.25.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:72d35d7aa0bba323965da807a462b0966c91608ef3a48ba761678cb20ce5d8b7", size = 5366047, upload-time = "2025-09-14T22:16:51.328Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/a3/732893eab0a3a7aecff8b99052fecf9f605cf0fb5fb6d0290e36beee47a4/zstandard-0.25.0-cp311-cp311-win32.whl", hash = "sha256:f5aeea11ded7320a84dcdd62a3d95b5186834224a9e55b92ccae35d21a8b63d4", size = 436484, upload-time = "2025-09-14T22:16:55.005Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/a3/c6155f5c1cce691cb80dfd38627046e50af3ee9ddc5d0b45b9b063bfb8c9/zstandard-0.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:daab68faadb847063d0c56f361a289c4f268706b598afbf9ad113cbe5c38b6b2", size = 506183, upload-time = "2025-09-14T22:16:52.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/3e/8945ab86a0820cc0e0cdbf38086a92868a9172020fdab8a03ac19662b0e5/zstandard-0.25.0-cp311-cp311-win_arm64.whl", hash = "sha256:22a06c5df3751bb7dc67406f5374734ccee8ed37fc5981bf1ad7041831fa1137", size = 462533, upload-time = "2025-09-14T22:16:53.878Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/fc/f26eb6ef91ae723a03e16eddb198abcfce2bc5a42e224d44cc8b6765e57e/zstandard-0.25.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b", size = 795738, upload-time = "2025-09-14T22:16:56.237Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/1c/d920d64b22f8dd028a8b90e2d756e431a5d86194caa78e3819c7bf53b4b3/zstandard-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00", size = 640436, upload-time = "2025-09-14T22:16:57.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/6c/288c3f0bd9fcfe9ca41e2c2fbfd17b2097f6af57b62a81161941f09afa76/zstandard-0.25.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64", size = 5343019, upload-time = "2025-09-14T22:16:59.302Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/15/efef5a2f204a64bdb5571e6161d49f7ef0fffdbca953a615efbec045f60f/zstandard-0.25.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea", size = 5063012, upload-time = "2025-09-14T22:17:01.156Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/37/a6ce629ffdb43959e92e87ebdaeebb5ac81c944b6a75c9c47e300f85abdf/zstandard-0.25.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb", size = 5394148, upload-time = "2025-09-14T22:17:03.091Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/79/2bf870b3abeb5c070fe2d670a5a8d1057a8270f125ef7676d29ea900f496/zstandard-0.25.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a", size = 5451652, upload-time = "2025-09-14T22:17:04.979Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/60/7be26e610767316c028a2cbedb9a3beabdbe33e2182c373f71a1c0b88f36/zstandard-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902", size = 5546993, upload-time = "2025-09-14T22:17:06.781Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/85/c7/3483ad9ff0662623f3648479b0380d2de5510abf00990468c286c6b04017/zstandard-0.25.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f", size = 5046806, upload-time = "2025-09-14T22:17:08.415Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/b3/206883dd25b8d1591a1caa44b54c2aad84badccf2f1de9e2d60a446f9a25/zstandard-0.25.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b", size = 5576659, upload-time = "2025-09-14T22:17:10.164Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/31/76c0779101453e6c117b0ff22565865c54f48f8bd807df2b00c2c404b8e0/zstandard-0.25.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6", size = 4953933, upload-time = "2025-09-14T22:17:11.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/e1/97680c664a1bf9a247a280a053d98e251424af51f1b196c6d52f117c9720/zstandard-0.25.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91", size = 5268008, upload-time = "2025-09-14T22:17:13.627Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/73/316e4010de585ac798e154e88fd81bb16afc5c5cb1a72eeb16dd37e8024a/zstandard-0.25.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708", size = 5433517, upload-time = "2025-09-14T22:17:16.103Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/60/dd0f8cfa8129c5a0ce3ea6b7f70be5b33d2618013a161e1ff26c2b39787c/zstandard-0.25.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512", size = 5814292, upload-time = "2025-09-14T22:17:17.827Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/5f/75aafd4b9d11b5407b641b8e41a57864097663699f23e9ad4dbb91dc6bfe/zstandard-0.25.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa", size = 5360237, upload-time = "2025-09-14T22:17:19.954Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/0b/8df9c4ad06af91d39e94fa96cc010a24ac4ef1378d3efab9223cc8593d40/zstandard-0.25.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94", size = 795735, upload-time = "2025-09-14T22:17:26.042Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/06/9ae96a3e5dcfd119377ba33d4c42a7d89da1efabd5cb3e366b156c45ff4d/zstandard-0.25.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1", size = 640440, upload-time = "2025-09-14T22:17:27.366Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/14/933d27204c2bd404229c69f445862454dcc101cd69ef8c6068f15aaec12c/zstandard-0.25.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f", size = 5343070, upload-time = "2025-09-14T22:17:28.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/db/ddb11011826ed7db9d0e485d13df79b58586bfdec56e5c84a928a9a78c1c/zstandard-0.25.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bfc4e20784722098822e3eee42b8e576b379ed72cca4a7cb856ae733e62192ea", size = 5063001, upload-time = "2025-09-14T22:17:31.044Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/00/87466ea3f99599d02a5238498b87bf84a6348290c19571051839ca943777/zstandard-0.25.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:457ed498fc58cdc12fc48f7950e02740d4f7ae9493dd4ab2168a47c93c31298e", size = 5394120, upload-time = "2025-09-14T22:17:32.711Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/95/fc5531d9c618a679a20ff6c29e2b3ef1d1f4ad66c5e161ae6ff847d102a9/zstandard-0.25.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:fd7a5004eb1980d3cefe26b2685bcb0b17989901a70a1040d1ac86f1d898c551", size = 5451230, upload-time = "2025-09-14T22:17:34.41Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/4b/e3678b4e776db00f9f7b2fe58e547e8928ef32727d7a1ff01dea010f3f13/zstandard-0.25.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8e735494da3db08694d26480f1493ad2cf86e99bdd53e8e9771b2752a5c0246a", size = 5547173, upload-time = "2025-09-14T22:17:36.084Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/d5/ba05ed95c6b8ec30bd468dfeab20589f2cf709b5c940483e31d991f2ca58/zstandard-0.25.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3a39c94ad7866160a4a46d772e43311a743c316942037671beb264e395bdd611", size = 5046736, upload-time = "2025-09-14T22:17:37.891Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/d5/870aa06b3a76c73eced65c044b92286a3c4e00554005ff51962deef28e28/zstandard-0.25.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:172de1f06947577d3a3005416977cce6168f2261284c02080e7ad0185faeced3", size = 5576368, upload-time = "2025-09-14T22:17:40.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/35/398dc2ffc89d304d59bc12f0fdd931b4ce455bddf7038a0a67733a25f550/zstandard-0.25.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3c83b0188c852a47cd13ef3bf9209fb0a77fa5374958b8c53aaa699398c6bd7b", size = 4954022, upload-time = "2025-09-14T22:17:41.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/5c/36ba1e5507d56d2213202ec2b05e8541734af5f2ce378c5d1ceaf4d88dc4/zstandard-0.25.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1673b7199bbe763365b81a4f3252b8e80f44c9e323fc42940dc8843bfeaf9851", size = 5267889, upload-time = "2025-09-14T22:17:43.577Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/e8/2ec6b6fb7358b2ec0113ae202647ca7c0e9d15b61c005ae5225ad0995df5/zstandard-0.25.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0be7622c37c183406f3dbf0cba104118eb16a4ea7359eeb5752f0794882fc250", size = 5433952, upload-time = "2025-09-14T22:17:45.271Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/01/b5f4d4dbc59ef193e870495c6f1275f5b2928e01ff5a81fecb22a06e22fb/zstandard-0.25.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:5f5e4c2a23ca271c218ac025bd7d635597048b366d6f31f420aaeb715239fc98", size = 5814054, upload-time = "2025-09-14T22:17:47.08Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/e5/fbd822d5c6f427cf158316d012c5a12f233473c2f9c5fe5ab1ae5d21f3d8/zstandard-0.25.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f187a0bb61b35119d1926aee039524d1f93aaf38a9916b8c4b78ac8514a0aaf", size = 5360113, upload-time = "2025-09-14T22:17:48.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/e0/69a553d2047f9a2c7347caa225bb3a63b6d7704ad74610cb7823baa08ed7/zstandard-0.25.0-cp313-cp313-win32.whl", hash = "sha256:7030defa83eef3e51ff26f0b7bfb229f0204b66fe18e04359ce3474ac33cbc09", size = 436936, upload-time = "2025-09-14T22:17:52.658Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/82/b9c06c870f3bd8767c201f1edbdf9e8dc34be5b0fbc5682c4f80fe948475/zstandard-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:1f830a0dac88719af0ae43b8b2d6aef487d437036468ef3c2ea59c51f9d55fd5", size = 506232, upload-time = "2025-09-14T22:17:50.402Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/57/60c3c01243bb81d381c9916e2a6d9e149ab8627c0c7d7abb2d73384b3c0c/zstandard-0.25.0-cp313-cp313-win_arm64.whl", hash = "sha256:85304a43f4d513f5464ceb938aa02c1e78c2943b29f44a750b48b25ac999a049", size = 462671, upload-time = "2025-09-14T22:17:51.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/5c/f8923b595b55fe49e30612987ad8bf053aef555c14f05bb659dd5dbe3e8a/zstandard-0.25.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e29f0cf06974c899b2c188ef7f783607dbef36da4c242eb6c82dcd8b512855e3", size = 795887, upload-time = "2025-09-14T22:17:54.198Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/09/d0a2a14fc3439c5f874042dca72a79c70a532090b7ba0003be73fee37ae2/zstandard-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:05df5136bc5a011f33cd25bc9f506e7426c0c9b3f9954f056831ce68f3b6689f", size = 640658, upload-time = "2025-09-14T22:17:55.423Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/7c/8b6b71b1ddd517f68ffb55e10834388d4f793c49c6b83effaaa05785b0b4/zstandard-0.25.0-cp314-cp314-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:f604efd28f239cc21b3adb53eb061e2a205dc164be408e553b41ba2ffe0ca15c", size = 5379849, upload-time = "2025-09-14T22:17:57.372Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/86/a48e56320d0a17189ab7a42645387334fba2200e904ee47fc5a26c1fd8ca/zstandard-0.25.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223415140608d0f0da010499eaa8ccdb9af210a543fac54bce15babbcfc78439", size = 5058095, upload-time = "2025-09-14T22:17:59.498Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/ad/eb659984ee2c0a779f9d06dbfe45e2dc39d99ff40a319895df2d3d9a48e5/zstandard-0.25.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e54296a283f3ab5a26fc9b8b5d4978ea0532f37b231644f367aa588930aa043", size = 5551751, upload-time = "2025-09-14T22:18:01.618Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/61/b3/b637faea43677eb7bd42ab204dfb7053bd5c4582bfe6b1baefa80ac0c47b/zstandard-0.25.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ca54090275939dc8ec5dea2d2afb400e0f83444b2fc24e07df7fdef677110859", size = 6364818, upload-time = "2025-09-14T22:18:03.769Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/dc/cc50210e11e465c975462439a492516a73300ab8caa8f5e0902544fd748b/zstandard-0.25.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e09bb6252b6476d8d56100e8147b803befa9a12cea144bbe629dd508800d1ad0", size = 5560402, upload-time = "2025-09-14T22:18:05.954Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/ae/56523ae9c142f0c08efd5e868a6da613ae76614eca1305259c3bf6a0ed43/zstandard-0.25.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a9ec8c642d1ec73287ae3e726792dd86c96f5681eb8df274a757bf62b750eae7", size = 4955108, upload-time = "2025-09-14T22:18:07.68Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/cf/c899f2d6df0840d5e384cf4c4121458c72802e8bda19691f3b16619f51e9/zstandard-0.25.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a4089a10e598eae6393756b036e0f419e8c1d60f44a831520f9af41c14216cf2", size = 5269248, upload-time = "2025-09-14T22:18:09.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/c0/59e912a531d91e1c192d3085fc0f6fb2852753c301a812d856d857ea03c6/zstandard-0.25.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f67e8f1a324a900e75b5e28ffb152bcac9fbed1cc7b43f99cd90f395c4375344", size = 5430330, upload-time = "2025-09-14T22:18:11.966Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/1d/7e31db1240de2df22a58e2ea9a93fc6e38cc29353e660c0272b6735d6669/zstandard-0.25.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:9654dbc012d8b06fc3d19cc825af3f7bf8ae242226df5f83936cb39f5fdc846c", size = 5811123, upload-time = "2025-09-14T22:18:13.907Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/49/fac46df5ad353d50535e118d6983069df68ca5908d4d65b8c466150a4ff1/zstandard-0.25.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4203ce3b31aec23012d3a4cf4a2ed64d12fea5269c49aed5e4c3611b938e4088", size = 5359591, upload-time = "2025-09-14T22:18:16.465Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/38/f249a2050ad1eea0bb364046153942e34abba95dd5520af199aed86fbb49/zstandard-0.25.0-cp314-cp314-win32.whl", hash = "sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12", size = 444513, upload-time = "2025-09-14T22:18:20.61Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/43/241f9615bcf8ba8903b3f0432da069e857fc4fd1783bd26183db53c4804b/zstandard-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2", size = 516118, upload-time = "2025-09-14T22:18:17.849Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/ef/da163ce2450ed4febf6467d77ccb4cd52c4c30ab45624bad26ca0a27260c/zstandard-0.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d", size = 476940, upload-time = "2025-09-14T22:18:19.088Z" },
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user