mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
split on tile_dim
This commit is contained in:
@@ -39,9 +39,15 @@ def store(gl:UOp, reg:UOp, *idxs):
|
||||
|
||||
def mma_AB(outacc:UOp, a:UOp, b:UOp, *endrngs):
|
||||
assert a.shape[1] == b.shape[0]
|
||||
rngs = [rng(s) for s in outacc.shape]
|
||||
red = rng(a.shape[1], AxisType.REDUCE)
|
||||
return outacc[*rngs].store(outacc[*rngs].load(red) + a[rngs[0],red].load() * b[red,rngs[1]].load(), *rngs, *endrngs, red, dtype=outacc.dtype).reshape(outacc.shape)
|
||||
# meta::unroll_i_j_in_range -- split on TILE_DIM
|
||||
rngs = [rng(s//TILE_DIM)*TILE_DIM for s in outacc.shape]
|
||||
red = rng(a.shape[1]//TILE_DIM, AxisType.REDUCE)*TILE_DIM
|
||||
# meta::unroll_i_in_range -- split reduce on TILE_DIM
|
||||
rngs = [x+rng(TILE_DIM) for x in rngs]
|
||||
red = red + rng(TILE_DIM, AxisType.REDUCE)
|
||||
store_rngs = [x for x in UOp.sink(*rngs, red).toposort() if x.op is Ops.RANGE]
|
||||
acc = outacc[*rngs].load(red) + a[rngs[0],red].load() * b[red,rngs[1]].load()
|
||||
return outacc[*rngs].store(acc, *store_rngs, *endrngs, dtype=outacc.dtype).reshape(outacc.shape)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: support string ranges
|
||||
|
||||
Reference in New Issue
Block a user