diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 4641bc582c..1953fdbb48 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -733,7 +733,7 @@ class TestLinearizer(unittest.TestCase):
     # check that the float4 cast collapses
     store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
     for val in store_vals:
-      assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST
+      assert val.dtype == dtypes.float.vec(4) and val.op not in {UOps.VECTORIZE, UOps.CAST}
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
   def test_grouped_store_values(self):
@@ -741,7 +741,7 @@
     out = x.flip((0,1)).contiguous()
     k = helper_linearizer_opt(out)[-1]
     store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
-    assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST
+    assert store_val.dtype == dtypes.float.vec(4) and store_val.op not in {UOps.VECTORIZE, UOps.CAST}
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -759,7 +759,7 @@
     barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
     # check that the float4 cast collapses for all stores
     for store in local_stores+global_stores:
-      assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
+      assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op not in {UOps.VECTORIZE, UOps.CAST}
     # # check the children's vins
     # TODO: src ALU are not the same, should it?
     # assert barrier.src == tuple(local_stores)
@@ -776,7 +776,7 @@
 
     # the float4 value stores directly in lds and we skip upcast
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
-    assert stores[0].src[-1].op is not UOps.CAST
+    assert stores[0].src[-1].op not in {UOps.VECTORIZE, UOps.CAST}
 
     # the global store doesn't change
     assert stores[1].src[2].dtype == dtypes.float
@@ -791,7 +791,7 @@
     ]
     k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
     out = [u for u in k.uops if u.op is UOps.STORE][0]
-    assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4)
+    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -803,7 +803,7 @@
            Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
     k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
     out = [u for u in k.uops if u.op is UOps.STORE][0]
-    assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2)
+    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(2)
 
 @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 41c04a1915..17cf46a5f0 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -119,16 +119,25 @@ class TestUOpGraph(TestUOps):
     self.assertEqual(out.op, UOps.CONST)
     self.assertEqual(out.arg, 0)
 
-  def test_cast_vectorized_fold(self):
+  def test_const_vectorize_fold(self):
+    c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
+    out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
+    g = UOpGraph([out])
+    self.assertEqual(len(g.uops), 1)
+    out = g.uops[-1]
+    self.assertEqual(out.op, UOps.CONST)
+    self.assertEqual(out.arg, 0.0)
+
+  def test_noop_vectorize_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
     idx = UOp.const(dtypes.int, 0)
     ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
-    cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
-    x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
+    vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
+    x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
     alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
     out = UOp(UOps.STORE, None, (d0, idx, alu))
     g = UOpGraph([out])
-    self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
+    self.assertEqual(len([x for x in g.uops if x.op is UOps.VECTORIZE]), 0)
 
   def test_cast_alu_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index cd1cc2255b..8954f7f85d 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -156,7 +156,7 @@ class Linearizer(Kernel):
         buf_uop = self.buf_uops[i]
         assert buf_uop is not None, f"buffer {i} wasn't UOped"
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-        rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
+        rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
         valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
         self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
         if localtype == localtype.scalar():
@@ -197,7 +197,7 @@
           amt = len(grouped)
           idx, valid = self.sts[i].expr_idxs(k)
           assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
-          store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
+          store_offset_new[k] = UOp(UOps.VECTORIZE, buf.dtype.vec(amt), tuple(grouped))
       store_offset = store_offset_new
 
     stores = []
@@ -205,7 +205,7 @@
       idx, valid = self.sts[i].expr_idxs(_idx)
       if isinstance(buf.dtype, ImageDType):
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-        rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
+        rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), \
           tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
       else:
         rendered_idx = idx.render(render_ops, self.loop_uops)
@@ -287,13 +287,13 @@
             next_ *= 1 if stride == 0 else sz
           return strides
         upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
