multiple simplifies is faster [pr] (#9586)

* multiple simplifies is faster [pr]

* cleanup

* cleanup
This commit is contained in:
George Hotz
2025-03-26 21:42:52 +08:00
committed by GitHub
parent 1e6e75e39a
commit 5c6cd884e3

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -15,13 +15,16 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# first, extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
midx, mmask = graph_rewrite(UOp.sink(UOp.sink(*[vec.gep(i) for i in range(vec.dtype.count)]),
UOp.sink(*[mask.gep(i) for i in range(vec.dtype.count)]) if mask is not None else UOp(Ops.NOOP)),
symbolic, name=f"index_buf_{buf.arg}").src
for i in range(vec.dtype.count):
idx = vec.gep(i).simplify()
idx: Any = midx.src[i]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
if mask is not None: root_src = (mmask.src[i], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# the buf.dtype is always a pointer