mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
white space
This commit is contained in:
@@ -104,7 +104,6 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
ctx.acc_num += 1
|
||||
return acc.assign(acc.alu(alu_op, ret))
|
||||
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
buf = x.src[0]
|
||||
@@ -117,8 +116,7 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
|
||||
else: store_back = False
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back:
|
||||
idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
|
||||
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
@@ -126,7 +124,7 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
|
||||
pm_lowerer = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx, x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
|
||||
Reference in New Issue
Block a user