From f54242849d54d5a49c30567bb6483f4db7894045 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 7 Feb 2025 09:44:54 +0800 Subject: [PATCH] failing test for the devectorize [pr] (#8940) * failing test for the devectorize [pr] * add DEVECTORIZE to method_cache --- test/test_ops.py | 5 +++++ tinygrad/codegen/rewriter.py | 10 +++++----- tinygrad/engine/realize.py | 9 ++++++--- tinygrad/helpers.py | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4b42ffdedb..014001e654 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2019,6 +2019,11 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(), lambda x,w: Tensor.conv2d(x,w,stride=2).relu()) + @unittest.expectedFailure + @unittest.skipIf(Device.DEFAULT != "LLVM", "DEVECTORIZE=0 only for LLVM") + def test_strided_conv2d_simple_vec(self): + with Context(DEVECTORIZE=0): self.test_strided_conv2d_simple() + def test_strided_conv2d(self): bs = 4 cin = 3 diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 3de8974bea..974581f2d9 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -4,7 +4,7 @@ from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp -from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same +from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same, DEVECTORIZE from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer @@ -514,13 +514,13 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # expand sink = graph_rewrite(sink, sym+expander) - if getenv("NO_DEVECTORIZE"): - # new devectorize for load/store - sink = graph_rewrite(sink, sym+devectorize_load_store) - else: + if DEVECTORIZE: # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+ mulacc_unrolled) + else: + # new devectorize only for load/store + sink = graph_rewrite(sink, sym+devectorize_load_store) # final rules for the renderer (without sym) sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index c8c3a06eea..b2102cc058 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -2,6 +2,7 @@ from typing import Optional, cast, Generator import time, pprint from dataclasses import dataclass, replace from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA +from tinygrad.helpers import DEVECTORIZE from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates @@ -99,11 +100,13 @@ class BufferXfer(BufferCopy): # **************** method cache **************** -method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {} +method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {} def get_runner(device:str, ast:UOp) -> CompiledRunner: - ckey = (device, ast.key, BEAM.value, NOOPT.value, False) + # TODO: this should be all context relevant to rendering + context = (BEAM.value, NOOPT.value, DEVECTORIZE.value) + ckey = (device, ast.key, context, False) if cret:=method_cache.get(ckey): return cret - bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True) + bkey = (device.split(":")[0], ast.key, context, True) if bret:=method_cache.get(bkey): method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) else: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 09054b4419..717ba5c6d8 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -110,7 +110,7 @@ TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextV FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) -CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0) +CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) @dataclass(frozen=True) class Metadata: