mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
switch quantization to unsigned/unsigned + add Ops.REDUCE (#9527)
* switch quantization to unsigned/unsigned + add Ops.REDUCE * tests * nhwc + replay pkl
This commit is contained in:
@@ -58,7 +58,7 @@ if __name__ == "__main__":
|
||||
return None
|
||||
return {"input": img.numpy()}
|
||||
quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
|
||||
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
|
||||
activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8,
|
||||
extra_options={"ActivationSymmetric": False})
|
||||
|
||||
run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
|
||||
@@ -74,5 +74,8 @@ if __name__ == "__main__":
|
||||
hit += y==t
|
||||
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
|
||||
MS_TARGET = 13.4
|
||||
print(f"need {GlobalCounters.global_ops/1e9*(1000/MS_TARGET):.2f} GFLOPS for {MS_TARGET:.2f} ms")
|
||||
|
||||
import pickle
|
||||
with open("/tmp/im.pkl", "wb") as f: pickle.dump(run_onnx_jit, f)
|
||||
|
||||
@@ -744,6 +744,10 @@ def get_onnx_ops():
|
||||
return y, scale, zero_point
|
||||
|
||||
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
|
||||
WEIGHT_SHIFT = 4
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
|
||||
# DSP swizzle memory
|
||||
x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
|
||||
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
|
||||
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import pickle, sys
|
||||
from dataclasses import replace
|
||||
from tinygrad import Device, Context
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv, BEAM
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
@@ -22,7 +22,11 @@ if __name__ == "__main__":
|
||||
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
|
||||
p: ProgramSpec = ei.prg.p
|
||||
k = Kernel(p.ast, Device["DSP"].renderer)
|
||||
if not getenv("NOOPT"):
|
||||
dsp_bufs = [Buffer("DSP", 1024+b.size*2, b.dtype).view(b.size, b.dtype, 512) for b in ei.bufs]
|
||||
if BEAM:
|
||||
from tinygrad.engine.search import beam_search
|
||||
k = beam_search(k, dsp_bufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
elif not getenv("NOOPT"):
|
||||
# only NCHW
|
||||
"""
|
||||
if knum in [6,7,9,11]:
|
||||
@@ -48,19 +52,102 @@ if __name__ == "__main__":
|
||||
elif knum == 3:
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
|
||||
elif knum == 29:
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
|
||||
k.apply_opt(Opt(OptOps.PADTO, 1, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 256))
|
||||
#k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
else:
|
||||
k.hand_coded_optimizations()
|
||||
"""
|
||||
"""
|
||||
if knum == 3:
|
||||
# 12544x32 * 32x16 -> 12544x16
|
||||
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 16))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 128//16))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 256//16))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
|
||||
pass
|
||||
elif knum == 6:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 0))
|
||||
elif knum == 4:
|
||||
# 12544x16 * 16x96 -> 12544x96
|
||||
# (with the biased add)
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 1, 96))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
#k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
#k.apply_opt(Opt(OptOps.PADTO, 0, 3))
|
||||
pass
|
||||
elif knum == 13:
|
||||
# 784x144 * 144x32 -> 784x32
|
||||
#k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
#k.apply_opt(Opt(OptOps.UNROLL, 0, 2))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
pass
|
||||
elif knum == 20:
|
||||
# 784x192 * 192x32 -> 784x32
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif knum == 35:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 128))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 64))
|
||||
elif knum == 37:
|
||||
pass
|
||||
elif knum == 24:
|
||||
#k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 64))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
|
||||
"""
|
||||
#if knum in [7, 11, 14, 18]:
|
||||
# alignment issue?
|
||||
#pass
|
||||
if knum == 4:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 96))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif knum == 6:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 24))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 16))
|
||||
elif knum == 11:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 144))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
|
||||
elif knum == 14:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 192))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
|
||||
elif knum == 37:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 384))
|
||||
else:
|
||||
k.hand_coded_optimizations()
|
||||
full_shape = k.full_shape
|
||||
out_shape = k.sts[0].shape
|
||||
out_strides = k.sts[0].real_strides()
|
||||
if len(out_strides) == 3:
|
||||
if full_shape[1] < 128:
|
||||
if full_shape[2] <= 32: k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
else: k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, full_shape[1]))
|
||||
if out_strides[0] < 128:
|
||||
upcast_0 = 128//out_strides[0]
|
||||
if out_shape[0]%upcast_0 == 0 and upcast_0 != 1: k.apply_opt(Opt(OptOps.UPCAST, 0, upcast_0))
|
||||
elif full_shape[1] % 128 == 0:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
|
||||
elif len(out_strides) == 1:
|
||||
#if full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
pass
|
||||
#print("here", out_shape, out_strides, k.name)
|
||||
#k.hand_coded_optimizations()
|
||||
#if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
|
||||
p2 = k.to_program()
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 1024+b.size*2, b.dtype).view(b.size, b.dtype, 512) for b in ei.bufs])
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
|
||||
new_ei.run()
|
||||
knum += 1
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# ruff: noqa: E501
|
||||
import numpy as np
|
||||
import unittest
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor, Context, Device, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.ops import Ops, UOp # noqa: F401 # pylint: disable=unused-import
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
|
||||
from tinygrad.engine.search import bufs_from_lin
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
N = 512
|
||||
|
||||
@@ -40,7 +43,7 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
prg = k.to_program()
|
||||
if replace_src is not None:
|
||||
old_name = prg.src.split("inscount();\n")[1].split("(")[0]
|
||||
old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
|
||||
prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))
|
||||
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
|
||||
for _ in range(run_count): ei.run(wait=True)
|
||||
@@ -195,9 +198,12 @@ class TestQuantizeOnnx(unittest.TestCase):
|
||||
}"""
|
||||
self.test_prequant_gemm_intacc(np.uint8, np.int8, src)
|
||||
|
||||
def test_prequant_gemm_intacc_32(self):
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
|
||||
self.test_prequant_gemm_intacc(np.uint8, np.int8, N=32, opts=opts)
|
||||
def test_prequant_gemm_intacc_128(self): self.test_prequant_gemm_intacc(np.uint8, np.int8, N=128)
|
||||
def test_prequant_gemm_intacc_256(self): self.test_prequant_gemm_intacc(np.uint8, np.int8, N=256)
|
||||
def test_prequant_gemm_intacc(self, xi=np.uint8, wi=np.uint8, replace_src=None, N=512, clip=True):
|
||||
def test_prequant_gemm_intacc(self, xi=np.uint8, wi=np.uint8, replace_src=None, N=512, clip=True, opts=None):
|
||||
X = Tensor(m1:=(np.random.uniform(0, 255, size=(N,N)).astype(xi))).realize()
|
||||
W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
|
||||
# ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
|
||||
@@ -206,7 +212,7 @@ class TestQuantizeOnnx(unittest.TestCase):
|
||||
out = (X.int().matmul(W.int())//1000)
|
||||
if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
|
||||
out = out.cast(tg_dtype)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
|
||||
sexec(out, opts, replace_src, run_count=1)
|
||||
tout = out.numpy()
|
||||
mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) / 1000)
|
||||
@@ -231,5 +237,132 @@ class TestQuantizeOnnx(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
|
||||
class TestDSPCache(unittest.TestCase):
|
||||
def test_cache_speed(self):
|
||||
# string becuase this breaks Python language server for syntax highlight for some reason
|
||||
ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.XOR, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MAX, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.XOR, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MAX, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.012368360534310341, src=(
|
||||
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.char, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.007441135589033365, src=(
|
||||
x22,)),)),)),)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=9.203465015161783e-05, src=(
|
||||
x36:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=33.812857328652136, src=(
|
||||
x36,)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.4999999, src=(
|
||||
x36,)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=136.0, src=(
|
||||
x36,)),)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=0, src=(
|
||||
x36,)),)),
|
||||
x41:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x36,)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=-256, src=(
|
||||
x36,)),)),
|
||||
x41,)),)),)),))""")
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
with Context(DEVECTORIZE=0, QUANTIZE=1):
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
prg = k.to_program()
|
||||
#print(prg.src)
|
||||
|
||||
new_src = """
|
||||
typedef int int32 __attribute__((aligned(128),vector_size(128)));
|
||||
typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128)));
|
||||
typedef unsigned char unsigned_char8 __attribute__((aligned(8),vector_size(8)));
|
||||
typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4)));
|
||||
typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128)));
|
||||
__attribute__((noinline)) void r_196_24_8_32_4(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value(
|
||||
128))) data2, int* restrict __attribute__((align_value(128))) data3) {
|
||||
int32 cast0 = (int32){0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
int32 val0 = *((int32*)((data3+0)));
|
||||
for (int ridx0 = 0; ridx0 < 196; ridx0++) {
|
||||
int32 acc0 = cast0;
|
||||
int32 acc1 = cast0;
|
||||
int32 acc2 = cast0;
|
||||
int32 acc3 = cast0;
|
||||
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768);
|
||||
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+192);
|
||||
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+384);
|
||||
__builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+576);
|
||||
for (int ridx1 = 0; ridx1 < 24; ridx1++) {
|
||||
signed_char128 val1 = *((signed_char128*)((data2+(ridx1<<8))));
|
||||
signed_char128 val2 = *((signed_char128*)((data2+((1+(ridx1<<1))<<7))));
|
||||
|
||||
int alu0 = ((ridx0*768)+(ridx1<<3));
|
||||
|
||||
unsigned_char8 val3 = *((unsigned_char8*)((data1+alu0)));
|
||||
__builtin_HEXAGON_Y2_dcfetch(((data1+alu0)+16));
|
||||
unsigned_char8 val4 = *((unsigned_char8*)((data1+(alu0+192))));
|
||||
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+192))+16));
|
||||
unsigned_char8 val5 = *((unsigned_char8*)((data1+(alu0+384))));
|
||||
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+384))+16));
|
||||
unsigned_char8 val6 = *((unsigned_char8*)((data1+(alu0+576))));
|
||||
__builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+576))+16));
|
||||
|
||||
unsigned_char4 alu5 = __builtin_shufflevector(val3, val3, 0, 1, 2, 3);
|
||||
unsigned_char4 alu6 = __builtin_shufflevector(val4, val4, 0, 1, 2, 3);
|
||||
unsigned_char4 alu7 = __builtin_shufflevector(val5, val5, 0, 1, 2, 3);
|
||||
unsigned_char4 alu8 = __builtin_shufflevector(val6, val6, 0, 1, 2, 3);
|
||||
acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val1, (*((unsigned int*)&alu5)));
|
||||
acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val1, (*((unsigned int*)&alu6)));
|
||||
acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val1, (*((unsigned int*)&alu7)));
|
||||
acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val1, (*((unsigned int*)&alu8)));
|
||||
|
||||
unsigned_char4 alu9 = __builtin_shufflevector(val3, val3, 4, 5, 6, 7);
|
||||
unsigned_char4 alu10 = __builtin_shufflevector(val4, val4, 4, 5, 6, 7);
|
||||
unsigned_char4 alu11 = __builtin_shufflevector(val5, val5, 4, 5, 6, 7);
|
||||
unsigned_char4 alu12 = __builtin_shufflevector(val6, val6, 4, 5, 6, 7);
|
||||
acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val2, (*((unsigned int*)&alu9)));
|
||||
acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val2, (*((unsigned int*)&alu10)));
|
||||
acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val2, (*((unsigned int*)&alu11)));
|
||||
acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val2, (*((unsigned int*)&alu12)));
|
||||
}
|
||||
unsigned_char128 alu18 = __builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc3+val0)*203)+32767)/65536)+136), (((((acc2+val0)*203)+32767)/65536)+136)), __builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc1+val0)*203)+32767)/65536)+136), (((((acc0+val0)*203)+32767)/65536)+136)));
|
||||
*((unsigned_char128*)((data0+(ridx0<<7)))) = alu18;
|
||||
}
|
||||
}
|
||||
"""
|
||||
prg = replace(prg, src=new_src+prg.src.split("/* DSP boilerplate */ ")[1])
|
||||
rt = CompiledRunner(prg)
|
||||
#Device.default.compiler.disassemble(rt.lib)
|
||||
ei = ExecItem(rt, bufs_from_lin(k))
|
||||
tm = ei.run(wait=True)
|
||||
print(f"final time {tm*1e6:.2f} us")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -114,7 +114,7 @@ class Ops(FastEnum):
|
||||
VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto()
|
||||
REDUCE_AXIS = auto(); REDUCE = auto() # noqa: E702
|
||||
|
||||
# helper ops
|
||||
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
|
||||
Reference in New Issue
Block a user