fixes from the dsp branch + 12500 lines (#9683)

* fixes from the dsp branch

* more changes

* those are gep pushing
George Hotz
2025-04-02 13:07:17 +08:00
committed by GitHub
parent c20f112e9f
commit 6f812d3f2f
12 changed files with 90 additions and 32 deletions

View File

@@ -335,8 +335,8 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
- name: Run unit tests
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
- - name: Repo line count < 12000 lines
-   run: MAX_LINE_COUNT=12000 python sz.py
+ - name: Repo line count < 12500 lines
+   run: MAX_LINE_COUNT=12500 python sz.py
fuzzing:
name: Fuzzing

View File

@@ -1,4 +1,4 @@
- import sys, onnx, time
+ import sys, onnx, time, pickle
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs, validate
@@ -33,4 +33,9 @@ if __name__ == "__main__":
if getenv("ORT"):
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
- print("model validated")
+ print("model validated")
+ if (fn:=getenv("SAVE_PKL", "")) != "":
+   with open(fn, "wb") as f:
+     pickle.dump(run_onnx_jit, f)
+   print(f"pkl saved to {fn}")
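
Since SAVE_PKL pickles the whole TinyJit callable, the captured model can be reloaded in a separate process and invoked like the original function. A minimal sketch of the load side, assuming a placeholder file name and input dict:

  import pickle
  # reload the jitted model that SAVE_PKL wrote out ("model.pkl" is a placeholder name)
  with open("model.pkl", "rb") as f:
    run_onnx_jit = pickle.load(f)
  # call it like the original function, with the same keyword tensors it was captured with
  out = run_onnx_jit(**inputs)  # inputs: the dict of input Tensors the model expects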

View File

@@ -70,7 +70,7 @@ if __name__ == "__main__":
GlobalCounters.reset()
p = run_onnx_jit(**{t_name:img})
assert p.shape == (1,1000)
- t = p.argmax().item()
+ t = p.to('cpu').argmax().item()
hit += y==t
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")

View File

@@ -1,12 +1,34 @@
import pickle, sys
from dataclasses import replace
- from tinygrad import Device, Context
+ from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
- from tinygrad.engine.realize import CompiledRunner
+ from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
import numpy as np
+ def move_jit_captured_to_dev(captured, device="DSP"):
+   captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
+   assign = {}
+   def move_buffer(b):
+     if b in assign: return assign[b]
+     if b._base is not None:
+       newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
+     else:
+       newbuf = Buffer(device, b.size, b.dtype)
+     if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
+     assign[b] = newbuf
+     return assign[b]
+   for item in captured.jit_cache:
+     for b in item.bufs:
+       if b is not None: move_buffer(b)
+   captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
+   return captured
if __name__ == "__main__":
with Context(DEBUG=0):
@@ -15,6 +37,10 @@ if __name__ == "__main__":
print(f"{f.tell()/1e6:.2f}M loaded")
print(type(fxn))
+ # Move all buffers to DSP device.
+ fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
+ new_jit = []
knum = 1
for ei in fxn.captured.jit_cache:
# skip the copy and the first kernel
@@ -22,9 +48,27 @@ if __name__ == "__main__":
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
p: ProgramSpec = ei.prg.p
k = Kernel(p.ast, Device["DSP"].renderer)
- dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
- k.hand_coded_optimizations()
+ if getenv("VALIDATE"):
+   with Context(NOOPT=1):
+     lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
+   correct = ei.bufs[0].numpy()
+   ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
+   GlobalCounters.kernel_count -= 1
+ if not getenv("NOOPT"): k.hand_coded_optimizations()
p2 = k.to_program()
- new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
+ new_ei = replace(ei, prg=CompiledRunner(p2))
new_ei.run()
+ new_jit.append(new_ei)
+ test = ei.bufs[0].numpy()
+ if getenv("VALIDATE"):
+   import numpy as np
+   np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
knum += 1
+ if getenv("RUN_JIT", 0):
+   fxn.captured.free_intermediates()
+   fxn.captured.jit_cache = new_jit
+   fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
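
The core of move_jit_captured_to_dev is that views must follow their base: a base buffer shared by several views has to be moved exactly once, and each view re-created against the new base at the same offset, which is what the assign memo dict guarantees. A stripped-down sketch of that pattern with a stand-in buffer class (illustrative names, not tinygrad's API; the real helper also copies the old contents into the new allocation):

  class FakeBuf:
    def __init__(self, device, size, base=None, offset=0):
      self.device, self.size, self.base, self.offset = device, size, base, offset

  def move_all(bufs, device):
    assign = {}                      # old buffer -> new buffer, so a shared base moves only once
    def move(b):
      if b in assign: return assign[b]
      if b.base is not None:         # a view: move its base first, then rebuild the view on top of it
        assign[b] = FakeBuf(device, b.size, base=move(b.base), offset=b.offset)
      else:                          # a real allocation: just recreate it on the new device
        assign[b] = FakeBuf(device, b.size)
      return assign[b]
    return [move(b) for b in bufs]

  base = FakeBuf("CPU", 1024)
  v1, v2 = FakeBuf("CPU", 256, base, 0), FakeBuf("CPU", 256, base, 256)
  moved = move_all([v1, v2], "DSP")
  assert moved[0].base is moved[1].base  # both views end up sharing the single moved base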

View File

@@ -150,7 +150,6 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
must_divide = True
if ctx is not None and ctx.device == "DSP":
lengths = [128,64,32,16,8,4]
- if ls.dtype.count < 128: return None # leave these as loads (probably means something is broken)
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass

View File

@@ -162,7 +162,7 @@ pm_lowerer = PatternMatcher([
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
- FP = (1 << 16)
+ FP = (1 << 15)
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
@@ -192,12 +192,15 @@ pm_quant = symbolic+PatternMatcher([
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# const push through add
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
- ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
-   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
# fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
+ ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+ # fixed point mult, replace (x.float()*c1 + y.float()*c2 + cc).int() with an int expression
+ ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
@@ -211,13 +214,13 @@ pm_quant = symbolic+PatternMatcher([
# where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
- lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
+ lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
- # split REDUCE into multiple reduces
+ # split REDUCE into multiple reduces (who remembers FOIL?)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
- lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
+ lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
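
Three of the quantization rewrites above are plain arithmetic identities, so they can be sanity-checked outside tinygrad: with FP = 1<<15 the fixed-point form (x*int(c1*FP) + int(cc*FP)) // FP approximates (x*c1 + cc).int(); the corrected where-on-two-adds rule pairs a0 with b0 and a1 with b1; and the FOIL-style reduce split now carries the missing c1*c2 term. A small numpy illustration of all three (my own example numbers, not tinygrad code):

  import numpy as np

  FP = 1 << 15
  x = np.arange(-128, 128, dtype=np.int32)   # quantized integer input
  c1, c2 = 0.04706, 13.0                     # example scale / offset constants
  ref = (x.astype(np.float32) * c1 + c2).astype(np.int32)  # the float expression being replaced
  fix = (x * int(c1 * FP) + int(c2 * FP)) // FP            # the integer-only rewrite
  assert np.abs(ref - fix).max() <= 1        # off-by-one from rounding/truncation is expected

  # where on two adds: x + where(v,a0,a1) + where(v,b0,b1) == x + where(v, a0+b0, a1+b1)
  v = x > 0
  a0, a1, b0, b1 = 1, 2, 10, 20
  assert (x + np.where(v, a0, a1) + np.where(v, b0, b1) == x + np.where(v, a0 + b0, a1 + b1)).all()

  # FOIL reduce split: sum((v1+c1)*(v2+c2)) needs all four product terms, including c1*c2
  v1, v2 = np.random.rand(64), np.random.rand(64)
  k1, k2 = 3.0, -2.0
  full = ((v1 + k1) * (v2 + k2)).sum()
  split = (v1 * v2).sum() + (k2 * v1).sum() + (k1 * v2).sum() + (k1 * k2 * np.ones(64)).sum()
  assert np.isclose(full, split)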

View File

@@ -4,7 +4,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType
- from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
+ from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
from tinygrad.codegen.transcendental import xpow
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -184,6 +184,11 @@ gep_pushing = PatternMatcher([
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
+ # CAT can't be rendered. it's a VECTORIZE on vectors, we expand it to a single VECTORIZE with GEPs (TODO: move this later)
+ (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+   if not isinstance(x.dtype, PtrDType) else None),
+ # VECTORIZE on same GEP
+ (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
])
symbolic = symbolic_simple+PatternMatcher([
@@ -420,9 +425,6 @@ sym = symbolic_flat+PatternMatcher([
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
- # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
- (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
-   if not isinstance(x.dtype, PtrDType) else None),
# tensor core with a 0 input is acc
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
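
In plain array terms, the CAT expansion that moved into gep_pushing is "concatenation = gather every lane of every source into one VECTORIZE", and the new VECTORIZE-on-same-GEP rule is the inverse collapse back into a single multi-lane GEP. A small numpy analogy (not UOp code):

  import numpy as np

  a, b = np.arange(4), np.arange(4, 8)       # two "vec4" sources of a CAT
  expanded = np.array([v[i] for v in (a, b) for i in range(len(v))])  # one VECTORIZE of per-lane GEPs
  assert (np.concatenate([a, b]) == expanded).all()

  x = np.arange(16)                          # a VECTORIZE whose sources are single-lane GEPs of x...
  lanes = (3, 1, 4, 1)
  assert (np.array([x[i] for i in lanes]) == x[list(lanes)]).all()  # ...is just one multi-lane GEP (a gather)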

View File

@@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
update_depends(depends, self.jit_cache)
for b in depends:
if b is not None:
- b.deallocate()
- if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
+ if b.is_allocated(): b.deallocate()
+ if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
self.__post_init__() # reset the graph state
def optimize_weights(self):

View File

@@ -398,7 +398,7 @@ if CAPTURE_PROCESS_REPLAY:
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
- metadata: tuple[Metadata, ...]
+ metadata: tuple[Metadata, ...] = ()
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
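
Defaulting metadata to () is what lets the DSP script above construct a bare ScheduleItem(p.ast, ei.bufs): with Python dataclasses, a trailing field with a default simply drops out of the required constructor arguments. A generic sketch of the same pattern (toy field types, not tinygrad's):

  from dataclasses import dataclass

  @dataclass(frozen=True)
  class Item:
    ast: str                        # stand-in for the UOp AST
    bufs: tuple[int, ...]           # stand-in for the Buffer tuple
    metadata: tuple[str, ...] = ()  # defaulted, so callers may omit it

  print(Item("ast", (1, 2)))        # Item(ast='ast', bufs=(1, 2), metadata=())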

View File

@@ -322,10 +322,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
if self.op is Ops.VIEW: return self.shape
+ # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+ parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
- return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
- # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
- and not (x.op is Ops.CONST and x.st is None)]))
+ return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@property
@@ -349,6 +349,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
def substitute(self, dvars:dict[UOp, UOp]):
+ if len(dvars) == 0: return self
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, _substitute, dvars, bottom_up=True)
@@ -976,6 +977,8 @@ renderer = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
+ (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
+ (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
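
The new CAST rule in this debug renderer builds a C-style cast string by slicing the dtype's repr: the [7:] only makes sense if the repr has the form "dtypes.<name>", so the "dtypes." prefix (7 characters) is dropped. A quick check of that slicing with plain strings (the repr format is inferred from the code, not verified against the library):

  s = "dtypes.int"             # what str(x.dtype) is assumed to look like
  assert s[7:] == "int"        # len("dtypes.") == 7
  print(f"({s[7:]})(ridx0)")   # -> (int)(ridx0), the NOOP arg the rule would produce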

View File

@@ -70,6 +70,8 @@ class Estimates:
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= (u.src[1] - u.src[0]).ssimplify()
+ # SPECIAL are already counted in mults
+ mults = mults.substitute({x:x.const_like(0) for x in mults.toposort if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
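
The intent of the substitution, as I read it, is to keep mults a concrete number when a loop's trip count mentions a launch index: the SPECIAL's full size is already a factor of mults, so the index itself gets pinned to 0 inside the expression instead of leaving a symbolic UOp behind. A hedged worked example (my own numbers, not from the commit):

  grid = 4                          # Ops.SPECIAL gidx0 with size 4
  mults = 1 * grid                  # the SPECIAL already multiplied its size into mults
  trip = lambda gidx0: 4 - gidx0    # a RANGE whose trip count references the SPECIAL
  mults = mults * trip(0)           # substitute SPECIAL -> 0, keeping the estimate a plain int
  assert mults == 16                # 16 iterations counted instead of a gidx0-dependent expression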

View File

@@ -33,9 +33,9 @@ class CPUGraph(GraphRunner):
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
- defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
+ defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
entry = device.renderer._render_entry("batched", targs)
- code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
+ code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
if DEBUG >= 4: print(code)
self.clprg = device.runtime("batched", device.compiler.compile_cached(code))
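
The defines change swaps an unordered set for dedup so the per-kernel #define lines keep their first-seen order before being joined into the batched source; an order-preserving dedup is just dict.fromkeys under the hood (assuming tinygrad's dedup helper behaves that way). A tiny illustration in plain Python:

  defines = ["#define A 1", "#define B 2", "#define A 1"]  # per-kernel defines, with repeats
  # set() removes duplicates, but its iteration order is arbitrary, so the emitted header could reorder between runs
  deduped = list(dict.fromkeys(defines))                   # keeps first-seen order: A then B
  print("\n".join(deduped))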