From c5ae66215aad41d94df59cd527cc5144e4f20a12 Mon Sep 17 00:00:00 2001
From: leopf <43857362+leopf@users.noreply.github.com>
Date: Wed, 18 Dec 2024 18:05:18 +0100
Subject: [PATCH] [pr] move has_valid into pm_lowerer (#8308)

* [pr] move has_valid into pm_lowerer

* simpler
---
 tinygrad/codegen/lowerer.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 1fae75b42f..3ffbd5e9b1 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -106,12 +106,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
 
 def lower_load_store(ctx: IndexContext, x: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
-  # TODO: check has_valid in UPat, not here
-  has_valid = valid.op is not Ops.CONST or valid.arg is not True
   buf = x.src[0]
   if x.op is Ops.LOAD:
     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
-    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
+    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
   if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
     reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
@@ -122,14 +120,14 @@ def lower_load_store(ctx: IndexContext, x: UOp):
   if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
-    has_valid = valid.op is not Ops.CONST or valid.arg is not True
-    return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
+    return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
 
 pm_lowerer = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
   (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
 ])
 
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
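
Below is a minimal, standalone sketch of the idea behind this patch, in plain Python rather than tinygrad's real UOp/UPat/PatternMatcher classes (the Node type and the lower_index, drop_const_true_valid, and rewrite helpers are hypothetical): lowering now always attaches the valid mask to the index, and a separate rewrite rule strips the mask when it is a constant True, instead of special-casing has_valid inside lower_load_store.

# Illustration only; not part of the patch above and not tinygrad's real API.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Node:
  op: str                         # e.g. "INDEX", "CONST", "DEFINE_GLOBAL"
  src: Tuple["Node", ...] = ()
  arg: object = None

def lower_index(buf: Node, idx: Node, valid: Node) -> Node:
  # old style: emit a 2-source INDEX here when valid is trivially True;
  # new style: always keep the valid and let the rewriter simplify it away
  return Node("INDEX", (buf, idx, valid))

def drop_const_true_valid(n: Node) -> Optional[Node]:
  # analogue of the added pm_lowerer rule: INDEX(b, idx, CONST(True)) -> INDEX(b, idx)
  if n.op == "INDEX" and len(n.src) == 3 and n.src[2].op == "CONST" and n.src[2].arg is True:
    return Node("INDEX", n.src[:2])
  return None

def rewrite(n: Node) -> Node:
  # bottom-up, single-pass rewrite; enough for this illustration
  n = Node(n.op, tuple(rewrite(s) for s in n.src), n.arg)
  return drop_const_true_valid(n) or n

if __name__ == "__main__":
  buf, idx = Node("DEFINE_GLOBAL"), Node("CONST", (), 0)
  always_true = Node("CONST", (), True)
  print(rewrite(lower_index(buf, idx, always_true)))  # prints an INDEX with only (buf, idx)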