-        # cast initial accs
-        wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
+        # vectorize initial accs
+        wmmas = [UOp(UOps.VECTORIZE, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
                  for x in range(0, len(accs[reduceop]), wmma_sz[2])]
         for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
           offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
-          ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
-                 UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+          ops = (UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+                 UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
                  wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
           # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
           wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index ebe3fc454e..696249cccf 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -17,7 +17,7 @@ class UOps(Enum):
   CONST = auto(); SPECIAL = auto() # noqa: E702
   NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
   # math ops
-  CAST = auto(); BITCAST = auto() # noqa: E702
+  CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
   ALU = auto(); WMMA = auto() # noqa: E702
   # memory/assignment ops
   LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
@@ -96,8 +96,10 @@
       assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
       arg = src[0].arg
      assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
-    if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
-    if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
+    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
+    if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
+    if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count
+    if uop is UOps.VECTORIZE: assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
     if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
     if uop is UOps.STORE: assert dtype is None, f"{uop} dtype must be None, got {dtype}"
 
@@ -231,6 +233,7 @@ constant_folder = PatternMatcher([
   # const rules
   (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
   (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
+  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
   # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
   (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
   (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
@@ -300,18 +303,23 @@
   # TODO: can do the invert of this (flip alt/load) when we fix double ops
   (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
    lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
-  # store float4/float2 directly (remove CAST/GEP)
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
-  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
-  # CAST-PHI-GEP -> PHI-CAST
-  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
-   lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
-  (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
-   lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
+  # store float4/float2 directly (remove VECTORIZE/GEP)
+  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+    UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
+  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+    UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
+  # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
+  (UPat(UOps.VECTORIZE, name="root", src=tuple(
+    UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
+   lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
+  (UPat(UOps.VECTORIZE, name="root", src=tuple(
+    UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
+   lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
   # NEG/CMPLT -> CMPLT
   (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
   # cast NOOP (NOTE: it's str to deal with PtrDType)
   (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
+  (UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   # fold gated LOAD/STORE
   (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
   (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index ab7aadbb99..0e4e2d1eb7 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -197,10 +197,12 @@ class PTXRenderer(Renderer):
 
         else: kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
         r[u] = r[src[0]]
+      elif uop in {UOps.VECTORIZE}:
+        assert src[0].dtype is not None and dtype.count > 1
+        r[u] = [r[x] for x in src] # type: ignore
       elif uop in {UOps.CAST, UOps.BITCAST}:
-        assert src[0].dtype is not None
-        if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
-        else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+        assert src[0].dtype is not None and dtype.count == 1
+        _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
       elif uop is UOps.DEFINE_LOCAL:
         # TODO: we should sum these, and fetch 0xC000 from somewhere
         assert args[1]*dtype.itemsize <= 0xC000, "too large local"
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 3b34e6deed..e251e54641 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -33,9 +33,12 @@ class CStyleLanguage(Renderer):
     TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
 
   # returns a str expression of the casted xs with the given type
-  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
-    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
-    if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
+  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
+    return f"({self.render_dtype(var_dtype)})({x})"
+
+  # returns a str expression of the vectorized xs with the given type
+  def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
     assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
     assert self.float4 is not None, "vectorized cast is not supported on this platform"
     return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
@@ -47,7 +50,8 @@
     elif dtype.scalar() == dtypes.bool: val = "1" if x else "0"
     elif dtype.scalar() == dtypes.float: val = f"{x}f"
     else: val = str(x)
-    return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
+    if dtype.count > 1: return self.render_vectorize([val] * dtype.count, dtype)
+    return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
 
   # returns a str expression of the loaded value with the output type
   def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -144,14 +148,14 @@
       elif uop is UOps.PHI:
         kk(f"{r[src[0]]} = {r[src[1]]};")
         r[u] = r[src[0]]
-      elif uop in {UOps.CAST, UOps.BITCAST}:
+      elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
+        assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
         if uop is UOps.BITCAST:
-          assert len(src) == 1
           precast = ssa('precast')
           kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
-          val = self.render_cast([precast], dtype, bitcast=True)
-        else:
-          val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+          val = self.render_cast(precast, dtype, bitcast=True)
+        elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
+        else: val = self.render_vectorize([r[x] for x in src], dtype)
         if child_count[u] <= 1: r[u] = val
         else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
       elif uop is UOps.DEFINE_LOCAL:
@@ -201,7 +205,7 @@ class OpenCLRenderer(CStyleLanguage):
   uses_vload = True
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
-    return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+    return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
@@ -232,8 +236,8 @@ class MetalRenderer(CStyleLanguage):
     UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
     UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
 
-  def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
-    return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+  def render_cast(self, x: str, var_dtype: DType, bitcast=False) -> str:
+    return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 7f9fbc1d0d..7bc7329327 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -100,7 +100,7 @@ class PythonProgram:
             del ul[i]
             i = loop_ends[i] + 1
             continue
-        elif uop in (UOps.CAST, UOps.BITCAST):
+        elif uop in (UOps.CAST, UOps.BITCAST, UOps.VECTORIZE):
          if dtype.count > 1: ul[i] = inp
          else:
            assert dtp[0].fmt and dtype.fmt