can_safe_cast -> can_lossless_cast (#13789)

In NumPy, a "safe" cast only means the result won't overflow, so "lossless" is the more precise name.
This commit is contained in:
chenyu
2025-12-21 11:29:19 -05:00
committed by GitHub
parent ed1fd7023b
commit 29ef0809bb
2 changed files with 5 additions and 5 deletions

View File

@@ -218,9 +218,9 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType)
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
@functools.cache
def can_safe_cast(dt0:DType, dt1:DType) -> bool:
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
# return whether dt1 preserves every value of dt0
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
# similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
if dt0 == dt1 or dt0 == dtypes.bool: return True
match dt1:
case dtypes.index: return dt0 in dtypes.ints

View File

@@ -2,7 +2,7 @@
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@@ -97,7 +97,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
# b.cast(a).cast(b) -> b if a preserves all values in b
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
# ** pow **
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
# positive const ** x
@@ -242,7 +242,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
# cast/long folding
# if the intermediate cast doesn't narrow, we can do it in one cast
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None),
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
# try to do math in int instead of long