move valid related functions to ops.py [pr] (#7229)

This commit is contained in:
chenyu
2024-10-22 21:10:12 -04:00
committed by GitHub
parent e90bbe6bbc
commit 169cc348fe
2 changed files with 64 additions and 63 deletions

View File

@@ -2,9 +2,9 @@ from __future__ import annotations
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher
from tinygrad.ops import graph_rewrite, symbolic_flat, is_irreducible, split_uop, identity_element
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -78,65 +78,6 @@ float4_folding = PatternMatcher([
# ***** image load valid simplification *****
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function regards its input
if is_irreducible(f): return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is UOps.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is UOps.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is UOps.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, BinaryOps.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
# simplify uop given that valid is True
for expr,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [graph_rewrite(uop.substitute({X:newX}), sym).substitute({newX:X}) for X,newX in candidate]
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
return uop
def simplify_valid(valid:UOp) -> Optional[UOp]:
ret:List[UOp] = []
something_changed = False
for stmt in split_uop(valid, BinaryOps.AND):
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition
@@ -891,6 +892,65 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
ret.append(u)
return functools.reduce(operator.add, ret) if changed else None
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function regards its input
if is_irreducible(f): return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is UOps.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is UOps.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is UOps.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, BinaryOps.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
# simplify uop given that valid is True
for expr,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [graph_rewrite(uop.substitute({X:newX}), symbolic_flat).substitute({newX:X}) for X,newX in candidate]
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
return uop
def simplify_valid(valid:UOp) -> Optional[UOp]:
ret:List[UOp] = []
something_changed = False
for stmt in split_uop(valid, BinaryOps.AND):
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None
symbolic = PatternMatcher([
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),