mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
quantization preprocessor from DSP, should be universal (#9437)
* quantization preprocessor from DSP, should be universal
* touchups
* fix tests
@@ -2,8 +2,9 @@ import numpy as np
 import unittest
 from dataclasses import replace
 from tinygrad import Tensor, Context, Device, dtypes
+from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
-from tinygrad.engine.realize import CompiledRunner, ExecItem
+from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item

 N = 512

@@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
   for _ in range(run_count): ei.run(wait=True)

+def get_quantized_model(sz):
+  from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
+  class FakeDataReader(CalibrationDataReader):
+    def __init__(self): self.cnt = 0
+    def get_next(self) -> dict:
+      self.cnt += 1
+      if self.cnt == 100: return None
+      return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
+  out_file = "/tmp/test_out.onnx"
+  quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
+                  FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
+                  activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
+                  extra_options={"ActivationSymmetric": False})
+  return out_file
+
+@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
+class TestQuantizeOnnxCPU(unittest.TestCase):
+  def test_quant_128(self, sz=128):
+    try:
+      import onnx
+    except ImportError:
+      raise unittest.SkipTest()
+    from extra.onnx import OnnxRunner
+    out_file = get_quantized_model(sz)
+    onnx_model = onnx.load(out_file)
+    run_onnx = OnnxRunner(onnx_model)
+    inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
+    with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
+      sched = run_onnx({"input":inp})["output"].schedule()
+      ei = lower_schedule_item(sched[-2])
+      daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC]
+    assert all(u.dtype.scalar() is dtypes.int for u in daccs)
+
 @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
 class TestQuantizeOnnx(unittest.TestCase):
   def test_quant_128(self): self.test_quant(128)
   def test_quant(self, sz=512):
-    from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
     from examples.benchmark_onnx import load_onnx_model
-    class FakeDataReader(CalibrationDataReader):
-      def __init__(self): self.cnt = 0
-      def get_next(self) -> dict:
-        self.cnt += 1
-        if self.cnt == 100: return None
-        return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
-    out_file = "/tmp/test_out.onnx"
     # divide is ~1500-2000 without reduce_range, 750-900 with it
-    quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
-                    FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
-                    activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
-                    extra_options={"ActivationSymmetric": False})
+    out_file = get_quantized_model(sz)
     run_onnx_jit, _ = load_onnx_model(out_file)
     with Context(DONT_REALIZE_EXPAND=1):
       run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
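Note: the diff calls create_gemm_model, a helper defined elsewhere in this test file and not shown in the hunk above. As a rough sketch only (names and the weight initialization are assumptions, not the real helper), it could build a one-node MatMul ONNX model and return its path, which quantize_static accepts as model_input:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

def create_gemm_model(path, N, K, M):  # hypothetical reconstruction, see note above
  # single MatMul: (N,K) float input x (K,M) float weight initializer -> (N,M) output
  X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [N, K])
  Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [N, M])
  W = numpy_helper.from_array(np.random.uniform(-1, 1, size=(K, M)).astype(np.float32), name="weight")
  graph = helper.make_graph([helper.make_node("MatMul", ["input", "weight"], ["output"])],
                            "gemm", [X], [Y], initializer=[W])
  onnx.save(helper.make_model(graph), path)
  return path  # quantize_static accepts a model path as its first argument

With QuantFormat.QDQ, quantize_static wraps that MatMul in QuantizeLinear/DequantizeLinear pairs, producing the quantized graph that the tests above feed to OnnxRunner and load_onnx_model.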