mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move valid related functions to ops.py [pr] (#7229)
This commit is contained in:
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher
|
||||
from tinygrad.ops import graph_rewrite, symbolic_flat, is_irreducible, split_uop, identity_element
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
|
||||
@@ -78,65 +78,6 @@ float4_folding = PatternMatcher([
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def is_increasing(f:UOp) -> bool:
  # conservative check: is f a monotonically increasing function of its input?
  # a False result only means this could not be established
  if is_irreducible(f): return True
  if f.op is not UOps.ALU: return False
  # a sum is accepted when both of its operands are accepted
  if f.arg is BinaryOps.ADD: return all(is_increasing(f.src[i]) for i in (0, 1))
  # MUL / IDIV by a non-negative constant defers to the left operand
  if f.arg in (BinaryOps.MUL, BinaryOps.IDIV):
    rhs = f.src[1]
    if rhs.op is UOps.CONST and rhs.arg >= 0: return is_increasing(f.src[0])
  return False
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
  # turns a single comparison into (X, is_upper_bound, c):
  #   X <= c  ->  (X, True, c)
  #   X >= c  ->  (X, False, c)
  # raises ValueError for anything that is not one of the two recognised shapes
  if valid.op is UOps.ALU:
    if valid.arg is BinaryOps.CMPNE:
      # (X < c).ne(True)  ->  X >= c
      inner, flag = valid.src[0], valid.src[1]
      if flag.op is UOps.CONST and flag.arg == 1 and inner.op is UOps.ALU and inner.arg is BinaryOps.CMPLT \
         and inner.src[1].op is UOps.CONST:
        return inner.src[0], False, inner.src[1].arg
    elif valid.arg is BinaryOps.CMPLT:
      # X < c  ->  X <= c-1
      limit = valid.src[1]
      if limit.op is UOps.CONST:
        return valid.src[0], True, limit.arg-1
  raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
  # simplify uop under the assumption that valid is True
  # returns None when valid can never hold, otherwise the (possibly unchanged) uop

  # step 1: collect {expr: [lower_bound, upper_bound]} from the AND-ed clauses of valid
  bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
  for clause in split_uop(valid, BinaryOps.AND):
    try:
      expr, is_upper, c = parse_valid(clause)
    except ValueError:
      return uop  # an unparseable clause means we give up and leave uop untouched
    bounds[expr][int(is_upper)] = c

  # step 2: use each bounded expression to try to simplify uop
  for expr, (lower, upper) in bounds.items():
    # lower bound above upper bound -> valid describes an empty set
    if lower is not None and upper is not None and lower > upper: return None

    # each candidate is a list of (original, constrained stand-in) pairs derived from valid;
    # uop is rewritten only when every pair of a candidate simplifies it to the same result
    candidates = []
    is_sum = expr.op is UOps.ALU and expr.arg is BinaryOps.ADD
    if is_sum and all(is_irreducible(term) and term.vmin == 0 for term in split_uop(expr, BinaryOps.ADD)):
      # simplex constraint X0 + X1 + ... > 0: check whether every Xi > 0 leads to the same output
      candidates.append([(term, UOp.variable("fake", 1, term.vmax, term.dtype)) for term in split_uop(expr, BinaryOps.ADD)])
    # the whole clause, narrowed to the parsed bounds (falling back to the expression's own vmin/vmax)
    lo = expr.vmin if lower is None else lower
    hi = expr.vmax if upper is None else upper
    candidates.append([(expr, UOp.variable("fake", lo, hi, expr.dtype))])

    for candidate in candidates:
      # substitute the stand-in, rewrite, then substitute the original back
      newuops = []
      for X, newX in candidate:
        rewritten = graph_rewrite(uop.substitute({X:newX}), sym)
        newuops.append(rewritten.substitute({newX:X}))
      if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
        # for a two-element VECTORIZE the two sources are accepted independently
        firsts = [n.src[0] for n in newuops]
        if all_same(firsts): uop = uop.replace(src=(firsts[0], uop.src[1]))
        seconds = [n.src[1] for n in newuops]
        if all_same(seconds): uop = uop.replace(src=(uop.src[0], seconds[0]))
      elif all_same(newuops):
        uop = newuops[0]

  return uop
|
||||
|
||||
def simplify_valid(valid:UOp) -> Optional[UOp]:
  # simplify each AND-ed clause of valid given the clauses that precede it
  # returns the re-joined valid if any clause was replaced, otherwise None
  kept:List[UOp] = []
  changed = False
  for stmt in split_uop(valid, BinaryOps.AND):
    simplified = stmt
    if kept:
      # a None result from uop_given_valid keeps the clause as it was
      given = uop_given_valid(functools.reduce(operator.and_, kept), stmt)
      if given is not None: simplified = given
    kept.append(simplified)
    if simplified is not stmt: changed = True
  if not changed: return None
  return functools.reduce(operator.and_, kept)
|
||||
|
||||
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition
|
||||
@@ -891,6 +892,65 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
ret.append(u)
|
||||
return functools.reduce(operator.add, ret) if changed else None
|
||||
|
||||
def is_increasing(f:UOp) -> bool:
  # conservative check: is f a monotonically increasing function of its input?
  # a False result only means this could not be established
  if is_irreducible(f): return True
  if f.op is not UOps.ALU: return False
  # a sum is accepted when both of its operands are accepted
  if f.arg is BinaryOps.ADD: return all(is_increasing(f.src[i]) for i in (0, 1))
  # MUL / IDIV by a non-negative constant defers to the left operand
  if f.arg in (BinaryOps.MUL, BinaryOps.IDIV):
    rhs = f.src[1]
    if rhs.op is UOps.CONST and rhs.arg >= 0: return is_increasing(f.src[0])
  return False
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
  # turns a single comparison into (X, is_upper_bound, c):
  #   X <= c  ->  (X, True, c)
  #   X >= c  ->  (X, False, c)
  # raises ValueError for anything that is not one of the two recognised shapes
  if valid.op is UOps.ALU:
    if valid.arg is BinaryOps.CMPNE:
      # (X < c).ne(True)  ->  X >= c
      inner, flag = valid.src[0], valid.src[1]
      if flag.op is UOps.CONST and flag.arg == 1 and inner.op is UOps.ALU and inner.arg is BinaryOps.CMPLT \
         and inner.src[1].op is UOps.CONST:
        return inner.src[0], False, inner.src[1].arg
    elif valid.arg is BinaryOps.CMPLT:
      # X < c  ->  X <= c-1
      limit = valid.src[1]
      if limit.op is UOps.CONST:
        return valid.src[0], True, limit.arg-1
  raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
  # simplify uop under the assumption that valid is True
  # returns None when valid can never hold, otherwise the (possibly unchanged) uop

  # step 1: collect {expr: [lower_bound, upper_bound]} from the AND-ed clauses of valid
  bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
  for clause in split_uop(valid, BinaryOps.AND):
    try:
      expr, is_upper, c = parse_valid(clause)
    except ValueError:
      return uop  # an unparseable clause means we give up and leave uop untouched
    bounds[expr][int(is_upper)] = c

  # step 2: use each bounded expression to try to simplify uop
  for expr, (lower, upper) in bounds.items():
    # lower bound above upper bound -> valid describes an empty set
    if lower is not None and upper is not None and lower > upper: return None

    # each candidate is a list of (original, constrained stand-in) pairs derived from valid;
    # uop is rewritten only when every pair of a candidate simplifies it to the same result
    candidates = []
    is_sum = expr.op is UOps.ALU and expr.arg is BinaryOps.ADD
    if is_sum and all(is_irreducible(term) and term.vmin == 0 for term in split_uop(expr, BinaryOps.ADD)):
      # simplex constraint X0 + X1 + ... > 0: check whether every Xi > 0 leads to the same output
      candidates.append([(term, UOp.variable("fake", 1, term.vmax, term.dtype)) for term in split_uop(expr, BinaryOps.ADD)])
    # the whole clause, narrowed to the parsed bounds (falling back to the expression's own vmin/vmax)
    lo = expr.vmin if lower is None else lower
    hi = expr.vmax if upper is None else upper
    candidates.append([(expr, UOp.variable("fake", lo, hi, expr.dtype))])

    for candidate in candidates:
      # substitute the stand-in, rewrite, then substitute the original back
      newuops = []
      for X, newX in candidate:
        rewritten = graph_rewrite(uop.substitute({X:newX}), symbolic_flat)
        newuops.append(rewritten.substitute({newX:X}))
      if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
        # for a two-element VECTORIZE the two sources are accepted independently
        firsts = [n.src[0] for n in newuops]
        if all_same(firsts): uop = uop.replace(src=(firsts[0], uop.src[1]))
        seconds = [n.src[1] for n in newuops]
        if all_same(seconds): uop = uop.replace(src=(uop.src[0], seconds[0]))
      elif all_same(newuops):
        uop = newuops[0]

  return uop
|
||||
|
||||
def simplify_valid(valid:UOp) -> Optional[UOp]:
  # simplify each AND-ed clause of valid given the clauses that precede it
  # returns the re-joined valid if any clause was replaced, otherwise None
  kept:List[UOp] = []
  changed = False
  for stmt in split_uop(valid, BinaryOps.AND):
    simplified = stmt
    if kept:
      # a None result from uop_given_valid keeps the clause as it was
      given = uop_given_valid(functools.reduce(operator.and_, kept), stmt)
      if given is not None: simplified = given
    kept.append(simplified)
    if simplified is not stmt: changed = True
  if not changed: return None
  return functools.reduce(operator.and_, kept)
|
||||
|
||||
symbolic = PatternMatcher([
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
|
||||
Reference in New Issue
Block a user