delete revectorize (#9000)

* delete revectorize

* test vectorized LLVM/CLANG

* idk about that

* was that the segfault?
George Hotz
2025-02-10 18:32:35 +08:00
committed by GitHub
parent fd9f9ec772
commit 0568720a68
8 changed files with 36 additions and 33 deletions

View File

@@ -403,8 +403,10 @@ jobs:
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: dsp
key: dsp-minimal
deps: testing_minimal
pydeps: "onnx==1.16.0 onnxruntime"
llvm: "true"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build QEMU Docker with cache
@@ -420,6 +422,10 @@ jobs:
run: DEBUG=2 DSP=1 python test/test_tiny.py
- name: Test quantize onnx
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
- name: Test LLVM=1 DEVECTORIZE=0
run: LLVM=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
#- name: Test CLANG=1 DEVECTORIZE=0
# run: CLANG=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
testwebgpu:
name: Linux (WebGPU)
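The CI step above moves the setup cache key to dsp-minimal and adds an LLVM=1 DEVECTORIZE=0 pytest run (the CLANG=1 variant stays commented out). DEVECTORIZE is the ContextVar declared in helpers.py later in this diff, so the same override can also be applied in-process; a minimal sketch, assuming any backend with vector support:

from tinygrad import Tensor
from tinygrad.helpers import Context

# scoped override of the DEVECTORIZE ContextVar: vector UOps are kept and the
# backend (e.g. LLVM) is left to emit the vector code itself
with Context(DEVECTORIZE=0):
  out = (Tensor.rand(16) + Tensor.rand(16)).realize()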

View File

@@ -57,8 +57,8 @@ if __name__ == "__main__":
return None
return {"input": img.numpy()}
quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": True})
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": False})
run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
t_name, t_spec = list(input_specs.items())[0]
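In the quantize_static call above, activation quantization moves from symmetric QInt8 to asymmetric QUInt8 while weights stay symmetric QInt8. A minimal sketch of the updated settings in isolation; the file paths, input name, shape, and calibration reader below are placeholders, not names from this repo:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomReader(CalibrationDataReader):  # placeholder calibration reader
  def __init__(self, n=8): self.it = iter([{"input": np.random.rand(1, 512).astype(np.float32)} for _ in range(n)])
  def get_next(self): return next(self.it, None)

quantize_static("model_fp32.onnx", "model_int8.onnx", RandomReader(),
                quant_format=QuantFormat.QDQ, per_channel=False,
                activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
                extra_options={"ActivationSymmetric": False})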

View File

@@ -51,8 +51,8 @@ class TestQuantizeOnnx(unittest.TestCase):
out_file = "/tmp/test_out.onnx"
quantize_static(create_gemm_model("/tmp/test_in.onnx"), out_file,
FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False,
activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": True})
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": False})
run_onnx_jit, _ = load_onnx_model(out_file)
with Context(NOOPT=1):
run_onnx_jit(input=Tensor(np.random.uniform(size=(1, N)).astype(np.float32)))
@@ -73,6 +73,15 @@ class TestQuantizeOnnx(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
sexec(out, opts)
def test_prequant_gemm_intacc(self):
N = 512
# ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.int8))
out = X.matmul(W)
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
sexec(out, opts)
def test_prequant_gemv(self):
N = 2048
# ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant

View File

@@ -3,8 +3,8 @@ from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -133,7 +133,7 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx.realizes, dict()) else realize(ctx, b, src)
# early realize before expand
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, b, src)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
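In realize_before_view above, the ad-hoc getenv("DONT_REALIZE_EXPAND") check becomes the DONT_REALIZE_EXPAND ContextVar added to helpers.py below. A ContextVar reads the environment once at import and can also be overridden per scope; a minimal sketch of that difference, assuming the variable is not set in the environment:

from tinygrad.helpers import Context, DONT_REALIZE_EXPAND

print(bool(DONT_REALIZE_EXPAND))      # False unless DONT_REALIZE_EXPAND=1 is exported
with Context(DONT_REALIZE_EXPAND=1):  # scoped override, no environment mutation
  print(bool(DONT_REALIZE_EXPAND))    # True inside the block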

View File

@@ -1,5 +1,5 @@
from typing import cast, Optional, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
@@ -141,6 +141,8 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
@atexit.register
def close_pool(): beam_pool.close()
min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
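The beam_search change above registers an atexit hook so the spawned worker pool is closed at interpreter shutdown instead of being leaked. A standalone sketch of the same pattern with illustrative names (not tinygrad code):

import atexit, multiprocessing

def make_pool(workers: int):
  pool = multiprocessing.get_context("spawn").Pool(workers)
  atexit.register(pool.close)  # runs once when the interpreter exits
  return pool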

View File

@@ -111,6 +111,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -5,6 +5,7 @@ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.codegen.rewriter import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
@@ -59,6 +60,10 @@ extra_pm = PatternMatcher([
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST/WHERE can't be vectorized
(UPat((Ops.CAST, Ops.WHERE), name="alu"), no_vectorized_alu),
])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
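The extra_pm additions above reuse no_vectorized_alu to scalarize bool-typed ops and CAST/WHERE, which many cstyle backends can't express on vector dtypes. A rough sketch of what such a devectorize rewrite does (not the actual no_vectorized_alu implementation, which lives in tinygrad.codegen.rewriter): split the vector op into one scalar op per lane, then VECTORIZE the lanes back together.

from tinygrad.ops import Ops, UOp

def devectorize_sketch(alu: UOp):
  if alu.dtype.count == 1: return None  # already scalar, nothing to rewrite
  lanes = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg)
                for i in range(alu.dtype.count))
  return UOp(Ops.VECTORIZE, alu.dtype, lanes)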

View File

@@ -9,30 +9,9 @@ from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
from tinygrad.helpers import all_same
from tinygrad.ops import PatternMatcher, UPat, GroupOp
def revectorize(v:UOp):
if not all_same([x.op for x in v.src]) or any(dtypes.is_bool(x.dtype) for x in v.src[0].src): return None
new_srcs = [UOp(Ops.VECTORIZE, v.src[0].src[i].dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i in range(len(v.src[0].src))]
return UOp(v.src[0].op, v.dtype, tuple(new_srcs), v.src[0].arg)
revectorize_pm = PatternMatcher([
(UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize),
# vectorize DEFINE_ACC (similar to expander)
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"),
lambda v: UOp(Ops.DEFINE_ACC, v.dtype,
(UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)),
# vectorize increasing GEPs = nothing (wrong if dtypes don't match!)
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"),
lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and \
[x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None),
])
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
extra_matcher = revectorize_pm+ClangRenderer.extra_matcher
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
@@ -101,9 +80,10 @@ class DSPAllocator(Allocator):
return DSPBuffer(va_addr, size, share_info, offset=0)
def _free(self, opaque:DSPBuffer, options:BufferSpec):
libc.munmap(opaque.va_addr, opaque.size)
os.close(opaque.share_info.fd)
qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)
if libc is not None and qcom_dsp is not None:
libc.munmap(opaque.va_addr, opaque.size)
os.close(opaque.share_info.fd)
qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)
def _as_buffer(self, src:DSPBuffer) -> memoryview: return to_mv(src.va_addr, src.size)
def _copyin(self, dest:DSPBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)