lil work on llvm speed (#10157)

* lil work on llvm speed

* llvm failing test

* 1e-4

* simpler failing test

* once is fine

* gpt suggests this syntax change

* bump that debug
Author: George Hotz
Date: 2025-05-04 19:37:26 -04:00
Committed by: GitHub
Parent: 36ccaa88a6
Commit: a0240d8c2b
6 changed files with 56 additions and 8 deletions


@@ -29,9 +29,9 @@ jobs:
     - name: External Benchmark Schedule
       run: PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
     - name: Speed Test
-      run: LLVM=1 LLVMOPT=1 python3 test/test_speed_v_torch.py
+      run: LLVM=1 python3 test/test_speed_v_torch.py
     - name: Speed Test (BEAM=2)
-      run: BEAM=2 LLVM=1 LLVMOPT=1 python3 test/test_speed_v_torch.py
+      run: BEAM=2 LLVM=1 python3 test/test_speed_v_torch.py

   docs:
     name: Docs
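
The same step can be reproduced outside CI; a rough local equivalent (a sketch only: it assumes a tinygrad checkout with torch installed, and passes LLVM=1 through the child environment because the backend is selected at import time):

import os, subprocess
# run the speed comparison against torch with the LLVM backend selected
subprocess.run(["python3", "test/test_speed_v_torch.py"], env={**os.environ, "LLVM": "1"}, check=True)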


@@ -40,7 +40,7 @@ if __name__ == "__main__":
     else: k.apply_opts(hand_coded_optimizations(k))
     kernels.append(k)

-  with Timing("***** model prep in "):
+  with Timing("***** model prep in "):
     kernels = [(k, k.get_optimized_ast(), get_rewrites_for_renderer(k.opts, linearizer=LINEARIZE)) for k in kernels]

   with Profiling(PROFILE, fn="/tmp/rewrite.prof"):

test/test_opt_gemm.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+import numpy as np
+import unittest
+from tinygrad import Tensor
+from tinygrad.helpers import get_single_element
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+
+class TestOptGemm(unittest.TestCase):
+  @classmethod
+  def setUpClass(cls):
+    N = 64
+    cls.a = Tensor.randn(N, N).contiguous().realize()
+    cls.b = Tensor.randn(N, N).contiguous().realize()
+    cls.res = cls.a.T.numpy() @ cls.b.T.numpy()
+
+  def _test_gemm_unrolled_permute_l(self, opts=[]):
+    t = self.a.T @ self.b.T
+    # TODO: this should be a generic test helper
+    si = get_single_element(t.schedule())
+    k = Kernel(si.ast)
+    k.apply_opts(opts)
+    run = CompiledRunner(k.to_program())
+    ExecItem(run, si.bufs).run()
+    test = si.bufs[0].numpy().reshape(self.res.shape)
+    np.testing.assert_allclose(self.res, test, atol=1e-4)
+
+  def test_gemm_unrolled_permute_l_44(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_424(self):
+    # was failing with LLVM
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_42(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_22(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+if __name__ == '__main__':
+  unittest.main()
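
To poke at the case the _424 test covers, the same Kernel/Opt API can dump the rendered kernel source; a minimal sketch (not part of the commit, and it assumes to_program().src holds the rendered code and that LLVM=1 selects the LLVM backend):

from tinygrad import Tensor
from tinygrad.helpers import get_single_element
from tinygrad.codegen.kernel import Kernel, Opt, OptOps

# build the same 64x64 permuted GEMM the test schedules
a = Tensor.randn(64, 64).contiguous().realize()
b = Tensor.randn(64, 64).contiguous().realize()
si = get_single_element((a.T @ b.T).schedule())

# apply the 4/2/4 upcast combination that was failing under LLVM
k = Kernel(si.ast)
k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4)])
print(k.to_program().src)  # inspect the generated kernel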


@@ -3,6 +3,7 @@ os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
 os.environ["MKL_NUM_THREADS"] = "1"
 os.environ["NUMEXPR_NUM_THREADS"] = "1"
 os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
 import unittest
 import torch
 torch.set_num_threads(1)
@@ -47,14 +48,16 @@ def helper_test_speed(f1, *args):
     cache_defeat += 1

     # manual pre sync
-    if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
+    if isinstance(args[0], Tensor):
+      local_device = Device[args[0].device]
+      local_device.synchronize()
     else: sync()

     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
     st = time.perf_counter()
     ret = f1(*args)
-    if isinstance(ret, Tensor): Device[ret.device].synchronize()
+    if isinstance(ret, Tensor): local_device.synchronize()
     else: sync()
     et = (time.perf_counter() - st) * 1000
     if i >= 1: ets.append(et)


@@ -184,7 +184,7 @@ def db_connection():
     # another connection has set it already or is in the process of setting it
     # that connection will lock the database
     with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
-    if DEBUG >= 7: _db_connection.set_trace_callback(print)
+    if DEBUG >= 8: _db_connection.set_trace_callback(print)
   return _db_connection

 def diskcache_clear():
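
For reference, set_trace_callback is the standard sqlite3 hook that the raised DEBUG threshold gates; a self-contained sketch of what it does:

import sqlite3

# set_trace_callback(print) makes the connection print every SQL statement it runs,
# which is what tinygrad's cache connection does once DEBUG is high enough (now >= 8)
conn = sqlite3.connect(":memory:")
conn.set_trace_callback(print)
conn.execute("CREATE TABLE t (x INTEGER)")
conn.execute("INSERT INTO t VALUES (1)")  # each statement is echoed to stdout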


@@ -98,7 +98,7 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
     f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
     f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
-    f" {ctx[x]} = phi {ldt(x.dtype)} [0, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
+    f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{x.arg} ], [ {ctx[x]}phi, %loop_latch_{x.arg} ]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
     f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
     f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
@@ -204,7 +204,7 @@ class LLVMRenderer(Renderer):
       for x in acc_to_assign:
         if u in x.src: # if this range is relevant for this acc
           vc += 1
-          kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
+          kernel.append(f" %acc{vc} = phi {ldt(x.dtype)} [ {r[x]}, %loop_entry_{u.arg} ], [ {r[acc_to_assign[x]]}, %loop_latch_{u.arg} ]")
           r[x] = f"%acc{vc}"

     return tuple(local_args), self._render_fn(name, args, kernel, prefix)
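
The two changed lines only alter the textual form of the emitted phi instructions, adding spaces inside the incoming-value brackets to match the spacing LLVM's own tools print. A tiny illustrative sketch of what the RANGE pattern now formats (the register and label names are made up for the example):

# stand-ins for the renderer's ctx[x], ldt(x.dtype) and x.arg for one loop counter
name, dtype_ir, arg = "%ridx0", "i32", 0
print(f" {name} = phi {dtype_ir} [ 0, %loop_entry_{arg} ], [ {name}phi, %loop_latch_{arg} ]")
# -> " %ridx0 = phi i32 [ 0, %loop_entry_0 ], [ %ridx0phi, %loop_latch_0 ]"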