update outbufs selection in test_linearizer [pr] (#12166)

This commit is contained in:
qazal
2025-09-14 13:46:49 +03:00
committed by GitHub
parent d1ae30f7ef
commit 1591e4f66b

View File

@@ -482,7 +482,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}"
# now all input buffers in s[-1] should be realized
# create fresh buffers for the outputs
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
return push_views(s[-1].ast), bufs
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
@@ -504,7 +504,7 @@ def reset_bufs(bufs:list[Buffer]):
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
outbufs = [real_bufs[x.src[0].base.arg] for x in realized_ast.src]
outbufs = real_bufs[:len(realized_ast.src)]
device = real_bufs[0].device
wanna_output = [np.array(x).flatten() for x in wanna_output]