From 8601115976c8d4cea493ff8fa87f9f4773d8b1c2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 15 Oct 2024 17:31:25 -0400 Subject: [PATCH] _get_chain -> split_uop [pr] (#7075) --- tinygrad/codegen/uopgraph.py | 20 +++++++++----------- tinygrad/ops.py | 12 ++++++------ tinygrad/shape/shapetracker.py | 4 ++-- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index c33f7604dc..0c5a1ab57d 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,8 +3,8 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType -from tinygrad.ops import UnaryOps, BinaryOps, UOp, UOps, identity_element -from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, symbolic_flat, is_irreducible, _get_chain +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher +from tinygrad.ops import graph_rewrite, symbolic_flat, is_irreducible, split_uop, identity_element from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -105,7 +105,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: # first, parse valid into {expr: (lower_bound, upper_bound)} bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None]) - for stmt in _get_chain(valid, BinaryOps.AND): + for stmt in split_uop(valid, BinaryOps.AND): expr, is_upper, c = parse_valid(stmt) bounds[expr][int(is_upper)] = c @@ -116,9 +116,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the idx into a same output, we rewrite idx candidates = [] - if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): + if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(uop, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)]) + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(uop, BinaryOps.ADD)]) # try checking the whole clause candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))]) @@ -147,14 +147,12 @@ def simplify_image_load(load:UOp) -> Optional[UOp]: # can drop valid if idx is out of bound when valid is False drop_stmt = [] - for stmt in _get_chain(valid, BinaryOps.AND): + for stmt in split_uop(valid, BinaryOps.AND): X, is_upper_bound, c = parse_valid(stmt) # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i - # TODO: does not need to be add chain? - if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \ - all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)): - testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx) + if not is_upper_bound and c == 1 and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)): + testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), split_uop(X, BinaryOps.ADD), idx) testidx = graph_rewrite(testidx, sym) if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0: drop_stmt.append(stmt) @@ -171,7 +169,7 @@ def simplify_image_load(load:UOp) -> Optional[UOp]: break if not drop_stmt and idx is start_idx: return None - new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None + new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx))) # ***** optional patterns ***** diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 059ff0b9c6..12215d3426 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -759,9 +759,9 @@ def type_verify(uops:List[UOp]): # *** most of symbolic lives here now *** -def _get_chain(x:UOp, sep:BinaryOps): +def split_uop(x:UOp, sep:BinaryOps): if x.op is UOps.ALU and x.arg is sep: - for s in x.src: yield from _get_chain(s, sep) + for s in x.src: yield from split_uop(s, sep) else: yield x def mod_folding(x:UOp, c:int) -> Optional[UOp]: @@ -771,7 +771,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c remainder, something_changed = [], False - for u in _get_chain(x, BinaryOps.ADD): + for u in split_uop(x, BinaryOps.ADD): if (factor:=u.const_factor())%c != factor: divides = u.divides(factor)*(factor%c) assert divides is not None @@ -791,7 +791,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: if 0 <= x.vmin and x.vmax < c: return x.const_like(0) quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 - for u in _get_chain(x, BinaryOps.ADD): + for u in split_uop(x, BinaryOps.ADD): if u.op is UOps.CONST: # add all const together first if rem_const != 0: something_changed = True @@ -830,7 +830,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]: def fold_unrolled_divs(divs:UOp): # div pattern in unrolled arange # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - add_chain, denominator, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), None, [], None + add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None for u in add_chain: if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None if denominator is None: denominator = u.src[1].arg @@ -854,7 +854,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. # returns x0 + x1 + ... in such case, or None if not changed, ret = False, [] - for u in _get_chain(X, BinaryOps.ADD): + for u in split_uop(X, BinaryOps.ADD): # assumed the const is the last src of MUL if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: changed = True diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index bea81181e6..48793c6e05 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint +from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint @dataclass(frozen=True) class ShapeTracker: @@ -75,7 +75,7 @@ class ShapeTracker: ret: List[Optional[sint]] = [None] * len(self.shape) idx, valid = self.to_indexed_uops() idx = graph_rewrite(idx, symbolic_flat) - for c in _get_chain(idx, BinaryOps.ADD): + for c in split_uop(idx, BinaryOps.ADD): if c.op is UOps.RANGE: ret[c.arg] = 1 if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg