mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
reshape rawbufs in test_linearizer (#5492)
* reshape rawbufs in test_linearizer * fix helper_linearizer_ast
This commit is contained in:
@@ -166,7 +166,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a = Tensor.randn(4, 1).realize()
|
||||
b = Tensor.randn(1, 1).realize()
|
||||
out = (a + b[0]).sum() + b[0]
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()[0]])[0]
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
# LOAD -> RANGE -> LOAD -> PHI
|
||||
assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
@@ -176,7 +176,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a = Tensor.randn(2, ).realize()
|
||||
b = Tensor.randn(1, 1).realize()
|
||||
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()[0]])[0]
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
if getenv("PTX"):
|
||||
# LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI
|
||||
@@ -1046,10 +1046,10 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
assert k.local_dims == 1
|
||||
assert k.upcasted == 1
|
||||
|
||||
def helper_linearizer_ast(_ast:Tuple[LazyOp, ...], inputs:List[Tensor], *args, **kwargs):
|
||||
if not isinstance(_ast, LazyOp): ast = LazyOp(MetaOps.SINK, _ast)
|
||||
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp], inputs:List[Tensor], *args, **kwargs):
|
||||
if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.SINK, ast)
|
||||
inbufs = [x.lazydata.buffer for x in inputs]
|
||||
outbufs = [Buffer(inbufs[-1].device, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
|
||||
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
|
||||
|
||||
def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
|
||||
@@ -1059,7 +1059,7 @@ def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
|
||||
def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts=[],
|
||||
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
|
||||
lins: List[Kernel] = []
|
||||
outbufs = [real_bufs[i] for i in range(len(realized_ast.src))]
|
||||
outbufs = [(real_bufs[i], lop.arg.st.shape) for i,lop in enumerate(realized_ast.src)]
|
||||
|
||||
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
|
||||
|
||||
@@ -1074,31 +1074,31 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
|
||||
if expected_color_size is not None:
|
||||
assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}"
|
||||
prg = get_prg(k)
|
||||
for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||
for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs)
|
||||
|
||||
for i, buf in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
|
||||
for i, (buf,shape) in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
|
||||
|
||||
# Get baseline if it is not provided, which is not optimized at all.
|
||||
k = Kernel(realized_ast)
|
||||
lins.append(k)
|
||||
prg = get_prg(k)
|
||||
prg.exec(real_bufs)
|
||||
if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).copy() for buf in outbufs]
|
||||
if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape).copy() for buf,shape in outbufs]
|
||||
else:
|
||||
for i, buf in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
|
||||
for i, (buf,shape) in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
|
||||
|
||||
# Check correctness of handcoded optimizations.
|
||||
k = Kernel(realized_ast)
|
||||
lins.append(k)
|
||||
k.hand_coded_optimizations()
|
||||
prg = get_prg(k)
|
||||
for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||
for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs)
|
||||
for i, buf in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
|
||||
for i, (buf,shape) in enumerate(outbufs):
|
||||
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
|
||||
for i, x in enumerate(opts): # Check custom transformations if any.
|
||||
check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
|
||||
return lins
|
||||
@@ -1745,7 +1745,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
|
||||
# can optimize further post PADTO
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
|
||||
], wanna_output=[(a.numpy()@b.numpy()+c.numpy()@d.numpy()).reshape(-1)])
|
||||
], wanna_output=[(a.numpy()@b.numpy()+c.numpy()@d.numpy()).reshape(N, N, 1)])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
|
||||
Reference in New Issue
Block a user