use argfix in smax/smin and remove if [pr] (#8045)

This commit is contained in:
chenyu
2024-12-04 17:06:13 -05:00
committed by GitHub
parent 4e518334b8
commit 5933ec8dc3

View File

@@ -5,7 +5,7 @@ from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -180,11 +180,10 @@ def resolve(x, default:bool=True):
# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
return python_fxn(max_num)
def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop