mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
kopt works with local+grouped reduce and tests (#1824)
This commit is contained in:
@@ -16,7 +16,7 @@ from examples.stable_diffusion import UNetModel
|
||||
|
||||
def kopt_search_hook(k, create_k, to_prg, baseline):
|
||||
import nevergrad as ng
|
||||
wanna_output = k.bufs[0].toCPU()
|
||||
wanna_output = k.bufs[0].toCPU().copy()
|
||||
def check_opt(x):
|
||||
try:
|
||||
k = create_k()
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import unittest
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps
|
||||
from tinygrad.ops import Compiled, Device
|
||||
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
@@ -85,5 +85,124 @@ class TestLinearizer(unittest.TestCase):
|
||||
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
|
||||
assert num_ops <= 0, "more load or alu uops than needed"
|
||||
|
||||
def helper_linearizer_opt(r:Tensor, opts=[]):
|
||||
wanna_output = None
|
||||
realized_ast = None
|
||||
|
||||
# HACK to get real ast.
|
||||
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
|
||||
def fake_exec_ast(ast, output=None, **kwargs):
|
||||
nonlocal realized_ast
|
||||
x = real_dev_exec_ast(ast, output, **kwargs)
|
||||
if not(ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized): realized_ast = ast # get last executed
|
||||
return x
|
||||
Device[Device.DEFAULT].exec_ast = fake_exec_ast
|
||||
r = r.realize() # realize an output buffer
|
||||
assert realized_ast is not None
|
||||
Device[Device.DEFAULT].exec_ast = real_dev_exec_ast
|
||||
|
||||
def check_opt(x, create_k, to_prg):
|
||||
k = create_k()
|
||||
k.process()
|
||||
k.apply_auto_opt(x)
|
||||
prg = to_prg(k)
|
||||
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# Get baseline, which is not optimized at all.
|
||||
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
wanna_output = k.bufs[0].toCPU().copy()
|
||||
|
||||
# Check correctness of handcoded optimiztions.
|
||||
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k.hand_coded_optimizations()
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
for x in opts: # Check custom transformations if any.
|
||||
check_opt(x, lambda: Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
|
||||
|
||||
class TestLinearizerOpts(unittest.TestCase):
|
||||
def test_local_and_grouped_reduce(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
|
||||
self.skipTest("Only Compiled uses linearizer with locals")
|
||||
|
||||
N = 128
|
||||
Tensor.manual_seed(1882)
|
||||
a = Tensor.rand(4, 4, N, N)
|
||||
b = Tensor.rand(4, 4, N)
|
||||
r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
|
||||
helper_linearizer_opt(r, [
|
||||
[(0, 2, 'L')], [(0, 8, 'L')], [(0, 16, 'L')], # Checking how it works with locals
|
||||
[(0, 2, 'G')], [(0, 32, 'G')], [(0, 64, 'G')], # Checking how it works with grouped reduce
|
||||
[(0, 2, 'L'), (0, 2, 'G')], [(0, 16, 'L'), (0, 16, 'G')], [(0, 32, 'L'), (0, 2, 'G')], [(0, 2, 'L'), (0, 64, 'G')], # Checking how it works with locals + grouped reduce
|
||||
[(0, 2, 'L'), (0, 2, 'G'), (0, 8, 'U'), (0, 4, 'R')], # Checking how it works with locals + grouped reduce + upcasts
|
||||
])
|
||||
|
||||
def test_upcasts(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled):
|
||||
self.skipTest("Only Compiled uses linearizer")
|
||||
|
||||
N = 16
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[(0, 2, 'U')], [(0, 4, 'U')], [(0, 8, 'U')], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
def test_full_upcast(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled):
|
||||
self.skipTest("Only Compiled uses linearizer")
|
||||
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(4)
|
||||
b = Tensor.rand(4)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[(0, 4, 'U')], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
def test_matmul(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
|
||||
self.skipTest("Only Compiled uses linearizer with locals")
|
||||
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = a@b
|
||||
helper_linearizer_opt(r, [
|
||||
[(0, 2, 'U')], [(0, 4, 'U'), (1, 4, 'U')], # Checking how it works with upcasts
|
||||
[(0, 2, 'L')], [(1, 32, 'L')], [(0, 4, 'L'), (1, 4, 'L')], [(0, 4, 'L'), (1, 32, 'L')], [(0, 16, 'L'), (1, 8, 'L')], # Checking how it works with locals
|
||||
[(0, 2, 'G')], [(0, 32, 'G')], [(0, 32, 'G'), (0, 4, 'R')], # Checking how it works with grouped_reduce
|
||||
[(0, 2, 'L'), (1, 2, 'L'), (0, 32, 'G')], [(0, 16, 'L'), (0, 32, 'G')], [(0, 16, 'L'), (0, 8, 'L'), (0, 4, 'G')], # Checking how it works with local+grouped_reduce
|
||||
[(0, 4, 'L'), (0, 4, 'L'), (0, 16, 'G'), (0, 4, 'R'), (0, 4, 'U'), (1, 2, 'U')], # Checking all together
|
||||
[(0, 4, 'L'), (0, 4, 'L'), (0, 16, 'G'), (0, 4, 'R'), (0, 8, 'U')], # Full global upcast + local
|
||||
])
|
||||
|
||||
def test_double_reduce(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local:
|
||||
self.skipTest("Only Compiled uses linearizer with locals")
|
||||
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(8, N, 8, N)
|
||||
r = a.sum(axis=(1,3))
|
||||
helper_linearizer_opt(r, [
|
||||
[(0, 2, 'G')], [(0, 32, 'G')], [(1, 2, 'G')], [(1, 32, 'G')], # Checking how it works with 1 grouped_reduce.
|
||||
[(0, 2, 'G'), (1, 2, 'G')], [(0, 16, 'G'), (1, 2, 'G')], [(0, 4, 'G'), (1, 64, 'G')], # Checking how it works with 2 grouped_reduces.
|
||||
[(0, 16, 'G'), (1, 2, 'G'), (1, 4, 'R')], [(0, 2, 'G'), (1, 32, 'G'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts.
|
||||
[(0, 4, 'L'), (1, 4, 'L'), (0, 8, 'G'), (1, 4, 'G')], [(0, 4, 'L'), (1, 4, 'L'), (0, 2, 'G'), (1, 32, 'G'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[(0, 2, 'L'), (1, 2, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U')], [(0, 2, 'L'), (1, 2, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U'), (0, 4, 'R'), (1, 4, 'R')], # Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[(0, 4, 'L'), (1, 4, 'L'), (0, 8, 'G'), (1, 4, 'G'), (0, 2, 'U'), (1, 2, 'U')], # No globals
|
||||
])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -131,17 +131,17 @@ class OptimizedKernel(Kernel):
|
||||
for axis, amt, typ in x:
|
||||
if axis is None or amt == 1: continue
|
||||
if typ == "G":
|
||||
assert self.full_shape[self.first_reduce+axis] % amt == 0, "no longer valid shift"
|
||||
self.shift_to(self.first_reduce+axis, amt, top=True, insert_before=self.first_reduce+axis+len(self.group_for_reduce))
|
||||
assert self.full_shape[self.first_reduce+axis+len(self.group_for_reduce)] % amt == 0, "no longer valid shift"
|
||||
self.shift_to(self.first_reduce+axis+len(self.group_for_reduce), amt, top=True, insert_before=self.first_reduce+len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(amt)
|
||||
if typ == "R":
|
||||
typ = "U"
|
||||
axis += self.first_reduce
|
||||
if typ == "U" and (len(self.group_for_reduce) == 0 or self.first_reduce != axis):
|
||||
axis += self.first_reduce + len(self.group_for_reduce)
|
||||
if typ == "U":
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
self.shift_to(axis, amt)
|
||||
self.upcast()
|
||||
elif typ == "L" and len(self.group_for_reduce) == 0: # TODO: Cannot mix local+group_for_reduce, codegen need to be fixed.
|
||||
elif typ == "L":
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
|
||||
@@ -11,15 +11,14 @@ def get_divisors(n, min_div = 1, max_div = 512):
|
||||
def kernel_optimize_opts(k:Linearizer):
|
||||
import nevergrad as ng
|
||||
opts = []
|
||||
if k.first_reduce < k.shape_len: # TODO: Grouped reduces do not work with other locals. More chances to mutate to 1, so locals can be used.
|
||||
opts.append(ng.p.TransitionChoice([(0,s,"G") for s in get_divisors(k.full_shape[k.first_reduce], min_div=16) if all(st.shape[k.first_reduce] % s == 0 or st.shape[k.first_reduce] == 1 for st in k.sts)], transitions=(0.8, 0.2)))
|
||||
for i in range(k.first_reduce):
|
||||
# TODO: the upcast always happen first, you might want to reverse this?
|
||||
# TODO: the order of the locals might improve things too
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"U") for s in get_divisors(k.full_shape[i], max_div=32)]))
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in get_divisors(k.full_shape[i])]))
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"U") for s in get_divisors(k.full_shape[i], max_div=8)]))
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in get_divisors(k.full_shape[i], min_div=4)]))
|
||||
for i in range(k.shape_len-k.first_reduce):
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in get_divisors(k.full_shape[k.first_reduce+i], max_div=32)]))
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in get_divisors(k.full_shape[k.first_reduce+i], max_div=8)]))
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"G") for s in get_divisors(k.full_shape[k.first_reduce+i], min_div=4) if all(st.shape[k.first_reduce+i] % s == 0 or st.shape[k.first_reduce+i] == 1 for st in k.sts)]))
|
||||
return opts
|
||||
|
||||
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
|
||||
|
||||
Reference in New Issue
Block a user