reorder cast before masking constants (#10609)

* failing test from fuzzer

* .numpy() handles bfloat16 better

* const->view->cast becomes const->cast->view

* update TestMovedConstFolding.test_cast_padded
This commit is contained in:
qazal
2025-06-03 15:44:03 +03:00
committed by GitHub
parent 910cabb081
commit ce9f12dc13
3 changed files with 11 additions and 6 deletions

View File

@@ -174,9 +174,9 @@ class TestMovedConstFolding(unittest.TestCase):
if is_dtype_supported(dtypes.uint16):
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# not folded
# folded
if is_dtype_supported(dtypes.int64):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
class TestReduceOpsConstFolding(unittest.TestCase):

View File

@@ -6,6 +6,7 @@ import unittest
import numpy as np
import functools
from typing import List, Optional, Union, cast
from hypothesis import assume, given, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
@@ -1705,13 +1706,15 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
def test_cast_padded_const(self):
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dtypes.float32)
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
def test_cast_padded_const(self, dt1, dt2):
assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dt2)
run_schedule(check_schedule(casted_view, 0))
realized_const_view = casted_view.contiguous()
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[0], [1], [0]])
np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]])
class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):

View File

@@ -84,6 +84,8 @@ sym = symbolic_simple+PatternMatcher([
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# CAST before masking constants
(UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
# make things that can't be images not images
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),