Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
fixes from the dsp branch + 12500 lines (#9683)

* fixes from the dsp branch
* more changes
* those are gep pushing
.github/workflows/test.yml (vendored, 4 lines changed)
@@ -335,8 +335,8 @@ jobs:
      run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
    - name: Run unit tests
      run: PYTHONPATH="." python -m pytest -n=auto test/unit/
-   - name: Repo line count < 12000 lines
-     run: MAX_LINE_COUNT=12000 python sz.py
+   - name: Repo line count < 12500 lines
+     run: MAX_LINE_COUNT=12500 python sz.py

  fuzzing:
    name: Fuzzing

@@ -1,4 +1,4 @@
-import sys, onnx, time
+import sys, onnx, time, pickle
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs, validate

@@ -33,4 +33,9 @@ if __name__ == "__main__":

  if getenv("ORT"):
    validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
-   print("model validated")
+   print("model validated")
+
+ if (fn:=getenv("SAVE_PKL", "")) != "":
+   with open(fn, "wb") as f:
+     pickle.dump(run_onnx_jit, f)
+   print(f"pkl saved to {fn}")

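The SAVE_PKL path pickles the jitted ONNX runner so it can be reloaded elsewhere (the DSP script further down does exactly that). A minimal sketch of loading it back; the file name, input name, and shape here are illustrative assumptions, not part of the commit:

import pickle
from tinygrad import Tensor

# "model.pkl" is an example path; use whatever SAVE_PKL pointed at
with open("model.pkl", "rb") as f:
  run_onnx_jit = pickle.load(f)

# call it like the original jitted function; the keyword name and shape are assumptions
out = run_onnx_jit(input=Tensor.zeros(1, 3, 224, 224))
print(out.shape)
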
@@ -70,7 +70,7 @@ if __name__ == "__main__":
    GlobalCounters.reset()
    p = run_onnx_jit(**{t_name:img})
    assert p.shape == (1,1000)
-   t = p.argmax().item()
+   t = p.to('cpu').argmax().item()
    hit += y==t
    print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")

@@ -1,12 +1,34 @@
import pickle, sys
from dataclasses import replace
-from tinygrad import Device, Context
+from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
import numpy as np
+
+def move_jit_captured_to_dev(captured, device="DSP"):
+  captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
+
+  assign = {}
+  def move_buffer(b):
+    if b in assign: return assign[b]
+
+    if b._base is not None:
+      newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
+    else:
+      newbuf = Buffer(device, b.size, b.dtype)
+    if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
+    assign[b] = newbuf
+    return assign[b]
+
+  for item in captured.jit_cache:
+    for b in item.bufs:
+      if b is not None: move_buffer(b)
+  captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
+  return captured
+
if __name__ == "__main__":
  with Context(DEBUG=0):

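In short, move_jit_captured_to_dev retargets a pickled TinyJit capture: it rewrites the expected input device, recreates every Buffer (views go through their base) on the target device, copies over any already-allocated data, and rebuilds the ExecItems against the new buffers. Usage mirrors the hunk below; a compact sketch, with an illustrative pickle path:

import pickle

# load a CapturedJit that was pickled on another backend ("model.pkl" is an example path)
with open("model.pkl", "rb") as f:
  fxn = pickle.load(f)
# point every captured buffer and expected input at the DSP device
fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
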
@@ -15,6 +37,10 @@ if __name__ == "__main__":
    print(f"{f.tell()/1e6:.2f}M loaded")
  print(type(fxn))

+ # Move all buffers to DSP device.
+ fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
+ new_jit = []
+
  knum = 1
  for ei in fxn.captured.jit_cache:
    # skip the copy and the first kernel

@@ -22,9 +48,27 @@ if __name__ == "__main__":
    if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
      p: ProgramSpec = ei.prg.p
      k = Kernel(p.ast, Device["DSP"].renderer)
      dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
-     k.hand_coded_optimizations()
+
+     if getenv("VALIDATE"):
+       with Context(NOOPT=1):
+         lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
+       correct = ei.bufs[0].numpy()
+       ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
+       GlobalCounters.kernel_count -= 1
+
+     if not getenv("NOOPT"): k.hand_coded_optimizations()
      p2 = k.to_program()
-     new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
+     new_ei = replace(ei, prg=CompiledRunner(p2))
      new_ei.run()
+     new_jit.append(new_ei)
+     test = ei.bufs[0].numpy()
+
+     if getenv("VALIDATE"):
+       import numpy as np
+       np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
    knum += 1

+ if getenv("RUN_JIT", 0):
+   fxn.captured.free_intermediates()
+   fxn.captured.jit_cache = new_jit
+   fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))

@@ -150,7 +150,6 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  must_divide = True
  if ctx is not None and ctx.device == "DSP":
    lengths = [128,64,32,16,8,4]
    if ls.dtype.count < 128: return None # leave these as loads (probably means something is broken)
    must_divide = False
  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
    pass

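On DSP the splitter works with vector lengths {128, 64, 32, 16, 8, 4}, and with must_divide = False the chosen length no longer has to divide the access width. Purely as an illustration of covering a width with those lengths (this is not the actual split_load_store logic), a greedy sketch:

# greedy cover of a vector width using the DSP-friendly lengths (illustrative only)
def cover_width(count, lengths=(128, 64, 32, 16, 8, 4)):
  out = []
  while count > 0:
    l = next((l for l in lengths if l <= count), count)
    out.append(l)
    count -= l
  return out

assert cover_width(200) == [128, 64, 8]  # 200 doesn't divide evenly into any single length
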
@@ -162,7 +162,7 @@ pm_lowerer = PatternMatcher([

# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****

-FP = (1 << 16)
+FP = (1 << 15)
pm_quant = symbolic+PatternMatcher([
  # cast after add/mul
  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),

@@ -192,12 +192,15 @@ pm_quant = symbolic+PatternMatcher([
     UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
   lambda ld,v,c1: ld*c1),

  # const push through add
  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),

  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
  # fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
   lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
  # fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),

  # where move
  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:

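The fixed-point rules rewrite float affine expressions such as (x.float()*c1 + c2).int() into pure integer arithmetic by pre-scaling the constants with FP (now 1 << 15). A small numeric check of the idea in plain Python, with made-up scale and offset values:

FP = 1 << 15
x = 37                     # an integer quantized value
c1, c2 = 0.02, 1.5         # float scale and offset, as a quantized model would provide

float_result = int(x * c1 + c2)
fixed_result = (x * int(c1 * FP) + int(c2 * FP)) // FP   # integer-only version
assert float_result == fixed_result == 2
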
@@ -211,13 +214,13 @@ pm_quant = symbolic+PatternMatcher([

  # where on two adds
  (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
-  lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
+  lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),

- # split REDUCE into multiple reduces
+ # split REDUCE into multiple reduces (who remembers FOIL?)
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
-  lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
+  lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])

def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

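Two of the rewrites above are easy to sanity-check numerically: the corrected "where on two adds" identity (both selects share the condition, so the true branches add together and the false branches add together), and the FOIL-style REDUCE split, which now also keeps the c1*c2 term. A small numpy check of the algebra (not tinygrad code):

import numpy as np

# where on two adds: x + where(v, a0, a1) + where(v, b0, b1) == x + where(v, a0+b0, a1+b1)
v = np.array([True, False, True])
x, a0, a1, b0, b1 = 1.0, 2.0, 3.0, 4.0, 5.0
assert np.allclose(x + np.where(v, a0, a1) + np.where(v, b0, b1),
                   x + np.where(v, a0 + b0, a1 + b1))

# FOIL split: sum((v1+c1)*(v2+c2)) == sum(v1*v2) + sum(c2*v1) + sum(c1*v2) + sum(c1*c2)
v1, v2 = np.arange(4.0), np.arange(4.0) + 1.0
c1, c2 = 2.0, 3.0
lhs = ((v1 + c1) * (v2 + c2)).sum()
rhs = (v1*v2).sum() + (c2*v1).sum() + (c1*v2).sum() + np.full_like(v1, c1*c2).sum()
assert np.isclose(lhs, rhs)
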
@@ -4,7 +4,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType
-from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
+from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
from tinygrad.codegen.transcendental import xpow

# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********

@@ -184,6 +184,11 @@ gep_pushing = PatternMatcher([
  (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
    if not isinstance(gep.dtype, PtrDType) else None),
+ # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
+ (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
+   if not isinstance(x.dtype, PtrDType) else None),
+ # VECTORIZE on same GEP
+ (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
])

symbolic = symbolic_simple+PatternMatcher([

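The new "VECTORIZE on same GEP" rule collapses a vector rebuilt from single-element GEPs of one source into a single multi-index GEP of that source. In plain-Python terms (an analogy, not tinygrad's UOp API):

# gep(x, i) picks element(s) of a vector; vectorizing geps of the same x is just a
# multi-index gep, i.e. a shuffle of x
def gep(src, idx):
  return [src[i] for i in idx] if isinstance(idx, tuple) else src[idx]

x = [10, 20, 30, 40]
assert [gep(x, 2), gep(x, 0), gep(x, 3)] == gep(x, (2, 0, 3))
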
@@ -420,9 +425,6 @@ sym = symbolic_flat+PatternMatcher([
  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
  # push some GEPs through WMMAs
  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
- # CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
- (UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
-   if not isinstance(x.dtype, PtrDType) else None),
  # tensor core with a 0 input is acc
  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),

@@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
-       b.deallocate()
-       if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
+       if b.is_allocated(): b.deallocate()
+       if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
    self.__post_init__() # reset the graph state

  def optimize_weights(self):

@@ -398,7 +398,7 @@ if CAPTURE_PROCESS_REPLAY:
class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
- metadata: tuple[Metadata, ...]
+ metadata: tuple[Metadata, ...] = ()

@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:

@@ -322,10 +322,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  @functools.cached_property
  def full_shape(self) -> tuple[sint, ...]:
    if self.op is Ops.VIEW: return self.shape
+   # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
+   parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
    # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
-   return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
-     # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
-     and not (x.op is Ops.CONST and x.st is None)]))
+   return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
  @property
  def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
  @property

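full_shape now collects the parents' shapes first and then takes the elementwise maximum, skipping parents that have no shape. The same idea with plain tuples, using max in place of smax (smax additionally handles symbolic dims):

# elementwise max over the parents' shapes, ignoring shapeless parents
parent_shapes = [(4, 1, 8), (4, 3, 8), ()]
full = tuple(max(dims) for dims in zip(*[s for s in parent_shapes if s != ()]))
assert full == (4, 3, 8)
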
@@ -349,6 +349,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def __int__(self): return self._eval(dtypes.ints, int)
  def __float__(self): return self._eval(dtypes.floats, float)
  def substitute(self, dvars:dict[UOp, UOp]):
+   if len(dvars) == 0: return self
    with Context(TRACK_MATCH_STATS=0):
      return graph_rewrite(self, _substitute, dvars, bottom_up=True)

@@ -976,6 +977,8 @@ renderer = PatternMatcher([
  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
  (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
  (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
  (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
  (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),

@@ -70,6 +70,8 @@ class Estimates:
      if u.op is Ops.RANGE:
        mult_stack.append(mults)
        mults *= (u.src[1] - u.src[0]).ssimplify()
+       # SPECIAL are already counted in mults
+       mults = mults.substitute({x:x.const_like(0) for x in mults.toposort if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
      elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults

@@ -33,9 +33,9 @@ class CPUGraph(GraphRunner):
    prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
    funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))

-   defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
+   defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
    entry = device.renderer._render_entry("batched", targs)
-   code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
+   code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry

    if DEBUG >= 4: print(code)
    self.clprg = device.runtime("batched", device.compiler.compile_cached(code))

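Switching the defines from set() to tinygrad's dedup keeps the first-seen order of the #define lines when they are later joined into the batched source; set() makes no ordering guarantee. A quick illustration:

from tinygrad.helpers import dedup

defines = ["#define A 1", "#define B 2", "#define A 1"]
# dedup preserves insertion order while dropping repeats
assert dedup(defines) == ["#define A 1", "#define B 2"]
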