move view pushing to codegen, try 2 (#11534)

* move view pushing to codegen, try 2

* fix up some linearizer tests

* fix test search

* fix test schedule

* delete that test

* fix test arange

* fix a few tests

* update tests

* push views

* ebs cleanup

* fix local/reg

* test and lint

* fix more tests

* test cleanups

* skipped that one
Author: George Hotz
Date: 2025-08-06 15:58:38 -07:00
Committed by: GitHub
Parent: 2d5bdc939d
Commit: 21570545d3
12 changed files with 38 additions and 180 deletions


@@ -542,8 +542,6 @@ jobs:
      run: PYTHONPATH="." GPU=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
    - name: Test MLPerf stuff
      run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
-    - name: Run handcode_opt
-      run: PYTHONPATH=. MODEL=resnet GPU=1 DEBUG=1 BS=4 HALF=0 python3 examples/handcode_opt.py
    - name: Run process replay tests
      uses: ./.github/actions/process-replay


@@ -1,134 +0,0 @@
from extra.models.resnet import ResNet50
from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.uop.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.opt.search import beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from extra.optimization.helpers import time_linearizer
from tinygrad.engine.realize import get_program

def get_sched_resnet():
  mdl = ResNet50()
  optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
  BS = getenv("BS", 64)

  # run model twice to get only what changes, these are the kernels of the model
  for _ in range(2):
    out = mdl(Tensor.empty(BS, 3, 224, 224))
    targets = [out]
    if getenv("BACKWARD"):
      optim.zero_grad()
      out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
      targets += [x for x in optim.schedule_step()]
    sched = Tensor.schedule(*targets)
    print(f"schedule length {len(sched)}")
  return sched

def get_sched_bert():
  mdl = get_mlperf_bert_model()
  optim = nn.optim.LAMB(nn.state.get_parameters(mdl))

  # fake data
  BS = getenv("BS", 9)
  input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
  masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
  next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)

  # run model twice to get only what changes, these are the kernels of the model
  for _ in range(2):
    lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
    targets = [lm_logits, seq_relationship_logits]
    if getenv("BACKWARD"):
      optim.zero_grad()
      loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
      # ignore grad norm and loss scaler for now
      loss.backward()
      targets += [x for x in optim.schedule_step()]
    sched = Tensor.schedule(*targets)
    print(f"schedule length {len(sched)}")
  return sched

if __name__ == "__main__":
  if getenv("HALF", 1):
    dtypes.default_float = dtypes.half

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  if getenv("BACKWARD"): Tensor.training = True
  print(f"optimizing for {Device.DEFAULT}")

  sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
  sched = [x for x in sched if x.ast.op is Ops.SINK]

  # focus on one kernel
  if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]

  # work with the schedule
  total_tm = 0
  running_gflops = 0
  usage = {}
  for i,si in enumerate(sched):
    if DEBUG >= 3: print(si.ast)
    rawbufs = bufs_from_lin(Kernel(si.ast))

    # "linearize" the op into uops in different ways
    lins: list[tuple[Kernel, str]] = []

    # always try hand coded opt
    lin = Kernel(si.ast, opts=device.renderer)
    lin.apply_opts(hand_coded_optimizations(lin))
    lins.append((lin, "HC"))

    # maybe try tensor cores
    lin = Kernel(si.ast, opts=device.renderer)
    if lin.apply_tensor_cores():
      lins.append((lin, "TC"))

    # try a beam search
    if beam:=getenv("BEAM"):
      lin = Kernel(si.ast, opts=device.renderer)
      lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
      lins.append((lin, "BEAM"))

    # try MCTS
    if mcts:=getenv("MCTS"):
      lin = Kernel(si.ast, opts=device.renderer)
      lin = mcts_search(lin, rawbufs, mcts)
      lins.append((lin, "MCTS"))

    # benchmark the programs
    choices = []
    for lin, nm in lins:
      tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
      ops = (prg:=get_program(lin.get_optimized_ast(), lin.opts)).estimates.ops
      gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
      choices.append((tm, gflops, lin, prg, nm))

    sorted_choices = sorted(choices, key=lambda x: x[0])
    if DEBUG >= 1: # print all kernels
      for tm, gflops, lin, prg, nm in choices:
        print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {colored(nm, 'green') if lin is sorted_choices[0][2] else nm}")

    tm, gflops, lin, prg, nm = sorted_choices[0]
    if getenv("SRC"):
      print(si.ast)
      print(lin.applied_opts)
      print(get_program(lin.get_optimized_ast(), lin.opts).src)

    total_tm += tm
    running_gflops += gflops * tm
    if (key := str([str(m) for m in si.metadata])) not in usage: usage[key] = (0, 0)
    usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[repr(m) if TRACEMETA >= 2 else str(m) for m in si.metadata]}")

  print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
  print("usage:")
  for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
    print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")


@@ -7,6 +7,7 @@ from tinygrad.opt.kernel import Opt, OptOps, Kernel, KernelOptError
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.opt.search import get_kernel_actions
from tinygrad.uop.ops import Ops
+from tinygrad.codegen import apply_rewrites, rewrites_for_views

class TestArange(unittest.TestCase):
  def _get_flops(self, N, opts=None):

@@ -49,11 +50,11 @@ class TestArange(unittest.TestCase):
  def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])

  def test_all_opts(self, opts=None, exclude=None):
-    k = Kernel(Tensor.arange(256).schedule()[-1].ast)
+    k = Kernel(apply_rewrites(Tensor.arange(256).schedule()[-1].ast, rewrites_for_views))
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_256 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]
-    k = Kernel(Tensor.arange(2560).schedule()[-1].ast)
+    k = Kernel(apply_rewrites(Tensor.arange(2560).schedule()[-1].ast, rewrites_for_views))
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_2560 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]


@@ -6,7 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled, HCQBuffer
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.system import PCIIfaceBase
from tinygrad.engine.realize import get_runner, CompiledRunner, get_program
-from tinygrad.opt.kernel import Kernel, Opt, OptOps
+from tinygrad.opt.kernel import Opt, OptOps
from tinygrad import Variable

MOCKGPU = getenv("MOCKGPU")

@@ -163,10 +163,8 @@ class TestHCQ(unittest.TestCase):
    a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
    b = a + 1
    si = b.schedule()[-1]
-    k = Kernel(si.ast, opts=TestHCQ.d0.renderer)
-    for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, arg=3))
-    runner = CompiledRunner(get_program(k.get_optimized_ast(), k.opts))
+    runner = CompiledRunner(get_program(si.ast, TestHCQ.d0.renderer, opts=[Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]))
    zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
    zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
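For reference, a minimal sketch (illustrative only, not part of this diff) of the pattern the updated test uses: the scheduled AST plus a list of Opts goes straight to get_program, so no Kernel has to be built by hand. The toy tensor and the single LOCAL opt here are assumptions for the example.

from tinygrad import Tensor, Device, dtypes
from tinygrad.opt.kernel import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, get_program

# schedule a tiny elementwise kernel, then lower it directly from its AST;
# view pushing and opt application now happen inside codegen
si = (Tensor.randint((3, 3, 3), dtype=dtypes.int) + 1).schedule()[-1]
prg = get_program(si.ast, Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.LOCAL, axis=0, arg=3)])
runner = CompiledRunner(prg)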


@@ -13,6 +13,9 @@ from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX, AMD_LLVM
from tinygrad.dtype import DType, dtypes, AddrSpace
+from tinygrad.codegen import apply_rewrites, rewrites_for_views
+
+def push_views(ast): return apply_rewrites(ast, rewrites_for_views)

def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
  if isinstance(r, Tensor): r = [r]

@@ -22,7 +25,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
  # now all input buffers in s[-1] should be realized
  # create fresh buffers for the outputs
  bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
-  return s[-1].ast, bufs
+  return push_views(s[-1].ast), bufs

def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
  a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)

@@ -121,7 +124,7 @@ class TestLinearizer(unittest.TestCase):
    with Context(FUSE_ARANGE=1):
      sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast
    real_index = dataset.numpy()[idxs.numpy()].reshape(4, 256, 1, 1)
-    helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
+    helper_linearizer_ast(push_views(sink), [dataset, idxs], wanna_output=[real_index])

  def test_two_nested_range(self):
    a = Tensor.randn(2, ).realize()

@@ -325,7 +328,7 @@ class TestLinearizer(unittest.TestCase):
    a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
    r = a.matmul(b, dtype=tc.dtype_out)
    sched = r.schedule()
-    realized_ast = sched[-1].ast
+    realized_ast = push_views(sched[-1].ast)
    kernel = Kernel(realized_ast)
    kernel.apply_tensor_cores(1, axis=0, tc_select=-1, tc_opt=2)
    prg = get_program(kernel.get_optimized_ast(), kernel.opts)

@@ -799,10 +802,7 @@ class TestFloat4(unittest.TestCase):
      c = a + b
      s = c.schedule()[0]
-      k = Kernel(s.ast)
-      k.shift_to(1, 4, AxisType.UPCAST) # manual trigger float4 dim
-      k.shift_to(1, shift, AxisType.UPCAST, insert_at=k.shape_len-1)
-      return get_program(k.get_optimized_ast(), k.opts).uops
+      return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops

    sizes = [13, 9, 17]
    shifts = [3, 2, 4]

@@ -948,7 +948,7 @@ class TestHandCodedOpts(unittest.TestCase):
    layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.empty(6, 20))
    s = layer_2.schedule()[-1]
-    k = Kernel(s.ast)
+    k = Kernel(push_views(s.ast))
    k.apply_opts(hand_coded_optimizations(k))
    assert len(k.bufs) == 6 # make sure all ops are done in one kernel
    # masked upcast should upcast masked axis of size 7

@@ -961,7 +961,7 @@ class TestHandCodedOpts(unittest.TestCase):
    monster = Tensor.stack(*[Tensor.stack(*[Tensor.empty(16) for _ in range(6)]) for _ in range(6)])
    s = monster.schedule()[-1]
-    k = Kernel(s.ast)
+    k = Kernel(push_views(s.ast))
    k.apply_opts(hand_coded_optimizations(k))
    assert len(k.bufs) == 37 # make sure all ops are done in one kernel
    # should upcast the two Tensor.stacks

@@ -977,7 +977,7 @@ class TestHandCodedOpts(unittest.TestCase):
    wino_schedule = out.schedule()
    # collect upcasts of tile transform kernels
    for i, si in enumerate(wino_schedule):
-      k = Kernel(si.ast)
+      k = Kernel(push_views(si.ast))
      k.apply_opts(hand_coded_optimizations(k))
      if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
      if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end)

@@ -989,7 +989,7 @@ class TestHandCodedOpts(unittest.TestCase):
    backward_schedule = Tensor.schedule(x.grad, w.grad)
    for si in backward_schedule:
-      k = Kernel(si.ast)
+      k = Kernel(push_views(si.ast))
      k.apply_opts(hand_coded_optimizations(k))
      if len(k.bufs) < 20: continue # not a tile transform kernel
      # heuristic number to make sure that at least some upcasts but not too many upcasts are being done
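As a side note, a minimal sketch (illustrative only, not part of this diff) of the push_views pattern these tests now share: with view pushing moved to codegen, a test that builds a Kernel from a scheduled AST applies the view rewrites itself first. The toy tensor below is an assumption for the example.

from tinygrad import Tensor
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.codegen import apply_rewrites, rewrites_for_views

# same helper as the one added at the top of this file
def push_views(ast): return apply_rewrites(ast, rewrites_for_views)

# push views before handing the AST to Kernel, then optimize as before
s = (Tensor.empty(6, 20) + 1).schedule()[-1]
k = Kernel(push_views(s.ast))
k.apply_opts(hand_coded_optimizations(k))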


@@ -40,10 +40,7 @@ def create_gemm_model(model_path:str, batch_size=N, in_size=N, out_size=N, bias=
def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
  si = out.schedule()[-1]
-  k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
-  #opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)]
-  k.apply_opts(opts)
-  prg = get_program(k.get_optimized_ast(), k.opts)
+  prg = get_program(si.ast, opts=opts)
  if replace_src is not None:
    old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
    prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))

@@ -297,10 +294,7 @@ class TestDSPCache(unittest.TestCase):
x41,)),)),)),))""")
    opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
    with Context(DEVECTORIZE=0, QUANTIZE=1):
-      k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-      k.apply_opts(opts)
-      prg = get_program(k.get_optimized_ast(), k.opts)
-      #print(prg.src)
+      prg = get_program(ast, opts=opts)

    new_src = """
typedef int int32 __attribute__((aligned(128),vector_size(128)));

@@ -362,7 +356,7 @@ __attribute__((noinline)) void r_196_32_4_24_8(unsigned char* restrict __attribu
    prg = replace(prg, src=new_src+prg.src.split("/* DSP boilerplate */ ")[1])
    rt = CompiledRunner(prg)
    #Device.default.compiler.disassemble(rt.lib)
-    ei = ExecItem(rt, bufs_from_lin(k))
+    ei = ExecItem(rt, bufs_from_lin(Kernel(ast)))
    tm = ei.run(wait=True)
    print(f"final time {tm*1e6:.2f} us")


@@ -131,10 +131,11 @@ class TestBEAM(unittest.TestCase):
    assert tm

  def test_beam_unnamed_kernels(self):
+    from test.test_linearizer import push_views
    a = Tensor.rand(100)
    b = Tensor.rand(100)
    si = (a+b).schedule()[-1]
-    lin = Kernel(si.ast)
+    lin = Kernel(push_views(si.ast))
    bufs = bufs_from_lin(lin)
    # TODO: beam should have better instrumentation so we don't have to check this indirect thing
    kcount = len(Kernel.kernel_cnt)


@@ -5,6 +5,7 @@ from tinygrad.device import Buffer
from tinygrad.opt.search import get_test_global_size, bufs_from_lin
from tinygrad.helpers import GlobalCounters
from extra.optimization.helpers import time_linearizer
+from test.test_linearizer import push_views

class TestSearchUtil(unittest.TestCase):
  def test_get_test_global_size(self):

@@ -25,7 +26,7 @@ class TestSearchUtil(unittest.TestCase):
    a = Tensor.randn(4, 4).realize()
    b = a+a[0]
    si = b.schedule()[0]
-    rawbufs = bufs_from_lin(Kernel(si.ast))
+    rawbufs = bufs_from_lin(Kernel(push_views(si.ast)))
    assert len(rawbufs) == 2
    assert all(r is not None for r in rawbufs)
    assert all(isinstance(r, Buffer) for r in rawbufs)

@@ -38,13 +39,13 @@ class TestTimeLinearizer(unittest.TestCase):
    si = (a+1).schedule()[0]
    # create fresh empty buffers
    rawbufs = [Buffer(b.device, b.size, b.dtype).allocate() for b in si.bufs]
-    tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
+    tm = time_linearizer(Kernel(push_views(si.ast)), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
    assert tm > 0 and tm != float('inf')

  # Ensure that the kernel count is not incremented by time_linearizer when clearing l2
  def test_kernel_count(self):
    ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast
-    lin = Kernel(ast)
+    lin = Kernel(push_views(ast))
    bufs = bufs_from_lin(lin)
    kernel_count = GlobalCounters.kernel_count


@@ -73,14 +73,6 @@ class TestUOpSpec(unittest.TestCase):
    st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), r+a)
    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)

-  def test_buffer_uops_st(self):
-    a = Tensor.randn(4, 4)+2
-    helper_test_verify_ast(ast:=a.schedule()[-1].ast)
-    store_st = [u.st for u in ast.toposort() if u.op is Ops.STORE][0]
-    self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
-    const_st = [u.st for u in ast.toposort() if u.op is Ops.CONST][0]
-    self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
-
  def test_assert_swizzle(self):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    a = UOp(Ops.LOAD, dtypes.float, (buf.view(ShapeTracker.from_shape((32, 1))),))


@@ -17,6 +17,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
from tinygrad.codegen.optional import get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.opt import pm_optimize
+from tinygrad.opt.swizzler import view_left, view_right, fix_kernel_ops

@dataclass
class RewriteStep:

@@ -29,6 +30,12 @@ class RewriteStep:
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)

+rewrites_for_views = [
+  RewriteStep(view_left, name="Main View Left"),
+  RewriteStep(view_right, name="Main View Right"),
+  RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel"),
+]
+
rewrites_for_linearizer = [
  RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
  RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),

@@ -44,6 +51,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []

+  # view pushing
+  ret.extend(rewrites_for_views)
+
  # this is kernel.py
  ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
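A short sketch (illustrative only, not part of this diff) of what the new stage does when applied: apply_rewrites, kept unchanged above, is a left fold over the RewriteStep list, so running rewrites_for_views is equivalent to the three graph_rewrite passes this commit removes from the scheduler in the last hunk below.

import functools
from tinygrad.codegen import apply_rewrites, rewrites_for_views

def push_views(sink):
  # identical to apply_rewrites(sink, rewrites_for_views): fold the AST through
  # "Main View Left", "Main View Right", then "Finalize Kernel" (bottom-up)
  return functools.reduce(lambda x, step: step(x), rewrites_for_views, sink)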


@@ -3,7 +3,8 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker

ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
-                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.LOAD}
+                     Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
+                     Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}

# **** Grouper decides which of the UOps realize


@@ -7,7 +7,7 @@ from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, g
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
-from tinygrad.opt.swizzler import merge_views, view_left, view_right, fix_kernel_ops, apply_swizzle, swizzle_reduceop
+from tinygrad.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop

# creation can recurse a lot
import sys

@@ -194,10 +194,6 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
  ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
  if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
-  # TODO: move these to codegen
-  ast = graph_rewrite(ast, view_left, name="Main View Left")
-  ast = graph_rewrite(ast, view_right, name="Main View Right")
-  ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel")
  return k.replace(arg=Kernel(ast, k.arg.metadata))

create_ast = PatternMatcher([