more stuff from DSP (#9689)

* more good stuff from dsp branch

* test pkl imagenet
This commit is contained in:
George Hotz
2025-04-02 15:27:48 +08:00
committed by GitHub
parent 6a5eacba8b
commit 4514fd91c1
5 changed files with 33 additions and 14 deletions

View File

@@ -1,12 +1,12 @@
import sys, onnx, time, pickle
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
from tinygrad import TinyJit, GlobalCounters, fetch, getenv
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs, validate
def load_onnx_model(onnx_file):
onnx_model = onnx.load(onnx_file)
run_onnx = OnnxRunner(onnx_model)
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
return run_onnx_jit, run_onnx.graph_inputs
if __name__ == "__main__":
@@ -34,8 +34,3 @@ if __name__ == "__main__":
if getenv("ORT"):
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
print("model validated")
if (fn:=getenv("SAVE_PKL", "")) != "":
with open(fn, "wb") as f:
pickle.dump(run_onnx_jit, f)
print(f"pkl saved to {fn}")

View File

@@ -0,0 +1,19 @@
import sys, pickle
from tinygrad import GlobalCounters
from tinygrad.helpers import fetch, getenv
from examples.test_onnx_imagenet import imagenet_dataloader
if __name__ == "__main__":
with open(fetch(sys.argv[1]), "rb") as f:
run_onnx_jit = pickle.load(f)
input_name = run_onnx_jit.captured.expected_names[0]
device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
print(f"input goes into {input_name=} on {device=}")
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
GlobalCounters.reset()
p = run_onnx_jit(**{input_name:img.to(device)})
assert p.shape == (1,1000)
t = p.to('cpu').argmax().item()
hit += y==t
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")

View File

@@ -175,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
gep_pushing = PatternMatcher([
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
@@ -191,9 +191,13 @@ gep_pushing = PatternMatcher([
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
])
symbolic = symbolic_simple+PatternMatcher([
commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for ints) **
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])
symbolic = symbolic_simple+commutative+PatternMatcher([
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **

View File

@@ -149,8 +149,7 @@ class CapturedJit(Generic[ReturnType]):
expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
def __reduce__(self):
# TODO: free_intermediates here?
self.optimize_weights()
# TODO: free_intermediates here? optimize_weights here?
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
self.expected_names, self.expected_st_vars_dtype_device)
@@ -218,12 +217,13 @@ def _prepare_jit_inputs(args, kwargs):
return input_buffers, var_vals, names, st_vars_dtype_device
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False, optimize=False):
assert fxn or captured, "need either a function or a CapturedJit"
self.fxn = fxn
self.captured: Optional[CapturedJit] = captured
self.cnt: int = 2 if self.fxn is None else 0
self.prune = prune
self.optimize = optimize
def add_buffer(self, b:Buffer) -> Buffer:
if found:=self._buffer_replace.get(b, None): return found
@@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
# set this for next run
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
if self.optimize: self.captured.optimize_weights()
elif self.cnt >= 2:
# jit exec
assert self.captured is not None

View File

@@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# devectorize any bools