From c07f13c438f52692fe2fb0b5d1d6acbd59222c42 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Thu, 29 May 2025 13:49:02 +0300 Subject: [PATCH] Docs for masked_fill (#10558) * add docs * fix doc examples * add to docs * fix typo --- docs/tensor/ops.md | 1 + tinygrad/tensor.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index f772b974ad..465add4a2f 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -37,6 +37,7 @@ ::: tinygrad.Tensor.scatter ::: tinygrad.Tensor.scatter_reduce ::: tinygrad.Tensor.masked_select +::: tinygrad.Tensor.masked_fill ::: tinygrad.Tensor.sort ::: tinygrad.Tensor.topk diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7ff31b9219..edc14ce229 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1608,6 +1608,23 @@ class Tensor(MathTrait): idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum() return x[idxs] + def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: + """ + Replace `self` with `value` wherever the elements of `mask` are True. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3, 4, 5]) + mask = Tensor([True, False, True, False, False]) + print(t.masked_fill(mask, -12).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3, 4, 5]) + mask = Tensor([True, False, True, False, False]) + value = Tensor([-1, -2, -3, -4, -5]) + print(t.masked_fill(mask, value).numpy()) + """ + return mask.where(value, self) + # ***** reduce ops ***** def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: @@ -3679,8 +3696,6 @@ class Tensor(MathTrait): cond, y = cond._broadcasted(y, match_dtype=False) return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y)) - def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self) - def copysign(self, other) -> Tensor: """ Return a tensor of with the magnitude of `self` and the sign of `other`, elementwise.