diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index e2b718288e..3206a50c9a 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -4,7 +4,7 @@ from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve from tinygrad.ops import graph_rewrite, GroupOp -from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic +from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer @@ -15,7 +15,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): if getenv("UNSAFE_DISABLE_MASK", 0): mask = None # generate the individual indexes midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]), - symbolic+load_store_indexing, name=f"index_buf_{buf.arg}") + symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}") # extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) for i in range(vec.dtype.count):