fix load / barrier (#11386)

* fix load / barrier

* cleanups

* fix CI
This commit is contained in:
George Hotz
2025-07-26 10:27:37 -07:00
committed by GitHub
parent 65673e68ca
commit 2c70eaf18c
8 changed files with 24 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views
@@ -40,7 +40,11 @@ def hl_spec_kernel3():
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm"))
axis_types = [
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.REDUCE, AxisType.UNROLL]
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
sink = graph_rewrite(sink, merge_views)
return sink