mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
failing test for the devectorize [pr] (#8940)

* failing test for the devectorize [pr]
* add DEVECTORIZE to method_cache
@@ -2019,6 +2019,11 @@ class TestOps(unittest.TestCase):
                     lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
                     lambda x,w: Tensor.conv2d(x,w,stride=2).relu())
 
+  @unittest.expectedFailure
+  @unittest.skipIf(Device.DEFAULT != "LLVM", "DEVECTORIZE=0 only for LLVM")
+  def test_strided_conv2d_simple_vec(self):
+    with Context(DEVECTORIZE=0): self.test_strided_conv2d_simple()
+
   def test_strided_conv2d(self):
     bs = 4
     cin = 3
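The new test in test/test_ops.py simply re-runs test_strided_conv2d_simple under Context(DEVECTORIZE=0). It is marked expectedFailure because the point of this PR is to check in a test that currently fails on the load/store-only devectorize path, and it is restricted to LLVM since that is the only backend this mode targets here. Below is a minimal standalone sketch (not part of the commit) of what the test exercises; the shapes are illustrative, not the ones test_strided_conv2d_simple uses.

import numpy as np
from tinygrad import Tensor, Device
from tinygrad.helpers import Context

if Device.DEFAULT == "LLVM":
  x = Tensor.randn(1, 4, 9, 9).realize()
  w = Tensor.randn(4, 4, 3, 3).realize()
  ref = x.conv2d(w, stride=2).relu().numpy()    # default path (DEVECTORIZE=1)
  with Context(DEVECTORIZE=0):                  # load/store-only devectorize
    out = x.conv2d(w, stride=2).relu().numpy()
  # at this commit the DEVECTORIZE=0 result is expected to be wrong (hence expectedFailure)
  np.testing.assert_allclose(ref, out, atol=1e-4, rtol=1e-4)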
@@ -4,7 +4,7 @@ from collections import defaultdict
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
 from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
 from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
-from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
+from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same, DEVECTORIZE
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
 from tinygrad.renderer import Renderer

@@ -514,13 +514,13 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   # expand
   sink = graph_rewrite(sink, sym+expander)
 
-  if getenv("NO_DEVECTORIZE"):
-    # new devectorize for load/store
-    sink = graph_rewrite(sink, sym+devectorize_load_store)
-  else:
+  if DEVECTORIZE:
     # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
     sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
                          mulacc_unrolled)
+  else:
+    # new devectorize only for load/store
+    sink = graph_rewrite(sink, sym+devectorize_load_store)
 
   # final rules for the renderer (without sym)
   sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
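In the codegen rewriter, the getenv("NO_DEVECTORIZE") escape hatch becomes a proper DEVECTORIZE ContextVar: 1 (the default) keeps the full devectorize pipeline (plus float4_folding when the renderer supports float4), 0 runs only devectorize_load_store. The branch builds on two mechanisms: PatternMatcher objects compose with +, and graph_rewrite applies the combined rules over the UOp graph. A hedged sketch of that mechanism with a toy rule (not one of the devectorize rules):

from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite

# toy rewrite rule: x * 1 -> x, in the same style as tinygrad's symbolic rules
mul_by_one = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])

x = UOp.variable("x", 0, 10)
out = graph_rewrite((x * 1).sink(), mul_by_one)
print(out)  # the MUL-by-1 node has been rewritten away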
@@ -2,6 +2,7 @@ from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
 from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.helpers import DEVECTORIZE
 from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates

@@ -99,11 +100,13 @@ class BufferXfer(BufferCopy):
 
 # **************** method cache ****************
 
-method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {}
+method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
 def get_runner(device:str, ast:UOp) -> CompiledRunner:
-  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
+  # TODO: this should be all context relevant to rendering
+  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
+  ckey = (device, ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, context, True)
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
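Since DEVECTORIZE now changes what full_graph_rewrite emits, a kernel compiled under one setting must not be served from the cache under the other. In tinygrad/engine/realize.py the fix folds everything rendering-relevant (BEAM, NOOPT, and now DEVECTORIZE) into a context tuple inside the method_cache key; the TODO notes this should eventually cover all such context. A self-contained sketch of the hazard the new key avoids (the names here are illustrative, not tinygrad's):

compiled: dict[tuple[str, bytes, tuple[int, ...], bool], str] = {}

def get(device: str, ast_key: bytes, beam: int, noopt: int, devectorize: int) -> str:
  context = (beam, noopt, devectorize)   # everything relevant to rendering
  ckey = (device, ast_key, context, False)
  if (hit := compiled.get(ckey)) is not None: return hit
  compiled[ckey] = prg = f"kernel rendered with DEVECTORIZE={devectorize}"
  return prg

a = get("LLVM", b"ast0", 0, 0, 1)
b = get("LLVM", b"ast0", 0, 0, 0)  # different context -> recompile, no stale reuse
assert a != b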
@@ -110,7 +110,7 @@ TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextV
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
-CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
+CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 
 @dataclass(frozen=True)
 class Metadata:
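DEVECTORIZE is defined in tinygrad/helpers.py with a default of 1, so nothing changes unless it is overridden via the environment or the Context manager. A hedged usage sketch:

from tinygrad.helpers import Context, DEVECTORIZE

print(DEVECTORIZE.value)     # 1, unless DEVECTORIZE=0 was set in the environment
with Context(DEVECTORIZE=0):
  print(DEVECTORIZE.value)   # 0 inside the context, as in the new test
print(DEVECTORIZE.value)     # restored to 1 on exit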