diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 7adf8a0cf0..5ae5f05a3f 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -218,9 +218,9 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"} @functools.cache -def can_safe_cast(dt0:DType, dt1:DType) -> bool: +def can_lossless_cast(dt0:DType, dt1:DType) -> bool: # return if dt1 preserves value of dt0 - # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html + # similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html if dt0 == dt1 or dt0 == dtypes.bool: return True match dt1: case dtypes.index: return dt0 in dtypes.ints diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index bdefa8e069..b18da2ca89 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -2,7 +2,7 @@ import math, operator, struct, functools from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu -from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid +from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup from tinygrad.uop.decompositions import xpow from tinygrad.uop.divandmod import div_and_mod_symbolic @@ -97,7 +97,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast), # b.cast(a).cast(b) -> b if a preserves all values in b - (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None), + (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None), # ** pow ** (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow), # positive const ** x @@ -242,7 +242,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)), # cast/long folding # if the intermediate cast doesnt narrow we can do it in one cast - (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None), + (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None), (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None), # try to do math in int instead of long