From 4514fd91c10d9cc7b827b1c45d7eeef8626e1660 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 2 Apr 2025 15:27:48 +0800
Subject: [PATCH] more stuff from DSP (#9689)

* more good stuff from dsp branch

* test pkl imagenet
---
 examples/benchmark_onnx.py    |  9 ++-------
 examples/test_pkl_imagenet.py | 19 +++++++++++++++++++
 tinygrad/codegen/symbolic.py  | 10 +++++++---
 tinygrad/engine/jit.py        |  7 ++++---
 tinygrad/renderer/cstyle.py   |  2 +-
 5 files changed, 33 insertions(+), 14 deletions(-)
 create mode 100644 examples/test_pkl_imagenet.py

diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py
index de7ac9122b..e88033bd0a 100644
--- a/examples/benchmark_onnx.py
+++ b/examples/benchmark_onnx.py
@@ -1,12 +1,12 @@
 import sys, onnx, time, pickle
-from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
+from tinygrad import TinyJit, GlobalCounters, fetch, getenv
 from tinygrad.frontend.onnx import OnnxRunner
 from extra.onnx_helpers import get_example_inputs, validate
 
 def load_onnx_model(onnx_file):
   onnx_model = onnx.load(onnx_file)
   run_onnx = OnnxRunner(onnx_model)
-  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
+  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
   return run_onnx_jit, run_onnx.graph_inputs
 
 if __name__ == "__main__":
@@ -34,8 +34,3 @@ if __name__ == "__main__":
     if getenv("ORT"):
       validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
       print("model validated")
-
-  if (fn:=getenv("SAVE_PKL", "")) != "":
-    with open(fn, "wb") as f:
-      pickle.dump(run_onnx_jit, f)
-      print(f"pkl saved to {fn}")
diff --git a/examples/test_pkl_imagenet.py b/examples/test_pkl_imagenet.py
new file mode 100644
index 0000000000..8110abf309
--- /dev/null
+++ b/examples/test_pkl_imagenet.py
@@ -0,0 +1,19 @@
+import sys, pickle
+from tinygrad import GlobalCounters
+from tinygrad.helpers import fetch, getenv
+from examples.test_onnx_imagenet import imagenet_dataloader
+
+if __name__ == "__main__":
+  with open(fetch(sys.argv[1]), "rb") as f:
+    run_onnx_jit = pickle.load(f)
+  input_name = run_onnx_jit.captured.expected_names[0]
+  device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
+  print(f"input goes into {input_name=} on {device=}")
+  hit = 0
+  for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
+    GlobalCounters.reset()
+    p = run_onnx_jit(**{input_name:img.to(device)})
+    assert p.shape == (1,1000)
+    t = p.to('cpu').argmax().item()
+    hit += y==t
+    print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py
index d589f35a61..df95f8d8ec 100644
--- a/tinygrad/codegen/symbolic.py
+++ b/tinygrad/codegen/symbolic.py
@@ -175,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
 gep_pushing = PatternMatcher([
   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
   (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
-   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
+   lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
   (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
    lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
   (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
@@ -191,9 +191,13 @@ gep_pushing = PatternMatcher([
   (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
 ])
 
-symbolic = symbolic_simple+PatternMatcher([
+commutative = PatternMatcher([
   # ** COMMUTATIVE flipping (only for ints) **
-  (UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  # NOTE: this can break merging vector math by only flipping some of them
+  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+])
+
+symbolic = symbolic_simple+commutative+PatternMatcher([
   # ** boolean algebra **
   (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x),  # x|(x&y) -> x
   # ** combine terms **
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index ac65ce5c7a..04fbfb1cb9 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -149,8 +149,7 @@ class CapturedJit(Generic[ReturnType]):
   expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
 
   def __reduce__(self):
-    # TODO: free_intermediates here?
-    self.optimize_weights()
+    # TODO: free_intermediates here? optimize_weights here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
 
@@ -218,12 +217,13 @@ def _prepare_jit_inputs(args, kwargs):
   return input_buffers, var_vals, names, st_vars_dtype_device
 
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False, optimize=False):
     assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
     self.captured: Optional[CapturedJit] = captured
     self.cnt: int = 2 if self.fxn is None else 0
     self.prune = prune
+    self.optimize = optimize
 
   def add_buffer(self, b:Buffer) -> Buffer:
     if found:=self._buffer_replace.get(b, None): return found
@@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
 
       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+      if self.optimize: self.captured.optimize_weights()
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index e900fc875b..857683118a 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
 extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
-   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
+   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
   # devectorize any bools