reduce collapse generic (#10045)

* reduce collapse generic

* new arange folder

* new range folding

* correct with sym

* all tests pass

* indexing ops passes

* failing tests

* fix tests, remove unused

* revert that

* torch indexing is fast

* skip on webgpu

* touchups

* comments
Author: George Hotz
Date: 2025-04-26 09:13:24 -04:00
Committed by: GitHub
Parent: 5cdc96409e
Commit: ea5dddc537
7 changed files with 105 additions and 58 deletions

View File

@@ -208,7 +208,7 @@ jobs:
sudo apt update || true
sudo apt install -y --no-install-recommends ninja-build
- name: Test beautiful_mnist in torch with TINY_BACKEND
run: PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
run: SPLIT_REDUCEOP=0 FUSE_ARANGE=1 PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
- name: Test some torch tests (expect failure)
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
    return loss
  test_acc = float('nan')
  for i in (t:=trange(70)):
  for i in (t:=trange(getenv("STEPS", 70))):
    samples = torch.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
    loss = step(samples)
    if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()

View File

@@ -2,7 +2,7 @@
import unittest
import torch
import numpy as np
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, Context
if getenv("TINY_BACKEND2"):
  import extra.torch_backend.backend2
  device = "cpu"
@@ -165,5 +165,15 @@ class TestTorchBackend(unittest.TestCase):
    result = a // b
    np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])
  def test_mnist_index(self):
    with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
      from tinygrad.nn.datasets import mnist
      X_train, Y_train, _, _ = mnist()
      X_train = torch.tensor(X_train.float().numpy(), device=device)
      Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
      samples = torch.randint(0, X_train.shape[0], (32,))
      X,Y = X_train[samples], Y_train[samples]
      X.cpu(), Y.cpu()

if __name__ == "__main__":
  unittest.main()
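
The new test drives fancy indexing through the torch backend with arange fusion on. Conceptually, X_train[samples] is lowered to a reduction over an arange-based one-hot mask, and FUSE_ARANGE lets the new reduce-collapse machinery fold that reduction so the mask is never materialized. A minimal numpy sketch of the identity being relied on (shapes and data are arbitrary stand-ins, not the MNIST arrays from the test):

import numpy as np

X = np.random.rand(256, 784).astype(np.float32)         # stand-in for X_train
samples = np.random.randint(0, X.shape[0], size=(32,))  # stand-in for the sampled indices

i = np.arange(X.shape[0])                   # the arange the folding removes
onehot = (i[None, :] == samples[:, None])   # (32, 256) boolean mask
gathered = (onehot[:, :, None] * X[None, :, :]).sum(axis=1)

np.testing.assert_allclose(gathered, X[samples])  # identical to direct indexing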

View File

@@ -148,10 +148,10 @@ class TestIndexing(unittest.TestCase):
    np.testing.assert_equal(X.numpy(), 0)
  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_index_mnist(self, noopt=1, op_limit=512*784*13):
  def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0):
    from tinygrad.nn.datasets import mnist
    X_train, Y_train, _, _ = mnist()
    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=split_reduceop):
      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
      GlobalCounters.reset()
      x = X_train[samples].numpy()
@@ -159,8 +159,14 @@ class TestIndexing(unittest.TestCase):
      assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
    np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
    np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
  @unittest.skip("not ready")
  # TODO: fix these on WEBGPU, it looks like it has to do with packed stuff
  @unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
  def test_index_mnist_opt(self): self.test_index_mnist(0)
  @unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
  def test_index_mnist_split(self): self.test_index_mnist(1, split_reduceop=1)
  @unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
  def test_index_mnist_opt_split(self): self.test_index_mnist(0, split_reduceop=1)
  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_llama_embedding(self, noopt=1, op_limit=65536):

View File

@@ -6,7 +6,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -332,13 +332,78 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
  return acc.assign(ret) if len(reduce_range) != 0 else ret

def no_vectorized_reduce(inp:UOp, red:UOp):
  if inp.dtype != red.dtype:
    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
    horizontal_amount = inp.dtype.count//red.dtype.count
    lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),))
    if red.dtype.vcount == 1: return red
  return no_vectorized_alu(red)

def range_fold(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
  # pseudo code: sum(val if i >= cut else 0 for i in range(lo, hi, st))
  # TODO: this function is so tricky and still probably wrong. test it
  total = (hi-lo+st-1) // st # real count in the range
  length = ((lo-cut+total*st)//st).maximum(0).minimum(total) # number in cut
  return length.cast(val.dtype) * val

def index_fold(buf:UOp, r:UOp, idx:UOp, r2:UOp) -> UOp|None:
  if r.arg != r2.arg: return None
  base_idx = (idx-r2.src[0])//r2.src[2] # indexed from 0 to the length of the range
  return buf.index(base_idx.cast(r.dtype)*r.src[2] + r.src[0], (idx >= r2.src[0]) & (idx < r2.src[1]))

pm_reduce_collapse = PatternMatcher([
  # put third arg in range
  (UPat(Ops.RANGE, src=(UPat.var(), UPat.var()), name="r"), lambda r: r.replace(src=r.src+(UOp.const(r.dtype, 1),))),
  # mul to range
  (UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
  # add to range
  (UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
  # 0 is "true" arg in where. fold the range
  ((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
    .where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD), range_fold),
  # devectorize REDUCE
  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red"), no_vectorized_reduce),
  # REDUCE on ADD
  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
  # MUL casted bool
  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
  # WHERE on LOAD (works on max too)
  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD),
   lambda buf,idx,gate: buf.index(idx, gate).load()),
  (UPat.var("gate").where(UPat(Ops.CONST, arg=0), UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD),
   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
  # INDEX on RANGE / gated RANGE
  (UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
  # index/load. TODO: this is more aggressive than needed
  (UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
  # cast on RANGE (fix torch indexing)
  (UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
])+sym

def reduce_collapse(red:UOp):
  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
  replaces = {red:red.replace(src=red.src[0:1])}
  for u in included:
    for s in u.src:
      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
  collapse_fxn = red.substitute(replaces)
  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse+devectorize, name="reduce_collapse")
  if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
  return sink.substitute({v:k for k,v in replaces.items()})

pm_reduce = PatternMatcher([
  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
])+sym
# *** uop graph ***
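
range_fold is the one piece the comment flags as tricky: it replaces sum(val if i >= cut else 0 for i in range(lo, hi, st)) with a closed-form count times val. Below is a standalone brute-force check of that arithmetic in plain Python, outside the UOp machinery. It assumes a positive stride and a range whose end lands exactly on a stride boundary (which is what the mul/add-to-range rules above produce), and the helper names are made up for the sketch:

import itertools

def closed_form(lo, hi, st, cut, val):
  # mirrors range_fold (Python floor division; after the clamp it matches truncating division for positive strides)
  total = (hi - lo + st - 1) // st
  length = min(max((lo - cut + total*st) // st, 0), total)
  return length * val

def brute_force(lo, hi, st, cut, val):
  # the pseudo code the pattern encodes: (i < cut).where(0, val), summed over the range
  return sum(val if i >= cut else 0 for i in range(lo, hi, st))

for lo, st, n, cut in itertools.product(range(0, 10, 3), (1, 2, 5), range(0, 5), range(-4, 20, 3)):
  hi = lo + n*st  # end exactly on a stride boundary
  assert closed_form(lo, hi, st, cut, 7) == brute_force(lo, hi, st, cut, 7), (lo, hi, st, cut)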

View File

@@ -4,7 +4,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType
from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element
from tinygrad.codegen.transcendental import xpow
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -369,40 +369,6 @@ def threefry2x32(x: UOp, key: UOp):
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
def loop_collapse(compval, multconst, rng:UOp, acc:UOp, extra:UOp, idx2=None,idx3=None,vec=None,
                  add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
  if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None # must be the right REDUCE
  if acc not in split_uop(extra, Ops.ADD): return None
  loop_start, loop_end = rng.src
  if loop_start.arg != 0:
    # TODO: support and test this with other mul and loop_starts
    if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mul:{mul.arg} loop_start:{loop_start.arg}")
    return None
  if idx2 is not None: add = add + idx2
  if idx3 is not None: add = add + idx3
  if vec is not None:
    # add, mul, loop_start, loop_end
    def dvec(x:UOp):
      if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
      return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
    add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
  if mul.vmin > 0:
    comprange = UOp.minimum(loop_end, UOp.maximum((add-compval+(loop_end-loop_start)*mul)//mul, loop_start))
  else:
    return None
  new_reduce_op = comprange.cast(multconst.dtype) * multconst
  # TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
  ret = new_acc.assign(new_acc+new_reduce_op)
  if extra is not acc: ret = ret + acc.assign(extra)
  return ret

def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
  if rng not in acc.src: return None
  new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
  new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
  return new_acc.assign(new_acc+new_load)

def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort())
  if len(reduce_unparented) == 0: return None
@@ -411,14 +377,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
  if alu.op is Ops.ADD:
    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret

acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = (arange_augrng<UPat.cvar("compval")).where(UPat.const(None, 0), UPat.cvar("multconst"))
acc_pat = UPat(Ops.DEFINE_ACC, name="acc")

def reduce_mul_chain(r:UOp):
  if r.arg not in {Ops.ADD, Ops.MAX}: return None
@@ -460,11 +419,6 @@ sym = symbolic_flat+PatternMatcher([
   lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
  ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
   lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
  # arange loop folding
  (acc_pat.assign(arange_m+UPat.var("extra")), loop_collapse),
  # indexing, with cast or where
  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
  (acc_pat.assign((UPat.var("idx")!=UPat(Ops.RANGE, name="rng")).where(UPat.const(None, 0.0), index_load)+acc_pat), index_collapse),
  # parentless reduce # TODO: add MUL
  (acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
  # ** self folding **

View File

@@ -266,6 +266,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
  def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
  @functools.cached_property
  def parents(self:UOp) -> dict[UOp, None]:
    ret = {s:None for s in self.src}
    for s in self.src: ret.update(s.parents)
    return ret
  @property
  def sparents(self:UOp) -> dict[UOp, None]: return {self:None, **self.parents}
  def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
    ret: dict[UOp, None] = {}
    stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
@@ -382,6 +390,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def cast_vec(self, dtype:DType): return UOp(Ops.CAST, dtype.vec(self.dtype.count), (self,))
  def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
  def gep(self, i:Union[tuple[int, ...], int]):
    if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
    if isinstance(i, int):
      # NOTE: these are just shortcuts to not have to create and fold later
      if self.op is Ops.VECTORIZE: return self.src[i]
@@ -389,7 +398,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
      if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
      i = (i,)
    return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
  def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
  def load(self, *src:UOp, **kwargs):
    if 'dtype' not in kwargs: kwargs['dtype'] = self.dtype.base
    return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def alu(self, arg, *src:UOp):
    out_dtype = (self, *src)[-1].dtype
@@ -412,6 +423,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
    return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
  def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
  def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
  def contiguous(self): return self.alu(Ops.CONTIGUOUS)
  def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
  def fuse(self): return self.alu(Ops.FUSE)
@@ -758,7 +770,7 @@ class UPat(MathTrait):
  # copied from UOp
  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
  def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
  def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
  def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)