mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
truncate unsigned const in cstyle (#11318)
it can be a warning or a hard error in clang PTX and PYTHON also need fix, skipping for now
This commit is contained in:
@@ -96,6 +96,14 @@ class TestDType(unittest.TestCase):
|
||||
get_available_cast_dtypes(self.DTYPE)
|
||||
))
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
|
||||
@unittest.skipIf(getenv("PTX"), "skip for now")
|
||||
def test_uint_overflow(self):
|
||||
if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
|
||||
v = dtypes.max(self.DTYPE)
|
||||
_test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
|
||||
_test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
|
||||
|
||||
def test_dtypes_fields(self):
|
||||
fields = dtypes.fields()
|
||||
self.assertIn("float", fields)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
|
||||
from tinygrad.opt import tc
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.devectorizer import no_vectorized_alu
|
||||
|
||||
@@ -34,8 +34,8 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
|
||||
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}ull"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
|
||||
# consts are rendered to larger type and casted
|
||||
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
|
||||
|
||||
Reference in New Issue
Block a user