mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
good RANGEIFY kernel counts in external_test_opt (#12242)
no push permute stuff. the model ones are less clear if it's good, some got slower
This commit is contained in:
12
test/external/external_test_opt.py
vendored
12
test/external/external_test_opt.py
vendored
@@ -4,7 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from tinygrad import GlobalCounters, Tensor, Device
|
||||
from tinygrad.helpers import getenv, Context
|
||||
from tinygrad.helpers import getenv, Context, RANGEIFY
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
@@ -106,7 +106,7 @@ class TestOptBinOp(unittest.TestCase):
|
||||
def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1))
|
||||
def test_no_binop_rerun_alt(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(16, 16, 1), lambda a,b: a*b)
|
||||
def test_no_binop_rerun_reduce_broadcast(self):
|
||||
return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2)
|
||||
return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=1 if RANGEIFY else 2)
|
||||
|
||||
@unittest.skip("this test started failing with the new change, based movementop issue")
|
||||
def test_no_binop_rerun_transposed(self): return self._test_no_binop_rerun(lambda a,b: (a.T*b.T).T, lambda a,b: a*b)
|
||||
@@ -164,7 +164,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed(self):
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
with CLCache(2):
|
||||
with CLCache(1 if RANGEIFY else 2):
|
||||
c = a.sum(2)
|
||||
d = c.permute(1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -172,7 +172,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed_through_contract_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache(2):
|
||||
with CLCache(1 if RANGEIFY else 2):
|
||||
c = a.sum(-1)
|
||||
d = c.reshape(16,16).permute(1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -180,7 +180,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed_through_contractw1s_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache(2):
|
||||
with CLCache(1 if RANGEIFY else 2):
|
||||
c = a.sum(-1)
|
||||
d = c.reshape(16,1,16).permute(2,1,0).contiguous()
|
||||
d.realize()
|
||||
@@ -188,7 +188,7 @@ class TestOpt(unittest.TestCase):
|
||||
|
||||
def test_permute_was_pushed_through_expand_reshape(self):
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
with CLCache(2):
|
||||
with CLCache(1 if RANGEIFY else 2):
|
||||
c = a.sum(2)
|
||||
d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
|
||||
d.realize()
|
||||
|
||||
Reference in New Issue
Block a user