From 5c6cd884e3fd0782b6daedcaff48c5d2ae04708b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 26 Mar 2025 21:42:52 +0800 Subject: [PATCH] multiple simplifies is faster [pr] (#9586) * multiple simplifies is faster [pr] * cleanup * cleanup --- tinygrad/codegen/devectorizer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index d2096bde83..c7b95fd9fd 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -4,7 +4,7 @@ from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve from tinygrad.ops import graph_rewrite, GroupOp -from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym +from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer @@ -15,13 +15,16 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): if getenv("UNSAFE_DISABLE_MASK", 0): mask = None # first, extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) + midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]), + UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)), + symbolic, name=f"index_buf_{buf.arg}").src for i in range(vec.dtype.count): - idx = vec.gep(i).simplify() + idx: Any = midx.src[i] if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 - if mask is not None: root_src = (mask.gep(i).simplify(), root_src) + if mask is not None: root_src = (mmask.src[i], root_src) offsets_rootsrc[root_src].setdefault(arg, []).append(i) # the buf.dtype is always a pointer