diff --git a/test/test_const_folding.py b/test/test_const_folding.py index a117baadd9..86d33abd0a 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -174,9 +174,9 @@ class TestMovedConstFolding(unittest.TestCase): if is_dtype_supported(dtypes.uint16): _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16)) np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0]) - # not folded + # folded if is_dtype_supported(dtypes.int64): - _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64)) + _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64)) np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0]) class TestReduceOpsConstFolding(unittest.TestCase): diff --git a/test/test_schedule.py b/test/test_schedule.py index c6312a8fc4..68833b7c4d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -6,6 +6,7 @@ import unittest import numpy as np import functools from typing import List, Optional, Union, cast +from hypothesis import assume, given, strategies as strat from tinygrad import nn, dtypes, Device, Tensor from tinygrad.device import is_dtype_supported @@ -1705,13 +1706,15 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(realized_const_view, 1)) self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]) - def test_cast_padded_const(self): - a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None)) - casted_view = a.cast(dtypes.float32) + @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all)) + def test_cast_padded_const(self, dt1, dt2): + assume(is_dtype_supported(dt1) and is_dtype_supported(dt2)) + a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None)) + casted_view = a.cast(dt2) run_schedule(check_schedule(casted_view, 0)) realized_const_view = casted_view.contiguous() run_schedule(check_schedule(realized_const_view, 1)) - self.assertListEqual(realized_const_view.tolist(), [[0], [1], [0]]) + np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]]) class TestIndexing(unittest.TestCase): def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int): diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index cf55891473..e4e45d0675 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -84,6 +84,8 @@ sym = symbolic_simple+PatternMatcher([ # remove cast to image when it's already a contiguous image (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)), lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), + # CAST before masking constants + (UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)), # make things that can't be images not images (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType) and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),