Directly store float4 nodes (#3564)
* float4 cast collapse
* simplify cstyle
* simplify uoptimizer
* ci

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
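For orientation, here is a minimal standalone sketch of the pattern this commit collapses: a STORE whose value is a float4 CAST built from four GEPs of one source vector can instead store that vector directly, dropping the GEP/CAST round-trip. The UOps/UOp classes below are hypothetical stand-ins for illustration only, not tinygrad's real types.

# Toy model of the GEP case; names mirror the diff, classes are made up.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Tuple

class UOps(Enum):
  DEFINE_GLOBAL = auto(); LOAD = auto(); GEP = auto(); CAST = auto(); STORE = auto()

@dataclass(frozen=True)
class UOp:
  uop: UOps
  dtype: str = "float"
  vin: Tuple["UOp", ...] = ()
  arg: Any = None

buf    = UOp(UOps.DEFINE_GLOBAL, "float*")
idx    = UOp(UOps.LOAD, "int")
vec    = UOp(UOps.LOAD, "float4", (buf, idx))                 # a float4 value already in a register
geps   = tuple(UOp(UOps.GEP, "float", (vec,), i) for i in range(4))
casted = UOp(UOps.CAST, "float4", geps)                       # re-assembles the same float4 lane by lane
store  = UOp(UOps.STORE, "float4", (buf, idx, casted))

# the rewrite: if every CAST input is a GEP of one source vector, store that
# source directly (val.vin[0].vin[0] in the real code) and drop the CAST
val = store.vin[-1]
if val.uop is UOps.CAST and all(el.uop is UOps.GEP for el in val.vin):
  store = UOp(UOps.STORE, store.dtype, (store.vin[0], store.vin[1], val.vin[0].vin[0]))

assert store.vin[-1] is vec and store.vin[-1].uop is not UOps.CAST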
@@ -756,5 +756,56 @@ class TestLinearizerHelper(unittest.TestCase):
     idxs = (uidx0 // 5, uidx0 * 5, uidx1)
     assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
 
+class TestLinearizerUOptimize(unittest.TestCase):
+  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
+  def test_grouped_store_phis(self):
+    x, y = Tensor.randn(64,64), Tensor.randn(64,64)
+    out = x.matmul(y)
+
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k.hand_coded_optimizations()
+    k.linearize()
+
+    # check that the float4 cast collapses
+    store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
+    for val in store_vals:
+      assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST
+
+  @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "device doesn't support float4")
+  def test_grouped_store_values(self):
+    x = Tensor.randn((4,3,6,6)).realize()
+    out = x.flip((0,1)).contiguous()
+
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k.hand_coded_optimizations()
+    k.linearize()
+
+    store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
+    assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
+
+  @unittest.skip("TODO: support locals replacement across the uop graph")
+  def test_grouped_store_locals_and_globals(self):
+    if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
+      self.skipTest("Only Compiled uses linearizer with locals and shared")
+
+    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
+    out = x@y
+
+    opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
+            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)]  # upcast accs in both reduces
+    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    for opt in opts: k.apply_opt(opt)
+    k.linearize()
+
+    local_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_LOCAL]
+    barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0]
+    global_stores = [u for u in k.uops if u.uop is UOps.STORE and u.vin[0].uop is UOps.DEFINE_GLOBAL]
+
+    # check that the float4 cast collapses for all stores
+    for store in local_stores+global_stores:
+      assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST
+    # check that the barrier uses the new stores
+    assert barrier.vin == tuple(local_stores)
+
 if __name__ == '__main__':
   unittest.main()
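To run just the new test class, something like the following should work from the repo root; the module path "test.test_linearizer" is an assumption based on where tinygrad's linearizer tests usually live (the file name is not shown in this diff), and the tests only run when the default device supports float4.

# Load and run only TestLinearizerUOptimize; module path is a guess, not from the diff.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("test.test_linearizer.TestLinearizerUOptimize")
unittest.TextTestRunner(verbosity=2).run(suite)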
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import functools, math
-from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
+from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, all_same
 from tinygrad.dtype import dtypes, DType
@@ -56,6 +56,10 @@ def uop_alu_resolve(u:UOp) -> sint:
   else:
     raise RuntimeError(f"ALU resolve fail @ {u.uop}")
 
+def phi_resolve_acc(u:UOp) -> UOp:
+  if u.uop == UOps.DEFINE_ACC: return u
+  return phi_resolve_acc(u.vin[0])
+
 class UOpGraph:
   def __init__(self):
     # list of uops
@@ -216,6 +220,17 @@ class UOpGraph:
     # (recursively) remove childless uops
     self.remove_childless()
 
+    # store float4 upcasts directly if possible
+    replaced_stores: Dict[UOp,UOp] = {}
+    for u in self.uops:
+      if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue
+      if u.vin[0].uop is UOps.DEFINE_LOCAL: continue  # TODO add support for local store
+      if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
+      elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
+    for prev,new in replaced_stores.items():
+      self.add(UOps.STORE, prev.dtype, (prev.vin[0],prev.vin[1],new), insert_before=self.uops.index(prev))
+      self.uops.remove(prev)
+
     # add UOps.END*
     self.add_ends()
 
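The PHI branch above covers vectorized accumulators: each lane of the CAST is a PHI whose vin[0] chains back to a single DEFINE_ACC, so phi_resolve_acc finds that accumulator and the STORE can target it directly. A rough standalone illustration of that walk, again with hypothetical stand-in classes rather than tinygrad's real UOp:

# Toy illustration of phi_resolve_acc: follow vin[0] until a DEFINE_ACC is hit.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Tuple

class UOps(Enum):
  DEFINE_ACC = auto(); ALU = auto(); PHI = auto(); CAST = auto()

@dataclass(frozen=True)
class UOp:
  uop: UOps
  vin: Tuple["UOp", ...] = ()

def phi_resolve_acc(u: UOp) -> UOp:
  if u.uop == UOps.DEFINE_ACC: return u
  return phi_resolve_acc(u.vin[0])

acc    = UOp(UOps.DEFINE_ACC)                              # the float4 accumulator
alus   = tuple(UOp(UOps.ALU, (acc,)) for _ in range(4))    # per-lane updates
phis   = tuple(UOp(UOps.PHI, (acc, alu)) for alu in alus)  # write-back of each lane
packed = UOp(UOps.CAST, phis)                              # re-packs the four PHI lanes

# every lane is a PHI, so the stored value collapses to the accumulator itself
assert all(el.uop is UOps.PHI for el in packed.vin)
assert phi_resolve_acc(packed) is acc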
@@ -247,4 +262,3 @@
         elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
         else: raise Exception("not implemented")
     return flops, mem
-
@@ -81,7 +81,8 @@ class CStyleLanguage(NamedTuple):
     if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
       return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.count > 1:
-      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};"  # noqa: E501
+      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
+      return f"*(({prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
     return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
   def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
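For a concrete sense of what this branch emits for a float4 store, the snippet below plugs example values into the new f-string; the names data0, alu0, val0 and the empty prefix are made up for illustration and are not from the diff (real backends supply their own buffer/smem prefixes). Note the other half of the change: the rendered value is no longer wrapped in a redundant ({buf_dtype.name}{var_dtype.count}) cast.

# Reproduce the refactored render_store branch with placeholder values.
prefix, buf_dtype_name, count = "", "float", 4         # e.g. no address-space prefix
buf_name, idx, var_name = "data0", "alu0", "val0"      # illustrative names only

rendered = f"*(({prefix}{buf_dtype_name}{count}*)({buf_name}+{idx})) = {var_name};"
print(rendered)  # -> *((float4*)(data0+alu0)) = val0;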