Fast DSP for MobileNetV2 (try 2) (#9467)

* Fast DSP for MobileNetV2 (try 2)

* enable fast path on uchar

* fix tests
Author: George Hotz
Date: 2025-03-17 15:10:36 +08:00 (committed by GitHub)
Parent: 15ee742afa
Commit: 52ae9af4dd
6 changed files with 51 additions and 16 deletions


@@ -728,7 +728,11 @@ def get_onnx_ops():
   def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
     out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
     y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
-    return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
+    if out_dtype == dtypes.uchar:
+      # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
+      return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous()
+    else:
+      return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
   def DynamicQuantizeLinear(x: Tensor):
     # only support uint8
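A quick way to sanity-check the uchar fast path above outside tinygrad: adding 0.4999999 and truncating matches round-then-clamp quantization everywhere except values that land exactly on a .5 boundary, which is why the code comment only says it "appears to work in practice". A minimal numpy sketch, purely illustrative and not part of the commit (scale and zero point values are made up):

import numpy as np

x = np.random.uniform(-4, 4, 10000).astype(np.float32)
scale, zero_point = np.float32(0.02), 128

ref  = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)                        # round-based reference
fast = np.clip((x / scale + 0.4999999 + zero_point).astype(np.int32), 0, 255).astype(np.uint8)   # truncation fast path

print("mismatches:", int((ref != fast).sum()))   # 0 in practice; only exact .5 ties can differ by one

Adding the zero point before the truncation keeps the intermediate value non-negative over the useful range, so truncating toward zero behaves like flooring value + 0.5 there, and anything out of range is clamped identically by both forms.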


@@ -1,6 +1,7 @@
 import pickle, sys
 from dataclasses import replace
-from tinygrad import Device
+from tinygrad import Device, Context
+from tinygrad.device import Buffer
 from tinygrad.helpers import getenv
 from tinygrad.engine.jit import TinyJit
 from tinygrad.engine.realize import CompiledRunner
@@ -8,10 +9,11 @@ from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
 
 if __name__ == "__main__":
-  with open(sys.argv[1], "rb") as f:
-    fxn: TinyJit = pickle.load(f)
-    print(f"{f.tell()/1e6:.2f}M loaded")
-  print(type(fxn))
+  with Context(DEBUG=0):
+    with open(sys.argv[1], "rb") as f:
+      fxn: TinyJit = pickle.load(f)
+      print(f"{f.tell()/1e6:.2f}M loaded")
+    print(type(fxn))
   knum = 1
   for ei in fxn.captured.jit_cache:
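For context on the Context(DEBUG=0) wrapper above: Context temporarily overrides tinygrad context variables such as DEBUG for the enclosed block, so unpickling the captured TinyJit does not print debug output. A tiny illustration, assuming a tinygrad install (not part of the commit):

from tinygrad import Context
from tinygrad.helpers import DEBUG

print(DEBUG.value)        # whatever DEBUG was set to in the environment
with Context(DEBUG=0):
  print(DEBUG.value)      # 0 inside the block
print(DEBUG.value)        # restored on exit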
@@ -21,17 +23,33 @@ if __name__ == "__main__":
     p: ProgramSpec = ei.prg.p
     k = Kernel(p.ast, Device["DSP"].renderer)
     if not getenv("NOOPT"):
-      if knum == 2:
+      if knum in [6,7,9,11]:
         k.apply_opt(Opt(OptOps.PADTO, 1, 128))
         k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
+      elif knum in [5,8]:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
+      elif knum == 2:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
+        #k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
+      elif knum == 1:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        #k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
       elif knum == 3:
-        k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128))
-        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
+        k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
       else:
         k.hand_coded_optimizations()
       #if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
     p2 = k.to_program()
-    new_ei = replace(ei, prg=CompiledRunner(p2))
+    new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 128+b.size*2, b.dtype).view(b.size, b.dtype, 128) for b in ei.bufs])
     new_ei.run()
     knum += 1
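The bufs replacement above gives every DSP buffer an over-sized allocation and hands the kernel a view that starts 128 bytes in, presumably so the 128-wide padded and unaligned accesses enabled elsewhere in this commit can run past the logical bounds without touching anything else. A small arithmetic sketch of that slack, with made-up sizes and element-sized units for simplicity (not code from the repo):

logical = 1000                  # stand-in for b.size
alloc   = 128 + logical*2       # what Buffer("DSP", 128+b.size*2, b.dtype) reserves
offset  = 128                   # the kernel's view starts here

# a 128-wide vector access starting at the last logical element still fits:
assert offset + (logical - 1) + 128 <= alloc
print("slack beyond the view:", alloc - (offset + logical))   # 1000 spare elements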


@@ -45,7 +45,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     global_offset += len(grp)
   assert None not in idxs, f"some idxs are missing {idxs}"
   # this base thing is for image, we want the CAT to be a normal pointer
-  return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
+  post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0]
+  return post_cat.gep(tuple(cast(list[int], idxs)))
 
 def cat_after_store(cat:UOp, data:UOp):
   # TODO: this is written in many places
@@ -143,7 +144,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   if (sz:=ls.src[0].dtype.count) == 1: return None
   lengths = []
   buf = idx.src[0]
-  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+  must_divide = True
+  if ctx is not None and ctx.device == "DSP":
+    lengths = [128,64,32,16,8,4]
+    must_divide = False
+  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
     pass
   elif isinstance(buf.dtype, ImageDType):
     lengths = [4]
@@ -158,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
     for fold_length in lengths:
       if global_offset+fold_length > sz: continue
       oidx = idx.src[1] + global_offset
-      if oidx.simplify().divides(fold_length) is None: continue
+      if must_divide and oidx.simplify().divides(fold_length) is None: continue
       lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
       if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
       if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
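The DSP branch above both widens the candidate vector sizes to 128/64/32/16/8/4 and drops the requirement that the starting offset divide the fold length. A stand-alone sketch of the resulting greedy covering, simplified from split_load_store (the real code works on UOps and also handles masks); illustrative only:

def split_widths(sz: int, lengths=(128, 64, 32, 16, 8, 4)) -> list[int]:
  # cover a sz-wide load/store with the largest allowed vector widths,
  # falling back to scalar (width 1) for whatever is left over
  out, off = [], 0
  while off < sz:
    for l in lengths + (1,):
      if off + l <= sz:
        out.append(l); off += l; break
  return out

print(split_widths(192))  # [128, 64]
print(split_widths(130))  # [128, 1, 1]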


@@ -501,10 +501,12 @@ class Kernel:
     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
 
     # potentially do more upcasts of non reduce axes based on a heuristic
+    is_dsp = self.opts is not None and self.opts.device == "DSP"
     upcasted_axis: set[int] = set()
     while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
       xb_choices = []
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
+      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
+      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
         # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
         if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
           xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
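The kernel.py change above swaps the usual 3/4 upcast candidates for a single 128-wide upcast on DSP, and only while nothing has been upcasted yet. A small sketch of how the candidate set changes (illustrative, not the real Kernel code):

import itertools

def upcast_candidates(n_axes: int, is_dsp: bool, already_upcasted: bool):
  # mirrors the itertools.product above: 3/4 per axis normally, one shot at 128 on DSP
  amounts = ([128] if not already_upcasted else []) if is_dsp else [3, 4]
  return list(itertools.product(range(n_axes), amounts))

print(upcast_candidates(3, is_dsp=False, already_upcasted=False))  # every axis with 3 and 4
print(upcast_candidates(3, is_dsp=True,  already_upcasted=False))  # every axis with 128
print(upcast_candidates(3, is_dsp=True,  already_upcasted=True))   # [] -- upcast only once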


@@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage):
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-    # round (down) to power of two
-    alignment = 2**int(math.log2(dt.itemsize))
+    # round (down) to power of two (this is actually the default clang behavior)
+    alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
     return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
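The ALIGNED=0 escape hatch above only changes the alignment attribute on the vector typedefs the Clang renderer emits; vector_size stays the same. A sketch of the string being produced (the type names here are made up; the real renderer derives them from the DType):

import math

def vector_typedef(scalar: str, vec_name: str, itemsize: int, aligned: bool) -> str:
  # round (down) to a power of two, or force alignment 1 when ALIGNED=0
  alignment = 2**int(math.log2(itemsize)) if aligned else 1
  return f"typedef {scalar} {vec_name} __attribute__((aligned({alignment}),vector_size({itemsize})));"

print(vector_typedef("unsigned char", "uchar128", 128, aligned=True))   # aligned(128): clang's default for a 128-byte vector
print(vector_typedef("unsigned char", "uchar128", 128, aligned=False))  # aligned(1): the compiler must assume unaligned access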


@@ -27,6 +27,11 @@ dsp_pm_late = PatternMatcher([
    lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
 ])
 
+# NOTE: this just increases readability of the generated code
+dsp_string = PatternMatcher([
+  (UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
+])
+
 class DSPRenderer(ClangRenderer):
   device = "DSP"
   supports_float4 = True
@@ -34,6 +39,7 @@ class DSPRenderer(ClangRenderer):
   kernel_prefix = "__attribute__((noinline)) "
   pre_matcher = dsp_pm
   extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
+  string_rewrite = dsp_string+ClangRenderer.string_rewrite
   type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
   code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
     Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
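On the string_rewrite line above: dsp_string is placed in front of the inherited ClangRenderer rules, so its int8/uint8 constant pattern takes precedence and those constants render as bare literals, while every other op falls through to the generic rendering. A toy model of that ordering, deliberately not tinygrad's real PatternMatcher API:

def render_const(dtype: str, arg: int, dsp: bool) -> str:
  # ordered rules: the first one that returns a string wins
  rules = []
  if dsp: rules.append(lambda d, a: str(a) if d in ("char", "unsigned char") else None)   # dsp_string analogue
  rules.append(lambda d, a: f"(({d})({a}))")                                              # generic fallback
  for rule in rules:
    if (s := rule(dtype, arg)) is not None: return s
  raise RuntimeError("unreachable: the fallback always matches")

print(render_const("unsigned char", 7, dsp=False))  # ((unsigned char)(7))
print(render_const("unsigned char", 7, dsp=True))   # 7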