all realize 2 (#4527)

* all realize 2

* tests fixup

* fix more tests

* fix openpilot

* fix tests

* unneeded
Author: George Hotz
Date: 2024-05-10 22:43:09 -07:00
Committed by: GitHub
Parent: d2c347fc74
Commit: 2f970a4fc2
21 changed files with 142 additions and 139 deletions
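
The pattern repeated across these 21 files: runner machinery (method_cache, Runner, CompiledRunner) moves from tinygrad/device.py to tinygrad/engine/realize.py, and Program moves to tinygrad/renderer. A minimal import sketch of that migration, using only the old and new paths visible in the hunks below:

# old paths (removed throughout this commit)
# from tinygrad.device import method_cache, Runner, CompiledRunner, Program

# new paths (added throughout this commit)
from tinygrad.engine.realize import method_cache, Runner, CompiledRunner
from tinygrad.renderer import Program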


@@ -4,7 +4,8 @@ from examples.llama import Transformer, MODEL_PARAMS
 from tinygrad.tensor import Tensor
 from tinygrad import Device
 from tinygrad.nn.state import get_state_dict
-from tinygrad.device import Allocator, method_cache
+from tinygrad.device import Allocator
+from tinygrad.engine.realize import method_cache
 from tinygrad.helpers import Profiling
 class FakeProgram:


@@ -9,6 +9,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOp
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
 from tinygrad.features.graph import print_tree
+from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
 from tinygrad.ops import LazyOp, UnaryOps, BufferOps
@@ -55,7 +56,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
 # TODO: images needs required_optimization
 try:
-prg = device.to_runner(lin)
+prg = CompiledRunner(lin.to_program())
 except Exception:
 traceback.print_exc()
 return "COMPILE_ERROR"


@@ -1,7 +1,7 @@
 import sys
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.device import Runner
+from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import Context, CI, OSX, getenv


@@ -10,8 +10,10 @@ from tinygrad.dtype import dtypes
 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
 from tinygrad.lazy import Buffer, create_lazybuffer
-from tinygrad.device import CompiledRunner, Device, Program
+from tinygrad.device import Device
 from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.engine.realize import CompiledRunner
+from tinygrad.renderer import Program
 # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
 def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):


@@ -1,5 +1,6 @@
 import numpy as np
 import unittest
+from dataclasses import replace
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
 from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
@@ -10,7 +11,7 @@ from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
 from tinygrad.tensor import Tensor
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import run_schedule, lower_schedule
+from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
 from tinygrad.helpers import prod, Context, getenv, CI
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
@@ -269,7 +270,7 @@ class TestLinearizer(unittest.TestCase):
 assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
 assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
-prg = Device[Device.DEFAULT].to_runner(k)
+prg = CompiledRunner(k.to_program())
 real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
 prg.exec(real_bufs)
 result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
@@ -586,7 +587,9 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
 wanna_output = None
 realized_ast, real_bufs = helper_realized_ast(r)
-def check_opt(opts, create_k, to_prg, expected_color_size):
+def get_prg(k:Linearizer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
+def check_opt(opts, create_k, expected_color_size):
 k = create_k()
 if apply_tc:
 assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
@@ -595,26 +598,26 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
 k.apply_opt(opt)
 if expected_color_size is not None:
 assert (cs:=[(x,y) for x,y in zip(k.colors(), k.full_shape)]) == expected_color_size, f"expected={expected_color_size} got={cs}"
-prg = to_prg(k)
+prg = get_prg(k)
 real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
 prg.exec(real_bufs)
 np.testing.assert_allclose(np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), wanna_output, atol=atol, rtol=rtol)
 # Get baseline, which is not optimized at all.
 k = Linearizer(realized_ast)
-prg = Device[Device.DEFAULT].to_runner(k)
+prg = get_prg(k)
 prg.exec(real_bufs)
 wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()
 # Check correctness of handcoded optimiztions.
 k = Linearizer(realized_ast)
 k.hand_coded_optimizations()
-prg = Device[Device.DEFAULT].to_runner(k)
+prg = get_prg(k)
 real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
 prg.exec(real_bufs)
 np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=atol, rtol=rtol)
 for i, x in enumerate(opts): # Check custom transformations if any.
-check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_runner, color_sizes[i] if i < len(color_sizes) else None)
+check_opt(x, lambda: Linearizer(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
 class TestKernelOpts(unittest.TestCase):
 def test_local_and_grouped_reduce(self):
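
helper_linearizer_opt above drops the to_prg parameter in favour of one get_prg helper. A sketch of that helper with its imports, assuming Program is a dataclass whose dname field names the target device (which is what the replace() call implies):

from dataclasses import replace
from tinygrad import Device
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.realize import CompiledRunner

def get_prg(k:Linearizer):
  # k.to_program() renders the kernel; replace() re-targets the resulting Program at
  # Device.DEFAULT before wrapping it in a CompiledRunner that can exec() real buffers.
  return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))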


@@ -1,12 +1,11 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.device import CompiledRunner
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import lower_schedule, BufferCopy
+from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner
 from tinygrad.features.multi import all_reduce, MultiLazyBuffer
 from random import randint
 import numpy as np


@@ -4,9 +4,11 @@ import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.dtype import dtypes, DType, PtrDType
-from tinygrad.device import Buffer, Device, CompiledRunner, Program
+from tinygrad.device import Buffer, Device
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.codegen.uops import exec_alu, UOpGraph
 from test.helpers import is_dtype_supported
@@ -210,9 +212,8 @@ class TestConstantFolding(unittest.TestCase):
 t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
 si = create_schedule([t.lazydata])
 assert len(si) == 1
-si = si[0]
-lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
-assert any(uop.uop is UOps.BITCAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
+ji = lower_schedule_item(si[-1])
+assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast"
 class TestLocalAccess(unittest.TestCase):
 @unittest.skipIf(Device.DEFAULT in {"LLVM"}, "device doesn't support local memory")
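
The bitcast test above now goes through the scheduler and lower_schedule_item instead of linearizing by hand. A standalone sketch of the new check, assuming ji.prg is a CompiledRunner whose .p holds the rendered Program (as the assertion's attribute chain indicates); imports follow the hunks in this diff:

from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearizer import UOps

t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
si = create_schedule([t.lazydata])   # schedule the lazy bitcast into kernels
ji = lower_schedule_item(si[-1])     # lower the last schedule item to an item holding a compiled runner
assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops)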


@@ -14,7 +14,7 @@ from tinygrad.engine.realize import lower_schedule_item
 def get_stats(x:Tensor):
 si = create_schedule([x.lazydata])[-1]
 ei = lower_schedule_item(si)
-return ei.prg.p.op_estimate, ei.prg.p.mem_estimate
+return ei.prg.op_estimate, ei.prg.mem_estimate
 class TestUOpsStats(unittest.TestCase):
 def test_simple_add(self):
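
get_stats above now reads op_estimate/mem_estimate straight off the lowered item's runner rather than the wrapped Program. A hypothetical usage sketch (the tensors and shapes are illustrative, not taken from the test):

from tinygrad import Tensor

a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
ops, mem = get_stats(a + b)   # scheduler's FLOP and byte estimates for the single elementwise add kernel
print(ops, mem)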