more stuff from DSP (#9689)

* more good stuff from dsp branch

* test pkl imagenet
George Hotz
2025-04-02 15:27:48 +08:00
committed by GitHub
parent 6a5eacba8b
commit 4514fd91c1
5 changed files with 33 additions and 14 deletions

@@ -1,12 +1,12 @@
 import sys, onnx, time, pickle
-from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
+from tinygrad import TinyJit, GlobalCounters, fetch, getenv
 from tinygrad.frontend.onnx import OnnxRunner
 from extra.onnx_helpers import get_example_inputs, validate
 
 def load_onnx_model(onnx_file):
   onnx_model = onnx.load(onnx_file)
   run_onnx = OnnxRunner(onnx_model)
-  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
+  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
   return run_onnx_jit, run_onnx.graph_inputs
 
 if __name__ == "__main__":
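Note: v.to(None) keeps the old behavior of v.to(Device.DEFAULT), since tinygrad canonicalizes a None device to the default; the change just drops the explicit Device import. A quick sketch of the equivalence:

from tinygrad import Tensor, Device

t = Tensor([1.0, 2.0])
# to(None) resolves to the default device, same as naming it explicitly
assert t.to(None).device == t.to(Device.DEFAULT).device == Device.DEFAULT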
@@ -34,8 +34,3 @@ if __name__ == "__main__":
   if getenv("ORT"):
     validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
     print("model validated")
-
-  if (fn:=getenv("SAVE_PKL", "")) != "":
-    with open(fn, "wb") as f:
-      pickle.dump(run_onnx_jit, f)
-    print(f"pkl saved to {fn}")

@@ -0,0 +1,19 @@
+import sys, pickle
+from tinygrad import GlobalCounters
+from tinygrad.helpers import fetch, getenv
+from examples.test_onnx_imagenet import imagenet_dataloader
+
+if __name__ == "__main__":
+  with open(fetch(sys.argv[1]), "rb") as f:
+    run_onnx_jit = pickle.load(f)
+  input_name = run_onnx_jit.captured.expected_names[0]
+  device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
+  print(f"input goes into {input_name=} on {device=}")
+  hit = 0
+  for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
+    GlobalCounters.reset()
+    p = run_onnx_jit(**{input_name:img.to(device)})
+    assert p.shape == (1,1000)
+    t = p.to('cpu').argmax().item()
+    hit += y==t
+    print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")

@@ -175,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
 gep_pushing = PatternMatcher([
   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
   (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
-   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
+   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
   (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
    lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
   (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
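The GEP/GEP rule folds two element selections into one; iterating range(len(g1.arg)) instead of range(g1.dtype.count) keys the fold off the actual index list. The composition itself is just index chaining (plain-Python sketch, not tinygrad code):

inner = (3, 1, 0, 2)                       # g2.arg: lanes picked from the source vector
outer = (2, 2, 1)                          # g1.arg: lanes picked from the inner GEP
composed = tuple(inner[i] for i in outer)  # gep(gep(x, inner), outer) == gep(x, composed)
assert composed == (0, 0, 1)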
@@ -191,9 +191,13 @@ gep_pushing = PatternMatcher([
   (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
 ])
 
-symbolic = symbolic_simple+PatternMatcher([
+commutative = PatternMatcher([
   # ** COMMUTATIVE flipping (only for ints) **
-  (UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  # NOTE: this can break merging vector math by only flipping some of them
+  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+])
+
+symbolic = symbolic_simple+commutative+PatternMatcher([
   # ** boolean algebra **
   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x),  # x|(x&y) -> x
   # ** combine terms **
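Pulling the COMMUTATIVE flip out into its own commutative PatternMatcher makes it composable with symbolic_simple, and the flip itself is a canonicalization: ordering operands by tuplize means x+y and y+x rewrite to one tree and deduplicate. A minimal sketch of the idea, with repr standing in for UOp.tuplize:

def canonicalize(op, a, b, key=repr):
  # put the "smaller" operand of a commutative op first
  return (op, a, b) if key(a) <= key(b) else (op, b, a)

assert canonicalize("ADD", "y", "x") == canonicalize("ADD", "x", "y")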

@@ -149,8 +149,7 @@ class CapturedJit(Generic[ReturnType]):
   expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
 
   def __reduce__(self):
-    # TODO: free_intermediates here?
-    self.optimize_weights()
+    # TODO: free_intermediates here? optimize_weights here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
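pickle calls __reduce__ to get a (callable, args) pair and rebuilds the object as callable(*args) on load, so dropping self.optimize_weights() from here moves the optimization out of serialization without changing what gets stored. The mechanism in miniature:

import pickle

class Point:
  def __init__(self, x, y): self.x, self.y = x, y
  def __reduce__(self): return self.__class__, (self.x, self.y)  # stored as (class, args)

p = pickle.loads(pickle.dumps(Point(1, 2)))  # load runs Point(1, 2)
assert (p.x, p.y) == (1, 2)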
@@ -218,12 +217,13 @@ def _prepare_jit_inputs(args, kwargs):
   return input_buffers, var_vals, names, st_vars_dtype_device
 
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False, optimize=False):
     assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
     self.captured: Optional[CapturedJit] = captured
     self.cnt: int = 2 if self.fxn is None else 0
     self.prune = prune
+    self.optimize = optimize
 
   def add_buffer(self, b:Buffer) -> Buffer:
     if found:=self._buffer_replace.get(b, None): return found
@@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+      if self.optimize: self.captured.optimize_weights()
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
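Net effect of the new flag: optimize_weights() now runs once, right after the JIT captures on the second call, instead of on every pickle. Hypothetical usage:

from tinygrad import TinyJit, Tensor

step = TinyJit(lambda x: (x * 2).realize(), optimize=True)
for i in range(3): step(Tensor([float(i)]))  # capture on call 2, then optimize_weights()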

@@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
 extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
-   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
+   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
   # devectorize any bools
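On cstyle backends a BITCAST renders as a bit-reinterpretation of its source, so the source must exist as a named value; the NOOP forces that materialization. The added exclusions presumably cover sources that already land in a variable (LOAD) or emit their own code (CUSTOM). What the reinterpretation means, sketched with struct:

import struct

# bitcast float->int32: same 4 bytes reinterpreted, like *(int*)&x in the C output
bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]
assert bits == 0x3f800000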