diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ee9fef15de..e081e0eb90 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -335,8 +335,8 @@ jobs:
       run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
     - name: Run unit tests
       run: PYTHONPATH="." python -m pytest -n=auto test/unit/
-    - name: Repo line count < 12000 lines
-      run: MAX_LINE_COUNT=12000 python sz.py
+    - name: Repo line count < 12500 lines
+      run: MAX_LINE_COUNT=12500 python sz.py
 
   fuzzing:
     name: Fuzzing
diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py
index 2a1425ba6e..de7ac9122b 100644
--- a/examples/benchmark_onnx.py
+++ b/examples/benchmark_onnx.py
@@ -1,4 +1,4 @@
-import sys, onnx, time
+import sys, onnx, time, pickle
 from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
 from tinygrad.frontend.onnx import OnnxRunner
 from extra.onnx_helpers import get_example_inputs, validate
@@ -33,4 +33,9 @@ if __name__ == "__main__":
 
   if getenv("ORT"):
     validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
-    print("model validated")
\ No newline at end of file
+    print("model validated")
+
+  if (fn:=getenv("SAVE_PKL", "")) != "":
+    with open(fn, "wb") as f:
+      pickle.dump(run_onnx_jit, f)
+    print(f"pkl saved to {fn}")
diff --git a/examples/test_onnx_imagenet.py b/examples/test_onnx_imagenet.py
index af125f84d2..1c35e29f6a 100644
--- a/examples/test_onnx_imagenet.py
+++ b/examples/test_onnx_imagenet.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
     GlobalCounters.reset()
     p = run_onnx_jit(**{t_name:img})
     assert p.shape == (1,1000)
-    t = p.argmax().item()
+    t = p.to('cpu').argmax().item()
     hit += y==t
     print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
 
diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py
index 9fbb21c108..1965c5bd73 100644
--- a/extra/replay_pkl.py
+++ b/extra/replay_pkl.py
@@ -1,12 +1,34 @@
 import pickle, sys
 from dataclasses import replace
-from tinygrad import Device, Context
+from tinygrad import Device, Context, Tensor, GlobalCounters
 from tinygrad.device import Buffer
 from tinygrad.helpers import getenv, BEAM
 from tinygrad.engine.jit import TinyJit
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
 from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+import numpy as np
+
+def move_jit_captured_to_dev(captured, device="DSP"):
+  captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
+
+  assign = {}
+  def move_buffer(b):
+    if b in assign: return assign[b]
+
+    if b._base is not None:
+      newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
+    else:
+      newbuf = Buffer(device, b.size, b.dtype)
+      if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
+    assign[b] = newbuf
+    return assign[b]
+
+  for item in captured.jit_cache:
+    for b in item.bufs:
+      if b is not None: move_buffer(b)
+  captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
+  return captured
 
 if __name__ == "__main__":
   with Context(DEBUG=0):
@@ -15,6 +37,10 @@ if __name__ == "__main__":
       print(f"{f.tell()/1e6:.2f}M loaded")
       print(type(fxn))
 
+  # Move all buffers to DSP device.
+  fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
+  new_jit = []
+
   knum = 1
   for ei in fxn.captured.jit_cache:
     # skip the copy and the first kernel
@@ -22,9 +48,27 @@
     if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
       p: ProgramSpec = ei.prg.p
       k = Kernel(p.ast, Device["DSP"].renderer)
-      dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
-      k.hand_coded_optimizations()
+
+      if getenv("VALIDATE"):
+        with Context(NOOPT=1):
+          lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
+          correct = ei.bufs[0].numpy()
+        ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
+        GlobalCounters.kernel_count -= 1
+
+      if not getenv("NOOPT"): k.hand_coded_optimizations()
       p2 = k.to_program()
-      new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
+      new_ei = replace(ei, prg=CompiledRunner(p2))
       new_ei.run()
+      new_jit.append(new_ei)
+      test = ei.bufs[0].numpy()
+
+      if getenv("VALIDATE"):
+        import numpy as np
+        np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
     knum += 1
+
+  if getenv("RUN_JIT", 0):
+    fxn.captured.free_intermediates()
+    fxn.captured.jit_cache = new_jit
+    fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index f3f74d0ba2..6690d2ac7d 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -150,7 +150,6 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   must_divide = True
   if ctx is not None and ctx.device == "DSP":
     lengths = [128,64,32,16,8,4]
