kopt works with local+grouped reduce and tests (#1824)

This commit is contained in:
nimlgen
2023-09-09 23:22:09 +03:00
committed by GitHub
parent 9da40c8448
commit 31fca43706
4 changed files with 130 additions and 12 deletions

View File

@@ -2,7 +2,7 @@ import numpy as np
import unittest
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
@@ -85,5 +85,124 @@ class TestLinearizer(unittest.TestCase):
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
realized_ast = None
# HACK to get real ast.
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
def fake_exec_ast(ast, output=None, **kwargs):
nonlocal realized_ast
x = real_dev_exec_ast(ast, output, **kwargs)
if not(ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized): realized_ast = ast # get last executed
return x
Device[Device.DEFAULT].exec_ast = fake_exec_ast
r = r.realize() # realize an output buffer
assert realized_ast is not None
Device[Device.DEFAULT].exec_ast = real_dev_exec_ast
def check_opt(x, create_k, to_prg):
k = create_k()
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k.process()
prg = Device[Device.DEFAULT].to_program(k)
prg.exec(k.bufs, force_wait=True)
wanna_output = k.bufs[0].toCPU().copy()
# Check correctness of handcoded optimiztions.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
class TestLinearizerOpts(unittest.TestCase):
def test_local_and_grouped_reduce(self):
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
self.skipTest("Only Compiled uses linearizer with locals")
N = 128
Tensor.manual_seed(1882)
a = Tensor.rand(4, 4, N, N)
b = Tensor.rand(4, 4, N)
r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
helper_linearizer_opt(r, [
[(0, 2, 'L')], [(0, 8, 'L')], [(0, 16, 'L')], # Checking how it works with locals
[(0, 2, 'G')], [(0, 32, 'G')], [(0, 64, 'G')], # Checking how it works with grouped reduce
[(0, 2, 'L'), (0, 2, 'G')], [(0, 16, 'L'), (0, 16, 'G')], [(0, 32, 'L'), (0, 2, 'G')], [(0, 2, 'L'), (0, 64, 'G')], # Checking how it works with locals + grouped reduce
[(0, 2, 'L'), (0, 2, 'G'), (0, 8, 'U'), (0, 4, 'R')], # Checking how it works with locals + grouped reduce + upcasts
])
def test_upcasts(self):
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
N = 16
Tensor.manual_seed(1772)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
r = (a+b).sqrt() * ((a+1).exp())
helper_linearizer_opt(r, [
[(0, 2, 'U')], [(0, 4, 'U')], [(0, 8, 'U')], # Checking how it works with upcasts
])
def test_full_upcast(self):
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
Tensor.manual_seed(1772)
a = Tensor.rand(4)
b = Tensor.rand(4)
r = (a+b).sqrt() * ((a+1).exp())
helper_linearizer_opt(r, [
[(0, 4, 'U')], # Checking how it works with upcasts
])
def test_matmul(self):
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
self.skipTest("Only Compiled uses linearizer with locals")
N = 128
Tensor.manual_seed(1552)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
r = a@b
helper_linearizer_opt(r, [
[(0, 2, 'U')], [(0, 4, 'U'), (1, 4, 'U')], # Checking how it works with upcasts
[(0, 2, 'L')], [(1, 32, 'L')], [(0, 4, 'L'), (1, 4, 'L')], [(0, 4, 'L'), (1, 32, 'L')], [(0, 16, 'L'), (1, 8, 'L')], # Checking how it works with locals
[(0, 2, 'G')], [(0, 32, 'G')], [(0, 32, 'G'), (0, 4, 'R')], # Checking how it works with grouped_reduce
[(0, 2, 'L'), (1, 2, 'L'), (0, 32, 'G')], [(0, 16, 'L'), (0, 32, 'G')], [(0, 16, 'L'), (0, 8, 'L'), (0, 4, 'G')], # Checking how it works with local+grouped_reduce
[(0, 4, 'L'), (0, 4, 'L'), (0, 16, 'G'), (0, 4, 'R'), (0, 4, 'U'), (1, 2, 'U')], # Checking all together
[(0, 4, 'L'), (0, 4, 'L'), (0, 16, 'G'), (0, 4, 'R'), (0, 8, 'U')], # Full global upcast + local
])
def test_double_reduce(self):
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
self.skipTest("Only Compiled uses linearizer with locals")
N = 128
Tensor.manual_seed(1552)
a = Tensor.rand(8, N, 8, N)
r = a.sum(axis=(1,3))
helper_linearizer_opt(r, [
[(0, 2, 'G')], [(0, 32, 'G')], [(1, 2, 'G')], [(1, 32, 'G')], # Checking how it works with 1 grouped_reduce.
[(0, 2, 'G'), (1, 2, 'G')], [(0, 16, 'G'), (1, 2, 'G')], [(0, 4, 'G'), (1, 64, 'G')], # Checking how it works with 2 grouped_reduces.
[(0, 16, 'G'), (1, 2, 'G'), (1, 4, 'R')], [(0, 2, 'G'), (1, 32, 'G'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts.
[(0, 4, 'L'), (1, 4, 'L'), (0, 8, 'G'), (1, 4, 'G')], [(0, 4, 'L'), (1, 4, 'L'), (0, 2, 'G'), (1, 32, 'G'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts + locals.
[(0, 2, 'L'), (1, 2, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U')], [(0, 2, 'L'), (1, 2, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U'), (0, 4, 'R'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts + locals.
[(0, 4, 'L'), (1, 4, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U'), (1, 2, 'U')], # No globals
])
if __name__ == '__main__':
unittest.main()