lowerer is kernel [run_process_replay] (#5437)

Author: George Hotz
Date: 2024-07-12 18:50:55 -07:00
Committed by: GitHub
Parent: b8342fb085
Commit: 03c2dc8bd7
33 changed files with 215 additions and 213 deletions

View File

@@ -2,7 +2,7 @@ from typing import List
from extra.models.resnet import ResNet50
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Compiled
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
@@ -84,24 +84,24 @@ if __name__ == "__main__":
if DEBUG >= 2:
for ast in si.ast: print_tree(ast)
rawbufs = bufs_from_lin(Lowerer(si.ast))
rawbufs = bufs_from_lin(Kernel(si.ast))
# "linearize" the op into uops in different ways
lins:List[Lowerer] = []
lins:List[Kernel] = []
# always try hand coded opt
lin = Lowerer(si.ast, opts=device.renderer)
lin = Kernel(si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append(lin)
# maybe try tensor cores
lin = Lowerer(si.ast, opts=device.renderer)
lin = Kernel(si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append(lin)
# try a beam search
if beam:=getenv("BEAM"):
lin = Lowerer(si.ast, opts=device.renderer)
lin = Kernel(si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)
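
A minimal sketch of the flow this hunk ends up with (an illustration, not part of the diff; it assumes the post-rename API in this commit, with si and device as in the snippet above):

from tinygrad.codegen.kernel import Kernel
lin = Kernel(si.ast, opts=device.renderer)   # previously Lowerer(si.ast, opts=device.renderer)
lin.hand_coded_optimizations()               # hand-coded opts apply the same way as before
prg = lin.to_program()                       # linearize + render, now methods on Kernel (see the kernel.py hunk below)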

View File

@@ -1,5 +1,5 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.lowerer import UOps, MemOp, UOp
from tinygrad.codegen.kernel import UOps, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG

View File

@@ -3,7 +3,7 @@ from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.lowerer import UOps, UOp
from tinygrad.codegen.kernel import UOps, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage

View File

@@ -1,7 +1,7 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.lowerer import UOps, UOp
from tinygrad.codegen.kernel import UOps, UOp
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

View File

@@ -2,7 +2,7 @@ import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.lowerer import UOps
from tinygrad.codegen.kernel import UOps
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH

View File

@@ -2,7 +2,7 @@ from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.lowerer import UOp, UOps
from tinygrad.codegen.kernel import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math

View File

@@ -38,9 +38,9 @@ B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
sched = create_schedule([C.lazydata])
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import CompilerOptions
lin = Lowerer(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
#lin.hand_coded_optimizations()
lin.linearize()
from tinygrad.runtime.ops_clang import renderer

View File

@@ -7,7 +7,7 @@ from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import getenv
# stuff needed to unpack a kernel
@@ -38,7 +38,7 @@ def dataset_from_cache(fn):
for f in tqdm(cur.fetchall()):
Xs,As = [], []
try:
lin = Lowerer(eval(f[0]))
lin = Kernel(eval(f[0]))
opts = pickle.loads(f[-1])
for o in opts:
Xs.append(lin_to_feats(lin, use_sts=True))

View File

@@ -13,7 +13,7 @@ inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
# more stuff
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
@@ -48,7 +48,7 @@ def dataset_from_cache(fn):
new_tm = min(opts_to_outcome[(ast,k)])
if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
try:
lin = Lowerer(eval(ast))
lin = Kernel(eval(ast))
except Exception:
continue
for opt in k[:-1]: lin.apply_opt(opt)

View File

@@ -1,12 +1,12 @@
import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import actions
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import tqdm
tactions = set()
def test_rebuild(lin):
linr = Lowerer(lin.ast)
linr = Kernel(lin.ast)
for o in lin.applied_opts:
assert o in actions, f"{o} is not in actions"
tactions.add(o)

View File

@@ -9,12 +9,12 @@ from tinygrad.shape.symbolic import Variable, NumNode
inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> Tuple[LazyOp,...]: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
def ast_str_to_lin(ast_str:str, opts=None): return Lowerer(ast_str_to_ast(ast_str), opts=opts)
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
k = Lowerer(ast, opts=opts)
k = Kernel(ast, opts=opts)
for opt in applied_opts:
k.apply_opt(opt)
return k
@@ -44,7 +44,7 @@ from tinygrad.shape.symbolic import Node
MAX_DIMS = 16
MAX_BUFS = 9
def lin_to_feats(lin:Lowerer, use_sts=True):
def lin_to_feats(lin:Kernel, use_sts=True):
assert lin.shape_len < MAX_DIMS, "too many dims"
all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]

View File

@@ -1,4 +1,4 @@
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
@@ -45,7 +45,7 @@ if __name__ == "__main__":
X,Y = [], []
for i,x in enumerate(tqdm(dset)):
ast, opts, tms = eval(x)
lin = Lowerer(ast)
lin = Kernel(ast)
for o in opts: lin.apply_opt(o)
if lin.shape_len >= MAX_DIMS: continue
if min(tms) == float('inf'): continue

View File

@@ -1,9 +1,9 @@
from typing import List, Tuple
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import get_linearizer_actions, actions
_net = None
def beam_q_estimate(beam:List[Tuple[Lowerer, float]]) -> List[Tuple[Lowerer, float]]:
def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:
global _net
if _net is None:
from tinygrad.nn.state import load_state_dict, safe_load

View File

@@ -4,7 +4,7 @@ from extra.optimization.helpers import ast_str_to_lin
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin

View File

@@ -6,7 +6,7 @@ from tinygrad import dtypes, Device
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, bufs_from_lin
# from resnet50, tinybox red
@@ -15,9 +15,9 @@ from tinygrad.engine.search import time_linearizer, bufs_from_lin
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.half), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=7, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(430592, 0, 3364, 58, 1, 0, 0, 0), offset=59, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float)), arg=None)), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=8, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), 
arg=dtypes.half),), arg=dtypes.float),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))
device = Device[Device.DEFAULT]
rawbufs = bufs_from_lin(Lowerer(ast))
rawbufs = bufs_from_lin(Kernel(ast))
lin = Lowerer(ast, opts=device.renderer)
lin = Kernel(ast, opts=device.renderer)
lin.hand_coded_optimizations()
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
print(f"{tm=}")

View File

@@ -3,14 +3,14 @@ from tinygrad.runtime.support.hip_comgr import compile_hip
from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
class TestHIPCompileSpeed(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT != "HIP", "only run on HIP")
def test_hip_compile(self):
a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
out = a + b
lin = Lowerer(create_schedule([out.lazydata])[-1].ast[0])
lin = Kernel(create_schedule([out.lazydata])[-1].ast[0])
lin.linearize()
reference = """

View File

@@ -8,7 +8,7 @@ from test.test_linearizer_failures import helper_test_lin
from tinygrad.engine.realize import get_runner, CompiledRunner
from test.external.fuzz_linearizer import get_fuzz_rawbufs
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -28,13 +28,13 @@ class TestNV(unittest.TestCase):
def test_oor_kernels(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
helper_test_lin(Lowerer(ast), opts=opts, failed_platforms=["NV"])
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
def test_error_on_huge_dims(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
lin = Lowerer(ast)
lin = Kernel(ast)
for opt in opts: lin.apply_opt(opt)
rawbufs = get_fuzz_rawbufs(lin)
prg = CompiledRunner(lin.to_program())

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
from tinygrad.tensor import Tensor
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.engine.graph import graph_uops
from tinygrad.engine.schedule import create_schedule
@@ -13,7 +13,7 @@ class TestUopsGraph(unittest.TestCase):
a = Tensor.rand(N,N)
b = Tensor.rand(N,N)
si = create_schedule([(a@b).lazydata])[-1]
lin = Lowerer(si.ast)
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
print(lin.colored_shape())
uops = lin.linearize().uops
@@ -24,7 +24,7 @@ class TestUopsGraph(unittest.TestCase):
def test_reduce(self):
a = Tensor.rand(1024*1024)
si = create_schedule([a.sum().lazydata])[-1]
lin = Lowerer(si.ast)
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)
@@ -34,7 +34,7 @@ class TestUopsGraph(unittest.TestCase):
x = Tensor.rand(1,3,16,16)
c = Conv2d(3, 16, (3,3))
si = create_schedule([c(x).elu().lazydata])[-1]
lin = Lowerer(si.ast)
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)

View File

@@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import UOp
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_linearizer_actions, bufs_from_lin
@@ -53,7 +53,7 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
rawbuf.copyin(mv)
return rawbuf
def run_linearizer(lin: Lowerer, rawbufs=None, var_vals=None):
def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}
@@ -72,7 +72,7 @@ def run_linearizer(lin: Lowerer, rawbufs=None, var_vals=None):
return "PASS"
def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in lin.membufs)
@@ -90,7 +90,7 @@ def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=N
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
if ground_truth is None and not has_bf16:
unoptimized = Lowerer(lin.ast)
unoptimized = Kernel(lin.ast)
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
@@ -117,7 +117,7 @@ def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=N
return ("PASS", rawbufs, var_vals, ground_truth,)
def fuzz_linearizer(lin: Lowerer, rtol=1e-2, atol=1e-2):
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
@@ -177,7 +177,7 @@ def fuzz_linearizer(lin: Lowerer, rtol=1e-2, atol=1e-2):
if FUZZ_ALL_ACTIONS: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
return failures
def _is_simple(lin: Lowerer) -> bool:
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:LazyOp = lin.ast.src[0]
if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
page_size = 100
@@ -17,7 +17,7 @@ for offset in tqdm(range(0, row_count, page_size)):
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}):
# try linearize
try:
k = Lowerer(ast, opts=opts)
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
good_src = k.opts.render(name, k.linearize().uops)
except Exception as e:

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.search import time_linearizer
@@ -37,7 +37,7 @@ if __name__ == "__main__":
import pickle
with open(args.pkl, 'rb') as file:
(ast, applied_opts,) = pickle.load(file)
lin = Lowerer(ast)
lin = Kernel(ast)
for opt in applied_opts:
lin.apply_opt(opt)
test_lins = [lin]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
print_tree(op)
print(op)
print(test_lin.applied_opts)
unoptimized_lin = Lowerer(test_lin.ast)
unoptimized_lin = Kernel(test_lin.ast)
unoptimized_lin.required_optimizations()
print(f"{unoptimized_lin.colored_shape()} -> {test_lin.colored_shape()}")
(msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)

View File

@@ -4,8 +4,7 @@ import unittest
from dataclasses import replace
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.device import Device, Buffer
@@ -38,7 +37,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
realized_ast = sched[-1].ast
run_schedule(sched)
out = r.numpy()
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -54,7 +53,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
r = a.matmul(b, acc_dtype=dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
@@ -211,7 +210,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_early_end_local(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
@@ -243,7 +242,7 @@ class TestLinearizer(unittest.TestCase):
LazyOp(op=BufferOps.STORE, src=(ast2,), arg=MemBuffer(idx=order.index(2), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
LazyOp(op=BufferOps.STORE, src=(ast3,), arg=MemBuffer(idx=order.index(3), dtype=dtypes.float, st=ShapeTracker.from_shape((1,))))
]
k = Lowerer([asts[i] for i in order])
k = Kernel([asts[i] for i in order])
def recursive_reduceops(x: LazyOp): return [c for v in x.src for c in recursive_reduceops(v)] + [v for v in list(x.src) if v.op in ReduceOps]
for i,r in enumerate(k.reduceops): assert not any([r in recursive_reduceops(x) for x in k.reduceops[:i]]), "reduceops are out of order"
x = Tensor.randn(32).realize()
@@ -256,7 +255,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_store_locals(self):
# ensure the result of local reducop is stored and loaded back into every thread for future use
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
@@ -273,7 +272,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_upcasting(self):
# when upcasting multiple reductions, ensure ast_parse will create multiple uops even when using the result of past reductions
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.upcast()
k.linearize()
define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
@@ -302,7 +301,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_multireduce_loop_scope(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),)),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
@@ -377,7 +376,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
@@ -395,7 +394,7 @@ class TestLinearizer(unittest.TestCase):
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
lin = Lowerer(ast)
lin = Kernel(ast)
lin.linearize()
assert len(lin.uops.uops) <= 7, "too many uops"
@@ -408,7 +407,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -419,7 +418,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.upcast()
k.linearize()
@@ -435,7 +434,7 @@ class TestLinearizer(unittest.TestCase):
def test_upcast_with_locals(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -469,7 +468,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -479,14 +478,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Lowerer(create_schedule([a.lazydata])[-1].ast)
k = Kernel(create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Lowerer(create_schedule([c.lazydata])[-1].ast)
k = Kernel(create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -550,7 +549,7 @@ class TestLinearizer(unittest.TestCase):
c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
realized_ast, real_bufs = helper_realized_ast(c)
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=2)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -567,7 +566,7 @@ class TestLinearizer(unittest.TestCase):
# check that get_linearizer_actions produces all 9 options
from tinygrad.engine.search import get_linearizer_actions
tc_actions = [k for i, k in get_linearizer_actions(Lowerer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
tc_actions = [k for i, k in get_linearizer_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@@ -677,7 +676,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1
lin = Lowerer(sched[0].ast)
lin = Kernel(sched[0].ast)
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
a = Tensor.rand((4,4))
@@ -697,7 +696,7 @@ class TestLinearizer(unittest.TestCase):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1
lin = Lowerer(sched[0].ast)
lin = Kernel(sched[0].ast)
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
@@ -716,7 +715,7 @@ class TestLinearizer(unittest.TestCase):
sched_copy = sched[:]
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
lin = Lowerer(sched_copy[-1].ast)
lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
@@ -844,7 +843,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -856,7 +855,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -872,7 +871,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -884,7 +883,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -902,7 +901,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.upcast()
k.linearize()
@@ -917,7 +916,7 @@ class TestFloat4(unittest.TestCase):
# don't.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -933,7 +932,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -949,7 +948,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -964,7 +963,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -977,7 +976,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
s = create_schedule([layer_2.lazydata])[-1]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
@@ -989,7 +988,7 @@ class TestHandCodedOpts(unittest.TestCase):
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = create_schedule([monster.lazydata])[-1]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
@@ -1003,7 +1002,7 @@ class TestHandCodedOpts(unittest.TestCase):
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Lowerer(si.ast)
k = Kernel(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -1016,7 +1015,7 @@ class TestHandCodedOpts(unittest.TestCase):
out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Lowerer(si.ast)
k = Kernel(si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -1058,11 +1057,11 @@ def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts=[],
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Lowerer]:
lins: List[Lowerer] = []
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
lins: List[Kernel] = []
outbufs = [real_bufs[i] for i in range(len(realized_ast.src))]
def get_prg(k:Lowerer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
def check_opt(opts, create_k, expected_color_size):
k = create_k()
@@ -1082,7 +1081,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Get baseline if it is not provided, which is not optimized at all.
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
lins.append(k)
prg = get_prg(k)
prg.exec(real_bufs)
@@ -1092,7 +1091,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Check correctness of handcoded optimiztions.
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
lins.append(k)
k.hand_coded_optimizations()
prg = get_prg(k)
@@ -1101,7 +1100,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
for i, buf in enumerate(outbufs):
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
for i, x in enumerate(opts): # Check custom transformations if any.
check_opt(x, lambda: Lowerer(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
return lins
# creates a back-to-back multi reduce AST by merging r0 and r1.
@@ -1438,14 +1437,14 @@ class TestKernelOpts(unittest.TestCase):
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
]
for x in invalid_opts:
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
with self.assertRaises(AssertionError):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_buf_index_not_found_tensor_core(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Lowerer(ast, opts=Device[Device.DEFAULT].renderer)
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
@@ -1462,7 +1461,7 @@ class TestKernelOpts(unittest.TestCase):
c, d = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
r1 = c.matmul(d, acc_dtype=tc.dtype_out)
ast = _temp_create_multireduce_ast(r0, r1)
lin = Lowerer(ast)
lin = Kernel(ast)
lin.apply_opt(Opt(op=OptOps.TC, axis=0, amt=2))
lin.linearize()
result = compare_linearizer(lin)

File diff suppressed because one or more lines are too long

View File

@@ -2,7 +2,7 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad.helpers import CI
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
@@ -12,7 +12,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
def _test_overflow(ast, opts):
lin = Lowerer(ast)
lin = Kernel(ast)
for opt in opts: lin.apply_opt(opt)
lin.linearize()
bufs = bufs_from_lin(lin)

View File

@@ -10,7 +10,7 @@ from tinygrad.device import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
from tinygrad.helpers import DEBUG, flatten, getenv
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
# test the (non loadops) ops linearize
for s in sched:
if s.ast.op is not MetaOps.SINK: continue
l = Lowerer(s.ast)
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
return sched

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -19,12 +19,12 @@ class TestTimeLinearizer(unittest.TestCase):
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast.lazyops if x.op is BufferOps.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
tm = time_linearizer(Lowerer(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
rawbufs = bufs_from_lin(lin:=Lowerer(si.ast))
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
assert len(rawbufs) == len(lin.membufs)
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)
@@ -36,7 +36,7 @@ class TestTimeLinearizer(unittest.TestCase):
"""
# ast of Tensor.zeros(16).contiguous().realize()
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
lin = Lowerer(ast)
lin = Kernel(ast)
bufs = bufs_from_lin(lin)
kernel_count = GlobalCounters.kernel_count
@@ -71,7 +71,7 @@ class TestBEAM(unittest.TestCase):
b = Tensor.rand(3)
realized_ast, _ = helper_realized_ast(a @ b)
from tinygrad.engine.search import get_linearizer_actions
lins = get_linearizer_actions(Lowerer(realized_ast), False).values()
lins = get_linearizer_actions(Kernel(realized_ast), False).values()
# ensure amt=0 are not duplicated
if Opt(OptOps.UPCAST, 0, 0) in actions:
@@ -88,7 +88,7 @@ class TestBEAM(unittest.TestCase):
def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
lin = Lowerer(ast)
lin = Kernel(ast)
bufs = bufs_from_lin(lin)
best_lin = beam_search(lin, bufs, 3)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import unittest
from tinygrad.codegen.lowerer import Lowerer
#from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
#from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop
@@ -16,7 +16,7 @@ def lower(*ast:LazyOp):
for op in ast: print_tree(op)
try: verify_lazyop(sink_ast)
except AssertionError: raise InvalidLazyOpException()
k = Lowerer(sink_ast)
k = Kernel(sink_ast)
k.linearize()
if DEBUG >= 6: k.uops.print()
if DEBUG >= 4: print(k.to_program().src)

View File

@@ -2,7 +2,7 @@ import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.ops import MetaOps
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
class TestWinograd(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestWinograd(unittest.TestCase):
if s.ast.op is not MetaOps.SINK: continue
ops = s.ast.lazyops
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Lowerer(s.ast)
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression

View File

@@ -4,7 +4,7 @@ from tinygrad import dtypes, Tensor
from tinygrad.helpers import prod
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import flops_mem
class TestFlopCounter(unittest.TestCase):
@@ -15,7 +15,7 @@ class TestFlopCounter(unittest.TestCase):
def compare_flop_counters(self, ast):
info = get_lazyop_info(ast.src[0])
lin = Lowerer(ast)
lin = Kernel(ast)
# NOTE: why does hand coded optimizations change flops for the GEMM?
#lin.hand_coded_optimizations()
lin.linearize()

View File

@@ -4,14 +4,19 @@ from dataclasses import replace
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
from tinygrad.engine.graph import print_tree
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, \
verify_lazyop, KernelInfo, get_lazyop_info
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, \
get_contraction, to_function_name, diskcache_put, ContextVar
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uops import UOps, flops_mem
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.codegen.lowerer import lazyop_to_uop
from dataclasses import dataclass
from enum import Enum, auto
@@ -719,3 +724,45 @@ class Kernel:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
return fixup_ast(self.ast)
# **** this is the lowerer ****
def linearize(self) -> Kernel:
modified_ast = self.get_optimized_ast()
if DEBUG >= 3:
print(self.name)
print_tree(modified_ast)
uop_sink = lazyop_to_uop(modified_ast, self.opts)
# extract global/local sizes
if self.opts.has_local:
self.global_size: Optional[List[int]] = [1,1,1]
self.local_size: Optional[List[int]] = [1,1,1]
for u in uop_sink.parents:
if u.op is UOps.SPECIAL:
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
else: self.global_size[u.arg[0]] = u.arg[2]
else:
self.global_size, self.local_size = None, None
# generate the UOpGraph
self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
if DEBUG >= 5: self.uops.print()
if getenv("GRAPHUOPS"):
self.uops.graph()
if getenv("GRAPHUOPS") == 2: exit(0)
return self
def to_program(self) -> Program:
self.linearize()
src = self.opts.render(name:=to_function_name(self.name), self.uops)
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
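A hedged usage sketch of the relocated methods; `ast` is an assumed scheduled LazyOp sink, and the default device is assumed to expose a renderer:

from tinygrad import Device
from tinygrad.codegen.kernel import Kernel

k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.hand_coded_optimizations()   # optional, as in get_linearizer further down
prg = k.to_program()           # calls linearize(), renders source, estimates flops/mem
print(prg.src, k.global_size, k.local_size)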

View File

@@ -1,14 +1,12 @@
from __future__ import annotations
from typing import List, Tuple, cast, Optional, Any, Dict
import functools
from tinygrad.codegen.kernel import Kernel
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, get_lazyop_info, KernelInfo
from tinygrad.codegen.uops import UOp, flops_mem, UOps
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.renderer import Program, Renderer
from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.renderer import Renderer
from tinygrad.helpers import getenv, prod
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
@@ -156,47 +154,5 @@ class IndependentLowerer:
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)
def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
# TODO: move this to Kernel
class Lowerer(Kernel):
def linearize(self) -> Lowerer:
modified_ast = self.get_optimized_ast()
if DEBUG >= 3:
print(self.name)
from tinygrad.engine.graph import print_tree
print_tree(modified_ast)
uop_sink = lazyop_to_uop(modified_ast, self.opts)
# extract global/local sizes
if self.opts.has_local:
self.global_size: Optional[List[int]] = [1,1,1]
self.local_size: Optional[List[int]] = [1,1,1]
for u in uop_sink.parents:
if u.op is UOps.SPECIAL:
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
else: self.global_size[u.arg[0]] = u.arg[2]
else:
self.global_size, self.local_size = None, None
# generate the UOpGraph
self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
if DEBUG >= 5: self.uops.print()
if getenv("GRAPHUOPS"):
self.uops.graph()
if getenv("GRAPHUOPS") == 2: exit(0)
return self
def to_program(self) -> Program:
self.linearize()
src = self.opts.render(name:=to_function_name(self.name), self.uops)
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
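What stays in lowerer.py is only the ast-to-UOp lowering; a minimal sketch of how the new Kernel.linearize consumes it, mirroring the code added to kernel.py above (`modified_ast` and `renderer` are assumed):

from tinygrad.codegen.lowerer import lazyop_to_uop
from tinygrad.codegen.uopgraph import UOpGraph

sink = lazyop_to_uop(modified_ast, renderer)   # single UOp sink for the optimized ast
uops = UOpGraph(sink, renderer)                # Kernel keeps this as self.uops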

View File

@@ -6,31 +6,31 @@ from tinygrad.ops import MetaOps, LazyOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.renderer import Renderer, Program
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import ScheduleItem
# **************** Program Creation ****************
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
def get_linearizer(renderer:Renderer, ast:LazyOp) -> Lowerer:
def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
if DEBUG >= 5:
from tinygrad.engine.graph import print_tree
print_tree(ast)
k = Lowerer(ast, opts=renderer)
k = Kernel(ast, opts=renderer)
k.required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
kb, k_opt = Lowerer(ast, opts=renderer), k
kb, k_opt = Kernel(ast, opts=renderer), k
kb.required_optimizations()
rawbufs = bufs_from_lin(kb, allocate=False)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
if getenv("BEAM_COMPARE", 1):
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
lins: List[Tuple[str, Lowerer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
if used_tensor_cores:
lins.append(("hc", Lowerer(ast, opts=renderer)))
lins.append(("hc", Kernel(ast, opts=renderer)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
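An end-to-end sketch of the updated entry point (this file appears to be tinygrad/engine/realize.py; `ast` is assumed to be a scheduled LazyOp, with NOOPT/BEAM behaving as above):

from tinygrad import Device
from tinygrad.engine.realize import get_linearizer

k = get_linearizer(Device[Device.DEFAULT].renderer, ast)   # now returns a Kernel
prg = k.to_program()                                       # render it like any other Kernel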

View File

@@ -6,7 +6,7 @@ from tinygrad.device import Device, Buffer, Compiler
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.tensor import Tensor
@@ -53,7 +53,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
class TimeoutException(Exception): pass
def timeout_handler(signum, frame): raise TimeoutException()
def _try_compile_linearized_w_idx(x:Tuple[int,Lowerer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
signal.signal(signal.SIGALRM, timeout_handler)
# set timeout
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
@@ -85,7 +85,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
# *** external API ***
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Lowerer, allocate:bool=True) -> List[Buffer]:
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
@@ -97,7 +97,7 @@ def bufs_from_lin(lin:Lowerer, allocate:bool=True) -> List[Buffer]:
return cast(List[Buffer], rawbufs)
# get dictionary of all possible actions
def get_linearizer_actions(lin:Lowerer, include_0=True) -> Dict[int, Lowerer]:
def get_linearizer_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
for i,a in enumerate(actions):
if a.axis is not None and a.op is not OptOps.TC:
@@ -115,7 +115,7 @@ def get_linearizer_actions(lin:Lowerer, include_0=True) -> Dict[int, Lowerer]:
return acted_lins
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Lowerer:
def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Kernel:
global beam_pool
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -123,7 +123,7 @@ def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
beam: List[Tuple[Lowerer, float]] = [(lin, float("inf"))]
beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
seen_libs = set()
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
@@ -140,8 +140,8 @@ def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:
acted_lins: List[Lowerer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
timed_lins: List[Tuple[Lowerer, float]] = []
acted_lins: List[Kernel] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
timed_lins: List[Tuple[Kernel, float]] = []
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
@@ -181,7 +181,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]
def time_linearizer(lin:Lowerer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
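A minimal sketch of driving the search API with the new Kernel type; `ast` and `renderer` are assumed, and amt is the beam width:

from tinygrad.engine.search import beam_search, bufs_from_lin, time_linearizer
from tinygrad.codegen.kernel import Kernel

k = Kernel(ast, opts=renderer)
rawbufs = bufs_from_lin(k)              # scrap buffers sized from the kernel's membufs
best = beam_search(k, rawbufs, amt=4)   # returns the fastest Kernel found
print(time_linearizer(best, rawbufs))   # timing in seconds, cached when CACHELEVEL >= 2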