mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
UOp.const(x.dtype, y) -> x.const(y) [run_process_replay] (#6276)
This commit is contained in:
@@ -93,16 +93,16 @@ class IndependentLowerer:
 idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
 # TODO: check has_valid in UPat, not here
 has_valid = valid.op is not UOps.CONST or valid.arg is not True
-if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
+if x.op is UOps.CONST: return valid.where(x.const(x.arg), x.const(0))
 buf = x.src[0]
 if x.op is UOps.LOAD:
 barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
-return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
+return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const(0), valid) if has_valid else ()) + barrier)
 # NOTE: only store the local reduceop in the threads that are actually doing the reduce
 store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
 x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
 # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
-if store_back: idx, _ = x.st_arg.to_indexed_uops([UOp.const(u.dtype, 0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
+if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
 if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
 for oidx, ridx in zip(self.idxs, self.ridxs):
 if oidx != ridx: valid = valid * oidx.eq(0)

@@ -78,8 +78,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
 if len(new_src) >= 4:
 new_src[2] = UOp(UOps.VECTORIZE, cast(DType, new_src[2].dtype).vec(4), tuple(new_src[2] for _ in range(4)))
 vec_load = UOp(UOps.LOAD, cast(DType, load.dtype).vec(4), tuple(new_src))
-return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)),
-range(4), UOp.const(load.dtype, float('nan')))
+return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), i)), range(4), load.const(float('nan')))
 
 float4_folding = PatternMatcher([
 (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),

Reference in New Issue
Block a user