Directly store float4 nodes (#3564)

* float4 cast collapse

* simplify cstyle

* simplify uoptimizer

* ci

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: qazal
Date: 2024-03-03 01:58:20 +02:00
Committed by: GitHub
Parent: 770707b376
Commit: a89afd4ffa

3 changed files with 69 additions and 3 deletions
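
For context on the "float4 cast collapse" bullet: before this change a float4 store went through a CAST uop that repacked four scalar values, and the renderer then added another value cast on the C side. After this change the STORE points directly at the vectorized value. A rough sketch of the two patterns handled in the uops.py hunk below (illustrative shapes only, not real UOp constructor calls):

# GEP case: the CAST packs four GEPs of the same float4 value v
#   before: STORE(buf, idx, CAST(float4, (GEP(v,0), GEP(v,1), GEP(v,2), GEP(v,3))))
#   after:  STORE(buf, idx, v)
# PHI case: the CAST packs PHIs that all resolve to one float4 accumulator
#   before: STORE(buf, idx, CAST(float4, (PHI(acc,...), PHI(acc,...), ...)))
#   after:  STORE(buf, idx, acc)   # acc is the DEFINE_ACC found by phi_resolve_acc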

View File

@@ -756,5 +756,56 @@ class TestLinearizerHelper(unittest.TestCase):
    idxs = (uidx0 // 5, uidx0 * 5, uidx1)
    assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)

class TestLinearizerUOptimize(unittest.TestCase):
  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
  def test_grouped_store_phis(self):
    x, y = Tensor.randn(64,64), Tensor.randn(64,64)
    out = x.matmul(y)

    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
    k.hand_coded_optimizations()
    k.linearize()

    # check that the float4 cast collapses
    store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
    for val in store_vals:
      assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST

  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
  def test_grouped_store_values(self):
    x = Tensor.randn((4,3,6,6)).realize()
    out = x.flip((0,1)).contiguous()

    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
    k.hand_coded_optimizations()
    k.linearize()

    store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
    assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST

  @unittest.skip("TODO: support locals replacement across the uop graph")
  def test_grouped_store_locals_and_globals(self):
    if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
      self.skipTest("Only Compiled uses linearizer with locals and shared")

    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
    out = x@y
    opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
    for opt in opts: k.apply_opt(opt)
    k.linearize()

    local_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_LOCAL]
    barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0]
    global_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_GLOBAL]

    # check that the float4 cast collapses for all stores
    for store in local_stores+global_stores:
      assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST

    # check that the barrier uses the new stores
    assert barrier.vin == tuple(local_stores)

if __name__ == '__main__':
  unittest.main()
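
The two enabled tests above share the same build-and-check pattern. A small helper along these lines captures the property being asserted (the helper name and factoring are hypothetical; it only reuses the Linearizer, create_schedule, UOps, and dtypes names already imported by the test file):

def assert_float4_stores_collapsed(out, vec=4):
  # lower the tensor's AST and apply the default hand-coded optimizations
  k = Linearizer(create_schedule([out.lazydata])[-1].ast)
  k.hand_coded_optimizations()
  k.linearize()
  # every store should write an already-vectorized value, not a CAST of scalars
  for u in k.uops:
    if u.uop is UOps.STORE:
      val = u.vin[-1]
      assert val.dtype == dtypes.float.vec(vec) and val.uop is not UOps.CAST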

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import functools, math
-from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
+from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
from tinygrad.dtype import dtypes, DType
@@ -56,6 +56,10 @@ def uop_alu_resolve(u:UOp) -> sint:
  else:
    raise RuntimeError(f"ALU resolve fail @ {u.uop}")

def phi_resolve_acc(u:UOp) -> UOp:
  if u.uop == UOps.DEFINE_ACC: return u
  return phi_resolve_acc(u.vin[0])

class UOpGraph:
  def __init__(self):
    # list of uops
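
phi_resolve_acc simply follows vin[0] links until it reaches the accumulator definition. A toy illustration of that recursion, using stand-in tuples rather than real UOps:

from collections import namedtuple

# stand-in for UOp: just an op name and an input tuple
U = namedtuple("U", ["uop", "vin"])
acc = U("DEFINE_ACC", ())
phi = U("PHI", (acc,))         # a PHI's first input is its accumulator
cst = U("CAST", (phi, phi))    # a CAST packing PHIs, as in the store pattern below

def resolve(u):
  if u.uop == "DEFINE_ACC": return u
  return resolve(u.vin[0])

assert resolve(cst) is acc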
@@ -216,6 +220,17 @@ class UOpGraph:
    # (recursively) remove childless uops
    self.remove_childless()

    # store float4 upcasts directly if possible
    replaced_stores: Dict[UOp,UOp] = {}
    for u in self.uops:
      if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue
      if u.vin[0].uop is UOps.DEFINE_LOCAL: continue  # TODO add support for local store
      if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
      elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
    for prev,new in replaced_stores.items():
      self.add(UOps.STORE, prev.dtype, (prev.vin[0],prev.vin[1],new), insert_before=self.uops.index(prev))
      self.uops.remove(prev)

    # add UOps.END*
    self.add_ends()
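
Note that replacements are collected into replaced_stores first and applied in a second loop, so self.uops is not mutated while it is being scanned. Assuming self.add with insert_before splices the new uop in at that index, the list bookkeeping amounts to the following (plain-list sketch, not UOpGraph code):

uops = ["idx", "old_store", "end"]
# splice the direct float4 store in at the old store's position, then drop the old one
uops.insert(uops.index("old_store"), "new_store")
uops.remove("old_store")
assert uops == ["idx", "new_store", "end"]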
@@ -247,4 +262,3 @@ class UOpGraph:
        elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
        else: raise Exception("not implemented")
    return flops, mem

View File

@@ -81,7 +81,8 @@ class CStyleLanguage(NamedTuple):
    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
      return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
    if var_dtype.count > 1:
-      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};"  # noqa: E501
+      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
+      return f"*(({prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
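
The effect of the cstyle change is easiest to see in the emitted C: the stored value is already a float4, so the extra value cast on the right-hand side is dropped. Plugging hypothetical names into the old and new f-strings above (an OpenCL-style "__global " buffer prefix is assumed here, not taken from a real backend config):

buffer_prefix, dtype_name, count = "__global ", "float", 4
buf_name, idx, var_name = "data0", "alu0", "val0"
old = f"*(({buffer_prefix}{dtype_name}{count}*)({buf_name}+{idx})) = ({dtype_name}{count}){var_name};"
new = f"*(({buffer_prefix}{dtype_name}{count}*)({buf_name}+{idx})) = {var_name};"
print(old)  # *((__global float4*)(data0+alu0)) = (float4)val0;
print(new)  # *((__global float4*)(data0+alu0)) = val0;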