mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update type for split_uop and where_on_load [pr] (#14319)
also variable names in where_on_load, before logic update
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable, Final, Iterator
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
@@ -467,7 +467,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
|
||||
|
||||
def split_uop(self:UOp, sep:Ops):
|
||||
def split_uop(self:UOp, sep:Ops) -> Iterator[UOp]:
|
||||
if self.op is sep:
|
||||
for s in self.src: yield from s.split_uop(sep)
|
||||
else: yield self
|
||||
@@ -734,7 +734,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
(fac, const), (div_fac, div_const) = self.pop_const(Ops.MUL), v.pop_const(Ops.MUL)
|
||||
new_count = collections.Counter(fac.split_uop(Ops.MUL))
|
||||
new_count.subtract(div_fac.split_uop(Ops.MUL))
|
||||
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod([*new_count.elements(), self.const_like(const//div_const)])
|
||||
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod(new_count.elements(), start=self.const_like(const//div_const))
|
||||
return None # generic None if we aren't sure
|
||||
def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
|
||||
def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
|
||||
|
||||
@@ -345,22 +345,22 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
|
||||
return UOp.const(dtypes.bool, True).prod(*keep).where(x, i) if drop else None
|
||||
pm_drop_and_clauses = PatternMatcher([(invalid_gate, drop_and_clauses)])
|
||||
|
||||
def where_on_load(c1, buf, x):
|
||||
c2 = x.get_valid()
|
||||
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
|
||||
def where_on_load(cond:UOp, buf:UOp, idx:UOp) -> UOp|None:
|
||||
load_valid = idx.get_valid()
|
||||
duplicate_clauses = [c for c in cond.split_uop(Ops.AND) if c in load_valid.split_uop(Ops.AND)]
|
||||
# we move the condition from the where to the load _as long as_ the condition doesn't have some range that would place it inside of a new range
|
||||
# also no data dependent loads!
|
||||
moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
|
||||
and all(u in x.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
|
||||
moved_clauses = [c for c in cond.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in idx.ranges for r in c.ranges)
|
||||
and all(u in idx.backward_slice_with_self for u in c.backward_slice_with_self if u.op is Ops.INDEX)]
|
||||
if not (removed:=moved_clauses+duplicate_clauses): return None
|
||||
# additionally we can drop the clause on the where if it already exists in the load
|
||||
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
|
||||
return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
|
||||
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in cond.split_uop(Ops.AND) if c not in removed])
|
||||
return remaining_clause.where(buf.index(idx.get_idx().valid(functools.reduce(operator.and_, moved_clauses, load_valid))), 0)
|
||||
|
||||
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
pm_move_where_on_load = PatternMatcher([
|
||||
(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
|
||||
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
|
||||
(UPat.var("cond").where(UPat.var("buf").index(UPat.var("idx")), 0), where_on_load),
|
||||
(UPat.var("cond").where(0, UPat.var("buf").index(UPat.var("idx"))), lambda cond,buf,idx: where_on_load(cond.logical_not(),buf,idx)),
|
||||
])
|
||||
|
||||
def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:
|
||||
|
||||
Reference in New Issue
Block a user