split on tile_dim

This commit is contained in:
George Hotz
2025-10-07 17:40:00 +08:00
parent f129d75ee5
commit 5ad62f130d

View File

@@ -39,9 +39,15 @@ def store(gl:UOp, reg:UOp, *idxs):
# NOTE(review): this hunk appears to be a diff view whose +/- markers and
# indentation were stripped. The hunk header (-39,9 +39,15) is consistent with
# the three statements right after the assert being the removed pre-change body
# and everything from the first "meta::" comment onward being the added body.
# Confirm against the repository before relying on this reading.
#
# mma_AB: for output indices (i, j) and a reduce index k, store
# outacc[i,j] + a[i,k] * b[k,j] back into outacc[i,j]. The store result is
# reshaped to outacc.shape and returned; any extra `endrngs` are forwarded to
# the store call unchanged.
def mma_AB(outacc:UOp, a:UOp, b:UOp, *endrngs):
# the contracted dimension must agree: a's second dim against b's first dim
assert a.shape[1] == b.shape[0]
# --- presumably the pre-change body (removed by this commit) ---
# one range per output dimension, plus one REDUCE range over the contracted dimension
rngs = [rng(s) for s in outacc.shape]
red = rng(a.shape[1], AxisType.REDUCE)
return outacc[*rngs].store(outacc[*rngs].load(red) + a[rngs[0],red].load() * b[red,rngs[1]].load(), *rngs, *endrngs, red, dtype=outacc.dtype).reshape(outacc.shape)
# --- presumably the post-change body (added by this commit) ---
# meta::unroll_i_j_in_range -- split on TILE_DIM
# outer part of each index: a range over size//TILE_DIM, scaled by TILE_DIM
# NOTE(review): floor division is used, so sizes are presumably multiples of TILE_DIM -- confirm
rngs = [rng(s//TILE_DIM)*TILE_DIM for s in outacc.shape]
red = rng(a.shape[1]//TILE_DIM, AxisType.REDUCE)*TILE_DIM
# meta::unroll_i_in_range -- split reduce on TILE_DIM
# inner part of each index: add a second range of size TILE_DIM to the scaled outer range
rngs = [x+rng(TILE_DIM) for x in rngs]
red = red + rng(TILE_DIM, AxisType.REDUCE)
# the indices are now expressions rather than bare ranges, so collect the RANGE ops
# they are built from (in toposort order) to hand to the store
store_rngs = [x for x in UOp.sink(*rngs, red).toposort() if x.op is Ops.RANGE]
# same accumulate expression as before, written with the split indices
acc = outacc[*rngs].load(red) + a[rngs[0],red].load() * b[red,rngs[1]].load()
return outacc[*rngs].store(acc, *store_rngs, *endrngs, dtype=outacc.dtype).reshape(outacc.shape)
if __name__ == "__main__":
# TODO: support string ranges