diff --git a/examples/llama3.py b/examples/llama3.py
index 3cc8c7666c..11d26a6827 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -103,7 +103,7 @@ class Int8Embedding:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
     big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
-    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+    return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

 def NF4Linear(block_size):
   _CODE = [
diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py
index ba47ce37f7..c92ebf529f 100644
--- a/examples/mlperf/initializers.py
+++ b/examples/mlperf/initializers.py
@@ -53,7 +53,7 @@ class EmbeddingBert(nn.Embedding):
     arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
-    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+    return (arange == idx).mul(vals).sum(2, dtype=vals.dtype)

 class LayerNormBert:
   def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
diff --git a/extra/gemm/fuzz_matmul.py b/extra/gemm/fuzz_matmul.py
index 9d22d55349..b024a29aef 100644
--- a/extra/gemm/fuzz_matmul.py
+++ b/extra/gemm/fuzz_matmul.py
@@ -22,7 +22,7 @@ if __name__ == "__main__":
       for K in range(K_START, K_STOP+1, K_STEP):
         print(f"testing {M=} {N=} {K=}")
         a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
-        c = a.matmul(b, acc_dtype=acc_dtype).realize()
+        c = a.matmul(b, dtype=acc_dtype).realize()
         comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
         nc = c.numpy()
         try:
diff --git a/extra/gemm/simple_conv.py b/extra/gemm/simple_conv.py
index d7e08ef8dd..ea682a0b30 100644
--- a/extra/gemm/simple_conv.py
+++ b/extra/gemm/simple_conv.py
@@ -23,7 +23,7 @@ if __name__ == "__main__":
   for i in range(CNT):
     if i > 0 and getenv("RAND", 0) != 0:
       a, b = rand_input()
-    c = a.conv2d(b, padding=PADDING, acc_dtype=acc_dtype).realize()
+    c = a.conv2d(b, padding=PADDING, dtype=acc_dtype).realize()

   if COMP:
     import numpy as np, time, torch
diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py
index e420018b80..fc06dad1f2 100644
--- a/extra/gemm/simple_matmul.py
+++ b/extra/gemm/simple_matmul.py
@@ -24,7 +24,7 @@ if __name__ == "__main__":
   for i in range(CNT):
     if i > 0 and getenv("RAND", 0) != 0:
       a, b = init_matrix(M, K), init_matrix(K, N)
-    c = a.matmul(b, acc_dtype=acc_dtype).realize()
+    c = a.matmul(b, dtype=acc_dtype).realize()

   ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
   res = c.numpy()
diff --git a/extra/gemm/simple_matvec.py b/extra/gemm/simple_matvec.py
index 685f926fbf..00966ba1f8 100644
--- a/extra/gemm/simple_matvec.py
+++ b/extra/gemm/simple_matvec.py
@@ -24,7 +24,7 @@ if __name__ == "__main__":
   for i in range(CNT):
     if i > 0 and getenv("RAND", 0) != 0:
       a, b = _rand(device)
-    c = a.matmul(b, acc_dtype=acc_dtype).realize()
+    c = a.matmul(b, dtype=acc_dtype).realize()
   nc = c.numpy()
   comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
   np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
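The extra/gemm scripts above only change the keyword they pass; the call pattern is otherwise unchanged. A minimal sketch of the new spelling, assuming a device with float16 support (shapes and tolerances here are illustrative, not taken from the scripts):

```python
import numpy as np
from tinygrad import Tensor, dtypes

# half inputs, accumulate in float32 via the renamed keyword (previously acc_dtype=)
a, b = Tensor.rand(64, 32, dtype=dtypes.float16), Tensor.rand(32, 48, dtype=dtypes.float16)
c = a.matmul(b, dtype=dtypes.float32).realize()

# same style of reference check the scripts perform
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
np.testing.assert_allclose(c.numpy(), ref, atol=1e-2, rtol=1e-2)
```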
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 1afb3a10f9..e43dc63163 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -331,7 +331,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   # NOTE: axis=[] in torch means all, change tinygrad?
   "aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None: self.sum(axis if axis is None or len(axis) else None, keepdim,
-    acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None),
+    dtype = _from_torch_dtype(dtype) if dtype is not None else None),
 }}

 # we add the "out" here
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 1c179b03ba..c8d4b00230 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -494,7 +494,7 @@ class TestTypeSpec(unittest.TestCase):
     with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
     with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")

-    np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
+    np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())

   @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_creation(self, default_int, default_float):
@@ -694,21 +694,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64

   @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
-  def test_sum_acc_dtype(self):
+  def test_sum_dtype_arg(self):
     t = Tensor([40000, 40000], dtype=dtypes.float16)
     # default float16 sum returns in float16, overflowed in this case
     assert t.sum().dtype == dtypes.float16
     assert math.isinf(t.sum().numpy().item())
-    # specifiying acc_dtype and it's not downcasted
-    assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32
-    np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000)
+    # specifiying dtype and it's not downcasted
+    assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
+    np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)

-  def test_prod_acc_dtype(self):
+  def test_prod_dtype_arg(self):
     t = Tensor([100, 200], dtype=dtypes.int32)
     assert t.prod().dtype == dtypes.int32
     np.testing.assert_allclose(t.prod().numpy(), 20000)
-    assert t.prod(acc_dtype=dtypes.float32).dtype == dtypes.float32
-    np.testing.assert_allclose(t.prod(acc_dtype=dtypes.float32).numpy(), 20000)
+    assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
+    np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)

   def test_mean(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
@@ -745,8 +745,8 @@ class TestAutoCastType(unittest.TestCase):
     t1 = Tensor([0, 1], dtype=dt1)
     t2 = Tensor([0, 1], dtype=dt2)
     assert (t1 @ t2).dtype == least_upper_dtype(dt1, dt2)
-    # if acc_dtype is specified, return in acc_dtype
-    assert (t1.matmul(t2, acc_dtype=acc_dt).dtype == acc_dt)
+    # if dtype is specified, return in dtype
+    assert (t1.matmul(t2, dtype=acc_dt).dtype == acc_dt)

   @staticmethod
   def check_where_alternate_input_other(input_, other, data_type):
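The renamed tests above also document the behavior being preserved: with no `dtype` argument a float16 sum still comes back as float16 (and can overflow), while an explicit `dtype` both widens the accumulator and becomes the result dtype. A small sketch mirroring `test_sum_dtype_arg`/`test_prod_dtype_arg`, assuming float16 is supported on the device:

```python
import math
from tinygrad import Tensor, dtypes

t = Tensor([40000, 40000], dtype=dtypes.float16)
assert t.sum().dtype == dtypes.float16 and math.isinf(t.sum().numpy().item())  # 80000 overflows half
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
assert t.sum(dtype=dtypes.float32).numpy().item() == 80000.0

p = Tensor([100, 200], dtype=dtypes.int32)
assert p.prod().dtype == dtypes.int32
assert p.prod(dtype=dtypes.float32).dtype == dtypes.float32
```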
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 8ed22f98e0..f9601eb131 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -28,7 +28,7 @@ def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer
 def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
   np_a, np_b = a.numpy(), b.numpy()
-  r = a.matmul(b, acc_dtype=dtype_out)
+  r = a.matmul(b, dtype=dtype_out)
   sched = r.schedule()
   realized_ast = sched[-1].ast
   run_schedule(sched)
@@ -47,7 +47,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
 def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, ensure_triggered:bool=True):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
-  r = a.matmul(b, acc_dtype=dtype_out)
+  r = a.matmul(b, dtype=dtype_out)
   sched = r.schedule()
   realized_ast = sched[-1].ast
   k = Kernel(realized_ast)
@@ -1050,11 +1050,11 @@ class TestLinearizer(unittest.TestCase):
     for tensor_dtype, acc_dtype, expected_dtype in tests:
       if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
         a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
-        helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
-        helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
-        helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
+        helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
+        helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
+        helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, dtype=acc_dtype), expected_dtype)
         d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
-        helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
+        helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype)

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_tensor_cores(self):
@@ -1101,7 +1101,7 @@ class TestLinearizer(unittest.TestCase):
     for axis in range(9):
       a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
       b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
-      c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
+      c = a.conv2d(b, padding=1, dtype=tc.dtype_out)
       realized_ast, real_bufs = helper_realized_ast(c)
       k = Kernel(realized_ast)
@@ -1130,7 +1130,7 @@ class TestLinearizer(unittest.TestCase):
   def test_tensor_cores_unroll_phi(self):
     tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
-    r = x.matmul(y, acc_dtype=tc.dtype_out)
+    r = x.matmul(y, dtype=tc.dtype_out)
     k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
     for u in k.uops:
       if u.op is Ops.WMMA:
@@ -1141,7 +1141,7 @@ class TestLinearizer(unittest.TestCase):
   def test_tensor_cores_unroll_casted_phi(self):
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
-    r = x.matmul(y, acc_dtype=tc.dtype_out)
+    r = x.matmul(y, dtype=tc.dtype_out)
     k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
     for u in k.uops:
       if u.op is Ops.WMMA:
@@ -1154,7 +1154,7 @@ class TestLinearizer(unittest.TestCase):
     # all ASSIGN children are outside the loop
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
-    r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
+    r = x.matmul(y, dtype=tc.dtype_out).relu()
     k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
     for u in k.uops:
       if u.op is Ops.WMMA:
@@ -2000,7 +2000,7 @@ class TestKernelOpts(unittest.TestCase):
       # bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
       if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
       a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
-      r = a.matmul(b, acc_dtype=tc.dtype_out)
+      r = a.matmul(b, dtype=tc.dtype_out)
       (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
       helper_linearizer_opt(r, [
         [],
@@ -2027,7 +2027,7 @@ class TestKernelOpts(unittest.TestCase):
       # bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
       if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
       a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
-      r = a.matmul(b, acc_dtype=tc.dtype_out)
+      r = a.matmul(b, dtype=tc.dtype_out)
       (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
       helper_linearizer_opt(r, [
         [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
@@ -2093,10 +2093,10 @@ class TestKernelOpts(unittest.TestCase):
     helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
     helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
     helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
-    helper_linearizer_opt(b.sum(acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
+    helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
     if Device.DEFAULT != "WEBGPU":
-      helper_linearizer_opt(b.sum(0, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
-      helper_linearizer_opt(b.sum(1, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
+      helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
+      helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])

     # having unsafe ops after sum is fine
     helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
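The tensor-core tests in test_linearizer.py pick the accumulator from the device's tensor-core descriptors rather than hard-coding it. A condensed sketch of that pattern, guarded because not every backend exposes tensor cores:

```python
from tinygrad import Tensor, Device

tcs = Device[Device.DEFAULT].renderer.tensor_cores
if tcs:
  tc = tcs[0]
  x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
  r = x.matmul(y, dtype=tc.dtype_out)  # accumulate in the tensor core's output dtype
  assert r.dtype == tc.dtype_out
```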
diff --git a/test/test_ops.py b/test/test_ops.py
index 497a6e9eee..61a86873a8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1273,11 +1273,11 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
     self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)

-  def test_sum_acc_dtype(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32))
-    if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64))
+  def test_sum_dtype_arg(self):
+    helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
+    if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))

-    with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="")
+    with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(dtype="")

   def test_sum_with_zeros_shape(self):
     helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
@@ -1294,8 +1294,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: x.prod(0))
     helper_test_op([()], lambda x: x.prod(-1))

-  def test_prod_acc_dtype(self):
-    with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(acc_dtype="")
+  def test_prod_dtype_arg(self):
+    with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(dtype="")

   def test_min(self):
     helper_test_op([(3,3)], lambda x: x.min())
diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py
index 2fd914f159..1f6681b0cc 100644
--- a/test/test_quantize_onnx.py
+++ b/test/test_quantize_onnx.py
@@ -69,7 +69,7 @@ class TestQuantizeOnnx(unittest.TestCase):
   def test_prequant_conv2d_1x1(self):
     X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
     W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8))
-    out = X.conv2d(W, acc_dtype=X.dtype)
+    out = X.conv2d(W, dtype=X.dtype)
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
     sexec(out, opts)
@@ -77,7 +77,7 @@ class TestQuantizeOnnx(unittest.TestCase):
     N = 512
     X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
     W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
-    out = X.matmul(W, acc_dtype=X.dtype)
+    out = X.matmul(W, dtype=X.dtype)
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
     sexec(out, opts)
@@ -204,7 +204,7 @@ class TestQuantizeOnnx(unittest.TestCase):
     W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)).realize()
     #out = X.cast(dtypes.int) @ W.cast(dtypes.int)
     #out = X @ W
-    out = X.matmul(W, acc_dtype=X.dtype)
+    out = X.matmul(W, dtype=X.dtype)
     opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
     sexec(out, opts)
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index ebe8a19329..83106b4a24 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -68,8 +68,8 @@ class TestPTXFailures(unittest.TestCase):
   def test_gated_define_acc_with_half_dtype(self):
     a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
     b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
-    result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy()
-    reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy()
+    result = a.pad((1,1)).matmul(b, dtype=dtypes.half).numpy()
+    reference = a.pad((1,1)).matmul(b, dtype=dtypes.float).numpy()
     np.testing.assert_allclose(result, reference)

 if __name__ == '__main__':
diff --git a/test/test_search.py b/test/test_search.py
index ea5bc92d13..aa54cbe3cf 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -116,7 +116,7 @@ class TestBEAM(unittest.TestCase):
     for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
       a = Tensor.rand(16, 16, dtype=dtype_in)
       b = Tensor.rand(16, 16, dtype=dtype_in)
-      realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))
+      realized_ast, _ = helper_realized_ast(a.matmul(b, dtype=dtype_out))
       lins = get_kernel_actions(Kernel(realized_ast)).values()
       assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1
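In test_quantize_onnx.py the point of `dtype=X.dtype` is to keep the accumulator in the input's uint8 dtype so the generated kernel stays in integer arithmetic for the prequantized path (values wrap on overflow, which those tests accept). The same request outside the test harness is just:

```python
import numpy as np
from tinygrad import Tensor, dtypes

X = Tensor(np.random.randint(0, 256, size=(16, 16), dtype=np.uint8))
W = Tensor(np.random.randint(0, 256, size=(16, 16), dtype=np.uint8))
out = X.matmul(W, dtype=X.dtype)  # accumulate (and return) in uint8, as the tests above do
assert out.dtype == dtypes.uint8
```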
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index fd9af9b2d9..3818634fb9 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -323,7 +323,7 @@ class Embedding:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
     big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
-    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+    return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

 class LSTMCell:
   """
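`nn.Embedding` (like the Int8 and BERT variants earlier in the diff) implements the lookup as a one-hot compare against an `arange`, a multiply by the weights, and a reduction that now takes `dtype=vals.dtype`. A rough standalone sketch of that trick with made-up shapes:

```python
from tinygrad import Tensor

vocab_sz, embed_sz = 8, 4
weight = Tensor.rand(vocab_sz, embed_sz)
idx = Tensor([[1, 3, 5]])                                   # (batch, seq)

arange = Tensor.arange(vocab_sz).unsqueeze(-1)              # (vocab_sz, 1)
big_shp = idx.shape + (vocab_sz, embed_sz)
mask = arange.expand(big_shp) == idx.reshape(idx.shape + (1, 1)).expand(big_shp)
out = mask.mul(weight.expand(big_shp)).sum(-2, dtype=weight.dtype)  # accumulate in the weight dtype
assert out.shape == (1, 3, embed_sz)
```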
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8bb77a5d83..2f1d5633b5 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1164,7 +1164,7 @@ class Tensor(SimpleMathTrait):
       # inject 1's for the extra dims added in create masks
       reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
       # sum reduce the extra dims introduced in create masks
-      x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
+      x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)

       # special permute case
       if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
@@ -1255,7 +1255,7 @@ class Tensor(SimpleMathTrait):
     assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     index = index.to(self.device)
     x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
-    return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
+    return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)

   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
     """
@@ -1564,14 +1564,14 @@ class Tensor(SimpleMathTrait):
     ret = self._apply_uop(UOp.r, op=op, axis=axis)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))

-  def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
+  def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
     """
     Returns the sum of the elements of the tensor along the specified axis or axes.

     You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained.

-    You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
+    You can pass in `dtype` keyword argument to control the data type of the accumulation.
     If not specified, the accumulation data type is chosen based on the input tensor's data type.

     ```python exec="true" source="above" session="tensor" result="python"
@@ -1588,17 +1588,17 @@ class Tensor(SimpleMathTrait):
     print(t.sum(axis=1).numpy())
     ```
     """
-    ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
-    return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
+    ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
+    return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

-  def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
+  def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
     """
     Returns the product of the elements of the tensor along the specified axis or axes.

     You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained.

-    You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
+    You can pass in `dtype` keyword argument to control the data type of the accumulation.
     If not specified, the accumulation data type is chosen based on the input tensor's data type.

     ```python exec="true" source="above" session="tensor" result="python"
@@ -1615,7 +1615,7 @@ class Tensor(SimpleMathTrait):
     print(t.prod(axis=1).numpy())
     ```
     """
-    return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
+    return self.cast(dtype if dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)

   def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
     """
@@ -2005,7 +2005,7 @@ class Tensor(SimpleMathTrait):
     return self._inverse().argmax(axis=axis, keepdim=keepdim)

   @staticmethod
-  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:DTypeLike|None=None) -> Tensor:
+  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor:
     """
     Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
@@ -2047,7 +2047,7 @@ class Tensor(SimpleMathTrait):

     # sum over all axes that's not in the output, then permute to the output order
     return functools.reduce(lambda a,b:a*b, xs_) \
-      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
+      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)

   # ***** processing ops *****
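`einsum` forwards the keyword straight into the final `sum` over the contracted axes, so requesting a wider accumulator is a single argument. A short sketch, assuming float16 is supported on the device:

```python
from tinygrad import Tensor, dtypes

a = Tensor.rand(4, 5, dtype=dtypes.float16)
b = Tensor.rand(5, 6, dtype=dtypes.float16)
out = Tensor.einsum("ik,kj->ij", a, b, dtype=dtypes.float32)  # contraction accumulates in float32
assert out.dtype == dtypes.float32 and out.shape == (4, 6)
```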
@@ -2182,7 +2182,7 @@ class Tensor(SimpleMathTrait):
     return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))

   def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
-             acc_dtype:DTypeLike|None=None) -> Tensor:
+             dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.
@@ -2208,7 +2208,7 @@ class Tensor(SimpleMathTrait):
     print(t.conv2d(w).numpy())
     ```
     """
-    if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
+    if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
     padding_ = self._resolve_pool_pads(padding, len(HW))
     assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"  # noqa: E501
@@ -2221,7 +2221,7 @@ class Tensor(SimpleMathTrait):
       x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])  # noqa: E501

       # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx)  # noqa: E501
+      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)  # noqa: E501
       return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

     HWI, HWO = (6,) * len(HW), (4,) * len(HW)  # F(4x4,3x3) winograd tiles
@@ -2246,7 +2246,7 @@ class Tensor(SimpleMathTrait):
     dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)

     # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
-    ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
+    ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))

     # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
     ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -2294,14 +2294,14 @@ class Tensor(SimpleMathTrait):
     padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

-  def dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
+  def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
     """
     Performs dot product between two tensors.

     If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
     If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.

-    You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
+    You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.

     ```python exec="true" source="above" session="tensor" result="python"
     a = Tensor([1, 2, 3])
@@ -2314,20 +2314,20 @@ class Tensor(SimpleMathTrait):
     print(a.dot(b).numpy())
     ```
     """
-    if IMAGE: return self.image_dot(w, acc_dtype)
+    if IMAGE: return self.image_dot(w, dtype)
     x, dx, dw = self, self.ndim, w.ndim
     if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
     if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
     x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
-    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
+    return (x*w).sum(-1, dtype=dtype).cast(least_upper_dtype(x.dtype, w.dtype) if dtype is None else dtype)

-  def matmul(self, x:Tensor, reverse=False, acc_dtype:DTypeLike|None=None) -> Tensor:
+  def matmul(self, x:Tensor, reverse=False, dtype:DTypeLike|None=None) -> Tensor:
     """
     Performs matrix multiplication between two tensors.

     You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
-    You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
+    You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.

     ```python exec="true" source="above" session="tensor" result="python"
     a = Tensor([[1, 2], [3, 4]])
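For `conv2d` the keyword controls the reduction over `cin*kH*kW` (and the winograd path's inner sum). A minimal sketch with made-up shapes, again assuming float16 support:

```python
from tinygrad import Tensor, dtypes

x = Tensor.rand(1, 8, 16, 16, dtype=dtypes.float16)  # NCHW input
w = Tensor.rand(4, 8, 3, 3, dtype=dtypes.float16)    # (cout, cin, kH, kW)
y = x.conv2d(w, dtype=dtypes.float32)                # accumulate the reduction in float32
assert y.dtype == dtypes.float32 and y.shape == (1, 4, 14, 14)
```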
@@ -2335,7 +2335,7 @@ class Tensor(SimpleMathTrait):
     print(a.matmul(b).numpy())
     ```
     """
-    return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
+    return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)

   def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
     assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
@@ -2553,14 +2553,14 @@ class Tensor(SimpleMathTrait):
     """
     src, mask = self._pre_scatter(dim, index, src)
     def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
-    # TODO: should not overwrite acc_dtype here?
-    if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
-    if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
+    # TODO: should not overwrite dtype here?
+    if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
+    if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
     if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
     if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
     if reduce == "mean":
-      count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
-      return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
+      count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
+      return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
     raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")

   def topk(self, k, dim=-1, largest=True, sorted_=True):
@@ -3677,7 +3677,7 @@ class Tensor(SimpleMathTrait):
     """
     # NOTE: it also works when `key` and `value` have symbolic shape.
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
-    qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
+    qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
     # handle attention mask
     if is_causal:
       if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
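`scatter_reduce` keeps its internal sums in `self.dtype` (the TODO above questions whether it should); from the caller's side nothing changes with this rename. A small usage sketch, with the expected values worked out by hand:

```python
from tinygrad import Tensor

t     = Tensor([1.0, 2.0, 3.0, 4.0])
index = Tensor([0, 0, 2])
src   = Tensor([10.0, 20.0, 30.0])

out = t.scatter_reduce(0, index, src, reduce="sum")  # include_self defaults to True
print(out.numpy())  # expected: [31., 2., 33., 4.]
```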
@@ -3983,7 +3983,7 @@ class Tensor(SimpleMathTrait):

   # *** image Tensor function replacements ***

-  def image_dot(self, w:Tensor, acc_dtype:DTypeLike|None=None) -> Tensor:
+  def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     x, dx, dw = self, self.ndim, w.ndim
     if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
@@ -3997,9 +3997,9 @@ class Tensor(SimpleMathTrait):
     cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
     # groups*cout x cin x H, W
     cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-    return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
+    return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

-  def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
+  def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
     base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef

     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
@@ -4048,7 +4048,7 @@ class Tensor(SimpleMathTrait):
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))

     # the conv!
-    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)

     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
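For downstream code the whole change is mechanical: every public reduction entry point that used to accept `acc_dtype` now spells it `dtype` (`sum`, `prod`, `einsum`, `dot`/`matmul`, `conv2d`, and the image variants). A hypothetical migration snippet:

```python
from tinygrad import Tensor, dtypes

a, b = Tensor.rand(8, 8, dtype=dtypes.float16), Tensor.rand(8, 8, dtype=dtypes.float16)
x, w = Tensor.rand(1, 8, 9, 9, dtype=dtypes.float16), Tensor.rand(4, 8, 3, 3, dtype=dtypes.float16)

a.sum(dtype=dtypes.float32)                             # was: a.sum(acc_dtype=dtypes.float32)
a.matmul(b, dtype=dtypes.float32)                       # was: a.matmul(b, acc_dtype=dtypes.float32)
Tensor.einsum("ij,jk->ik", a, b, dtype=dtypes.float32)  # was: ... acc_dtype=dtypes.float32
x.conv2d(w, dtype=dtypes.float32)                       # was: x.conv2d(w, acc_dtype=dtypes.float32)
```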