Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)

lowerer is kernel [run_process_replay] (#5437)
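In short: the Lowerer class (tinygrad.codegen.lowerer) is renamed to Kernel (tinygrad.codegen.kernel), and call sites across examples, extra, and test are updated. A minimal before/after sketch of the pattern, mirroring the external_test_hip_compile.py change further down (assumes the tinygrad API at this commit):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.kernel import Kernel  # was: from tinygrad.codegen.lowerer import Lowerer

a, b = Tensor([1, 2, 3, 4, 5]), Tensor([1, 2, 3, 4, 5])
out = a + b
lin = Kernel(create_schedule([out.lazydata])[-1].ast[0])  # was: Lowerer(...)
lin.linearize()  # lowering to uops keeps the same API, only the class name changes
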
@@ -2,7 +2,7 @@ from typing import List
from extra.models.resnet import ResNet50
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Compiled
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
@@ -84,24 +84,24 @@ if __name__ == "__main__":
if DEBUG >= 2:
for ast in si.ast: print_tree(ast)

-rawbufs = bufs_from_lin(Lowerer(si.ast))
+rawbufs = bufs_from_lin(Kernel(si.ast))

# "linearize" the op into uops in different ways
-lins:List[Lowerer] = []
+lins:List[Kernel] = []

# always try hand coded opt
-lin = Lowerer(si.ast, opts=device.renderer)
+lin = Kernel(si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append(lin)

# maybe try tensor cores
-lin = Lowerer(si.ast, opts=device.renderer)
+lin = Kernel(si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append(lin)

# try a beam search
if beam:=getenv("BEAM"):
-lin = Lowerer(si.ast, opts=device.renderer)
+lin = Kernel(si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)

@@ -1,5 +1,5 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
-from tinygrad.codegen.lowerer import UOps, MemOp, UOp
+from tinygrad.codegen.kernel import UOps, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG

@@ -3,7 +3,7 @@ from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.lowerer import UOps, UOp
+from tinygrad.codegen.kernel import UOps, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage

@@ -1,7 +1,7 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-from tinygrad.codegen.lowerer import UOps, UOp
+from tinygrad.codegen.kernel import UOps, UOp
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

@@ -2,7 +2,7 @@ import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
-from tinygrad.codegen.lowerer import UOps
+from tinygrad.codegen.kernel import UOps
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH

@@ -2,7 +2,7 @@ from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
-from tinygrad.codegen.lowerer import UOp, UOps
+from tinygrad.codegen.kernel import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math

@@ -38,9 +38,9 @@ B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)

sched = create_schedule([C.lazydata])
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.device import CompilerOptions
-lin = Lowerer(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
+lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
#lin.hand_coded_optimizations()
lin.linearize()
from tinygrad.runtime.ops_clang import renderer

@@ -7,7 +7,7 @@ from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import getenv

# stuff needed to unpack a kernel
@@ -38,7 +38,7 @@ def dataset_from_cache(fn):
for f in tqdm(cur.fetchall()):
Xs,As = [], []
try:
-lin = Lowerer(eval(f[0]))
+lin = Kernel(eval(f[0]))
opts = pickle.loads(f[-1])
for o in opts:
Xs.append(lin_to_feats(lin, use_sts=True))

@@ -13,7 +13,7 @@ inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps

# more stuff
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
@@ -48,7 +48,7 @@ def dataset_from_cache(fn):
new_tm = min(opts_to_outcome[(ast,k)])
if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
try:
-lin = Lowerer(eval(ast))
+lin = Kernel(eval(ast))
except Exception:
continue
for opt in k[:-1]: lin.apply_opt(opt)

@@ -1,12 +1,12 @@
import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import actions
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import tqdm

tactions = set()
def test_rebuild(lin):
-linr = Lowerer(lin.ast)
+linr = Kernel(lin.ast)
for o in lin.applied_opts:
assert o in actions, f"{o} is not in actions"
tactions.add(o)

@@ -9,12 +9,12 @@ from tinygrad.shape.symbolic import Variable, NumNode
inf, nan = float('inf'), float('nan')

# kernel unpacker
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> Tuple[LazyOp,...]: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
-def ast_str_to_lin(ast_str:str, opts=None): return Lowerer(ast_str_to_ast(ast_str), opts=opts)
+def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
-k = Lowerer(ast, opts=opts)
+k = Kernel(ast, opts=opts)
for opt in applied_opts:
k.apply_opt(opt)
return k
@@ -44,7 +44,7 @@ from tinygrad.shape.symbolic import Node

MAX_DIMS = 16
MAX_BUFS = 9
-def lin_to_feats(lin:Lowerer, use_sts=True):
+def lin_to_feats(lin:Kernel, use_sts=True):
assert lin.shape_len < MAX_DIMS, "too many dims"

all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]

@@ -1,4 +1,4 @@
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
@@ -45,7 +45,7 @@ if __name__ == "__main__":
X,Y = [], []
for i,x in enumerate(tqdm(dset)):
ast, opts, tms = eval(x)
-lin = Lowerer(ast)
+lin = Kernel(ast)
for o in opts: lin.apply_opt(o)
if lin.shape_len >= MAX_DIMS: continue
if min(tms) == float('inf'): continue

@@ -1,9 +1,9 @@
from typing import List, Tuple
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import get_linearizer_actions, actions

_net = None
-def beam_q_estimate(beam:List[Tuple[Lowerer, float]]) -> List[Tuple[Lowerer, float]]:
+def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:
global _net
if _net is None:
from tinygrad.nn.state import load_state_dict, safe_load

@@ -4,7 +4,7 @@ from extra.optimization.helpers import ast_str_to_lin
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin

@@ -6,7 +6,7 @@ from tinygrad import dtypes, Device
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, bufs_from_lin

# from resnet50, tinybox red
@@ -15,9 +15,9 @@ from tinygrad.engine.search import time_linearizer, bufs_from_lin
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.half), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=7, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(430592, 0, 3364, 58, 1, 0, 0, 0), offset=59, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float)), arg=None)), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=8, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), 
arg=dtypes.half),), arg=dtypes.float),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 128, 56, 56, 1, 1, 1), strides=(401408, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))

device = Device[Device.DEFAULT]
-rawbufs = bufs_from_lin(Lowerer(ast))
+rawbufs = bufs_from_lin(Kernel(ast))

-lin = Lowerer(ast, opts=device.renderer)
+lin = Kernel(ast, opts=device.renderer)
lin.hand_coded_optimizations()
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
print(f"{tm=}")

test/external/external_test_hip_compile.py (vendored, 4 changed lines)
@@ -3,14 +3,14 @@ from tinygrad.runtime.support.hip_comgr import compile_hip
from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.engine.schedule import create_schedule
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel

class TestHIPCompileSpeed(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT != "HIP", "only run on HIP")
def test_hip_compile(self):
a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
out = a + b
-lin = Lowerer(create_schedule([out.lazydata])[-1].ast[0])
+lin = Kernel(create_schedule([out.lazydata])[-1].ast[0])
lin.linearize()

reference = """

test/external/external_test_nv.py (vendored, 6 changed lines)
@@ -8,7 +8,7 @@ from test.test_linearizer_failures import helper_test_lin
from tinygrad.engine.realize import get_runner, CompiledRunner
from test.external.fuzz_linearizer import get_fuzz_rawbufs

-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -28,13 +28,13 @@ class TestNV(unittest.TestCase):
def test_oor_kernels(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
-helper_test_lin(Lowerer(ast), opts=opts, failed_platforms=["NV"])
+helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])

def test_error_on_huge_dims(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
-lin = Lowerer(ast)
+lin = Kernel(ast)
for opt in opts: lin.apply_opt(opt)
rawbufs = get_fuzz_rawbufs(lin)
prg = CompiledRunner(lin.to_program())

test/external/external_test_uops_graphing.py (vendored, 8 changed lines)
@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
from tinygrad.tensor import Tensor
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.engine.graph import graph_uops
from tinygrad.engine.schedule import create_schedule
@@ -13,7 +13,7 @@ class TestUopsGraph(unittest.TestCase):
a = Tensor.rand(N,N)
b = Tensor.rand(N,N)
si = create_schedule([(a@b).lazydata])[-1]
-lin = Lowerer(si.ast)
+lin = Kernel(si.ast)
lin.hand_coded_optimizations()
print(lin.colored_shape())
uops = lin.linearize().uops
@@ -24,7 +24,7 @@ class TestUopsGraph(unittest.TestCase):
def test_reduce(self):
a = Tensor.rand(1024*1024)
si = create_schedule([a.sum().lazydata])[-1]
-lin = Lowerer(si.ast)
+lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)
@@ -34,7 +34,7 @@ class TestUopsGraph(unittest.TestCase):
x = Tensor.rand(1,3,16,16)
c = Conv2d(3, 16, (3,3))
si = create_schedule([c(x).elu().lazydata])[-1]
-lin = Lowerer(si.ast)
+lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)

test/external/fuzz_linearizer.py (vendored, 12 changed lines)
@@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin

from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import UOp
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_linearizer_actions, bufs_from_lin
@@ -53,7 +53,7 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
rawbuf.copyin(mv)
return rawbuf

-def run_linearizer(lin: Lowerer, rawbufs=None, var_vals=None):
+def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}

@@ -72,7 +72,7 @@ def run_linearizer(lin: Lowerer, rawbufs=None, var_vals=None):

return "PASS"

-def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
+def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in lin.membufs)

@@ -90,7 +90,7 @@ def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=N
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}

if ground_truth is None and not has_bf16:
-unoptimized = Lowerer(lin.ast)
+unoptimized = Kernel(lin.ast)
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
@@ -117,7 +117,7 @@ def compare_linearizer(lin: Lowerer, rawbufs=None, var_vals=None, ground_truth=N

return ("PASS", rawbufs, var_vals, ground_truth,)

-def fuzz_linearizer(lin: Lowerer, rtol=1e-2, atol=1e-2):
+def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
@@ -177,7 +177,7 @@ def fuzz_linearizer(lin: Lowerer, rtol=1e-2, atol=1e-2):
if FUZZ_ALL_ACTIONS: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
return failures

-def _is_simple(lin: Lowerer) -> bool:
+def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:LazyOp = lin.ast.src[0]
if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm

page_size = 100
@@ -17,7 +17,7 @@ for offset in tqdm(range(0, row_count, page_size)):
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}):
# try linearize
try:
-k = Lowerer(ast, opts=opts)
+k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
good_src = k.opts.render(name, k.linearize().uops)
except Exception as e:

test/external/verify_kernel.py (vendored, 6 changed lines)
@@ -3,7 +3,7 @@ from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.search import time_linearizer

@@ -37,7 +37,7 @@ if __name__ == "__main__":
import pickle
with open(args.pkl, 'rb') as file:
(ast, applied_opts,) = pickle.load(file)
-lin = Lowerer(ast)
+lin = Kernel(ast)
for opt in applied_opts:
lin.apply_opt(opt)
test_lins = [lin]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
print_tree(op)
print(op)
print(test_lin.applied_opts)
-unoptimized_lin = Lowerer(test_lin.ast)
+unoptimized_lin = Kernel(test_lin.ast)
unoptimized_lin.required_optimizations()
print(f"{unoptimized_lin.colored_shape()} -> {test_lin.colored_shape()}")
(msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)

@@ -4,8 +4,7 @@ import unittest
from dataclasses import replace
from test.external.fuzz_linearizer import compare_linearizer

-from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
+from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.device import Device, Buffer
@@ -38,7 +37,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
realized_ast = sched[-1].ast
run_schedule(sched)
out = r.numpy()
-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -54,7 +53,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
r = a.matmul(b, acc_dtype=dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast
-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
@@ -211,7 +210,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_early_end_local(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
-k = Lowerer(ast)
+k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
@@ -243,7 +242,7 @@ class TestLinearizer(unittest.TestCase):
LazyOp(op=BufferOps.STORE, src=(ast2,), arg=MemBuffer(idx=order.index(2), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
LazyOp(op=BufferOps.STORE, src=(ast3,), arg=MemBuffer(idx=order.index(3), dtype=dtypes.float, st=ShapeTracker.from_shape((1,))))
]
-k = Lowerer([asts[i] for i in order])
+k = Kernel([asts[i] for i in order])
def recursive_reduceops(x: LazyOp): return [c for v in x.src for c in recursive_reduceops(v)] + [v for v in list(x.src) if v.op in ReduceOps]
for i,r in enumerate(k.reduceops): assert not any([r in recursive_reduceops(x) for x in k.reduceops[:i]]), "reduceops are out of order"
x = Tensor.randn(32).realize()
@@ -256,7 +255,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_store_locals(self):
# ensure the result of local reducop is stored and loaded back into every thread for future use
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
-k = Lowerer(ast)
+k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
@@ -273,7 +272,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_upcasting(self):
# when upcasting multiple reductions, ensure ast_parse will create multiple uops even when using the result of past reductions
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
-k = Lowerer(ast)
+k = Kernel(ast)
k.upcast()
k.linearize()
define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
@@ -302,7 +301,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_multireduce_loop_scope(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),)),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
-k = Lowerer(ast)
+k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
@@ -377,7 +376,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]

-k = Lowerer(create_schedule([r.lazydata])[-1].ast)
+k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
@@ -395,7 +394,7 @@ class TestLinearizer(unittest.TestCase):
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))

ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
-lin = Lowerer(ast)
+lin = Kernel(ast)
lin.linearize()

assert len(lin.uops.uops) <= 7, "too many uops"
@@ -408,7 +407,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])

-k = Lowerer(create_schedule([r.lazydata])[-1].ast)
+k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -419,7 +418,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()

-k = Lowerer(create_schedule([r.lazydata])[-1].ast)
+k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.upcast()
k.linearize()
@@ -435,7 +434,7 @@ class TestLinearizer(unittest.TestCase):
def test_upcast_with_locals(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
-k = Lowerer(create_schedule([r.lazydata])[-1].ast)
+k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()

@@ -469,7 +468,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)

-k = Lowerer(create_schedule([r.lazydata])[-1].ast)
+k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -479,14 +478,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
-k = Lowerer(create_schedule([a.lazydata])[-1].ast)
+k = Kernel(create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype

def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
-k = Lowerer(create_schedule([c.lazydata])[-1].ast)
+k = Kernel(create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -550,7 +549,7 @@ class TestLinearizer(unittest.TestCase):
c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
realized_ast, real_bufs = helper_realized_ast(c)

-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=2)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -567,7 +566,7 @@ class TestLinearizer(unittest.TestCase):

# check that get_linearizer_actions produces all 9 options
from tinygrad.engine.search import get_linearizer_actions
-tc_actions = [k for i, k in get_linearizer_actions(Lowerer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
+tc_actions = [k for i, k in get_linearizer_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@@ -677,7 +676,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1

-lin = Lowerer(sched[0].ast)
+lin = Kernel(sched[0].ast)
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg

a = Tensor.rand((4,4))
@@ -697,7 +696,7 @@ class TestLinearizer(unittest.TestCase):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1
-lin = Lowerer(sched[0].ast)
+lin = Kernel(sched[0].ast)
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"

def test_assign_fold(self):
@@ -716,7 +715,7 @@ class TestLinearizer(unittest.TestCase):
sched_copy = sched[:]
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
-lin = Lowerer(sched_copy[-1].ast)
+lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
@@ -844,7 +843,7 @@ class TestFloat4(unittest.TestCase):
c = a + b

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.hand_coded_optimizations()
k.linearize()

@@ -856,7 +855,7 @@ class TestFloat4(unittest.TestCase):
c = a + b

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -872,7 +871,7 @@ class TestFloat4(unittest.TestCase):
c = a + b

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()

@@ -884,7 +883,7 @@ class TestFloat4(unittest.TestCase):
c = a + b

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -902,7 +901,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.upcast()
k.linearize()

@@ -917,7 +916,7 @@ class TestFloat4(unittest.TestCase):
# don't.

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -933,7 +932,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -949,7 +948,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -964,7 +963,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a

s = create_schedule([c.lazydata])[0]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -977,7 +976,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))

s = create_schedule([layer_2.lazydata])[-1]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
@@ -989,7 +988,7 @@ class TestHandCodedOpts(unittest.TestCase):
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])

s = create_schedule([monster.lazydata])[-1]
-k = Lowerer(s.ast)
+k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
@@ -1003,7 +1002,7 @@ class TestHandCodedOpts(unittest.TestCase):
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
-k = Lowerer(si.ast)
+k = Kernel(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -1016,7 +1015,7 @@ class TestHandCodedOpts(unittest.TestCase):
out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
-k = Lowerer(si.ast)
+k = Kernel(si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -1058,11 +1057,11 @@ def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)

def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts=[],
-apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Lowerer]:
-lins: List[Lowerer] = []
+apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
+lins: List[Kernel] = []
outbufs = [real_bufs[i] for i in range(len(realized_ast.src))]

-def get_prg(k:Lowerer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
+def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))

def check_opt(opts, create_k, expected_color_size):
k = create_k()
@@ -1082,7 +1081,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)

# Get baseline if it is not provided, which is not optimized at all.
-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
lins.append(k)
prg = get_prg(k)
prg.exec(real_bufs)
@@ -1092,7 +1091,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)

# Check correctness of handcoded optimiztions.
-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
lins.append(k)
k.hand_coded_optimizations()
prg = get_prg(k)
@@ -1101,7 +1100,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
for i, buf in enumerate(outbufs):
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
for i, x in enumerate(opts): # Check custom transformations if any.
-check_opt(x, lambda: Lowerer(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
+check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
return lins

# creates a back-to-back multi reduce AST by merging r0 and r1.
@@ -1438,14 +1437,14 @@ class TestKernelOpts(unittest.TestCase):
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
]
for x in invalid_opts:
-k = Lowerer(realized_ast)
+k = Kernel(realized_ast)
with self.assertRaises(AssertionError):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_buf_index_not_found_tensor_core(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
-k = Lowerer(ast, opts=Device[Device.DEFAULT].renderer)
+k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))

@@ -1462,7 +1461,7 @@ class TestKernelOpts(unittest.TestCase):
c, d = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
r1 = c.matmul(d, acc_dtype=tc.dtype_out)
ast = _temp_create_multireduce_ast(r0, r1)
-lin = Lowerer(ast)
+lin = Kernel(ast)
lin.apply_opt(Opt(op=OptOps.TC, axis=0, amt=2))
lin.linearize()
result = compare_linearizer(lin)

File diff suppressed because one or more lines are too long
@@ -2,7 +2,7 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad.helpers import CI
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin

@@ -12,7 +12,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

def _test_overflow(ast, opts):
-lin = Lowerer(ast)
+lin = Kernel(ast)
for opt in opts: lin.apply_opt(opt)
lin.linearize()
bufs = bufs_from_lin(lin)

@@ -10,7 +10,7 @@ from tinygrad.device import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
from tinygrad.helpers import DEBUG, flatten, getenv
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
# test the (non loadops) ops linearize
for s in sched:
if s.ast.op is not MetaOps.SINK: continue
-l = Lowerer(s.ast)
+l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
return sched

@@ -1,7 +1,7 @@
import unittest

from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.codegen.lowerer import Lowerer
+from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -19,12 +19,12 @@ class TestTimeLinearizer(unittest.TestCase):
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast.lazyops if x.op is BufferOps.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
-tm = time_linearizer(Lowerer(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
+tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')

def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
-rawbufs = bufs_from_lin(lin:=Lowerer(si.ast))
+rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
assert len(rawbufs) == len(lin.membufs)
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)
@@ -36,7 +36,7 @@ class TestTimeLinearizer(unittest.TestCase):
"""
# ast of Tensor.zeros(16).contiguous().realize()
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
-lin = Lowerer(ast)
+lin = Kernel(ast)
bufs = bufs_from_lin(lin)

kernel_count = GlobalCounters.kernel_count
@@ -71,7 +71,7 @@ class TestBEAM(unittest.TestCase):
b = Tensor.rand(3)
realized_ast, _ = helper_realized_ast(a @ b)
from tinygrad.engine.search import get_linearizer_actions
-lins = get_linearizer_actions(Lowerer(realized_ast), False).values()
+lins = get_linearizer_actions(Kernel(realized_ast), False).values()

# ensure amt=0 are not duplicated
if Opt(OptOps.UPCAST, 0, 0) in actions:
@@ -88,7 +88,7 @@ class TestBEAM(unittest.TestCase):
def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
lin = Lowerer(ast)
lin = Kernel(ast)

bufs = bufs_from_lin(lin)
best_lin = beam_search(lin, bufs, 3)

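For orientation, a minimal sketch of the renamed search API exercised by the tests above, driven end to end; Tensor.lazydata and Device.DEFAULT are assumptions from the surrounding codebase, and the schedule step is illustrative rather than taken from this diff:

# hedged sketch: beam-search one compute kernel with the renamed Kernel class
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import bufs_from_lin, beam_search, time_linearizer

out = Tensor.rand(3) @ Tensor.rand(3)
sink = [s.ast for s in create_schedule([out.lazydata]) if s.ast.op is MetaOps.SINK][-1]  # last compute AST

lin = Kernel(sink, opts=Device[Device.DEFAULT].renderer)
rawbufs = bufs_from_lin(lin)               # scrap buffers sized from the kernel's membufs
best_lin = beam_search(lin, rawbufs, 3)    # explore opt actions with a beam of width 3
print(time_linearizer(best_lin, rawbufs))  # wall-clock seconds for the winning candidate
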
@@ -1,7 +1,7 @@
from __future__ import annotations
import unittest
from tinygrad.codegen.lowerer import Lowerer
#from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
#from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop
@@ -16,7 +16,7 @@ def lower(*ast:LazyOp):
for op in ast: print_tree(op)
try: verify_lazyop(sink_ast)
except AssertionError: raise InvalidLazyOpException()
k = Lowerer(sink_ast)
k = Kernel(sink_ast)
k.linearize()
if DEBUG >= 6: k.uops.print()
if DEBUG >= 4: print(k.to_program().src)

@@ -2,7 +2,7 @@ import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.ops import MetaOps
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule

class TestWinograd(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestWinograd(unittest.TestCase):
if s.ast.op is not MetaOps.SINK: continue
ops = s.ast.lazyops
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Lowerer(s.ast)
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression

@@ -4,7 +4,7 @@ from tinygrad import dtypes, Tensor
from tinygrad.helpers import prod
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import flops_mem

class TestFlopCounter(unittest.TestCase):
@@ -15,7 +15,7 @@ class TestFlopCounter(unittest.TestCase):

def compare_flop_counters(self, ast):
info = get_lazyop_info(ast.src[0])
lin = Lowerer(ast)
lin = Kernel(ast)
# NOTE: why does hand coded optimizations change flops for the GEMM?
#lin.hand_coded_optimizations()
lin.linearize()

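The hunk above compares the graph-level flop estimate with the new uop-level counter; a sketch of the two calls side by side, assuming Tensor.lazydata and Device.DEFAULT as before:

# hedged sketch: compare get_lazyop_info against flops_mem for one kernel
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps, get_lazyop_info
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import flops_mem
from tinygrad.engine.schedule import create_schedule

out = Tensor.rand(8, 8).sum()
sink_ast = [s.ast for s in create_schedule([out.lazydata]) if s.ast.op is MetaOps.SINK][-1]

lin = Kernel(sink_ast, opts=Device[Device.DEFAULT].renderer)
lin.linearize()
info = get_lazyop_info(sink_ast.src[0])   # analytic estimate from the LazyOp graph
ops, mem = flops_mem(lin.uops.uops)       # counted from the linearized uops
print(info.flops, ops, info.mem_estimate, mem)
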
@@ -4,14 +4,19 @@ from dataclasses import replace
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
from tinygrad.engine.graph import print_tree
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, \
verify_lazyop, KernelInfo, get_lazyop_info
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, \
get_contraction, to_function_name, diskcache_put, ContextVar
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uops import UOps, flops_mem
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.codegen.lowerer import lazyop_to_uop
from dataclasses import dataclass
from enum import Enum, auto

@@ -719,3 +724,45 @@ class Kernel:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
return fixup_ast(self.ast)

# **** this is the lowerer ****

def linearize(self) -> Kernel:
modified_ast = self.get_optimized_ast()

if DEBUG >= 3:
print(self.name)
print_tree(modified_ast)

uop_sink = lazyop_to_uop(modified_ast, self.opts)

# extract global/local sizes
if self.opts.has_local:
self.global_size: Optional[List[int]] = [1,1,1]
self.local_size: Optional[List[int]] = [1,1,1]
for u in uop_sink.parents:
if u.op is UOps.SPECIAL:
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
else: self.global_size[u.arg[0]] = u.arg[2]
else:
self.global_size, self.local_size = None, None

# generate the UOpGraph
self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
if DEBUG >= 5: self.uops.print()
if getenv("GRAPHUOPS"):
self.uops.graph()
if getenv("GRAPHUOPS") == 2: exit(0)
return self

def to_program(self) -> Program:
self.linearize()
src = self.opts.render(name:=to_function_name(self.name), self.uops)
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))

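A minimal usage sketch of the linearize/to_program path added to Kernel above; it mirrors the Winograd test earlier in this diff, with Tensor.lazydata and Device.DEFAULT assumed from the surrounding codebase:

# hedged sketch: schedule a small expression and lower every SINK through Kernel
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule

out = (Tensor.rand(16) * 2).sum()
for si in create_schedule([out.lazydata]):
    if si.ast.op is not MetaOps.SINK: continue            # skip non-compute schedule items
    k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
    k.hand_coded_optimizations()
    prg = k.to_program()                                  # to_program() calls linearize() itself
    print(prg.src)                                        # rendered source for the target device
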
@@ -1,14 +1,12 @@
from __future__ import annotations
from typing import List, Tuple, cast, Optional, Any, Dict
import functools
from tinygrad.codegen.kernel import Kernel
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, get_lazyop_info, KernelInfo
from tinygrad.codegen.uops import UOp, flops_mem, UOps
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.renderer import Program, Renderer
from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.renderer import Renderer
from tinygrad.helpers import getenv, prod

# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
@@ -156,47 +154,5 @@ class IndependentLowerer:
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)

def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)

# TODO: move this to Kernel
class Lowerer(Kernel):
def linearize(self) -> Lowerer:
modified_ast = self.get_optimized_ast()

if DEBUG >= 3:
print(self.name)
from tinygrad.engine.graph import print_tree
print_tree(modified_ast)

uop_sink = lazyop_to_uop(modified_ast, self.opts)

# extract global/local sizes
if self.opts.has_local:
self.global_size: Optional[List[int]] = [1,1,1]
self.local_size: Optional[List[int]] = [1,1,1]
for u in uop_sink.parents:
if u.op is UOps.SPECIAL:
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
else: self.global_size[u.arg[0]] = u.arg[2]
else:
self.global_size, self.local_size = None, None

# generate the UOpGraph
self.uops:UOpGraph = UOpGraph(uop_sink, self.opts)
if DEBUG >= 5: self.uops.print()
if getenv("GRAPHUOPS"):
self.uops.graph()
if getenv("GRAPHUOPS") == 2: exit(0)
return self

def to_program(self) -> Program:
self.linearize()
src = self.opts.render(name:=to_function_name(self.name), self.uops)
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))

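What stays in lowerer.py is the lazyop_to_uop entry point that Kernel.linearize now calls; a sketch of using it directly, under the same Tensor.lazydata / Device.DEFAULT assumptions:

# hedged sketch: lower an optimized AST to a UOp sink without going through Kernel.linearize
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.lowerer import lazyop_to_uop
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.engine.schedule import create_schedule

out = Tensor.rand(4, 4).sum()
sink = [s.ast for s in create_schedule([out.lazydata]) if s.ast.op is MetaOps.SINK][-1]

k = Kernel(sink, opts=Device[Device.DEFAULT].renderer)
uop_sink = lazyop_to_uop(k.get_optimized_ast(), k.opts)   # the step Kernel.linearize performs
uops = UOpGraph(uop_sink, k.opts)                         # same graph linearize stores on self.uops
uops.print()                                              # dump the linearized uops
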
@@ -6,31 +6,31 @@ from tinygrad.ops import MetaOps, LazyOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.renderer import Renderer, Program
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import ScheduleItem

# **************** Program Creation ****************

logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
def get_linearizer(renderer:Renderer, ast:LazyOp) -> Lowerer:
def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
if DEBUG >= 5:
from tinygrad.engine.graph import print_tree
print_tree(ast)
k = Lowerer(ast, opts=renderer)
k = Kernel(ast, opts=renderer)
k.required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
kb, k_opt = Lowerer(ast, opts=renderer), k
kb, k_opt = Kernel(ast, opts=renderer), k
kb.required_optimizations()
rawbufs = bufs_from_lin(kb, allocate=False)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
if getenv("BEAM_COMPARE", 1):
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
lins: List[Tuple[str, Lowerer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
if used_tensor_cores:
lins.append(("hc", Lowerer(ast, opts=renderer)))
lins.append(("hc", Kernel(ast, opts=renderer)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))

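Callers normally reach Kernel through get_linearizer above, which layers required_optimizations, tensor cores, hand-coded opts and an optional BEAM search; a sketch of invoking it directly, where the module path tinygrad.engine.realize and Device.DEFAULT are assumptions rather than facts stated in this hunk:

# hedged sketch: obtain an optimized Kernel for one compute AST via the engine entry point
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps
from tinygrad.engine.realize import get_linearizer
from tinygrad.engine.schedule import create_schedule

out = Tensor.rand(8, 8).sum()
sink = [s.ast for s in create_schedule([out.lazydata]) if s.ast.op is MetaOps.SINK][-1]

k = get_linearizer(Device[Device.DEFAULT].renderer, sink)   # applies the TC/HC/BEAM policy above
prg = k.to_program()                                        # Program carrying flop/mem estimates
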
@@ -6,7 +6,7 @@ from tinygrad.device import Device, Buffer, Compiler
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.tensor import Tensor
@@ -53,7 +53,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
class TimeoutException(Exception): pass
def timeout_handler(signum, frame): raise TimeoutException()

def _try_compile_linearized_w_idx(x:Tuple[int,Lowerer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
signal.signal(signal.SIGALRM, timeout_handler)
# set timeout
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
@@ -85,7 +85,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
# *** external API ***

# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Lowerer, allocate:bool=True) -> List[Buffer]:
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
@@ -97,7 +97,7 @@ def bufs_from_lin(lin:Lowerer, allocate:bool=True) -> List[Buffer]:
return cast(List[Buffer], rawbufs)

# get dictionary of all possible actions
def get_linearizer_actions(lin:Lowerer, include_0=True) -> Dict[int, Lowerer]:
def get_linearizer_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
for i,a in enumerate(actions):
if a.axis is not None and a.op is not OptOps.TC:
@@ -115,7 +115,7 @@ def get_linearizer_actions(lin:Lowerer, include_0=True) -> Dict[int, Lowerer]:
return acted_lins

beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Lowerer:
def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Kernel:
global beam_pool
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -123,7 +123,7 @@ def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret

beam: List[Tuple[Lowerer, float]] = [(lin, float("inf"))]
beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
seen_libs = set()

default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
@@ -140,8 +140,8 @@ def beam_search(lin:Lowerer, rawbufs:List[Buffer], amt:int, allow_test_size=True
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:
acted_lins: List[Lowerer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
timed_lins: List[Tuple[Lowerer, float]] = []
acted_lins: List[Kernel] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
timed_lins: List[Tuple[Kernel, float]] = []
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
@@ -181,7 +181,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]

def time_linearizer(lin:Lowerer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

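For completeness, a sketch of one expansion step of the beam loop above, spelled out with the renamed Kernel type; the schedule step and Device.DEFAULT are the same assumptions as in the earlier sketches:

# hedged sketch: expand one beam node and time its neighbours serially
from tinygrad import Tensor, Device
from tinygrad.ops import MetaOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import bufs_from_lin, get_linearizer_actions, time_linearizer

out = Tensor.rand(16, 16) @ Tensor.rand(16, 16)
sink = [s.ast for s in create_schedule([out.lazydata]) if s.ast.op is MetaOps.SINK][-1]
lin = Kernel(sink, opts=Device[Device.DEFAULT].renderer)
rawbufs = bufs_from_lin(lin)

neighbours = get_linearizer_actions(lin, include_0=False).values()   # one applied Opt per candidate
timed = sorted(((k, time_linearizer(k, rawbufs)) for k in neighbours), key=lambda x: x[1])
best_kernel, best_time = timed[0]                                    # fastest single-opt extension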