mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
update outbufs selection in test_linearizer [pr] (#12166)
This commit is contained in:
@@ -482,7 +482,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
|
||||
assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}"
|
||||
# now all input buffers in s[-1] should be realized
|
||||
# create fresh buffers for the outputs
|
||||
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
||||
bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
||||
return push_views(s[-1].ast), bufs
|
||||
|
||||
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
|
||||
@@ -504,7 +504,7 @@ def reset_bufs(bufs:list[Buffer]):
|
||||
|
||||
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
|
||||
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
|
||||
outbufs = [real_bufs[x.src[0].base.arg] for x in realized_ast.src]
|
||||
outbufs = real_bufs[:len(realized_ast.src)]
|
||||
device = real_bufs[0].device
|
||||
wanna_output = [np.array(x).flatten() for x in wanna_output]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user