mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
more stuff from DSP (#9689)
* more good stuff from dsp branch * test pkl imagenet
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
import sys, onnx, time, pickle
|
import sys, onnx, time, pickle
|
||||||
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
|
from tinygrad import TinyJit, GlobalCounters, fetch, getenv
|
||||||
from tinygrad.frontend.onnx import OnnxRunner
|
from tinygrad.frontend.onnx import OnnxRunner
|
||||||
from extra.onnx_helpers import get_example_inputs, validate
|
from extra.onnx_helpers import get_example_inputs, validate
|
||||||
|
|
||||||
def load_onnx_model(onnx_file):
|
def load_onnx_model(onnx_file):
|
||||||
onnx_model = onnx.load(onnx_file)
|
onnx_model = onnx.load(onnx_file)
|
||||||
run_onnx = OnnxRunner(onnx_model)
|
run_onnx = OnnxRunner(onnx_model)
|
||||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
|
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
|
||||||
return run_onnx_jit, run_onnx.graph_inputs
|
return run_onnx_jit, run_onnx.graph_inputs
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -34,8 +34,3 @@ if __name__ == "__main__":
|
|||||||
if getenv("ORT"):
|
if getenv("ORT"):
|
||||||
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
|
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
|
||||||
print("model validated")
|
print("model validated")
|
||||||
|
|
||||||
if (fn:=getenv("SAVE_PKL", "")) != "":
|
|
||||||
with open(fn, "wb") as f:
|
|
||||||
pickle.dump(run_onnx_jit, f)
|
|
||||||
print(f"pkl saved to {fn}")
|
|
||||||
|
|||||||
19
examples/test_pkl_imagenet.py
Normal file
19
examples/test_pkl_imagenet.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import sys, pickle
|
||||||
|
from tinygrad import GlobalCounters
|
||||||
|
from tinygrad.helpers import fetch, getenv
|
||||||
|
from examples.test_onnx_imagenet import imagenet_dataloader
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open(fetch(sys.argv[1]), "rb") as f:
|
||||||
|
run_onnx_jit = pickle.load(f)
|
||||||
|
input_name = run_onnx_jit.captured.expected_names[0]
|
||||||
|
device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
|
||||||
|
print(f"input goes into {input_name=} on {device=}")
|
||||||
|
hit = 0
|
||||||
|
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
|
||||||
|
GlobalCounters.reset()
|
||||||
|
p = run_onnx_jit(**{input_name:img.to(device)})
|
||||||
|
assert p.shape == (1,1000)
|
||||||
|
t = p.to('cpu').argmax().item()
|
||||||
|
hit += y==t
|
||||||
|
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||||
@@ -175,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
|||||||
gep_pushing = PatternMatcher([
|
gep_pushing = PatternMatcher([
|
||||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
|
||||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||||
@@ -191,9 +191,13 @@ gep_pushing = PatternMatcher([
|
|||||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||||
])
|
])
|
||||||
|
|
||||||
symbolic = symbolic_simple+PatternMatcher([
|
commutative = PatternMatcher([
|
||||||
# ** COMMUTATIVE flipping (only for ints) **
|
# ** COMMUTATIVE flipping (only for ints) **
|
||||||
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
# NOTE: this can break merging vector math by only flipping some of them
|
||||||
|
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||||
|
])
|
||||||
|
|
||||||
|
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||||
# ** boolean algebra **
|
# ** boolean algebra **
|
||||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||||
# ** combine terms **
|
# ** combine terms **
|
||||||
|
|||||||
@@ -149,8 +149,7 @@ class CapturedJit(Generic[ReturnType]):
|
|||||||
expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
|
expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
# TODO: free_intermediates here?
|
# TODO: free_intermediates here? optimize_weights here?
|
||||||
self.optimize_weights()
|
|
||||||
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
|
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
|
||||||
self.expected_names, self.expected_st_vars_dtype_device)
|
self.expected_names, self.expected_st_vars_dtype_device)
|
||||||
|
|
||||||
@@ -218,12 +217,13 @@ def _prepare_jit_inputs(args, kwargs):
|
|||||||
return input_buffers, var_vals, names, st_vars_dtype_device
|
return input_buffers, var_vals, names, st_vars_dtype_device
|
||||||
|
|
||||||
class TinyJit(Generic[ReturnType]):
|
class TinyJit(Generic[ReturnType]):
|
||||||
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
|
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False, optimize=False):
|
||||||
assert fxn or captured, "need either a function or a CapturedJit"
|
assert fxn or captured, "need either a function or a CapturedJit"
|
||||||
self.fxn = fxn
|
self.fxn = fxn
|
||||||
self.captured: Optional[CapturedJit] = captured
|
self.captured: Optional[CapturedJit] = captured
|
||||||
self.cnt: int = 2 if self.fxn is None else 0
|
self.cnt: int = 2 if self.fxn is None else 0
|
||||||
self.prune = prune
|
self.prune = prune
|
||||||
|
self.optimize = optimize
|
||||||
|
|
||||||
def add_buffer(self, b:Buffer) -> Buffer:
|
def add_buffer(self, b:Buffer) -> Buffer:
|
||||||
if found:=self._buffer_replace.get(b, None): return found
|
if found:=self._buffer_replace.get(b, None): return found
|
||||||
@@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
|
|||||||
|
|
||||||
# set this for next run
|
# set this for next run
|
||||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
|
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
|
||||||
|
if self.optimize: self.captured.optimize_weights()
|
||||||
elif self.cnt >= 2:
|
elif self.cnt >= 2:
|
||||||
# jit exec
|
# jit exec
|
||||||
assert self.captured is not None
|
assert self.captured is not None
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
|
|||||||
extra_pm = PatternMatcher([
|
extra_pm = PatternMatcher([
|
||||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||||
(UPat(Ops.BITCAST, name="x"),
|
(UPat(Ops.BITCAST, name="x"),
|
||||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
|
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
|
||||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||||
# devectorize any bools
|
# devectorize any bools
|
||||||
|
|||||||
Reference in New Issue
Block a user