test case to sum twice in different order (#12253)

* test case to sum twice in different orders

fixed by #12251

* try metal
This commit is contained in:
chenyu
2025-09-20 10:11:57 -04:00
committed by GitHub
parent 4756971c88
commit 393c6b236c
2 changed files with 22 additions and 1 deletion

View File

@@ -559,6 +559,21 @@ jobs:
- name: Test ONNX
run: CL=1 RANGEIFY=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
testrangeifymacos:
name: MacOS (rangeify)
runs-on: macos-14
timeout-minutes: 15
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: metal
deps: testing
- name: Test METAL=1 RANGEIFY=1
run: METAL=1 RANGEIFY=1 python -m pytest -n=auto test/test_ops.py --durations=20
testdevectorize:
name: Linux (devectorize)
runs-on: ubuntu-24.04

View File

@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
import numpy as np
from typing import List, Callable
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY, OSX
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@@ -312,6 +312,12 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1),
lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True)
@unittest.skipUnless(OSX or Device.DEFAULT=="CPU", "TODO fail on some devices")
def test_sum_twice(self):
helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 1)).sum())
helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 2)).sum())
helper_test_op([(4, 4, 4)], lambda x: x.sum((1, 2)).sum())
# this is more complex and won't fold for a while
def test_sum_cat_collapse(self):
helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1),