mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
dsp cleanups [pr] (#8866)
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import sys, onnx, time
|
import sys, onnx, time
|
||||||
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
|
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
|
||||||
from tinygrad.tensor import _from_np_dtype
|
|
||||||
from extra.onnx import OnnxRunner
|
from extra.onnx import OnnxRunner
|
||||||
|
|
||||||
def load_onnx_model(fn):
|
def load_onnx_model(fn):
|
||||||
@@ -18,19 +17,25 @@ def load_onnx_model(fn):
|
|||||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
|
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
|
||||||
return run_onnx_jit, input_shapes, input_types
|
return run_onnx_jit, input_shapes, input_types
|
||||||
|
|
||||||
|
def get_new_inputs(input_shapes):
  """Build one fresh set of random input tensors for a benchmark run.

  Args:
    input_shapes: mapping of input name -> shape. Entries are generated in sorted-name order.

  Returns:
    dict mapping each input name to a realized Tensor whose values are drawn uniformly from [0, 8)
    and then cast to that input's dtype.

  NOTE: this reads the module-level `input_types` (input name -> numpy dtype), which this file only
  binds inside the `if __name__ == "__main__":` block. The helper therefore works when the file is
  run as a script, but raises NameError if the module is imported and the function is called directly.
  """
  # numpy-free alternative that was used previously:
  #from tinygrad.tensor import _from_np_dtype
  #return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
  import numpy as np
  # Scale BEFORE casting: uniform() samples lie in [0, 1), so casting first truncates every sample to 0
  # for integer dtypes and the model would be benchmarked on all-zero inputs. For float dtypes the
  # result is unchanged, since multiplying by 8 (a power of two) is exact.
  return {k:Tensor((np.random.uniform(size=shp) * 8).astype(input_types[k])).realize() for k,shp in sorted(input_shapes.items())}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
|
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
|
||||||
print("loaded model")
|
print("loaded model")
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
new_inputs = get_new_inputs(input_shapes)
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
print(f"run {i}")
|
print(f"run {i}")
|
||||||
run_onnx_jit(**new_inputs)
|
run_onnx_jit(**new_inputs)
|
||||||
|
|
||||||
# run 20 times
|
# run 20 times
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
new_inputs = get_new_inputs(input_shapes)
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
out = run_onnx_jit(**new_inputs)
|
out = run_onnx_jit(**new_inputs)
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv
|
|||||||
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
|
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
|
||||||
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
|
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
|
||||||
|
|
||||||
|
# QUANT=1 python3 examples/test_onnx_imagenet.py
|
||||||
|
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
|
||||||
|
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
|
||||||
|
|
||||||
def imagenet_dataloader(cnt=0):
|
def imagenet_dataloader(cnt=0):
|
||||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||||
@@ -61,7 +65,7 @@ if __name__ == "__main__":
|
|||||||
assert shape[1:] == (3,224,224), f"shape is {shape}"
|
assert shape[1:] == (3,224,224), f"shape is {shape}"
|
||||||
|
|
||||||
hit = 0
|
hit = 0
|
||||||
for i,(img,y) in enumerate(imagenet_dataloader()):
|
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
|
||||||
p = run_onnx_jit(**{t_name:img})
|
p = run_onnx_jit(**{t_name:img})
|
||||||
assert p.shape == (1,1000)
|
assert p.shape == (1,1000)
|
||||||
t = p.argmax().item()
|
t = p.argmax().item()
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
|||||||
print("mmapped", hex(res))
|
print("mmapped", hex(res))
|
||||||
to_mv(res, 0x10)[1] = 0xaa
|
to_mv(res, 0x10)[1] = 0xaa
|
||||||
|
|
||||||
from tinygrad.runtime.ops_clang import ClangCompiler
|
from tinygrad.runtime.ops_dsp import ClangCompiler
|
||||||
cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])
|
cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])
|
||||||
|
|
||||||
obj = cc.compile("""
|
obj = cc.compile("""
|
||||||
|
|||||||
27
extra/dsp/opt.py
Normal file
27
extra/dsp/opt.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from tinygrad.runtime.ops_dsp import DSPCompiler
|
||||||
|
|
||||||
|
# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
compiler = DSPCompiler()
|
||||||
|
|
||||||
|
lib = compiler.compile("""
|
||||||
|
typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128)));
|
||||||
|
typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256)));
|
||||||
|
|
||||||
|
void test(unsigned char *c, unsigned char *a, unsigned char *b) {
|
||||||
|
HVX_Vector t0 = *(HVX_Vector*)a;
|
||||||
|
//HVX_VectorPair t1 = *((HVX_VectorPair*)b);
|
||||||
|
HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B();
|
||||||
|
for (int i = 0; i < 128; i++) {
|
||||||
|
//__builtin_HEXAGON_V6_lvsplatb_128B(t0[i])
|
||||||
|
//acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1;
|
||||||
|
//acc += t0[i] * t1;
|
||||||
|
unsigned int t1 = ((unsigned int *)b)[i];
|
||||||
|
//acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1);
|
||||||
|
acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1);
|
||||||
|
}
|
||||||
|
*((HVX_Vector*)c) = acc;
|
||||||
|
}""")
|
||||||
|
|
||||||
|
compiler.disassemble(lib)
|
||||||
@@ -1,25 +1,9 @@
|
|||||||
import platform, tempfile, pathlib, subprocess, sys
|
import platform, subprocess, sys
|
||||||
from tinygrad.helpers import cpu_objdump, capstone_flatdump
|
from tinygrad.helpers import capstone_flatdump
|
||||||
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
||||||
from tinygrad.runtime.support.elf import jit_loader
|
from tinygrad.runtime.support.elf import jit_loader
|
||||||
from tinygrad.renderer.cstyle import ClangRenderer
|
from tinygrad.renderer.cstyle import ClangRenderer
|
||||||
|
|
||||||
# Used by ops_dsp.py
|
|
||||||
class ClangCompiler(Compiler):
|
|
||||||
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
|
|
||||||
self.args = ['-march=native'] if args is None else args
|
|
||||||
self.objdump_tool = objdump_tool
|
|
||||||
super().__init__(cachekey)
|
|
||||||
|
|
||||||
def compile(self, src:str) -> bytes:
|
|
||||||
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
|
|
||||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
|
||||||
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
|
|
||||||
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
|
|
||||||
return pathlib.Path(output_file.name).read_bytes()
|
|
||||||
|
|
||||||
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
|
|
||||||
|
|
||||||
class ClangJITCompiler(Compiler):
|
class ClangJITCompiler(Compiler):
|
||||||
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
|
def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Tuple, Any, List
|
from typing import Tuple, Any, List
|
||||||
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
|
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess
|
||||||
assert sys.platform != 'win32'
|
assert sys.platform != 'win32'
|
||||||
from tinygrad.device import BufferSpec, Compiled, Allocator
|
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
|
||||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||||
from tinygrad.ops import Ops, UOp
|
from tinygrad.ops import Ops, UOp
|
||||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
|
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump
|
||||||
from tinygrad.runtime.ops_clang import ClangCompiler
|
|
||||||
from tinygrad.renderer.cstyle import ClangRenderer
|
from tinygrad.renderer.cstyle import ClangRenderer
|
||||||
from tinygrad.runtime.autogen import libc, qcom_dsp
|
from tinygrad.runtime.autogen import libc, qcom_dsp
|
||||||
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
||||||
@@ -91,10 +90,23 @@ class DSPAllocator(Allocator):
|
|||||||
def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
|
def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
|
||||||
def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
|
def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
|
||||||
|
|
||||||
class DSPDevice(Compiled):
|
class ClangCompiler(Compiler):
|
||||||
def __init__(self, device:str=""):
|
def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
|
||||||
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
|
self.args = ['-march=native'] if args is None else args
|
||||||
|
self.objdump_tool = objdump_tool
|
||||||
|
super().__init__(cachekey)
|
||||||
|
|
||||||
|
def compile(self, src:str) -> bytes:
|
||||||
|
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
|
||||||
|
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||||
|
subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
|
||||||
|
'-', '-o', str(output_file.name)], input=src.encode('utf-8'))
|
||||||
|
return pathlib.Path(output_file.name).read_bytes()
|
||||||
|
|
||||||
|
def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
|
||||||
|
|
||||||
|
class DSPCompiler(ClangCompiler):
|
||||||
|
def __init__(self):
|
||||||
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
|
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
|
||||||
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
|
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
|
||||||
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
|
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
|
||||||
@@ -103,15 +115,19 @@ class DSPDevice(Compiled):
|
|||||||
self.link_ld.flush()
|
self.link_ld.flush()
|
||||||
|
|
||||||
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
|
compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
|
||||||
super().__init__(device, DSPAllocator(self), DSPRenderer(),
|
return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump')
|
||||||
ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
|
|
||||||
|
class DSPDevice(Compiled):
|
||||||
|
def __init__(self, device:str=""):
|
||||||
|
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
|
||||||
|
super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self))
|
||||||
|
|
||||||
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
|
fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
|
||||||
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
|
self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
|
||||||
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
|
ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
|
||||||
|
|
||||||
self.init_dsp()
|
self.init_dsp()
|
||||||
RPCListner(self).start()
|
RPCListener(self).start()
|
||||||
|
|
||||||
def open_lib(self, lib):
|
def open_lib(self, lib):
|
||||||
self.binded_lib, self.binded_lib_off = lib, 0
|
self.binded_lib, self.binded_lib_off = lib, 0
|
||||||
@@ -149,7 +165,7 @@ class DSPDevice(Compiled):
|
|||||||
qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
|
qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
|
||||||
qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))
|
qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))
|
||||||
|
|
||||||
class RPCListner(threading.Thread):
|
class RPCListener(threading.Thread):
|
||||||
def __init__(self, device:DSPDevice):
|
def __init__(self, device:DSPDevice):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device, self.daemon = device, True
|
self.device, self.daemon = device, True
|
||||||
|
|||||||
Reference in New Issue
Block a user