diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index ea76d43e17..c0bf192ef5 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -42,7 +42,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC # ** lowerer (rewrite_shapetracker_with_index) ** ret: list[RewriteStep] = [] if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) - ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer")) + ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True)) # ** expander (expand_rewrite) ** ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic")) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 116d18d2ef..a7b3649fd3 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -17,11 +17,6 @@ def get_index(ast:UOp) -> IndexContext: # NOTE: assumes the shape is full_shape = ast.full_shape first_upcasted = len(full_shape)-ki.upcasted - # if there's no reduce, this is first_upcasted. assumes reduces are at the end - first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort() if x.op is Ops.REDUCE_AXIS)) - local_loads = [x for x in ast.toposort() if x.op is Ops.LOAD and x.src[0].base.op is Ops.DEFINE_LOCAL] - # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces - group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)]) # all loops are RANGES idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(g),), i) for i,g in enumerate(full_shape[:first_upcasted])] @@ -32,6 +27,11 @@ def get_index(ast:UOp) -> IndexContext: idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),))) # late indexes (group for reduce) + # if there's no reduce, this is first_upcasted. assumes reduces are at the end + first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort() if x.op is Ops.REDUCE_AXIS)) + local_loads = [x for x in ast.toposort() if x.op is Ops.LOAD and x.src[0].base.op is Ops.DEFINE_LOCAL] + # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces + group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)]) ridxs = idxs[:] for a in range(first_reduce, first_reduce+group_for_reduces): ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(full_shape[a]),), 1000+a) @@ -51,11 +51,13 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp): # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op) -def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp): - idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and buf.op is Ops.DEFINE_LOCAL else ctx.idxs) - if x.op is Ops.LOAD: - barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[1],)),) if buf.op is Ops.DEFINE_LOCAL else () - return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier) +def lower_load(ctx: IndexContext, x: UOp, buf: UOp): + idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs) + barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[1],)),) if buf.op is Ops.DEFINE_LOCAL else () + return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier) + +def lower_store(ctx: IndexContext, x: UOp, buf: UOp): + idx, valid = x.st_arg.to_indexed_uops(ctx.idxs) # NOTE: only store the local reduceop in the threads that are actually doing the reduce if cast(PtrDType, buf.dtype).local and x.src[1].op is Ops.REDUCE: reduce_input = x.src[1].src[0] @@ -82,6 +84,6 @@ pm_lowerer = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"), lower_const), # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed - (UPat((Ops.LOAD, Ops.STORE), src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load_store), - (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), + (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load), + (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store), ]) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 6ee9ed5813..acf8d698ba 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -443,6 +443,8 @@ sym = symbolic_flat+PatternMatcher([ (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), # ** pow ** ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))), + # index true is index without op + (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), # ** load/store folding ** (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),