fix load_worlds filter_novariable (#7564)

filter based on "DEFINE_VAR" instead of "Variable". also added a unit test to make sure dataset includes image and variable kernels
This commit is contained in:
chenyu
2024-11-05 16:06:39 -05:00
committed by GitHub
parent c805e3fff5
commit e7b18cf5c0
3 changed files with 24 additions and 2 deletions

View File

@@ -32,8 +32,8 @@ def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True)
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} before filters")
if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} after filters")
if filter_novariable: ast_strs = [x for x in ast_strs if "DEFINE_VAR" not in x]
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} after filters {filter_reduce=}, {filter_noimage=}, {filter_novariable=}")
random.seed(1337)
random.shuffle(ast_strs)
return ast_strs

View File

@@ -0,0 +1,19 @@
import unittest
from extra.optimization.helpers import load_worlds
class TestKernelDataset(unittest.TestCase):
def test_load_worlds_filters(self):
all_kernels = load_worlds(filter_reduce=False, filter_noimage=False, filter_novariable=False)
reduce_kernels = load_worlds(filter_reduce=True, filter_noimage=False, filter_novariable=False)
self.assertGreater(len(all_kernels), len(reduce_kernels))
image_kernels = load_worlds(filter_reduce=False, filter_noimage=True, filter_novariable=False)
self.assertGreater(len(all_kernels), len(image_kernels))
variable_kernels = load_worlds(filter_reduce=False, filter_noimage=False, filter_novariable=True)
self.assertGreater(len(all_kernels), len(variable_kernels))
if __name__ == '__main__':
unittest.main()