-    if ls.dtype.count < 128: return None # leave these as loads (probably means something is broken)
     must_divide = False
   elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
     pass
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index f2347a896e..e1e93e70f2 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -162,7 +162,7 @@ pm_lowerer = PatternMatcher([
 
 # **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
 
-FP = (1 << 16)
+FP = (1 << 15)
 pm_quant = symbolic+PatternMatcher([
   # cast after add/mul
   (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
@@ -192,12 +192,15 @@
       UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
    lambda ld,v,c1: ld*c1),
+
+  # const push through add
+  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
 
   # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
-  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
-   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
-  # fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
-  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
-   lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+  # fixed point mult, replace (x.float()*c1 + y.float()*c2 + cc).int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
   # where move
   (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
@@ -211,13 +214,13 @@
 
   # where on two adds
   (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
-   lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
+   lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
 
-  # split REDUCE into multiple reduces
+  # split REDUCE into multiple reduces (who remembers FOIL?)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
    lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
-   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
+   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
 ])
 
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py
index fea8066d44..d589f35a61 100644
--- a/tinygrad/codegen/symbolic.py
+++ b/tinygrad/codegen/symbolic.py
@@ -4,7 +4,7 @@ import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
 from tinygrad.dtype import ConstType, dtypes, PtrDType
-from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
+from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
 from tinygrad.codegen.transcendental import xpow
 
 # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -184,6 +184,11 @@ gep_pushing = PatternMatcher([
   (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
      if not isinstance(gep.dtype, PtrDType) else None),
+  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZE with GEPs (TODO: move this later)
+  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+    if not isinstance(x.dtype, PtrDType) else None),
+  # VECTORIZE on same GEP
+  (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
 ])
 
 symbolic = symbolic_simple+PatternMatcher([
@@ -420,9 +425,6 @@ sym = symbolic_flat+PatternMatcher([
   (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
   # push some GEPs through WMMAs
   (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
-  # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
-  (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
-    if not isinstance(x.dtype, PtrDType) else None),
   # tensor core with a 0 input is acc
   (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
   (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index dfd5556456..ac65ce5c7a 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
     update_depends(depends, self.jit_cache)
     for b in depends:
       if b is not None:
-        b.deallocate()
-        if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
+        if b.is_allocated(): b.deallocate()
+        if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
     self.__post_init__() # reset the graph state
 
   def optimize_weights(self):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index b9a63d766c..f0c47f0e41 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -398,7 +398,7 @@ if CAPTURE_PROCESS_REPLAY:
 class ScheduleItem:
   ast: UOp
   bufs: tuple[Buffer, ...]
-  metadata: tuple[Metadata, ...]
+  metadata: tuple[Metadata, ...] = ()
 
 @track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
 def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a17f80c2a7..ed007dc290 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -322,10 +322,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def full_shape(self) -> tuple[sint, ...]:
     if self.op is Ops.VIEW: return self.shape
+    # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+    parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
     # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
-    return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
-      # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
-      and not (x.op is Ops.CONST and x.st is None)]))
+    return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
   @property
   def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
   @property
@@ -349,6 +349,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def __int__(self): return self._eval(dtypes.ints, int)
   def __float__(self): return self._eval(dtypes.floats, float)
   def substitute(self, dvars:dict[UOp, UOp]):
+    if len(dvars) == 0: return self
    with Context(TRACK_MATCH_STATS=0):
      return graph_rewrite(self, _substitute, dvars, bottom_up=True)
 
@@ -976,6 +977,8 @@ renderer = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
   (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
   (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
+  (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
+  (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),
   (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
   (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
   (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index a08c3495b5..ebdf97a6be 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -70,6 +70,8 @@ class Estimates:
       if u.op is Ops.RANGE:
         mult_stack.append(mults)
         mults *= (u.src[1] - u.src[0]).ssimplify()
+        # SPECIAL are already counted in mults
+        mults = mults.substitute({x:x.const_like(0) for x in mults.toposort if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
       elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
       elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
       elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
diff --git a/tinygrad/runtime/graph/cpu.py b/tinygrad/runtime/graph/cpu.py
index 454ce5f411..c7256545ae 100644
--- a/tinygrad/runtime/graph/cpu.py
+++ b/tinygrad/runtime/graph/cpu.py
@@ -33,9 +33,9 @@ class CPUGraph(GraphRunner):
     prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
     funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
-    defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
+    defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
     entry = device.renderer._render_entry("batched", targs)
-    code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
+    code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
 
     if DEBUG >= 4: print(code)
     self.clprg = device.runtime("batched", device.compiler.compile_cached(code))
 