lil work on llvm speed (#10157)
* lil work on llvm speed
* llvm failing test
* 1e-4
* simpler failing test
* once is fine
* gpt suggests this syntax change
* bump that debug
.github/workflows/test.yml (4 changed lines, vendored):

@@ -29,9 +29,9 @@ jobs:
     - name: External Benchmark Schedule
       run: PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
     - name: Speed Test
-      run: LLVM=1 LLVMOPT=1 python3 test/test_speed_v_torch.py
+      run: LLVM=1 python3 test/test_speed_v_torch.py
     - name: Speed Test (BEAM=2)
-      run: BEAM=2 LLVM=1 LLVMOPT=1 python3 test/test_speed_v_torch.py
+      run: BEAM=2 LLVM=1 python3 test/test_speed_v_torch.py

   docs:
     name: Docs
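Both speed-test steps drop the LLVMOPT=1 flag from their environment. If useful, a changed step can be reproduced outside CI; a minimal sketch, assuming it is run from the repository root (the subprocess invocation is illustrative, not part of the commit):

# Rough local equivalent of the "Speed Test" CI step above.
# LLVMOPT=1 is no longer set, matching the workflow change.
import os, subprocess
env = dict(os.environ, LLVM="1")
subprocess.run(["python3", "test/test_speed_v_torch.py"], env=env, check=True)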
test/external/external_benchmark_schedule.py (2 changed lines, vendored):

@@ -40,7 +40,7 @@ if __name__ == "__main__":
       else: k.apply_opts(hand_coded_optimizations(k))
       kernels.append(k)

-  with Timing("***** model prep in "):
+  with Timing("***** model prep in "):
     kernels = [(k, k.get_optimized_ast(), get_rewrites_for_renderer(k.opts, linearizer=LINEARIZE)) for k in kernels]

   with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
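Timing and Profiling in the hunk above are context managers from tinygrad.helpers. A minimal sketch of how they are used, assuming the signatures the diff itself relies on (the block bodies here are placeholders):

# Sketch: Timing prints the elapsed time of its block with the given prefix;
# Profiling dumps cProfile stats to `fn` when its first argument is truthy.
from tinygrad.helpers import Timing, Profiling
with Timing("***** model prep in "):
  pass  # placeholder: build the kernel list here
with Profiling(True, fn="/tmp/rewrite.prof"):
  pass  # placeholder: run the rewrite step here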
test/test_opt_gemm.py (new file, 45 lines):

@@ -0,0 +1,45 @@
+import numpy as np
+import unittest
+from tinygrad import Tensor
+from tinygrad.helpers import get_single_element
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+
+class TestOptGemm(unittest.TestCase):
+  @classmethod
+  def setUpClass(cls):
+    N = 64
+    cls.a = Tensor.randn(N, N).contiguous().realize()
+    cls.b = Tensor.randn(N, N).contiguous().realize()
+    cls.res = cls.a.T.numpy() @ cls.b.T.numpy()
+
+  def _test_gemm_unrolled_permute_l(self, opts=[]):
+    t = self.a.T @ self.b.T
+    # TODO: this should be a generic test helper
+    si = get_single_element(t.schedule())
+    k = Kernel(si.ast)
+    k.apply_opts(opts)
+    run = CompiledRunner(k.to_program())
+    ExecItem(run, si.bufs).run()
+    test = si.bufs[0].numpy().reshape(self.res.shape)
+    np.testing.assert_allclose(self.res, test, atol=1e-4)
+
+  def test_gemm_unrolled_permute_l_44(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_424(self):
+    # was failing with LLVM
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_42(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+  def test_gemm_unrolled_permute_l_22(self):
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
+    self._test_gemm_unrolled_permute_l(opts)
+
+if __name__ == '__main__':
+  unittest.main()
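test_gemm_unrolled_permute_l_424 is the "simpler failing test" from the commit message, and the atol=1e-4 matches its "1e-4" bullet. A standalone sketch of the same repro, assembled only from APIs the new test file already imports:

# Standalone repro of the previously-failing 4/2/4 case: schedule a permuted
# 64x64 GEMM, hand-apply the UPCAST opts that miscompiled on LLVM, run the
# kernel, and check the result buffer against numpy.
import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import get_single_element
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem

a = Tensor.randn(64, 64).contiguous().realize()
b = Tensor.randn(64, 64).contiguous().realize()
si = get_single_element((a.T @ b.T).schedule())
k = Kernel(si.ast)
k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4),
              Opt(op=OptOps.UPCAST, axis=1, arg=2),
              Opt(op=OptOps.UPCAST, axis=0, arg=4)])
ExecItem(CompiledRunner(k.to_program()), si.bufs).run()
np.testing.assert_allclose(a.T.numpy() @ b.T.numpy(),
                           si.bufs[0].numpy().reshape(64, 64), atol=1e-4)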
test/test_speed_v_torch.py:

@@ -3,6 +3,7 @@ os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
 os.environ["MKL_NUM_THREADS"] = "1"
 os.environ["NUMEXPR_NUM_THREADS"] = "1"
 os.environ["OMP_NUM_THREADS"] = "1"
 os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
 import unittest
 import torch
 torch.set_num_threads(1)
@@ -47,14 +48,16 @@ def helper_test_speed(f1, *args):
         cache_defeat += 1

     # manual pre sync
-    if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
+    if isinstance(args[0], Tensor):
+      local_device = Device[args[0].device]
+      local_device.synchronize()
     else: sync()

     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
     st = time.perf_counter()
     ret = f1(*args)
-    if isinstance(ret, Tensor): Device[ret.device].synchronize()
+    if isinstance(ret, Tensor): local_device.synchronize()
     else: sync()
     et = (time.perf_counter() - st) * 1000
     if i >= 1: ets.append(et)
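The second hunk caches the Device lookup in local_device so the post-run synchronize reuses the handle resolved during the pre-sync instead of doing a second dictionary lookup inside the timed region. A minimal sketch of the resulting pattern, with a hypothetical helper name (timed_ms) and the same Tensor-in-implies-Tensor-out assumption the original helper makes:

# Sketch: resolve the device handle once, sync before starting the clock,
# run f, then sync on the same cached handle so queued kernels are counted.
import time
from tinygrad import Tensor, Device

def timed_ms(f, *args):  # hypothetical helper mirroring helper_test_speed
  local_device = Device[args[0].device] if isinstance(args[0], Tensor) else None
  if local_device is not None: local_device.synchronize()  # manual pre sync
  st = time.perf_counter()
  ret = f(*args)
  if isinstance(ret, Tensor): local_device.synchronize()  # wait for kernels
  return ret, (time.perf_counter() - st) * 1000  # milliseconds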
tinygrad/helpers.py:

@@ -184,7 +184,7 @@ def db_connection():
     # another connection has set it already or is in the process of setting it
     # that connection will lock the database
     with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
-    if DEBUG >= 7: _db_connection.set_trace_callback(print)
+    if DEBUG >= 8: _db_connection.set_trace_callback(print)
   return _db_connection

 def diskcache_clear():
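This is the "bump that debug" bullet: the SQL-trace threshold moves from DEBUG >= 7 to DEBUG >= 8. Both touched lines are plain sqlite3; a stdlib-only sketch of what they do, with a hypothetical database path:

# Sketch: put the cache db in WAL mode, ignoring the OperationalError another
# connection holding the lock would raise, and only trace SQL at DEBUG >= 8.
import contextlib, sqlite3
conn = sqlite3.connect("/tmp/example_cache.db")  # hypothetical path
with contextlib.suppress(sqlite3.OperationalError):
  conn.execute("PRAGMA journal_mode=WAL").fetchone()
DEBUG = 8
if DEBUG >= 8: conn.set_trace_callback(print)  # print every executed statement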
tinygrad/renderer/llvmir.py:

@@ -98,7 +98,7 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
     f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
     f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
-    f" {ctx[x]} = phi {ldt(x.dtype)} [0, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
+    f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{x.arg} ], [ {ctx[x]}phi, %loop_latch_{x.arg} ]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
     f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
     f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
@@ -204,7 +204,7 @@ class LLVMRenderer(Renderer):
       for x in acc_to_assign:
         if u in x.src: # if this range is relevant for this acc
           vc += 1
-          kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
+          kernel.append(f" %acc{vc} = phi {ldt(x.dtype)} [ {r[x]}, %loop_entry_{u.arg} ], [ {r[acc_to_assign[x]]}, %loop_latch_{u.arg} ]")
           r[x] = f"%acc{vc}"
     return tuple(local_args), self._render_fn(name, args, kernel, prefix)
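These phi lines are the "gpt suggests this syntax change" bullet. The old accumulator line concatenated two f-strings with no separator, so the type and the first incoming pair ran together (e.g. "phi float[%acc0, ..."); the new strings put a space after the type and space the incoming pairs as "[ value, %label ]", matching LLVM's canonical printing. A hand-assembled example of the loop skeleton these patterns emit after the change (register and label numbers are illustrative, and the final conditional branch is implied by the icmp rather than shown in the hunk, so treat this as a sketch, not renderer output):

# Example of the loop IR shape produced by the RANGE/ENDRANGE patterns above.
loop_ir = """  br label %loop_entry_0
loop_entry_0:
  br label %loop_body_0
loop_body_0:
  %ridx0 = phi i32 [ 0, %loop_entry_0 ], [ %ridx0phi, %loop_latch_0 ]
  ; ... loop body; accumulator phis like %acc0 sit here too ...
  br label %loop_latch_0
loop_latch_0:
  %ridx0phi = add i32 %ridx0, 1
  %exit0 = icmp ult i32 %ridx0phi, %end0
"""
print(loop_ir)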