mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
bring cast before view back (#9230)
* bring cast before view back
* tune it to only trigger on expands
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -355,7 +355,7 @@ jobs:
|
||||
opencl: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot alt model correctness (float32)
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot fastvits model correctness (float32)
|
||||
|
||||
@@ -166,12 +166,11 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: this is folded due to CAST_BEFORE_VIEW
|
||||
# update: CAST_BEFORE_VIEW=1 is no longer supported
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# not folded
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
|
||||
@@ -794,8 +794,7 @@ class TestAutoCastType(unittest.TestCase):
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
|
||||
# NOTE: this is broken without default_dtype because of CAST_BEFORE_VIEW
|
||||
b = (a * 5).sum(acc_dtype=default_dtype)
|
||||
b = (a * 5).sum()
|
||||
b.backward() # if there is dtype mismatch, lazy should assert
|
||||
assert a.grad.dtype == a.dtype
|
||||
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
|
||||
|
||||
@@ -1540,7 +1540,6 @@ class TestSchedule(unittest.TestCase):
|
||||
def test_late_fusion_post_expand(self):
|
||||
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_view(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
|
||||
@@ -1550,18 +1549,16 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 8)
|
||||
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
|
||||
|
||||
# NOTE: we might want to reconsider pushing this cast before the shrink
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
# NOTE: we only reorder CAST if it's an EXPAND
|
||||
def test_cast_after_shrink(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
|
||||
casted_view.realize()
|
||||
self.assertEqual(casted_view.lazydata.base.realized.size, 4)
|
||||
self.assertEqual(casted_view.lazydata.base.realized.size, 2)
|
||||
realized_view = casted_view.contiguous().realize()
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
|
||||
self.assertListEqual(realized_view.tolist(), [[0, 1]])
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
@@ -1571,7 +1568,6 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_const(self):
|
||||
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dtypes.float32)
|
||||
|
||||
@@ -78,6 +78,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
# put CAST to a smaller (or equal-sized) dtype before EXPAND
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st) \
|
||||
if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
|
||||
# make things that can't be images not images
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
|
||||
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
|
||||
|
||||
Reference in New Issue
Block a user