diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index da4e8dd293..5b50669439 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -6,7 +6,7 @@ from tinygrad.ops import Ops from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.engine.search import beam_search, bufs_from_lin if __name__ == "__main__": diff --git a/test/helpers.py b/test/helpers.py index c04f5d333d..e76da3d02b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -10,7 +10,7 @@ from tinygrad.dtype import ConstType, DType from tinygrad.nn.state import get_parameters from tinygrad.helpers import T, unwrap from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator def derandomize_model(model): diff --git a/test/test_const_folding.py b/test/test_const_folding.py index ebe4fc9226..8e4a1a6c42 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes from tinygrad.dtype import DType from tinygrad.ops import Ops, UOp from tinygrad.helpers import CI -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite import numpy as np from tinygrad.device import is_dtype_supported diff --git a/test/test_pickle.py b/test/test_pickle.py index 969656e23e..3bfae36c89 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -20,7 +20,7 @@ class TestPickle(unittest.TestCase): self.assertEqual(pm2.rewrite(sink).key, tt.key) def test_pickle_main_pattern_matcher(self): - from tinygrad.codegen.rewriter import sym + from tinygrad.codegen.devectorizer import sym pickle.dumps(sym) def test_pickle_realized_tensor(self): diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 645f57ff24..474d5a6007 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -1,7 +1,7 @@ import unittest from typing import List, cast import numpy as np -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.codegen.linearize import linearize_uop from tinygrad.device import Buffer, Device, is_dtype_supported from tinygrad.dtype import dtypes diff --git a/test/test_tensor.py b/test/test_tensor.py index 6ab3f504f3..fb9e3b57ff 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.ops import Ops, UOp from tinygrad.runtime.support.compiler_cuda import PTX from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.dtype import DType diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 1a97dadf59..ee06194af2 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG, AMX from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index -from tinygrad.codegen.rewriter import full_graph_rewrite, graph_rewrite, sym +from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym from tinygrad.codegen.expander import expander, expand_rewrite from tinygrad.codegen.linearize import linearize_uop from tinygrad.shape.shapetracker import ShapeTracker, View diff --git a/test/test_uops.py b/test/test_uops.py index f0de508016..58f23b29d9 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -13,7 +13,7 @@ from tinygrad.renderer import ProgramSpec from tinygrad.engine.schedule import fix_kernel_ops from tinygrad.engine.realize import CompiledRunner, get_kernel from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.codegen.symbolic import sym from tinygrad.device import is_dtype_supported from tinygrad.codegen.kernel import Kernel, Opt, OptOps diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 18b8bbb1d4..40538daa7c 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -2,7 +2,7 @@ import unittest, math from tinygrad import dtypes from tinygrad.helpers import all_same from tinygrad.ops import GroupOp, UOp, Ops, exec_alu -from tinygrad.codegen.rewriter import full_graph_rewrite, mulacc_unrolled +from tinygrad.codegen.devectorizer import full_graph_rewrite, mulacc_unrolled # Helper function to apply the graph rewrite def apply_rewrite(expr): diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 7257a6dd0a..919766681f 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -6,7 +6,7 @@ from tinygrad.helpers import prod from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad import Variable from tinygrad.ops import UOp, Ops, graph_rewrite -from tinygrad.codegen.rewriter import sym +from tinygrad.codegen.devectorizer import sym from itertools import product def shapetracker_getitem(st:ShapeTracker, val:int): diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index b46f4f8f60..ebf42c4a1e 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -1,6 +1,6 @@ import unittest, itertools -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.dtype import dtypes from tinygrad.ops import UOp, Ops from tinygrad.codegen.symbolic import simplify_valid diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index e4552679d6..45377065ea 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -3,7 +3,7 @@ import unittest, pickle from tinygrad.dtype import dtypes, ConstType from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite, sym +from tinygrad.codegen.devectorizer import full_graph_rewrite, sym from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer from tinygrad import Variable import functools diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/devectorizer.py similarity index 100% rename from tinygrad/codegen/rewriter.py rename to tinygrad/codegen/devectorizer.py diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index fef7f693f8..18afee1f7c 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -15,7 +15,7 @@ from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROC from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import strides_for_shape from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.devectorizer import full_graph_rewrite from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction class OptOps(Enum): diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index ab18ff6218..3b03d69cd0 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -5,7 +5,7 @@ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore -from tinygrad.codegen.rewriter import no_vectorized_alu +from tinygrad.codegen.devectorizer import no_vectorized_alu base_rewrite = PatternMatcher([ (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),