update type for split_uop and where_on_load [pr] (#14319)

also variable names in where_on_load, before logic update
This commit is contained in:
chenyu
2026-01-24 11:17:41 -05:00
committed by GitHub
parent cb69b7b2b2
commit 00e9ba0b82
2 changed files with 12 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final, Iterator
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
from dataclasses import dataclass
from enum import Enum, auto
@@ -467,7 +467,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
def split_uop(self:UOp, sep:Ops):
def split_uop(self:UOp, sep:Ops) -> Iterator[UOp]:
if self.op is sep:
for s in self.src: yield from s.split_uop(sep)
else: yield self
@@ -734,7 +734,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
(fac, const), (div_fac, div_const) = self.pop_const(Ops.MUL), v.pop_const(Ops.MUL)
new_count = collections.Counter(fac.split_uop(Ops.MUL))
new_count.subtract(div_fac.split_uop(Ops.MUL))
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod([*new_count.elements(), self.const_like(const//div_const)])
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod(new_count.elements(), start=self.const_like(const//div_const))
return None # generic None if we aren't sure
def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)

View File

@@ -345,22 +345,22 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
return UOp.const(dtypes.bool, True).prod(*keep).where(x, i) if drop else None
pm_drop_and_clauses = PatternMatcher([(invalid_gate, drop_and_clauses)])
def where_on_load(c1, buf, x):
c2 = x.get_valid()
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
def where_on_load(cond:UOp, buf:UOp, idx:UOp) -> UOp|None:
load_valid = idx.get_valid()
duplicate_clauses = [c for c in cond.split_uop(Ops.AND) if c in load_valid.split_uop(Ops.AND)]
# we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range
# also no data dependent loads!
moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
moved_clauses = [c for c in cond.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in idx.ranges for r in c.ranges)
and all(u in idx.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
if not (removed:=moved_clauses+duplicate_clauses): return None
# aditionally we can drop the clause on the where if it already exists in the load
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in removed])
return remaining_clause.where(buf.index(idx.get_idx().valid(functools.reduce(operator.and_, moved_clauses, load_valid))), 0)
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
pm_move_where_on_load = PatternMatcher([
(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
(UPat.var("cond").where(UPat.var("buf").index(UPat.var("idx")), 0), where_on_load),
(UPat.var("cond").where(0, UPat.var("buf").index(UPat.var("idx"))), lambda cond,buf,idx: where_on_load(cond.logical_not(),buf,idx)),
])
def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: