failing test for the devectorize [pr] (#8940)

* failing test for the devectorize [pr]

* add DEVECTORIZE to method_cache
This commit is contained in:
George Hotz
2025-02-07 09:44:54 +08:00
committed by GitHub
parent ee1a0fb8ec
commit f54242849d
4 changed files with 17 additions and 9 deletions

View File

@@ -2019,6 +2019,11 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=2).relu())
# Variant of test_strided_conv2d_simple: re-runs that test with DEVECTORIZE=0 applied through Context.
# Decorated with expectedFailure, so the run is currently expected to fail.
# Skipped whenever Device.DEFAULT is not "LLVM" (see the skip reason string on the decorator).
@unittest.expectedFailure
@unittest.skipIf(Device.DEFAULT != "LLVM", "DEVECTORIZE=0 only for LLVM")
def test_strided_conv2d_simple_vec(self):
with Context(DEVECTORIZE=0): self.test_strided_conv2d_simple()
def test_strided_conv2d(self):
bs = 4
cin = 3

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -514,13 +514,13 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
sink = graph_rewrite(sink, sym+expander)
if getenv("NO_DEVECTORIZE"):
# new devectorize for load/store
sink = graph_rewrite(sink, sym+devectorize_load_store)
else:
if DEVECTORIZE:
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
mulacc_unrolled)
else:
# new devectorize only for load/store
sink = graph_rewrite(sink, sym+devectorize_load_store)
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)

View File

@@ -2,6 +2,7 @@ from typing import Optional, cast, Generator
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.helpers import DEVECTORIZE
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -99,11 +100,13 @@ class BufferXfer(BufferCopy):
# **************** method cache ****************
method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {}
method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
# TODO: this should be all context relevant to rendering
context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
ckey = (device, ast.key, context, False)
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
bkey = (device.split(":")[0], ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
else:

View File

@@ -110,7 +110,7 @@ TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextV
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
@dataclass(frozen=True)
class Metadata: