diff --git a/tinygrad/callify.py b/tinygrad/callify.py index 6b34ae4674..494297ec4a 100644 --- a/tinygrad/callify.py +++ b/tinygrad/callify.py @@ -144,12 +144,11 @@ pm_early_transform_tensor_graph = PatternMatcher([ def untag_and_append(ctx:AllocCtx, x:UOp): if x.tag is None: return None ret = x.replace(tag=None) + replace_uop = ret + while replace_uop.op is Ops.AFTER: replace_uop = replace_uop.src[0] for t in x.tag: original_uop: UOp = ctx.uop_list[t] - replace_uop = ret - while replace_uop.op is Ops.AFTER: replace_uop = replace_uop.src[0] ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape) - if ret.op is not Ops.AFTER: ctx.assigns.append(ret) # AFTER gets appended by append_after return ret def append_after(ctx:AllocCtx, x:UOp):