From f484db0e6344d56f212e24974681f227027979e5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:18:53 +0800 Subject: [PATCH] dsp cleanups [pr] (#8866) --- examples/benchmark_onnx.py | 11 +++++++--- examples/test_onnx_imagenet.py | 6 +++++- extra/dsp/compile.py | 2 +- extra/dsp/opt.py | 27 ++++++++++++++++++++++++ tinygrad/runtime/ops_clang.py | 20 ++---------------- tinygrad/runtime/ops_dsp.py | 38 ++++++++++++++++++++++++---------- 6 files changed, 70 insertions(+), 34 deletions(-) create mode 100644 extra/dsp/opt.py diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index 498e626aa6..333ac9fbba 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -1,6 +1,5 @@ import sys, onnx, time from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch -from tinygrad.tensor import _from_np_dtype from extra.onnx import OnnxRunner def load_onnx_model(fn): @@ -18,19 +17,25 @@ def load_onnx_model(fn): run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True) return run_onnx_jit, input_shapes, input_types +def get_new_inputs(input_shapes): + #from tinygrad.tensor import _from_np_dtype + #return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + import numpy as np + return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())} + if __name__ == "__main__": run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1]) print("loaded model") for i in range(3): - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs = get_new_inputs(input_shapes) GlobalCounters.reset() print(f"run {i}") run_onnx_jit(**new_inputs) # run 20 times for _ in range(20): - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs = get_new_inputs(input_shapes) GlobalCounters.reset() st = time.perf_counter() out = run_onnx_jit(**new_inputs) diff --git a/examples/test_onnx_imagenet.py b/examples/test_onnx_imagenet.py index a8e5a8c56a..1f27e23e0b 100644 --- a/examples/test_onnx_imagenet.py +++ b/examples/test_onnx_imagenet.py @@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv # https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx # ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx +# QUANT=1 python3 examples/test_onnx_imagenet.py +# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx +# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx + def imagenet_dataloader(cnt=0): input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) @@ -61,7 +65,7 @@ if __name__ == "__main__": assert shape[1:] == (3,224,224), f"shape is {shape}" hit = 0 - for i,(img,y) in enumerate(imagenet_dataloader()): + for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)): p = run_onnx_jit(**{t_name:img}) assert p.shape == (1,1000) t = p.argmax().item() diff --git a/extra/dsp/compile.py b/extra/dsp/compile.py index 93ee4cf9c0..cb3c18a880 100755 --- a/extra/dsp/compile.py +++ b/extra/dsp/compile.py @@ -37,7 +37,7 @@ if __name__ == "__main__": print("mmapped", hex(res)) to_mv(res, 0x10)[1] = 0xaa - from tinygrad.runtime.ops_clang import ClangCompiler + from tinygrad.runtime.ops_dsp import ClangCompiler cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"]) obj = cc.compile(""" diff --git a/extra/dsp/opt.py b/extra/dsp/opt.py new file mode 100644 index 0000000000..fbe35e7ccb --- /dev/null +++ b/extra/dsp/opt.py @@ -0,0 +1,27 @@ +from tinygrad.runtime.ops_dsp import DSPCompiler + +# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py + +if __name__ == "__main__": + compiler = DSPCompiler() + + lib = compiler.compile(""" +typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128))); +typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256))); + +void test(unsigned char *c, unsigned char *a, unsigned char *b) { + HVX_Vector t0 = *(HVX_Vector*)a; + //HVX_VectorPair t1 = *((HVX_VectorPair*)b); + HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B(); + for (int i = 0; i < 128; i++) { + //__builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) + //acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1; + //acc += t0[i] * t1; + unsigned int t1 = ((unsigned int *)b)[i]; + //acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1); + acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1); + } + *((HVX_Vector*)c) = acc; +}""") + + compiler.disassemble(lib) diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 2baf572382..463799f305 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,25 +1,9 @@ -import platform, tempfile, pathlib, subprocess, sys -from tinygrad.helpers import cpu_objdump, capstone_flatdump +import platform, subprocess, sys +from tinygrad.helpers import capstone_flatdump from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.runtime.support.elf import jit_loader from tinygrad.renderer.cstyle import ClangRenderer -# Used by ops_dsp.py -class ClangCompiler(Compiler): - def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): - self.args = ['-march=native'] if args is None else args - self.objdump_tool = objdump_tool - super().__init__(cachekey) - - def compile(self, src:str) -> bytes: - # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here - with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', - '-', '-o', str(output_file.name)], input=src.encode('utf-8')) - return pathlib.Path(output_file.name).read_bytes() - - def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) - class ClangJITCompiler(Compiler): def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 7cac17c6c1..8813bafa45 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,12 +1,11 @@ from __future__ import annotations from typing import Tuple, Any, List -import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys +import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess assert sys.platform != 'win32' -from tinygrad.device import BufferSpec, Compiled, Allocator +from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.ops import Ops, UOp -from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv -from tinygrad.runtime.ops_clang import ClangCompiler +from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.runtime.autogen import libc, qcom_dsp if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import @@ -91,10 +90,23 @@ class DSPAllocator(Allocator): def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes) def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset) -class DSPDevice(Compiled): - def __init__(self, device:str=""): - self.ion_fd = os.open('/dev/ion', os.O_RDONLY) +class ClangCompiler(Compiler): + def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'): + self.args = ['-march=native'] if args is None else args + self.objdump_tool = objdump_tool + super().__init__(cachekey) + def compile(self, src:str) -> bytes: + # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as output_file: + subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib', + '-', '-o', str(output_file.name)], input=src.encode('utf-8')) + return pathlib.Path(output_file.name).read_bytes() + + def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) + +class DSPCompiler(ClangCompiler): + def __init__(self): # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem. sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss'] sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections]) @@ -103,15 +115,19 @@ class DSPDevice(Compiled): self.link_ld.flush() compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"] - super().__init__(device, DSPAllocator(self), DSPRenderer(), - ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self)) + return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump') + +class DSPDevice(Compiled): + def __init__(self, device:str=""): + self.ion_fd = os.open('/dev/ion', os.O_RDONLY) + super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self)) fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes())) self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True)) ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes) self.init_dsp() - RPCListner(self).start() + RPCListener(self).start() def open_lib(self, lib): self.binded_lib, self.binded_lib_off = lib, 0 @@ -149,7 +165,7 @@ class DSPDevice(Compiled): qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd) qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0)) -class RPCListner(threading.Thread): +class RPCListener(threading.Thread): def __init__(self, device:DSPDevice): super().__init__() self.device, self.daemon = device, True