mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
more stuff from DSP (#9689)
* more good stuff from dsp branch * test pkl imagenet
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
import sys, onnx, time, pickle
|
||||
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
|
||||
from tinygrad import TinyJit, GlobalCounters, fetch, getenv
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx_helpers import get_example_inputs, validate
|
||||
|
||||
def load_onnx_model(onnx_file):
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
|
||||
return run_onnx_jit, run_onnx.graph_inputs
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -34,8 +34,3 @@ if __name__ == "__main__":
|
||||
if getenv("ORT"):
|
||||
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
|
||||
print("model validated")
|
||||
|
||||
if (fn:=getenv("SAVE_PKL", "")) != "":
|
||||
with open(fn, "wb") as f:
|
||||
pickle.dump(run_onnx_jit, f)
|
||||
print(f"pkl saved to {fn}")
|
||||
|
||||
19
examples/test_pkl_imagenet.py
Normal file
19
examples/test_pkl_imagenet.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import sys, pickle
|
||||
from tinygrad import GlobalCounters
|
||||
from tinygrad.helpers import fetch, getenv
|
||||
from examples.test_onnx_imagenet import imagenet_dataloader
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open(fetch(sys.argv[1]), "rb") as f:
|
||||
run_onnx_jit = pickle.load(f)
|
||||
input_name = run_onnx_jit.captured.expected_names[0]
|
||||
device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
|
||||
print(f"input goes into {input_name=} on {device=}")
|
||||
hit = 0
|
||||
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
|
||||
GlobalCounters.reset()
|
||||
p = run_onnx_jit(**{input_name:img.to(device)})
|
||||
assert p.shape == (1,1000)
|
||||
t = p.to('cpu').argmax().item()
|
||||
hit += y==t
|
||||
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
@@ -175,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
@@ -191,9 +191,13 @@ gep_pushing = PatternMatcher([
|
||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
commutative = PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for ints) **
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# NOTE: this can break merging vector math by only flipping some of them
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# ** combine terms **
|
||||
|
||||
@@ -149,8 +149,7 @@ class CapturedJit(Generic[ReturnType]):
|
||||
expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
|
||||
|
||||
def __reduce__(self):
|
||||
# TODO: free_intermediates here?
|
||||
self.optimize_weights()
|
||||
# TODO: free_intermediates here? optimize_weights here?
|
||||
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
|
||||
self.expected_names, self.expected_st_vars_dtype_device)
|
||||
|
||||
@@ -218,12 +217,13 @@ def _prepare_jit_inputs(args, kwargs):
|
||||
return input_buffers, var_vals, names, st_vars_dtype_device
|
||||
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
|
||||
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False, optimize=False):
|
||||
assert fxn or captured, "need either a function or a CapturedJit"
|
||||
self.fxn = fxn
|
||||
self.captured: Optional[CapturedJit] = captured
|
||||
self.cnt: int = 2 if self.fxn is None else 0
|
||||
self.prune = prune
|
||||
self.optimize = optimize
|
||||
|
||||
def add_buffer(self, b:Buffer) -> Buffer:
|
||||
if found:=self._buffer_replace.get(b, None): return found
|
||||
@@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
|
||||
# set this for next run
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
|
||||
if self.optimize: self.captured.optimize_weights()
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
assert self.captured is not None
|
||||
|
||||
@@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
|
||||
extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
|
||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# devectorize any bools
|
||||
|
||||
Reference in New Issue
Block a user