mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
[pr] move has_valid into pm_lowerer (#8308)
* [pr] move has_valid into pm_lowerer * simpler
This commit is contained in:
@@ -106,12 +106,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not Ops.CONST or valid.arg is not True
|
||||
buf = x.src[0]
|
||||
if x.op is Ops.LOAD:
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
|
||||
reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
|
||||
@@ -122,14 +120,14 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
has_valid = valid.op is not Ops.CONST or valid.arg is not True
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
|
||||
|
||||
pm_lowerer = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
|
||||
Reference in New Issue
Block a user