Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
delete revectorize (#9000)
* delete revectorize
* test vectorized LLVM/CLANG
* idk about that
* was that the segfault?
.github/workflows/test.yml (8 changed lines, vendored)
@@ -403,8 +403,10 @@ jobs:
     - name: Setup Environment
       uses: ./.github/actions/setup-tinygrad
       with:
-        key: dsp
+        key: dsp-minimal
+        deps: testing_minimal
         pydeps: "onnx==1.16.0 onnxruntime"
+        llvm: "true"
     - name: Set up Docker Buildx
       uses: docker/setup-buildx-action@v3
     - name: Build QEMU Docker with cache

@@ -420,6 +422,10 @@ jobs:
       run: DEBUG=2 DSP=1 python test/test_tiny.py
     - name: Test quantize onnx
       run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
+    - name: Test LLVM=1 DEVECTORIZE=0
+      run: LLVM=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+    #- name: Test CLANG=1 DEVECTORIZE=0
+    #  run: CLANG=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"

   testwebgpu:
     name: Linux (WebGPU)

@@ -57,8 +57,8 @@ if __name__ == "__main__":
       return None
     return {"input": img.numpy()}
   quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
-                  activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
-                  extra_options={"ActivationSymmetric": True})
+                  activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
+                  extra_options={"ActivationSymmetric": False})

   run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
   t_name, t_spec = list(input_specs.items())[0]

@@ -51,8 +51,8 @@ class TestQuantizeOnnx(unittest.TestCase):
     out_file = "/tmp/test_out.onnx"
     quantize_static(create_gemm_model("/tmp/test_in.onnx"), out_file,
                     FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False,
-                    activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
-                    extra_options={"ActivationSymmetric": True})
+                    activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
+                    extra_options={"ActivationSymmetric": False})
     run_onnx_jit, _ = load_onnx_model(out_file)
     with Context(NOOPT=1):
       run_onnx_jit(input=Tensor(np.random.uniform(size=(1, N)).astype(np.float32)))

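Both quantize_static hunks above switch the calibration from symmetric QInt8 activations to asymmetric QUInt8 activations while keeping QInt8 weights. For reference, here is a minimal, self-contained sketch of the same onnxruntime call with those settings; the model paths, the data reader, and the input name are placeholders, not taken from the repo.

```python
# Hedged sketch: static QDQ quantization with onnxruntime, mirroring the new settings above.
# "model_fp32.onnx", "model_int8.onnx", and the input name "input" are illustrative placeholders.
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomReader(CalibrationDataReader):
  def __init__(self, n=8):
    self.batches = iter([{"input": np.random.rand(1, 512).astype(np.float32)} for _ in range(n)])
  def get_next(self):
    # return one calibration feed dict per call, then None when exhausted
    return next(self.batches, None)

quantize_static("model_fp32.onnx", "model_int8.onnx", RandomReader(),
                quant_format=QuantFormat.QDQ, per_channel=False,
                activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
                extra_options={"ActivationSymmetric": False})
```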
@@ -73,6 +73,15 @@ class TestQuantizeOnnx(unittest.TestCase):
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
     sexec(out, opts)

+  def test_prequant_gemm_intacc(self):
+    N = 512
+    # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
+    X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
+    W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.int8))
+    out = X.matmul(W)
+    opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+    sexec(out, opts)
+
   def test_prequant_gemv(self):
     N = 2048
     # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant

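The new test_prequant_gemm_intacc exercises a uint8 x int8 matmul with integer accumulation. As a point of reference only (this is not the repo's verification code), the expected result can be computed in numpy by widening both operands before the matmul:

```python
# Hedged reference sketch: integer-accumulated uint8 x int8 GEMM, shapes mirroring the test above.
import numpy as np
N = 512
X = np.random.randint(0, 256, size=(N, N), dtype=np.uint8)
W = np.random.randint(-128, 128, size=(N, N), dtype=np.int8)
# widen to int32 so the accumulation happens in integers, with no float rounding
ref = X.astype(np.int32) @ W.astype(np.int32)
```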
@@ -3,8 +3,8 @@ from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
 from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
-from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
-from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape

@@ -133,7 +133,7 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
   if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
     return None if can_pad(src, ctx.realizes, dict()) else realize(ctx, b, src)
   # early realize before expand
-  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+  if resolve(prod(src.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, b, src)
   # otherwise safety check pads
   return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)

@@ -1,5 +1,5 @@
 from typing import cast, Optional, Callable
-import itertools, functools, random, math, time, multiprocessing, traceback, signal
+import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer

@@ -141,6 +141,8 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+    @atexit.register
+    def close_pool(): beam_pool.close()

   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")

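The atexit hook added above closes the spawned BEAM worker pool when the process exits. A minimal standalone sketch of the same pattern follows; the helper name is illustrative, not tinygrad's.

```python
# Hedged sketch: register an atexit hook right after creating a multiprocessing pool,
# so the pool is closed on interpreter shutdown even if the caller never tears it down.
import atexit, multiprocessing

def make_pool(workers: int):
  pool = multiprocessing.get_context("spawn").Pool(workers)
  @atexit.register
  def _close_pool(): pool.close()  # close (not terminate): lets in-flight work drain
  return pool

if __name__ == "__main__":
  p = make_pool(2)
  print(p.map(abs, [-1, -2, -3]))
```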
@@ -111,6 +111,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
+DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)

 @dataclass(frozen=True)
 class Metadata:

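With DONT_REALIZE_EXPAND promoted from a raw getenv call to a ContextVar, the flag can be flipped per scope (for example with Context(DONT_REALIZE_EXPAND=1)) as well as via the environment. The sketch below is a simplified stand-in for that pattern, not tinygrad's actual ContextVar/Context implementation.

```python
# Hedged sketch: an env-var-backed context variable that can also be overridden in a scope.
# Simplified illustration of the ContextVar/Context pattern, not tinygrad's real classes.
import os, contextlib

class CtxVar:
  def __init__(self, key: str, default: int):
    self.key, self.value = key, int(os.getenv(key, default))
  def __bool__(self): return bool(self.value)
  def __ge__(self, other): return self.value >= other

DONT_REALIZE_EXPAND = CtxVar("DONT_REALIZE_EXPAND", 0)

@contextlib.contextmanager
def ctx(var: CtxVar, value: int):
  old, var.value = var.value, value
  try: yield
  finally: var.value = old

# usage: temporarily enable the flag for one block
with ctx(DONT_REALIZE_EXPAND, 1):
  assert DONT_REALIZE_EXPAND
assert not DONT_REALIZE_EXPAND  # assumes the env var is unset outside the block
```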
@@ -5,6 +5,7 @@ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
+from tinygrad.codegen.rewriter import no_vectorized_alu

 base_rewrite = PatternMatcher([
   (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),

@@ -59,6 +60,10 @@ extra_pm = PatternMatcher([
     lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  # devectorize any bools
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+  # CAST/WHERE can't be vectorized
+  (UPat((Ops.CAST, Ops.WHERE), name="alu"), no_vectorized_alu),
 ])

 def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

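no_vectorized_alu, reused here from the rewriter, splits a vector op back into per-lane scalar ops. The toy sketch below shows that idea on a miniature stand-in IR rather than tinygrad's UOp/UPat types: a vector "add" becomes one scalar add per lane, with the lane results gathered back by a vectorize node.

```python
# Hedged toy sketch of devectorization on a miniature IR (not tinygrad's UOp/UPat API):
# a vector ALU op is rewritten into per-lane scalar ops whose results are re-vectorized.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str            # e.g. "add", "gep", "vectorize", "const"
  lanes: int         # vector width (1 = scalar)
  src: tuple = ()
  arg: object = None

def devectorize_alu(n: Node) -> Node | None:
  if n.op not in {"add", "mul"} or n.lanes == 1: return None
  # one scalar op per lane, each reading lane i of every source via a GEP
  scalars = tuple(Node(n.op, 1, tuple(Node("gep", 1, (s,), i) for s in n.src)) for i in range(n.lanes))
  return Node("vectorize", n.lanes, scalars)

# usage: a 4-wide add becomes vectorize(add(gep(a,0),gep(b,0)), ..., add(gep(a,3),gep(b,3)))
a, b = Node("const", 4, arg=1.0), Node("const", 4, arg=2.0)
print(devectorize_alu(Node("add", 4, (a, b))))
```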
@@ -9,30 +9,9 @@ from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.runtime.autogen import libc, qcom_dsp
 if getenv("IOCTL"): import extra.dsp.run  # noqa: F401 # pylint: disable=unused-import

-from tinygrad.helpers import all_same
-from tinygrad.ops import PatternMatcher, UPat, GroupOp
-
-def revectorize(v:UOp):
-  if not all_same([x.op for x in v.src]) or any(dtypes.is_bool(x.dtype) for x in v.src[0].src): return None
-  new_srcs = [UOp(Ops.VECTORIZE, v.src[0].src[i].dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i in range(len(v.src[0].src))]
-  return UOp(v.src[0].op, v.dtype, tuple(new_srcs), v.src[0].arg)
-
-revectorize_pm = PatternMatcher([
-  (UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize),
-  # vectorize DEFINE_ACC (similar to expander)
-  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"),
-   lambda v: UOp(Ops.DEFINE_ACC, v.dtype,
-     (UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)),
-  # vectorize increasing GEPs = nothing (wrong if dtypes don't match!)
-  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"),
-   lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and \
-     [x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None),
-])
-
 class DSPRenderer(ClangRenderer):
   device = "DSP"
   supports_float4 = True
-  extra_matcher = revectorize_pm+ClangRenderer.extra_matcher
   buffer_suffix = " restrict __attribute__((align_value(128)))"
   kernel_prefix = "__attribute__((noinline)) "
   type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }

@@ -101,9 +80,10 @@ class DSPAllocator(Allocator):
     return DSPBuffer(va_addr, size, share_info, offset=0)

   def _free(self, opaque:DSPBuffer, options:BufferSpec):
-    libc.munmap(opaque.va_addr, opaque.size)
-    os.close(opaque.share_info.fd)
-    qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)
+    if libc is not None and qcom_dsp is not None:
+      libc.munmap(opaque.va_addr, opaque.size)
+      os.close(opaque.share_info.fd)
+      qcom_dsp.ION_IOC_FREE(self.dev.ion_fd, handle=opaque.share_info.handle)

   def _as_buffer(self, src:DSPBuffer) -> memoryview: return to_mv(src.va_addr, src.size)
   def _copyin(self, dest:DSPBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
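The _free change above guards the munmap/ION teardown so it becomes a no-op when the libc/qcom_dsp bindings are unavailable. A generic sketch of that defensive-cleanup pattern, with placeholder names that are not from the repo:

```python
# Hedged sketch: make resource teardown a no-op when optional platform bindings are missing.
# "platform_lib" and "Handle" are illustrative placeholders, not tinygrad names.
import importlib

try:
  platform_lib = importlib.import_module("platform_lib")  # optional native bindings
except ImportError:
  platform_lib = None

class Handle:
  def __init__(self, fd: int): self.fd = fd
  def free(self):
    # only touch the native API if it actually loaded; otherwise do nothing
    if platform_lib is not None:
      platform_lib.release(self.fd)
```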