UOp.const(x.dtype, y) -> x.const(y) [run_process_replay] (#6276)

This commit is contained in:
chenyu
2024-08-24 21:39:50 -04:00
committed by GitHub
parent 00282afa41
commit b86907c6c7
2 changed files with 4 additions and 5 deletions

View File

@@ -93,16 +93,16 @@ class IndependentLowerer:
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
-if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
+if x.op is UOps.CONST: return valid.where(x.const(x.arg), x.const(0))
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
-return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
+return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
-if store_back: idx, _ = x.st_arg.to_indexed_uops([UOp.const(u.dtype, 0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
+if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(self.idxs, self.ridxs):
if oidx != ridx: valid = valid * oidx.eq(0)

View File

@@ -78,8 +78,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
if len(new_src) >= 4:
new_src[2] = UOp(UOps.VECTORIZE, cast(DType, new_src[2].dtype).vec(4), tuple(new_src[2] for _ in range(4)))
vec_load = UOp(UOps.LOAD, cast(DType, load.dtype).vec(4), tuple(new_src))
-return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)),
-range(4), UOp.const(load.dtype, float('nan')))
+return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const(float('nan')))
float4_folding = PatternMatcher([
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),