From 0ce0f5101028e8b5ca70eb8cf8241b7baf945e45 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 2 Aug 2025 16:26:37 -0700
Subject: [PATCH] generic double cast folding (#11481)

b.cast(a).cast(b) -> b if a preserves all values in b
---
Review notes (this area is not part of the commit message):
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Alignment spaces inside context lines may differ from
  the tree; apply with --ignore-whitespace if a hunk is rejected.
* can_safe_cast gained a uint16 target case (uint8 -> uint16 is lossless).
  The post-image blob id on the dtype.py index line predates that change.

 tinygrad/dtype.py        | 17 +++++++++++++++++
 tinygrad/uop/symbolic.py |  5 +++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index cc971df3a4..2fb651763f 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -193,6 +193,23 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
 INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
 
+@functools.cache
+def can_safe_cast(dt0:DType, dt1:DType) -> bool:
+  # return whether dt1 can represent every value of dt0, so casting dt0 -> dt1 -> dt0 gives back the input
+  # conservative: any (dt0, dt1) pair that is not listed below returns False
+  # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
+  if dt0 == dt1 or dt0 == dtypes.bool: return True
+  match dt1:
+    case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16)
+    case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16)
+    case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
+    case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
+    case dtypes.uint16: return dt0 == dtypes.uint8
+    case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
+    case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
+    case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
+    case _: return False
+
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index 867b1b61b9..2c295d0af9 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -3,7 +3,7 @@ from typing import Any, Literal, cast
 import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
-from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace
+from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast
 from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
 from tinygrad.uop.transcendental import xpow
 
@@ -65,6 +65,8 @@ symbolic_simple = PatternMatcher([
   (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
   (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
   (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+  # b.cast(a).cast(b) -> b if a preserves all values in b
+  (UPat.var('x').cast().named('a').cast().named('b'), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
   # ** pow **
   (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
   # positive const ** x
@@ -427,7 +429,6 @@ sym = symbolic_flat+PatternMatcher([
   (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
   # threefry + remove longs
   (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
-  (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
   ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
   (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
   (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),