dsp cleanups [pr] (#8866)

George Hotz
2025-02-03 15:18:53 +08:00
committed by GitHub
parent af2c2837f6
commit f484db0e63
6 changed files with 70 additions and 34 deletions

View File

@@ -1,6 +1,5 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
from extra.onnx import OnnxRunner
def load_onnx_model(fn):
@@ -18,19 +17,25 @@ def load_onnx_model(fn):
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
return run_onnx_jit, input_shapes, input_types
def get_new_inputs(input_shapes):
#from tinygrad.tensor import _from_np_dtype
#return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
import numpy as np
return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())}
if __name__ == "__main__":
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
print("loaded model")
for i in range(3):
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs = get_new_inputs(input_shapes)
GlobalCounters.reset()
print(f"run {i}")
run_onnx_jit(**new_inputs)
# run 20 times
for _ in range(20):
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs = get_new_inputs(input_shapes)
GlobalCounters.reset()
st = time.perf_counter()
out = run_onnx_jit(**new_inputs)
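
For reference, the new get_new_inputs helper can be exercised on its own. A minimal sketch, with hypothetical input_shapes/input_types standing in for what load_onnx_model returns:

import numpy as np
from tinygrad import Tensor

input_shapes = {"input": (1, 3, 224, 224)}   # hypothetical; normally returned by load_onnx_model
input_types  = {"input": np.float32}

def get_new_inputs(input_shapes):
    # same construction as above: uniform noise scaled by 8, realized before the timed run
    return {k: Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize()
            for k, shp in sorted(input_shapes.items())}

new_inputs = get_new_inputs(input_shapes)
print({k: (v.shape, v.dtype) for k, v in new_inputs.items()})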

View File

@@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
# QUANT=1 python3 examples/test_onnx_imagenet.py
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
def imagenet_dataloader(cnt=0):
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
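
The mean/std constants above are the standard ImageNet normalization values. A rough sketch of how they are typically applied to an NCHW uint8 image; the dataloader body itself is outside this hunk, so the exact pipeline here is an assumption:

from tinygrad import Tensor, dtypes

input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std  = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)

img = (Tensor.rand(1, 3, 224, 224) * 255).cast(dtypes.uint8)    # stand-in for a decoded RGB image
x = (img.cast(dtypes.float32) / 255 - input_mean) / input_std   # scale to [0,1], normalize per channel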
@@ -61,7 +65,7 @@ if __name__ == "__main__":
assert shape[1:] == (3,224,224), f"shape is {shape}"
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader()):
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
p = run_onnx_jit(**{t_name:img})
assert p.shape == (1,1000)
t = p.argmax().item()
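
The cnt=100 argument caps the evaluation loop. A self-contained sketch of the assumed semantics (cnt=0 streams the whole dataset, cnt=N stops after N samples), with stand-in names since the real dataloader and model are not shown in this hunk:

from itertools import islice

def fake_dataloader():                       # stand-in for imagenet_dataloader()
    for i in range(1000):
        yield f"img{i}.jpg", i % 1000        # (image, label) pairs

def capped(it, cnt=0):
    return it if cnt == 0 else islice(it, cnt)

hit = total = 0
for img, y in capped(fake_dataloader(), cnt=100):
    pred = y                                 # stand-in for run_onnx_jit(...).argmax().item()
    hit += int(pred == y); total += 1
print(f"evaluated {total} samples, top-1 {hit/total:.1%}")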

View File

@@ -37,7 +37,7 @@ if __name__ == "__main__":
print("mmapped", hex(res))
to_mv(res, 0x10)[1] = 0xaa
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.runtime.ops_dsp import ClangCompiler
cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])
obj = cc.compile("""

extra/dsp/opt.py Normal file
View File

@@ -0,0 +1,27 @@
from tinygrad.runtime.ops_dsp import DSPCompiler
# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py
if __name__ == "__main__":
compiler = DSPCompiler()
lib = compiler.compile("""
typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128)));
typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256)));
void test(unsigned char *c, unsigned char *a, unsigned char *b) {
HVX_Vector t0 = *(HVX_Vector*)a;
//HVX_VectorPair t1 = *((HVX_VectorPair*)b);
HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B();
for (int i = 0; i < 128; i++) {
//__builtin_HEXAGON_V6_lvsplatb_128B(t0[i])
//acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1;
//acc += t0[i] * t1;
unsigned int t1 = ((unsigned int *)b)[i];
//acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1);
acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1);
}
*((HVX_Vector*)c) = acc;
}""")
compiler.disassemble(lib)
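
The kernel above leans on the HVX reduce-multiply-accumulate intrinsic. The NumPy model below is my reading of vrmpybus_acc (each 32-bit lane of acc accumulates the dot product of its four unsigned bytes of t0 with the four signed bytes of the scalar word t1), offered as a sketch rather than a spec:

import numpy as np

def vrmpybus_acc_model(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a: 128 unsigned bytes (one HVX vector), b: 512 bytes read as 128 32-bit scalar words
    acc = np.zeros(32, dtype=np.int64)                        # 32 x 32-bit lanes (int64 here to dodge overflow)
    lanes = a.reshape(32, 4).astype(np.int64)                 # four unsigned bytes per lane
    for i in range(128):                                      # one vrmpybus_acc per loop iteration, as in the C
        word = b[4*i:4*i+4].view(np.int8).astype(np.int64)    # the scalar's bytes are treated as signed
        acc += lanes @ word                                   # per-lane 4-byte dot product, accumulated
    return acc

a = np.random.randint(0, 256, 128, dtype=np.uint8)
b = np.random.randint(0, 256, 512, dtype=np.uint8)
print(vrmpybus_acc_model(a, b)[:4])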

View File

@@ -1,25 +1,9 @@
import platform, tempfile, pathlib, subprocess, sys
from tinygrad.helpers import cpu_objdump, capstone_flatdump
import platform, subprocess, sys
from tinygrad.helpers import capstone_flatdump
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.renderer.cstyle import ClangRenderer
# Used by ops_dsp.py
class ClangCompiler(Compiler):
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
self.args = ['-march=native'] if args is None else args
self.objdump_tool = objdump_tool
super().__init__(cachekey)
def compile(self, src:str) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
class ClangJITCompiler(Compiler):
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
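
With ClangCompiler gone from ops_clang.py, cross-compile users import it from the DSP runtime instead, as the extra/dsp script above now does. A minimal usage sketch, assuming an LLVM toolchain with the Hexagon backend is on PATH:

from tinygrad.runtime.ops_dsp import ClangCompiler

cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"],
                   objdump_tool="llvm-objdump")
lib = cc.compile("int add(int a, int b) { return a + b; }")   # bytes of the compiled shared object
cc.disassemble(lib)                                           # dumps Hexagon assembly via llvm-objdump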

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
from typing import Tuple, Any, List
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
@@ -91,10 +90,23 @@ class DSPAllocator(Allocator):
def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
class DSPDevice(Compiled):
def __init__(self, device:str=""):
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
class ClangCompiler(Compiler):
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
self.args = ['-march=native'] if args is None else args
self.objdump_tool = objdump_tool
super().__init__(cachekey)
def compile(self, src:str) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
class DSPCompiler(ClangCompiler):
def __init__(self):
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
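
For reference, that join expands to one 4k-aligned stanza per section; the SECTIONS wrapper and temp-file plumbing live in the elided lines, so only the stanza text below is exact:

sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
print('\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections]))
# .hash : ALIGN(4096) { *(.hash) }
# .text : ALIGN(4096) { *(.text) }
# .rela.plt : ALIGN(4096) { *(.rela.plt) }
# ... and so on through .bss, each section forced to a 4096-byte boundary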
@@ -103,15 +115,19 @@ class DSPDevice(Compiled):
self.link_ld.flush()
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
super().__init__(device, DSPAllocator(self), DSPRenderer(),
ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump')
class DSPDevice(Compiled):
def __init__(self, device:str=""):
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self))
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
self.init_dsp()
RPCListner(self).start()
RPCListener(self).start()
def open_lib(self, lib):
self.binded_lib, self.binded_lib_off = lib, 0
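
With the compiler wired in through DSPCompiler, the device can be smoke-tested end to end. A minimal sketch, assuming a Qualcomm target where /dev/ion and /dsp/cdsp/fastrpc_shell_3 exist:

from tinygrad import Tensor

a = Tensor([1, 2, 3, 4], device="DSP")
b = Tensor([10, 20, 30, 40], device="DSP")
print((a + b).tolist())   # the elementwise kernel is compiled by DSPCompiler and dispatched over FastRPC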
@@ -149,7 +165,7 @@ class DSPDevice(Compiled):
qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))
class RPCListner(threading.Thread):
class RPCListener(threading.Thread):
def __init__(self, device:DSPDevice):
super().__init__()
self.device, self.daemon = device, True
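
RPCListener follows the standard daemon-thread pattern. A generic, self-contained sketch of that pattern (not the repo's implementation):

import threading, queue

class Listener(threading.Thread):
    def __init__(self, requests: queue.Queue):
        super().__init__()
        self.requests, self.daemon = requests, True    # daemon=True: the thread dies with the main process
    def run(self):
        while True:
            self.handle(self.requests.get())
            self.requests.task_done()
    def handle(self, req): print("handling", req)

q = queue.Queue()
Listener(q).start()
q.put("remote invoke")
q.join()                                               # block until the listener has serviced the request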