mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fast DSP for MobileNetV2 (try 2) (#9467)
* Fast DSP for MobileNetV2 (try 2)
* enable fast path on uchar
* fix tests
This commit is contained in:
@@ -728,7 +728,11 @@ def get_onnx_ops():
|
||||
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
|
||||
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
|
||||
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
|
||||
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
|
||||
if out_dtype == dtypes.uchar:
|
||||
# this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
|
||||
return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous()
|
||||
else:
|
||||
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
|
||||
|
||||
def DynamicQuantizeLinear(x: Tensor):
|
||||
# only support uint8
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pickle, sys
|
||||
from dataclasses import replace
|
||||
from tinygrad import Device
|
||||
from tinygrad import Device, Context
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
@@ -8,10 +9,11 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
fxn: TinyJit = pickle.load(f)
|
||||
print(f"{f.tell()/1e6:.2f}M loaded")
|
||||
print(type(fxn))
|
||||
with Context(DEBUG=0):
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
fxn: TinyJit = pickle.load(f)
|
||||
print(f"{f.tell()/1e6:.2f}M loaded")
|
||||
print(type(fxn))
|
||||
|
||||
knum = 1
|
||||
for ei in fxn.captured.jit_cache:
|
||||
@@ -21,17 +23,33 @@ if __name__ == "__main__":
|
||||
p: ProgramSpec = ei.prg.p
|
||||
k = Kernel(p.ast, Device["DSP"].renderer)
|
||||
if not getenv("NOOPT"):
|
||||
if knum == 2:
|
||||
if knum in [6,7,9,11]:
|
||||
k.apply_opt(Opt(OptOps.PADTO, 1, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
|
||||
elif knum in [5,8]:
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
|
||||
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
|
||||
elif knum == 2:
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
|
||||
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
|
||||
#k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
|
||||
elif knum == 1:
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
|
||||
#k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
|
||||
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
|
||||
elif knum == 3:
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
|
||||
else:
|
||||
k.hand_coded_optimizations()
|
||||
#if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
|
||||
p2 = k.to_program()
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2))
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 128+b.size*2, b.dtype).view(b.size, b.dtype, 128) for b in ei.bufs])
|
||||
new_ei.run()
|
||||
knum += 1
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
global_offset += len(grp)
|
||||
assert None not in idxs, f"some idxs are missing {idxs}"
|
||||
# this base thing is for image, we want the CAT to be a normal pointer
|
||||
return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
|
||||
post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0]
|
||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
||||
|
||||
def cat_after_store(cat:UOp, data:UOp):
|
||||
# TODO: this is written in many places
|
||||
@@ -143,7 +144,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
if (sz:=ls.src[0].dtype.count) == 1: return None
|
||||
lengths = []
|
||||
buf = idx.src[0]
|
||||
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
|
||||
must_divide = True
|
||||
if ctx is not None and ctx.device == "DSP":
|
||||
lengths = [128,64,32,16,8,4]
|
||||
must_divide = False
|
||||
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
lengths = [4]
|
||||
@@ -158,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
for fold_length in lengths:
|
||||
if global_offset+fold_length > sz: continue
|
||||
oidx = idx.src[1] + global_offset
|
||||
if oidx.simplify().divides(fold_length) is None: continue
|
||||
if must_divide and oidx.simplify().divides(fold_length) is None: continue
|
||||
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
|
||||
|
||||
@@ -501,10 +501,12 @@ class Kernel:
|
||||
for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
is_dsp = self.opts is not None and self.opts.device == "DSP"
|
||||
upcasted_axis: set[int] = set()
|
||||
while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
|
||||
|
||||
@@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage):
|
||||
if sys.platform == 'win32':
|
||||
kernel_prefix = "__attribute__((ms_abi)) "
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
# round (down) to power of two
|
||||
alignment = 2**int(math.log2(dt.itemsize))
|
||||
# round (down) to power of two (this is actually the default clang behavior)
|
||||
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
|
||||
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
|
||||
@@ -27,6 +27,11 @@ dsp_pm_late = PatternMatcher([
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
dsp_string = PatternMatcher([
|
||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
|
||||
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
@@ -34,6 +39,7 @@ class DSPRenderer(ClangRenderer):
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
|
||||
string_rewrite = dsp_string+ClangRenderer.string_rewrite
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
|
||||
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
|
||||
|
||||
Reference in New Issue
Block a user