store/load not pass through index (#11381)
* noop
* fix noop
* store cat is NOOP
* store dtype is void
* stores aren't passed through anymore
* meh, skip those for ptx
* correct ptx skip
* hl runs
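The heart of the change is the UOp.store hunk below: a STORE now carries dtypes.void instead of the pointer dtype, so the address space has to be read from the store's pointer source, and anything that depends on a store takes it as an extra LOAD source rather than loading "through" the store. A minimal sketch of the before/after usage, assuming only the UOp helpers that already appear in this diff (DEFINE_GLOBAL, index, const, store, load); the buffer size, index, and constant value are made up for illustration, and the snippet only builds the UOp graph, it does not run a kernel:

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16), arg=0)  # hypothetical 16-float buffer
idx = buf.index(UOp.const(dtypes.int, 0))
st = idx.store(UOp.const(dtypes.float, 1.0))

# before: the STORE kept the pointer dtype, so dependent reads were written as
#         idx.store(x).load() and checks like u.dtype.addrspace worked on the STORE itself
# after:  st.dtype is dtypes.void; the dependency is passed as an extra LOAD source,
#         and the address space comes from the pointer in src[0]
val = idx.load(st)
addrspace = st.src[0].dtype.addrspace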
@@ -4,6 +4,7 @@ from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker

N = 4096
run_count = 5
@@ -20,13 +21,14 @@ def hl_spec_kernel3():
nbIterWaveN = 2

# define buffers
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
# TODO: remove these views once the defines have a shape
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N*N,)))
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N*N,)))
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N*N,)))
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,)))

# shape buffers. TODO: permutes
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
@@ -37,8 +39,7 @@ def hl_spec_kernel3():
Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape)
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)

out = (A_col.store(As.store(a.load()).load()).load() * B_row.store(Bs.store(b.load()).load()).load()).r(Ops.ADD, (8, 9))
out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm"))
sink = graph_rewrite(sink, merge_views)
return sink
@@ -291,7 +291,7 @@ class TestLinearizer(unittest.TestCase):
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)

stores = [u for u in program.uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]

# the first store is to lds and can be upcasted
assert stores[0].src[1].dtype == dtypes.float.vec(4)
@@ -633,6 +633,7 @@ class TestLinearizer(unittest.TestCase):
helper(Tensor.arange(255), max_ops=2)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_grouped_store_phis(self):
"""
float4 acc0 = float4(0.0,0.0,0.0,0.0);
@@ -648,7 +649,7 @@ class TestLinearizer(unittest.TestCase):
k = helper_linearizer_opt(out)[-1]
uops = get_program(k.get_optimized_ast(), k.opts).uops
# check that the float4 cast collapses
store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
for val in store_vals:
assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE
@@ -699,12 +700,13 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_grouped_store_local_only(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = helper_linearizer_opt(r)[-1]
uops = get_program(k.get_optimized_ast(), k.opts).uops
stores = [u for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]

# the float4 value stores directly in lds and we skip upcast
self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4))
@@ -5,7 +5,7 @@ from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition, all_same
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.renderer import Renderer

# ***** image load valid simplification *****
@@ -111,11 +111,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
offset += s.dtype.count
# dtype CAT
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
assert len(dtypes) == len(ret) and all_same([(x.size, x.addrspace) for x in dtypes])
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
return UOp(Ops.NOOP, src=tuple(ret))

def gep_on_store(gep:UOp, st:UOp, sto:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
@@ -184,7 +180,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
break

# if it wasn't split, we return None. otherwise we CAT them
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if len(ret) > 1 else None
if len(ret) <= 1: return None
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))

def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
@@ -287,10 +284,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret

def no_vectorized_reduce(inp:UOp, red:UOp):
if inp.dtype != red.dtype:
@@ -167,10 +167,8 @@ class CStyleLanguage(Renderer):
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void:
if u.op is Ops.STORE: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: pass
else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
@@ -188,9 +188,6 @@ class LLVMRenderer(Renderer):
if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
kernel.append(cast(str, l))

# stores pass the first arg through
if u.op is Ops.STORE: r[u] = r[u.src[0]]
return tuple(local_args), self._render_fn(name, args, kernel, prefix)

barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
@@ -40,7 +40,7 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP}
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
@@ -58,7 +58,6 @@ class PythonProgram:
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
for (m,o,g),v in zip(inp[0], val):
if g: _store(m, o+j, v)
ul[i] = inp[0]
i += 1
continue
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
@@ -236,7 +236,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
i = (i,)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, self.dtype, (self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def alu(self, op, *src:UOp, **kwargs):
out_dtype = (self, *src)[-1].dtype
@@ -405,8 +405,8 @@ def reduce_mul_chain(r:UOp):
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)

# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
@@ -457,9 +457,6 @@ sym = symbolic_flat+PatternMatcher([
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # NULL pointer store does nothing. NULL pointer load produces 0
# remove NOOPs from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
(UPat(Ops.BARRIER, name="root"),
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)