truncate unsigned const in cstyle (#11318)

it can be a warning or a hard error in clang

PTX and PYTHON also need fix, skipping for now
This commit is contained in:
chenyu
2025-07-22 08:02:12 -04:00
committed by GitHub
parent 6668d6d241
commit a41140241b
2 changed files with 11 additions and 3 deletions

View File

@@ -96,6 +96,14 @@ class TestDType(unittest.TestCase):
get_available_cast_dtypes(self.DTYPE)
))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
@unittest.skipIf(getenv("PTX"), "skip for now")
def test_uint_overflow(self):
if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
v = dtypes.max(self.DTYPE)
_test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
_test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
def test_dtypes_fields(self):
fields = dtypes.fields()
self.assertIn("float", fields)

View File

@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
from tinygrad.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
from tinygrad.renderer import Renderer
from tinygrad.codegen.devectorizer import no_vectorized_alu
@@ -34,8 +34,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),