diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index be8b65305d..9e4fa537b3 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -756,5 +756,56 @@ class TestLinearizerHelper(unittest.TestCase):
     idxs = (uidx0 // 5, uidx0 * 5, uidx1)
     assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
 
+class TestLinearizerUOptimize(unittest.TestCase):
+  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
+  def test_grouped_store_phis(self):
+    x, y = Tensor.randn(64,64), Tensor.randn(64,64)
+    out = x.matmul(y)
+
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k.hand_coded_optimizations()
+    k.linearize()
+
+    # check that the float4 cast collapses
+    store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
+    for val in store_vals:
+      assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST
+
+  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
+  def test_grouped_store_values(self):
+    x = Tensor.randn((4,3,6,6)).realize()
+    out = x.flip((0,1)).contiguous()
+
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k.hand_coded_optimizations()
+    k.linearize()
+
+    store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
+    assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
+
+  @unittest.skip("TODO: support locals replacement across the uop graph")
+  def test_grouped_store_locals_and_globals(self):
+    if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
+      self.skipTest("Only Compiled uses linearizer with locals and shared")
+
+    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
+    out = x@y
+
+    opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
+            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)]  # upcast accs in both reduces
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    for opt in opts: k.apply_opt(opt)
+    k.linearize()
+
+    local_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_LOCAL]
+    barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0]
+    global_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_GLOBAL]
+
+    # check that the float4 cast collapses for all stores
+    for store in local_stores+global_stores:
+      assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST
+    # check that the barrier uses the new stores
+    assert barrier.vin == tuple(local_stores)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index ff213cfe82..be2e90645b 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import functools, math
-from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
+from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, all_same
 from tinygrad.dtype import dtypes, DType
@@ -56,6 +56,10 @@ def uop_alu_resolve(u:UOp) -> sint:
   else: raise RuntimeError(f"ALU resolve fail @ {u.uop}")
 
+def phi_resolve_acc(u:UOp) -> UOp:
+  if u.uop == UOps.DEFINE_ACC: return u
+  return phi_resolve_acc(u.vin[0])
+
 class UOpGraph:
   def __init__(self):
     # list of uops
@@ -216,6 +220,17 @@ class UOpGraph:
     # (recursively) remove childless uops
     self.remove_childless()
 
+    # store float4 upcasts directly if possible
+    replaced_stores: Dict[UOp,UOp] = {}
+    for u in self.uops:
+      if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue
+      if u.vin[0].uop is UOps.DEFINE_LOCAL: continue  # TODO add support for local store
+      if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
+      elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
+    for prev,new in replaced_stores.items():
+      self.add(UOps.STORE, prev.dtype, (prev.vin[0],prev.vin[1],new), insert_before=self.uops.index(prev))
+      self.uops.remove(prev)
+
     # add UOps.END*
     self.add_ends()
@@ -247,4 +262,3 @@ class UOpGraph:
       elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
       else: raise Exception("not implemented")
     return flops, mem
-
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 842b1721dc..4b875cd17d 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -81,7 +81,8 @@ class CStyleLanguage(NamedTuple):
     if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
       return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.count > 1:
-      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};"  # noqa: E501
+      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
+      return f"*(({prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
     return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
   def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"