diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index b33bb8d89a..42bb29df93 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -408,16 +408,17 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: # map tensors to buffer/const, optionally apply a VIEW on top becomes_map: dict[UOp, UOp] = {} for k,v in tensor_map.items(): - # ASSIGN always becomes the target buffer - if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0] - # if we created a new buffer for this tensor, map it to the assigned buffer - elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN: - becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st)) - # tensors can also simplify to an existing buffer/const - else: - if k is v: continue - if v.base.op is Ops.BUFFER: becomes_map[k] = v - if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v + if (kernel:=kernel_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st)) + if k is v: continue + if k.op is Ops.ASSIGN: + becomes_map[k] = k.src[0] + continue + op = v.base.op + if op is Ops.BUFFER: becomes_map[k] = v + if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v + if op is Ops.ASSIGN: + new_buf = v.base.src[0] + becomes_map[k] = new_buf if new_buf.st == v.st else new_buf.view(unwrap(v.st)) # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign kernel_assign: dict[UOp, UOp] = {}