mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* little changes from dsp branch
* not that one
* need the where
* Revert "need the where"
This reverts commit 140f89c878.
64 lines
2.8 KiB
Python
64 lines
2.8 KiB
Python
import pickle, sys
|
|
from dataclasses import replace
|
|
from tinygrad import Device, Context
|
|
from tinygrad.device import Buffer
|
|
from tinygrad.helpers import getenv, BEAM
|
|
from tinygrad.engine.jit import TinyJit
|
|
from tinygrad.engine.realize import CompiledRunner
|
|
from tinygrad.renderer import ProgramSpec
|
|
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
|
|
|
def _apply_hand_opts(k: Kernel, knum: int) -> None:
  """Apply hand-picked DSP optimizations to kernel `k` in place.

  A few kernels are special-cased by their 1-based index `knum`. The rest are matched
  by the pattern of `k.full_shape`. A kernel matching no case is left unoptimized.
  `k.full_shape` is deliberately re-read after each `apply_opt`, because applying an
  opt can change the shape that later opts index into.
  """
  if knum == 1:
    k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
  elif knum == 66:
    k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
  elif k.full_shape[-3:] == (32,3,3):
    #if k.full_shape[-4]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
    # 3x3 dwconv
    # UNROLL with amount 0 is applied twice on purpose (presumably "whole axis" -- TODO confirm)
    k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
    k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
    k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-3, 32))
    if k.full_shape[-4]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
  elif len(k.full_shape) == 3 and k.full_shape[1] == 32:
    #if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
    # weight without more
    k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
    if k.full_shape[0]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
  elif len(k.full_shape) == 4 and k.full_shape[2] == 32:
    #if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
    # weight with more
    k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
    k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
    if k.full_shape[1]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
  elif len(k.full_shape) == 1:
    # upcast by the largest of 128/64/32 that evenly divides the single axis
    for sz in [128,64,32]:
      if k.full_shape[0]%sz == 0:
        k.apply_opt(Opt(OptOps.UPCAST, 0, sz))
        break

if __name__ == "__main__":
  # Load a pickled TinyJit (path in argv[1]) and re-run its compiled kernels on the DSP.
  # Env vars: KNUM (0 = all kernels, n = only the n-th compiled kernel, 1-based),
  #           BEAM / BEAM_ESTIMATE (beam search), NOOPT (skip the hand-tuned opts).
  if len(sys.argv) < 2: sys.exit(f"usage: {sys.argv[0]} <pickled TinyJit file>")

  with Context(DEBUG=0):
    # NOTE(security): pickle.load can execute arbitrary code -- only load files you trust.
    with open(sys.argv[1], "rb") as f:
      fxn: TinyJit = pickle.load(f)
      print(f"{f.tell()/1e6:.2f}M loaded")
    print(type(fxn))

  pknum = getenv("KNUM", 0)  # loop-invariant: read once
  knum = 1
  for ei in fxn.captured.jit_cache:
    # only CompiledRunner entries with all buffers allocated are run and counted;
    # anything else (e.g. copies) is skipped without advancing knum
    if not (isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs)): continue
    if pknum == 0 or knum == pknum:
      p: ProgramSpec = ei.prg.p
      k = Kernel(p.ast, Device["DSP"].renderer)
      # each DSP buffer is allocated 8192 elements larger than needed and viewed at offset 4096
      # (presumably guard space around the real data -- TODO confirm intent and offset units)
      dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
      if BEAM:
        from tinygrad.engine.search import beam_search
        k = beam_search(k, dsp_bufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
      elif not getenv("NOOPT"):
        _apply_hand_opts(k, knum)
      # recompile the (possibly optimized) kernel and run it against the DSP buffers
      new_ei = replace(ei, prg=CompiledRunner(k.to_program()), bufs=dsp_bufs)
      new_ei.run()
    knum += 1
|