mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
reorder cast before masking constants (#10609)
* failing test from fuzzer
* .numpy() handles bfloat16 better
* const->view->cast becomes const->cast->view
* update TestMovedConstFolding.test_cast_padded_const
This commit is contained in:
@@ -174,9 +174,9 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# not folded
|
||||
# folded
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
class TestReduceOpsConstFolding(unittest.TestCase):
|
||||
|
||||
@@ -6,6 +6,7 @@ import unittest
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import List, Optional, Union, cast
|
||||
from hypothesis import assume, given, strategies as strat
|
||||
|
||||
from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -1705,13 +1706,15 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
def test_cast_padded_const(self):
|
||||
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dtypes.float32)
|
||||
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
|
||||
def test_cast_padded_const(self, dt1, dt2):
|
||||
assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
|
||||
a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dt2)
|
||||
run_schedule(check_schedule(casted_view, 0))
|
||||
realized_const_view = casted_view.contiguous()
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[0], [1], [0]])
|
||||
np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]])
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
|
||||
|
||||
@@ -84,6 +84,8 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
# CAST before masking constants
|
||||
(UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
|
||||
# make things that can't be images not images
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
|
||||
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
|
||||
|
||||
Reference in New Issue
Block a user