mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix load_worlds filter_novariable (#7564)
filter based on "DEFINE_VAR" instead of "Variable". also added a unit test to make sure dataset includes image and variable kernels
This commit is contained in:
@@ -32,8 +32,8 @@ def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True)
|
||||
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} before filters")
|
||||
if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
|
||||
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
|
||||
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
|
||||
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} after filters")
|
||||
if filter_novariable: ast_strs = [x for x in ast_strs if "DEFINE_VAR" not in x]
|
||||
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} after filters {filter_reduce=}, {filter_noimage=}, {filter_novariable=}")
|
||||
random.seed(1337)
|
||||
random.shuffle(ast_strs)
|
||||
return ast_strs
|
||||
|
||||
19
extra/optimization/test_helpers.py
Normal file
19
extra/optimization/test_helpers.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import unittest
|
||||
|
||||
from extra.optimization.helpers import load_worlds
|
||||
|
||||
class TestKernelDataset(unittest.TestCase):
|
||||
def test_load_worlds_filters(self):
|
||||
all_kernels = load_worlds(filter_reduce=False, filter_noimage=False, filter_novariable=False)
|
||||
|
||||
reduce_kernels = load_worlds(filter_reduce=True, filter_noimage=False, filter_novariable=False)
|
||||
self.assertGreater(len(all_kernels), len(reduce_kernels))
|
||||
|
||||
image_kernels = load_worlds(filter_reduce=False, filter_noimage=True, filter_novariable=False)
|
||||
self.assertGreater(len(all_kernels), len(image_kernels))
|
||||
|
||||
variable_kernels = load_worlds(filter_reduce=False, filter_noimage=False, filter_novariable=True)
|
||||
self.assertGreater(len(all_kernels), len(variable_kernels))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user