diff --git a/extra/amax/cast_amax.py b/extra/amax/cast_amax.py
index d641255010..6369b18144 100644
--- a/extra/amax/cast_amax.py
+++ b/extra/amax/cast_amax.py
@@ -48,10 +48,10 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
   if isinstance(device, tuple):
     axis, ndev = xw13.axis, len(device)
     assert axis in (0, 1), f"unsupported sharding axis={axis}"
-    grad_xw13 = Tensor(Tensor.invalid(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
+    grad_xw13 = Tensor(Tensor.invalids(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
     dname = device[0].split(":")[0]
   else:
-    grad_xw13 = Tensor.invalid(*xw13.shape, dtype=dtypes.bfloat16, device=device)
+    grad_xw13 = Tensor.invalids(*xw13.shape, dtype=dtypes.bfloat16, device=device)
     dname = device.split(":")[0] if isinstance(device, str) else device
   grad_x2_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
   fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
@@ -67,12 +67,12 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T
   if isinstance(xw13.device, tuple):
     axis, ndev = xw13.uop.axis, len(xw13.device)
     assert axis in (0, 1), f"unsupported sharding axis={axis}"
-    fp8_out = Tensor(Tensor.invalid(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
-    amax_buf = Tensor(Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
+    fp8_out = Tensor(Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
+    amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
     dname = xw13.device[0].split(":")[0]
   else:
-    fp8_out = Tensor.invalid(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
-    amax_buf = Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
+    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
+    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
     dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
   fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
   fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py
index d3927d7c95..f24dbd2847 100644
--- a/extra/gemm/cdna_asm_gemm.py
+++ b/extra/gemm/cdna_asm_gemm.py
@@ -2745,14 +2745,14 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
   if is_multi:
     if n_sharded:
-      out = Tensor(Tensor.invalid(batch, M, N//len(a.device), dtype=out_dtype, device=a.device).uop.multi(2), device=a.device)
+      out = Tensor(Tensor.invalids(batch, M, N//len(a.device), dtype=out_dtype, device=a.device).uop.multi(2), device=a.device)
     elif m_sharded:
-      out = Tensor(Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device).uop.multi(1), device=a.device)
+      out = Tensor(Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device).uop.multi(1), device=a.device)
     else:
-      out = Tensor(Tensor.invalid(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=out_dtype, device=a.device).uop.multi(0),
+      out = Tensor(Tensor.invalids(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=out_dtype, device=a.device).uop.multi(0),
                    device=a.device)
   else:
-    out = Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device)
+    out = Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device)
   renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
   dname, arch = dname.split(":")[0], renderer.target.arch
diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py
index fb33ae8124..e22cb5f55a 100644
--- a/extra/thunder/amd/fa.py
+++ b/extra/thunder/amd/fa.py
@@ -10,11 +10,11 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
 def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None=None) -> Tensor:
   dtype = dtype or ref.dtype
-  if not isinstance(ref.device, tuple): return Tensor.invalid(*shape, dtype=dtype, device=ref.device)
+  if not isinstance(ref.device, tuple): return Tensor.invalids(*shape, dtype=dtype, device=ref.device)
   shard_axis = ref.uop.axis if axis is None else axis
   shape = tuple(s // len(ref.device) if i == shard_axis else s for i, s in enumerate(shape))
   axis = ref.uop.axis if axis is None else axis
-  return Tensor(Tensor.invalid(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
+  return Tensor(Tensor.invalids(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
 
 def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor: return _sharded_empty(ref.shape, ref, axis)
diff --git a/test/backend/test_custom_kernel.py b/test/backend/test_custom_kernel.py
index 3c4efa9ae2..ba4d834200 100644
--- a/test/backend/test_custom_kernel.py
+++ b/test/backend/test_custom_kernel.py
@@ -299,8 +299,8 @@ class TestCustomKernel(unittest.TestCase):
     from tinygrad import function
     @function(precompile=True)
     def run(x:Tensor, w:Tensor) -> Tensor:
-      out = Tensor.invalid(*x.shape, dtype=x.dtype)
-      tmp = Tensor.invalid(*x.shape, dtype=x.dtype)
+      out = Tensor.invalids(*x.shape, dtype=x.dtype)
+      tmp = Tensor.invalids(*x.shape, dtype=x.dtype)
       out, tmp = Tensor.custom_kernel(out, tmp, x, w, fxn=custom_add_with_tmp)[:2]
       return out+tmp
diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py
index 2582c2383a..0833cbadc6 100644
--- a/test/backend/test_setitem.py
+++ b/test/backend/test_setitem.py
@@ -293,7 +293,7 @@ class TestSetitem(unittest.TestCase):
     np.testing.assert_allclose(b.numpy(), [0, 2, 4, 6])
 
   def test_setitem_multiple_disjoint_on_invalid(self):
-    z = Tensor.invalid(10, dtype="int").realize()
+    z = Tensor.invalids(10, dtype="int").realize()
     z[2:5] = 2
     z[6:7] = 3
     z.realize()
diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index f3eb9088b2..cd498b92f7 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -430,8 +430,8 @@ class TestFunctionTuple(unittest.TestCase):
     @function(precompile=True, precompile_backward=True)
     def f(a:Tensor):
-      c = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
-      d = Tensor.invalid(3, dtype=a.dtype, device=a.device)
+      c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
+      d = Tensor.invalids(3, dtype=a.dtype, device=a.device)
       c, d = Tensor.custom_kernel(c, d, a, fxn=my_kernel, grad_fxn=my_grad)[:2]
       return c, d
@@ -454,8 +454,8 @@
     @function(precompile=True, precompile_backward=True)
     def f(a:Tensor):
-      c = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
-      d = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
+      c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
+      d = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
       c, d = Tensor.custom_kernel(c, d, a,
                                   fxn=my_kernel, grad_fxn=my_grad)[:2]
       return (c, d)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 721dd2d591..84e00af809 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -645,7 +645,7 @@ class Tensor(OpMixin):
   # ***** creation helper functions *****
 
   @staticmethod
-  def invalid(*shape, **kwargs) -> Tensor:
+  def invalids(*shape, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with Invalid.