From 00e9ba0b82eca7850ea861b898db0db03f45edac Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 24 Jan 2026 11:17:41 -0500 Subject: [PATCH] update type for split_uop and where_on_load [pr] (#14319) also variable names in where_on_load, before logic update --- tinygrad/uop/ops.py | 6 +++--- tinygrad/uop/symbolic.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 1ba458b893..60b03f26bc 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final +from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final, Iterator import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections from dataclasses import dataclass from enum import Enum, auto @@ -467,7 +467,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op) def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax - def split_uop(self:UOp, sep:Ops): + def split_uop(self:UOp, sep:Ops) -> Iterator[UOp]: if self.op is sep: for s in self.src: yield from s.split_uop(sep) else: yield self @@ -734,7 +734,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): (fac, const), (div_fac, div_const) = self.pop_const(Ops.MUL), v.pop_const(Ops.MUL) new_count = collections.Counter(fac.split_uop(Ops.MUL)) new_count.subtract(div_fac.split_uop(Ops.MUL)) - if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod([*new_count.elements(), self.const_like(const//div_const)]) + if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod(new_count.elements(), start=self.const_like(const//div_const)) return None # generic None if we aren't sure def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self) def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 136def2d86..effb084d3a 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -345,22 +345,22 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None: return UOp.const(dtypes.bool, True).prod(*keep).where(x, i) if drop else None pm_drop_and_clauses = PatternMatcher([(invalid_gate, drop_and_clauses)]) -def where_on_load(c1, buf, x): - c2 = x.get_valid() - duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)] +def where_on_load(cond:UOp, buf:UOp, idx:UOp) -> UOp|None: + load_valid = idx.get_valid() + duplicate_clauses = [c for c in cond.split_uop(Ops.AND) if c in load_valid.split_uop(Ops.AND)] # we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range # also no data dependent loads! - moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges) - and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)] + moved_clauses = [c for c in cond.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in idx.ranges for r in c.ranges) + and all(u in idx.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)] if not (removed:=moved_clauses+duplicate_clauses): return None # aditionally we can drop the clause on the where if it already exists in the load - remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed]) - return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0) + remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in removed]) + return remaining_clause.where(buf.index(idx.get_idx().valid(functools.reduce(operator.and_, moved_clauses, load_valid))), 0) # where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer pm_move_where_on_load = PatternMatcher([ - (UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load), - (UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)), + (UPat.var("cond").where(UPat.var("buf").index(UPat.var("idx")), 0), where_on_load), + (UPat.var("cond").where(0, UPat.var("buf").index(UPat.var("idx"))), lambda cond,buf,idx: where_on_load(cond.logical_not(),buf,idx)), ]) def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: