mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
can_safe_cast -> can_lossless_cast (#13789)
A "safe" cast in NumPy only guarantees that the result won't overflow, not that every value is preserved exactly, so "lossless" is the more precise name.
This commit is contained in:
@@ -218,9 +218,9 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType)
|
||||
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
|
||||
|
||||
@functools.cache
|
||||
def can_safe_cast(dt0:DType, dt1:DType) -> bool:
|
||||
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
|
||||
# return if dt1 preserves value of dt0
|
||||
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
|
||||
# similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
|
||||
if dt0 == dt1 or dt0 == dtypes.bool: return True
|
||||
match dt1:
|
||||
case dtypes.index: return dt0 in dtypes.ints
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
@@ -97,7 +97,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
|
||||
# b.cast(a).cast(b) -> b if a preserves all values in b
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
|
||||
# ** pow **
|
||||
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
|
||||
# positive const ** x
|
||||
@@ -242,7 +242,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
|
||||
# cast/long folding
|
||||
# if the intermediate cast doesn't narrow we can do it in one cast
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None),
|
||||
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
|
||||
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
|
||||
# try to do math in int instead of long
|
||||
|
||||
Reference in New Issue
Block a user