Tensor.invalid -> Tensor.invalids (#15849)
matches ones and zeros, and avoids sharing a name with UOp.invalid
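Below, a minimal sketch of the renamed helper, assuming only the calling
convention visible in this diff (shape varargs plus dtype/device keywords,
matching ones and zeros); the example shapes and dtypes are illustrative:

    from tinygrad import Tensor, dtypes

    a = Tensor.ones(2, 3, dtype=dtypes.float32)      # defined contents
    b = Tensor.zeros(2, 3, dtype=dtypes.float32)     # defined contents
    c = Tensor.invalids(2, 3, dtype=dtypes.float32)  # uninitialized; was Tensor.invalid

Throughout this patch, invalids-created tensors serve as pre-allocated output
buffers handed to Tensor.custom_kernel, which writes into them.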
@@ -48,10 +48,10 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
   if isinstance(device, tuple):
     axis, ndev = xw13.axis, len(device)
     assert axis in (0, 1), f"unsupported sharding axis={axis}"
-    grad_xw13 = Tensor(Tensor.invalid(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
+    grad_xw13 = Tensor(Tensor.invalids(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
     dname = device[0].split(":")[0]
   else:
-    grad_xw13 = Tensor.invalid(*xw13.shape, dtype=dtypes.bfloat16, device=device)
+    grad_xw13 = Tensor.invalids(*xw13.shape, dtype=dtypes.bfloat16, device=device)
     dname = device.split(":")[0] if isinstance(device, str) else device
   grad_x2_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
   fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
@@ -67,12 +67,12 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T
   if isinstance(xw13.device, tuple):
     axis, ndev = xw13.uop.axis, len(xw13.device)
     assert axis in (0, 1), f"unsupported sharding axis={axis}"
-    fp8_out = Tensor(Tensor.invalid(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
-    amax_buf = Tensor(Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
+    fp8_out = Tensor(Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
+    amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
     dname = xw13.device[0].split(":")[0]
   else:
-    fp8_out = Tensor.invalid(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
-    amax_buf = Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
+    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
+    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
     dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
   fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
   fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
@@ -2745,14 +2745,14 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N

   if is_multi:
     if n_sharded:
-      out = Tensor(Tensor.invalid(batch, M, N//len(a.device), dtype=out_dtype, device=a.device).uop.multi(2), device=a.device)
+      out = Tensor(Tensor.invalids(batch, M, N//len(a.device), dtype=out_dtype, device=a.device).uop.multi(2), device=a.device)
     elif m_sharded:
-      out = Tensor(Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device).uop.multi(1), device=a.device)
+      out = Tensor(Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device).uop.multi(1), device=a.device)
     else:
-      out = Tensor(Tensor.invalid(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=out_dtype, device=a.device).uop.multi(0),
+      out = Tensor(Tensor.invalids(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=out_dtype, device=a.device).uop.multi(0),
                    device=a.device)
   else:
-    out = Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device)
+    out = Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device)

   renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
   dname, arch = dname.split(":")[0], renderer.target.arch
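The dname extraction in the hunks above follows one pattern in both branches:
strip the device index to recover the backend name, taking device[0] first in
the multi-device case. A sketch, simplified to the str/tuple cases, with an
illustrative device tuple:

    device = ("AMD:0", "AMD:1")  # illustrative multi-device tuple
    dname = device[0].split(":")[0] if isinstance(device, tuple) else device.split(":")[0]
    assert dname == "AMD"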
@@ -10,11 +10,11 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo

 def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None=None) -> Tensor:
   dtype = dtype or ref.dtype
-  if not isinstance(ref.device, tuple): return Tensor.invalid(*shape, dtype=dtype, device=ref.device)
+  if not isinstance(ref.device, tuple): return Tensor.invalids(*shape, dtype=dtype, device=ref.device)
   shard_axis = ref.uop.axis if axis is None else axis
   shape = tuple(s // len(ref.device) if i == shard_axis else s for i, s in enumerate(shape))
   axis = ref.uop.axis if axis is None else axis
-  return Tensor(Tensor.invalid(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
+  return Tensor(Tensor.invalids(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)

 def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
   return _sharded_empty(ref.shape, ref, axis)
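_sharded_empty above captures the allocation pattern the earlier hunks repeat
inline: divide the sharded axis by the device count, allocate the per-shard
buffer with invalids, then rewrap it as a multi tensor. A sketch under an
assumed 2-device shard on axis 0 (device names and shape are illustrative):

    from tinygrad import Tensor, dtypes

    devices = ("CPU:0", "CPU:1")  # illustrative device tuple
    full_shape, axis = (8, 16), 0
    # per-shard shape: only the sharded axis is divided by the device count
    shard = tuple(s // len(devices) if i == axis else s for i, s in enumerate(full_shape))
    out = Tensor(Tensor.invalids(*shard, dtype=dtypes.float32, device=devices).uop.multi(axis),
                 device=devices)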
@@ -299,8 +299,8 @@ class TestCustomKernel(unittest.TestCase):
     from tinygrad import function
     @function(precompile=True)
     def run(x:Tensor, w:Tensor) -> Tensor:
-      out = Tensor.invalid(*x.shape, dtype=x.dtype)
-      tmp = Tensor.invalid(*x.shape, dtype=x.dtype)
+      out = Tensor.invalids(*x.shape, dtype=x.dtype)
+      tmp = Tensor.invalids(*x.shape, dtype=x.dtype)
       out, tmp = Tensor.custom_kernel(out, tmp, x, w, fxn=custom_add_with_tmp)[:2]
       return out+tmp

@@ -293,7 +293,7 @@ class TestSetitem(unittest.TestCase):
     np.testing.assert_allclose(b.numpy(), [0, 2, 4, 6])

   def test_setitem_multiple_disjoint_on_invalid(self):
-    z = Tensor.invalid(10, dtype="int").realize()
+    z = Tensor.invalids(10, dtype="int").realize()
     z[2:5] = 2
     z[6:7] = 3
     z.realize()
@@ -430,8 +430,8 @@ class TestFunctionTuple(unittest.TestCase):

     @function(precompile=True, precompile_backward=True)
     def f(a:Tensor):
-      c = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
-      d = Tensor.invalid(3, dtype=a.dtype, device=a.device)
+      c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
+      d = Tensor.invalids(3, dtype=a.dtype, device=a.device)
       c, d = Tensor.custom_kernel(c, d, a, fxn=my_kernel, grad_fxn=my_grad)[:2]
       return c, d

@@ -454,8 +454,8 @@ class TestFunctionTuple(unittest.TestCase):

     @function(precompile=True, precompile_backward=True)
     def f(a:Tensor):
-      c = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
-      d = Tensor.invalid(*a.shape, dtype=a.dtype, device=a.device)
+      c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
+      d = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
       c, d = Tensor.custom_kernel(c, d, a, fxn=my_kernel, grad_fxn=my_grad)[:2]
       return (c, d)

@@ -645,7 +645,7 @@ class Tensor(OpMixin):
   # ***** creation helper functions *****

   @staticmethod
-  def invalid(*shape, **kwargs) -> Tensor:
+  def invalids(*shape, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with Invalid.

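Per the docstring hunk above, an invalids tensor carries undefined (Invalid)
contents until written. The setitem test in this patch exercises exactly that:
realize a buffer, then fill disjoint slices, leaving the rest undefined:

    from tinygrad import Tensor

    z = Tensor.invalids(10, dtype="int").realize()
    z[2:5] = 2
    z[6:7] = 3
    z.realize()  # only z[2:5] and z[6:7] hold defined values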