mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use argfix in smax/smin and remove if [pr] (#8045)
This commit is contained in:
@@ -5,7 +5,7 @@ from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
@@ -180,11 +180,10 @@ def resolve(x, default:bool=True):
|
||||
|
||||
# smax/smin are replacements for max/min that preserve symbolic
|
||||
def _suop(lst, uop_fxn, python_fxn):
|
||||
max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
|
||||
if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
|
||||
return python_fxn(max_num)
|
||||
def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.maximum, max)
|
||||
def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.minimum, min)
|
||||
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
|
||||
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
|
||||
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
|
||||
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
|
||||
|
||||
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||
|
||||
Reference in New Issue
Block a user