From cb7a7f69c759206d8e4bf682f9619558d4e09306 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 15 Mar 2025 07:49:37 +0800 Subject: [PATCH] quantization preprocessor from DSP, should be universal (#9437) * quantization preprocessor from DSP, should be universal * touchups * fix tests --- .github/workflows/test.yml | 4 ++- extra/replay_pkl.py | 2 ++ test/test_quantize_onnx.py | 49 +++++++++++++++++++++-------- tinygrad/codegen/lowerer.py | 63 +++++++++++++++++++++++++++++++++++-- tinygrad/helpers.py | 1 + tinygrad/ops.py | 1 + tinygrad/runtime/ops_dsp.py | 6 ++-- 7 files changed, 106 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a0b4038507..a6cb1ebb34 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -423,6 +423,8 @@ jobs: run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20 - name: Test Additional ONNX Ops (CPU) run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py + - name: Test Quantize ONNX + run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py - name: Run CLOUD=1 Test run: | CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py @@ -467,7 +469,7 @@ jobs: testdsp: name: Linux (DSP) runs-on: ubuntu-24.04 - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: Checkout Code uses: actions/checkout@v4 diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py index 43272adc8c..fc40280d49 100644 --- a/extra/replay_pkl.py +++ b/extra/replay_pkl.py @@ -26,6 +26,8 @@ if __name__ == "__main__": k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0)) k.apply_opt(Opt(OptOps.PADTO, 2, 128)) k.apply_opt(Opt(OptOps.UPCAST, 2, 128)) + elif knum == 3: + k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128)) else: k.hand_coded_optimizations() p2 = k.to_program() diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index 1f6681b0cc..e1bce65950 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -2,8 +2,9 @@ import numpy as np import unittest from dataclasses import replace from tinygrad import Tensor, Context, Device, dtypes +from tinygrad.ops import Ops from tinygrad.codegen.kernel import Kernel, Opt, OptOps -from tinygrad.engine.realize import CompiledRunner, ExecItem +from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item N = 512 @@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3): ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) for _ in range(run_count): ei.run(wait=True) +def get_quantized_model(sz): + from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader + class FakeDataReader(CalibrationDataReader): + def __init__(self): self.cnt = 0 + def get_next(self) -> dict: + self.cnt += 1 + if self.cnt == 100: return None + return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)} + out_file = "/tmp/test_out.onnx" + quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file, + FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False, + activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": False}) + return out_file + +@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU") +class TestQuantizeOnnxCPU(unittest.TestCase): + def test_quant_128(self, sz=128): + try: + import onnx + except ImportError: + raise unittest.SkipTest() + from extra.onnx import OnnxRunner + out_file = get_quantized_model(sz) + onnx_model = onnx.load(out_file) + run_onnx = OnnxRunner(onnx_model) + inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)) + with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1): + sched = run_onnx({"input":inp})["output"].schedule() + ei = lower_schedule_item(sched[-2]) + daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC] + assert all(u.dtype.scalar() is dtypes.int for u in daccs) + @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP") class TestQuantizeOnnx(unittest.TestCase): def test_quant_128(self): self.test_quant(128) def test_quant(self, sz=512): - from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader from examples.benchmark_onnx import load_onnx_model - class FakeDataReader(CalibrationDataReader): - def __init__(self): self.cnt = 0 - def get_next(self) -> dict: - self.cnt += 1 - if self.cnt == 100: return None - return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)} - out_file = "/tmp/test_out.onnx" # divide is ~1500-2000 without reduce_range, 750-900 with it - quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file, - FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False, - activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8, - extra_options={"ActivationSymmetric": False}) + out_file = get_quantized_model(sz) run_onnx_jit, _ = load_onnx_model(out_file) with Context(DONT_REALIZE_EXPAND=1): run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 90861e3a9b..40b0bc96f0 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -2,11 +2,12 @@ import functools, itertools, operator, math from dataclasses import dataclass from typing import cast -from tinygrad.dtype import dtypes, PtrDType -from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop +from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype +from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp from tinygrad.renderer import Renderer -from tinygrad.helpers import all_int, prod, partition, flatten, unwrap +from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE from tinygrad.codegen.expander import expand_rewrite +from tinygrad.codegen.symbolic import symbolic # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None: @@ -156,9 +157,65 @@ pm_lowerer = PatternMatcher([ # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), + (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]), +]) + +# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints **** + +def view_to_mask(x:UOp): + from tinygrad.shape.shapetracker import ShapeTracker, View + st = cast(ShapeTracker, x.st) + if len(st.views) > 1: return None + if st.views[-1].mask is None: return None + return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),)) + +FP = (1 << 16) +pm_quant = symbolic+PatternMatcher([ + # cast after add/mul + (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32), + lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)), + (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32), + lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)), + # MUL after reduce + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c), + # CAST after reduce (doesn't work if it's a size change) + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"), + lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None), + # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats) + (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats), + lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None), + # mul 0 * c1 is 0 + (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) * + UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1), + # mul (with plus) 0 * c1 is 0 + (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) * + (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \ + UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"), + lambda ld,v,c1: ld*c1), + # fixed point mult, replace (x.float()*c1+c2).int() with an int expression + ((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int), + lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP), + # where move + (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul: + (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None), + ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c), + (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid: + (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)), + ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) * + UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2: + x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))), + # don't care + (UPat(Ops.STORE, name="x"), lambda x: + x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \ + if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None), + (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"), + lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))), + (UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c), + (UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None), ]) def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: + if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize") sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts)) # expand_rewrite turns this into a vectorized program return expand_rewrite(sink) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 2b51316083..3c87eba7d6 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -113,6 +113,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0) +QUANTIZE = ContextVar("QUANTIZE", 0) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 57b5959578..42e4ec929d 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -154,6 +154,7 @@ class Ops(FastEnum): # CUSTOMI is inline CUSTOM = auto(); CUSTOMI = auto() # noqa: E702 + IGNORE = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 2d23df950d..dbf604bbec 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -20,9 +20,9 @@ dsp_pm = PatternMatcher([ ]) dsp_pm_late = PatternMatcher([ - (UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), - (UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), - (UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), + (UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), + (UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), + (UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), (UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True), lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])), ])