Files
tinygrad/extra/autopad.py
chenyu 8798d120bb autopad shapetracker for BEAM (#2375)
* autopad shapetracker for BEAM

* OptOps.PADTO

* skip that test for now

* correct padding reduce axis

* just 32

* avoid more than double the FLOPs

* cleanups

* test case

* no support for triton and llvm yet

* typos

* symbolic shape would not work

* cannot PADTO with MAX kernel

* advance db version

* no breaking change - don't advance db version

* is triton just python?

* Revert "is triton just python?"

This reverts commit 17e776c25587615e33a3634c2fb0bb8591ce65d4.

* Revert "Revert "is triton just python?""

This reverts commit 6c434c01e1c4b0ea0431ec18632cd859fb3cf260.

* support llvm

* is it really passing in CI only?

* update tests

* oh triton test passed

* simpler

* revert that, with a test

* check if st are the same

* Revert "check if st are the same"

This reverts commit d2a5eac110a5da1af82a2728c883779ef69c3cad.

* update the db version

* rebase artifact
2023-11-22 21:05:25 -05:00

40 lines
1.0 KiB
Python

from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from test.external.fuzz_linearizer import run_linearizer
from tinygrad.codegen.kernel import Opt, OptOps
# Experiment 1: square matmul whose side length is not a multiple of 32
# (17**3 == 4913 == 153*32 + 17), run through the linearizer with PADTO opts.
N = 17**3
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
c = a @ b
# Keep only schedule items whose ast.op is not a LoadOps member; exactly one
# item is expected to remain, and its AST is what gets linearized.
sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
# Apply PADTO with amt=32 to axes 0, 1 and 2, in that order.
# NOTE(review): presumably axes 0/1 are the output dims and axis 2 is the
# reduce dim of the matmul, each padded up to a multiple of 32 -- confirm.
for pad_axis in range(3):
  lin.apply_opt(Opt(op=OptOps.PADTO, axis=pad_axis, amt=32))
lin.hand_coded_optimizations()
lin.linearize()
print(f"{lin.applied_opts=}")
run_linearizer(lin)
# Exit here: everything after this line is currently not executed.
# `raise SystemExit` is used instead of `quit()` because `quit` is injected by
# the `site` module for interactive sessions and is unavailable under `python -S`.
raise SystemExit
### Experiment 2 (not reached while the early exit above is in place):
### column-sum of a 61x61 random tensor; 61 is not a multiple of 32.
a = Tensor.rand(61, 61).sum(axis=0)
# Discard schedule items whose ast.op is a LoadOps member; one item must remain.
sched = list(filter(lambda si: si.ast.op not in LoadOps, a.lazydata.schedule()))
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
# NOTE: a LOCAL opt on axis 0 was left disabled here by the original author:
# lin.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=32))
# Apply PADTO with amt=32 to axis 0, then axis 1.
for pad_axis in (0, 1):
  lin.apply_opt(Opt(op=OptOps.PADTO, axis=pad_axis, amt=32))
lin.hand_coded_optimizations()
lin.linearize()
run_linearizer(lin)