mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
dsp opt test has different axis number on rangeify (#12309)
This commit is contained in:
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import unittest
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor, Context, Device, dtypes
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program
|
||||
@@ -93,7 +94,8 @@ class TestQuantizeOnnx(unittest.TestCase):
|
||||
X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
|
||||
W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8))
|
||||
out = X.conv2d(W, dtype=X.dtype)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
# rangeify merges axis in a different order
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0 if RANGEIFY else 1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
def test_prequant_gemm(self):
|
||||
|
||||
Reference in New Issue
Block a user