fix load / barrier (#11386)

* fix load / barrier

* cleanups

* fix CI
This commit is contained in:
George Hotz
2025-07-26 10:27:37 -07:00
committed by GitHub
parent 65673e68ca
commit 2c70eaf18c
8 changed files with 24 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views
@@ -40,7 +40,11 @@ def hl_spec_kernel3():
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm"))
axis_types = [
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.REDUCE, AxisType.UNROLL]
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
sink = graph_rewrite(sink, merge_views)
return sink

View File

@@ -122,8 +122,8 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
UPat.var("mask"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
@@ -154,6 +154,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:

View File

@@ -86,9 +86,6 @@ expander = PatternMatcher([
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_REG, name="acc"), name="v"),
lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
# BARRIERs aren't actually expanded
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),

View File

@@ -48,7 +48,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[1],)),) if buf.op is Ops.DEFINE_LOCAL else ()
barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):

View File

@@ -118,7 +118,9 @@ class CStyleLanguage(Renderer):
def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType):
prefix = self.smem_prefix if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast else self.buffer_prefix
prefix = ""
if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix
if dt.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
return prefix + self.render_dtype(dt.base) + "*"
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
return self.type_map.get(scalar:=dt.scalar(), scalar.name)

View File

@@ -83,6 +83,8 @@ class GroupOp:
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

View File

@@ -238,6 +238,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self): return UOp(Ops.BARRIER, src=(self,))
def alu(self, op, *src:UOp, **kwargs):
out_dtype = (self, *src)[-1].dtype
if op in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool

View File

@@ -144,18 +144,18 @@ spec = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# early LOAD has a <bufview, store?>
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat(Ops.STORE))), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
# early STORE has a <bufview, val>
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat())), lambda: True),
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
# **** new style load/store ****
# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD on STORE
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
@@ -165,8 +165,8 @@ spec = PatternMatcher([
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
# STORE takes a <bufidx, val, gate?>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
(UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
(UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),