quantization preprocessor from DSP, should be universal (#9437)

* quantization preprocessor from DSP, should be universal

* touchups

* fix tests
George Hotz, committed by GitHub
2025-03-15 07:49:37 +08:00
parent ca5064a5b6
commit cb7a7f69c7
7 changed files with 106 additions and 20 deletions
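For context on what this commit's tests exercise: quantizing with onnxruntime's `QuantFormat.QDQ` (used below) brackets tensors with QuantizeLinear/DequantizeLinear node pairs. A numpy sketch of the standard ONNX semantics of those two ops, for the uint8 activations the tests request — background only, not code from this commit:

```python
import numpy as np

# Standard ONNX semantics of the Q/DQ node pairs that QuantFormat.QDQ inserts.
def quantize_linear(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
  # uint8 activations (QuantType.QUInt8 below); np.round is round-half-to-even, as ONNX specifies
  return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize_linear(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
  return (q.astype(np.int32) - zero_point).astype(np.float32) * scale
```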


@@ -2,8 +2,9 @@ import numpy as np
 import unittest
 from dataclasses import replace
 from tinygrad import Tensor, Context, Device, dtypes
+from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
-from tinygrad.engine.realize import CompiledRunner, ExecItem
+from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
 N = 512
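The CPU test added in the next hunk asserts that every `DEFINE_ACC` in the lowered kernel is integer-typed: a correctly quantized GEMM multiplies int8/uint8 values, accumulates in int32, and only rescales to float at the end. A minimal numpy sketch of that arithmetic (zero points assumed 0 for brevity, i.e. symmetric quantization):

```python
import numpy as np

# int8 x int8 GEMM with an int32 accumulator, then a single float rescale at the end.
# This is the kernel shape the test expects: the accumulator (DEFINE_ACC) is int, not float.
def qgemm(xq: np.ndarray, wq: np.ndarray, x_scale: float, w_scale: float) -> np.ndarray:
  acc = xq.astype(np.int32) @ wq.astype(np.int32)      # integer accumulation, no float rounding error
  return acc.astype(np.float32) * (x_scale * w_scale)  # dequantize the result once
```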
@@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
   for _ in range(run_count): ei.run(wait=True)
 
+def get_quantized_model(sz):
+  from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
+  class FakeDataReader(CalibrationDataReader):
+    def __init__(self): self.cnt = 0
+    def get_next(self) -> dict:
+      self.cnt += 1
+      if self.cnt == 100: return None
+      return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
+  out_file = "/tmp/test_out.onnx"
+  quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
+                  FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
+                  activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
+                  extra_options={"ActivationSymmetric": False})
+  return out_file
+
+@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
+class TestQuantizeOnnxCPU(unittest.TestCase):
+  def test_quant_128(self, sz=128):
+    try:
+      import onnx
+    except ImportError:
+      raise unittest.SkipTest()
+    from extra.onnx import OnnxRunner
+    out_file = get_quantized_model(sz)
+    onnx_model = onnx.load(out_file)
+    run_onnx = OnnxRunner(onnx_model)
+    inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
+    with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
+      sched = run_onnx({"input":inp})["output"].schedule()
+      ei = lower_schedule_item(sched[-2])
+    daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC]
+    assert all(u.dtype.scalar() is dtypes.int for u in daccs)
+
 @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
 class TestQuantizeOnnx(unittest.TestCase):
   def test_quant_128(self): self.test_quant(128)
   def test_quant(self, sz=512):
-    from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
     from examples.benchmark_onnx import load_onnx_model
-    class FakeDataReader(CalibrationDataReader):
-      def __init__(self): self.cnt = 0
-      def get_next(self) -> dict:
-        self.cnt += 1
-        if self.cnt == 100: return None
-        return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
-    out_file = "/tmp/test_out.onnx"
-    # divide is ~1500-2000 without reduce_range, 750-900 with it
-    quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
-                    FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
-                    activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
-                    extra_options={"ActivationSymmetric": False})
+    out_file = get_quantized_model(sz)
     run_onnx_jit, _ = load_onnx_model(out_file)
     with Context(DONT_REALIZE_EXPAND=1):
       run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
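`create_gemm_model` is defined earlier in this test file and does not appear in this diff. For reference, a hypothetical sketch of such a helper under the signature the call sites imply (path plus three sizes, returning the path so it can feed straight into `quantize_static`); the shapes and the `onnx.helper` usage are assumptions, not the commit's actual code:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Hypothetical stand-in for create_gemm_model: a float32 MatMul model that
# quantize_static can convert to QDQ form. M, N, K guessed from the (sz, sz, sz) call sites.
def create_gemm_model(path: str, M: int, N: int, K: int) -> str:
  weight = numpy_helper.from_array(np.random.uniform(size=(K, N)).astype(np.float32), name="weight")
  node = helper.make_node("MatMul", ["input", "weight"], ["output"])
  graph = helper.make_graph([node], "gemm",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, [M, K])],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [M, N])],
    initializer=[weight])
  onnx.save(helper.make_model(graph), path)
  return path
```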