mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
all realize 2 (#4527)
* all realize 2
* tests fixup
* fix more tests
* fix openpilot
* fix tests
* unneeded
@@ -1,5 +1,6 @@
 import numpy as np
 import unittest
+from dataclasses import replace

 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
 from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
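
The newly imported dataclasses.replace is used further down in this diff, inside get_prg, to rebind a compiled program to the default device. A minimal runnable sketch of what replace does, using a hypothetical ProgramSpec stand-in rather than tinygrad's actual Program dataclass:

# Sketch: dataclasses.replace copies a (frozen) dataclass instance,
# overriding only the named fields. ProgramSpec is a hypothetical stand-in.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ProgramSpec:
  name: str
  dname: str  # device name the program is bound to

p = ProgramSpec(name="gemm", dname="CLANG")
q = replace(p, dname="GPU")  # new instance; only dname changes
assert p.dname == "CLANG" and q.dname == "GPU"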
@@ -10,7 +11,7 @@ from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
 from tinygrad.tensor import Tensor
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import run_schedule, lower_schedule
+from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
 from tinygrad.helpers import prod, Context, getenv, CI
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
@@ -269,7 +270,7 @@ class TestLinearizer(unittest.TestCase):
     assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
     assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"

-    prg = Device[Device.DEFAULT].to_runner(k)
+    prg = CompiledRunner(k.to_program())
     real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
     prg.exec(real_bufs)
     result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
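
The copyin of zeros above is a zero-then-check pattern: the output buffer is cleared before the kernel runs, so any element the program fails to write stays zero and trips the later comparison. A runnable plain-numpy sketch of the same idea (run_kernel is a stand-in, not tinygrad code):

import numpy as np

def checked_run(run_kernel, out_size, expected, atol=1e-4, rtol=1e-4):
  out = np.zeros(out_size, dtype=np.float32)  # zero so unwritten elements are caught
  run_kernel(out)                             # stand-in for prg.exec(real_bufs)
  np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol)

checked_run(lambda out: out.fill(1.0), 4, np.ones(4, dtype=np.float32))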
@@ -586,7 +587,9 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
   wanna_output = None
   realized_ast, real_bufs = helper_realized_ast(r)

-  def check_opt(opts, create_k, to_prg, expected_color_size):
+  def get_prg(k:Linearizer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
+
+  def check_opt(opts, create_k, expected_color_size):
     k = create_k()
     if apply_tc:
       assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
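
This hunk drops the to_prg factory parameter in favor of a single get_prg helper defined once in the enclosing scope, so every call site compiles kernels the same way. A generic runnable sketch of that refactor (illustrative names, not tinygrad APIs):

# Before: the program factory is threaded through as a parameter.
def check_opt_old(create_k, to_prg): return to_prg(create_k())

# After: one shared helper, defined once next to the state it needs.
def get_prg(k): return f"compiled:{k}"       # stand-in for CompiledRunner(...)
def check_opt_new(create_k): return get_prg(create_k())

assert check_opt_old(lambda: "k", get_prg) == check_opt_new(lambda: "k")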
@@ -595,26 +598,26 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
       k.apply_opt(opt)
     if expected_color_size is not None:
       assert (cs:=[(x,y) for x,y in zip(k.colors(), k.full_shape)]) == expected_color_size, f"expected={expected_color_size} got={cs}"
-    prg = to_prg(k)
+    prg = get_prg(k)
     real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
     prg.exec(real_bufs)
     np.testing.assert_allclose(np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), wanna_output, atol=atol, rtol=rtol)

   # Get baseline, which is not optimized at all.
   k = Linearizer(realized_ast)
-  prg = Device[Device.DEFAULT].to_runner(k)
+  prg = get_prg(k)
   prg.exec(real_bufs)
   wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()

   # Check correctness of handcoded optimizations.
   k = Linearizer(realized_ast)
   k.hand_coded_optimizations()
-  prg = Device[Device.DEFAULT].to_runner(k)
+  prg = get_prg(k)
   real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
   prg.exec(real_bufs)
   np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=atol, rtol=rtol)
   for i, x in enumerate(opts): # Check custom transformations if any.
-    check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_runner, color_sizes[i] if i < len(color_sizes) else None)
+    check_opt(x, lambda: Linearizer(realized_ast), color_sizes[i] if i < len(color_sizes) else None)

 class TestKernelOpts(unittest.TestCase):
   def test_local_and_grouped_reduce(self):
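
Taken together, helper_linearizer_opt now runs three phases: an unoptimized baseline that captures wanna_output, a hand-coded-optimizations pass checked against it, and one check_opt pass per custom opt list. A toy runnable sketch of that flow (stand-in functions, not tinygrad APIs):

def baseline_sum(xs): return sum(xs)               # unoptimized reference kernel

def pairwise_sum(xs):                              # stand-in "optimized" variant
  if len(xs) == 1: return xs[0]
  mid = len(xs) // 2
  return pairwise_sum(xs[:mid]) + pairwise_sum(xs[mid:])

data = list(range(1, 9))
wanna_output = baseline_sum(data)                  # phase 1: capture the baseline
for optimized in (pairwise_sum,):                  # phases 2-3: each variant must match
  assert optimized(data) == wanna_output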