Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
reduce collapse generic (#10045)
* reduce collapse generic
* new arange folder
* new range folding
* correct with sym
* all tests pass
* indexing ops passes
* failing tests
* fix tests, remove unused
* revert that
* torch indexing is fast
* skip on webgpu
* touchups
* comments
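Editor's note — a hedged sketch (not part of the commit) of the user-visible effect, mirroring the tests added below; sizes and variable names are illustrative. With FUSE_ARANGE=1 and SPLIT_REDUCEOP=0, fancy indexing should lower to a gather rather than a full one-hot reduction over the indexed axis:

    from tinygrad import Tensor
    from tinygrad.helpers import Context, GlobalCounters

    with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
        X = Tensor.rand(512, 784).realize()
        samples = Tensor.randint(32, high=512).realize()
        GlobalCounters.reset()
        out = X[samples].realize()            # gather 32 rows of X
        print(GlobalCounters.global_ops)      # expected to stay far below 32*512*784

The op-count bound asserted in test_index_mnist below checks exactly this kind of behavior.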
.github/workflows/test.yml (vendored)
@@ -208,7 +208,7 @@ jobs:
sudo apt update || true
sudo apt install -y --no-install-recommends ninja-build
- name: Test beautiful_mnist in torch with TINY_BACKEND
run: PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
run: SPLIT_REDUCEOP=0 FUSE_ARANGE=1 PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
- name: Test some torch tests (expect failure)
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
@@ -54,7 +54,7 @@ if __name__ == "__main__":
return loss
test_acc = float('nan')
for i in (t:=trange(70)):
for i in (t:=trange(getenv("STEPS", 70))):
samples = torch.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
loss = step(samples)
if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
@@ -2,7 +2,7 @@
import unittest
import torch
import numpy as np
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, Context
if getenv("TINY_BACKEND2"):
import extra.torch_backend.backend2
device = "cpu"
@@ -165,5 +165,15 @@ class TestTorchBackend(unittest.TestCase):
result = a // b
np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])

def test_mnist_index(self):
with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
X_train = torch.tensor(X_train.float().numpy(), device=device)
Y_train = torch.tensor(Y_train.cast('int64').numpy(), device=device)
samples = torch.randint(0, X_train.shape[0], (32,))
X,Y = X_train[samples], Y_train[samples]
X.cpu(), Y.cpu()

if __name__ == "__main__":
unittest.main()
@@ -148,10 +148,10 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_equal(X.numpy(), 0)

@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_index_mnist(self, noopt=1, op_limit=512*784*13):
def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0):
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=split_reduceop):
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
GlobalCounters.reset()
x = X_train[samples].numpy()
@@ -159,8 +159,14 @@ class TestIndexing(unittest.TestCase):
assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
@unittest.skip("not ready")

# TODO: fix these on WEBGPU, it looks like it has to do with packed stuff
@unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
def test_index_mnist_opt(self): self.test_index_mnist(0)
@unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
def test_index_mnist_split(self): self.test_index_mnist(1, split_reduceop=1)
@unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
def test_index_mnist_opt_split(self): self.test_index_mnist(0, split_reduceop=1)

@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_llama_embedding(self, noopt=1, op_limit=65536):
@@ -6,7 +6,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -332,13 +332,78 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.assign(ret) if len(reduce_range) != 0 else ret

def no_vectorized_reduce(inp:UOp, red:UOp):
if inp.dtype != red.dtype:
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
horizontal_amount = inp.dtype.count//red.dtype.count
lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),))
if red.dtype.vcount == 1: return red
return no_vectorized_alu(red)
def range_fold(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
# pseudocode: sum(val if i >= cut else 0 for i in range(lo, hi, st))
# TODO: this function is so tricky and still probably wrong. test it
total = (hi-lo+st-1) // st # real count in the range
length = ((lo-cut+total*st)//st).maximum(0).minimum(total) # number in cut
return length.cast(val.dtype) * val
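Editor's note — the closed form above can be sanity-checked against a brute-force loop in plain Python (an illustrative check, not part of the diff; it assumes a positive step and hi >= lo, which is what RANGE supplies in practice):

    def closed_form(lo, hi, st, cut, val):
        total = (hi - lo + st - 1) // st                            # iterations in range(lo, hi, st)
        length = min(max((lo - cut + total * st) // st, 0), total)  # iterations with i >= cut
        return length * val

    def brute_force(lo, hi, st, cut, val):
        return sum(val if i >= cut else 0 for i in range(lo, hi, st))

    for lo in range(-4, 5):
        for hi in range(lo, lo + 9):
            for st in (1, 2, 3):
                for cut in range(-8, 10):
                    assert closed_form(lo, hi, st, cut, 7) == brute_force(lo, hi, st, cut, 7)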
def index_fold(buf:UOp, r:UOp, idx:UOp, r2:UOp) -> UOp|None:
if r.arg != r2.arg: return None
base_idx = (idx-r2.src[0])//r2.src[2] # indexed from 0 to the length of the range
return buf.index(base_idx.cast(r.dtype)*r.src[2] + r.src[0], (idx >= r2.src[0]) & (idx < r2.src[1]))
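Editor's note — for intuition, a plain-Python picture of the gather this rule targets (illustrative, step-1 case): inside an ADD reduce, a load at the range position gated by idx equalling a parallel range is a one-hot sum, which collapses to a single bounds-gated load:

    buf = [10.0, 20.0, 30.0, 40.0]
    idx = 2
    # before folding: one gated load per loop iteration, all but one masked to 0
    one_hot_sum = sum(buf[r] if idx == r else 0.0 for r in range(len(buf)))
    # after folding: a single load, gated only by the range bounds
    folded = buf[idx] if 0 <= idx < len(buf) else 0.0
    assert one_hot_sum == folded == 30.0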
pm_reduce_collapse = PatternMatcher([
# put third arg in range
(UPat(Ops.RANGE, src=(UPat.var(), UPat.var()), name="r"), lambda r: r.replace(src=r.src+(UOp.const(r.dtype, 1),))),
# mul to range
(UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
# add to range
(UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
# 0 is "true" arg in where. fold the range
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD), range_fold),
# devectorize REDUCE
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red"), no_vectorized_reduce),
# REDUCE on ADD
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
# WHERE on LOAD (works on max too)
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD),
lambda buf,idx,gate: buf.index(idx, gate).load()),
(UPat.var("gate").where(UPat(Ops.CONST, arg=0), UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD),
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
# index/load. TODO: this is more aggressive than needed
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
# cast on RANGE (fix torch indexing)
(UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
])+sym
def reduce_collapse(red:UOp):
included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
replaces = {red:red.replace(src=red.src[0:1])}
for u in included:
for s in u.src:
if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
collapse_fxn = red.substitute(replaces)
sink = graph_rewrite(collapse_fxn, pm_reduce_collapse+devectorize, name="reduce_collapse")
if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
return sink.substitute({v:k for k,v in replaces.items()})
pm_reduce = PatternMatcher([
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
])+sym

# *** uop graph ***
@@ -4,7 +4,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType
from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element
from tinygrad.codegen.transcendental import xpow

# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -369,40 +369,6 @@ def threefry2x32(x: UOp, key: UOp):

# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********

def loop_collapse(compval, multconst, rng:UOp, acc:UOp, extra:UOp, idx2=None,idx3=None,vec=None,
add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None # must be the right REDUCE
if acc not in split_uop(extra, Ops.ADD): return None
loop_start, loop_end = rng.src
if loop_start.arg != 0:
# TODO: support and test this with other mul and loop_starts
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mul:{mul.arg} loop_start:{loop_start.arg}")
return None
if idx2 is not None: add = add + idx2
if idx3 is not None: add = add + idx3
if vec is not None:
# add, mul, loop_start, loop_end
def dvec(x:UOp):
if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
if mul.vmin > 0:
comprange = UOp.minimum(loop_end, UOp.maximum((add-compval+(loop_end-loop_start)*mul)//mul, loop_start))
else:
return None
new_reduce_op = comprange.cast(multconst.dtype) * multconst
# TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
ret = new_acc.assign(new_acc+new_reduce_op)
if extra is not acc: ret = ret + acc.assign(extra)
return ret

def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
if rng not in acc.src: return None
new_load = UOp.load(buf.index(add+mul*idx, (idx >= rng.src[0]) & (idx < rng.src[1])), dtype=ld.dtype)
new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
return new_acc.assign(new_acc+new_load)

def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort())
if len(reduce_unparented) == 0: return None
@@ -411,14 +377,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
if alu.op is Ops.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret

acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)

index_load = UPat.var("buf").index(rng_aug).load(name="ld")

arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = (arange_augrng<UPat.cvar("compval")).where(UPat.const(None, 0), UPat.cvar("multconst"))
acc_pat = UPat(Ops.DEFINE_ACC, name="acc")

def reduce_mul_chain(r:UOp):
if r.arg not in {Ops.ADD, Ops.MAX}: return None
@@ -460,11 +419,6 @@ sym = symbolic_flat+PatternMatcher([
lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# arange loop folding
(acc_pat.assign(arange_m+UPat.var("extra")), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign((UPat.var("idx")!=UPat(Ops.RANGE, name="rng")).where(UPat.const(None, 0.0), index_load)+acc_pat), index_collapse),
# parentless reduce # TODO: add MUL
(acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# ** self folding **
@@ -266,6 +266,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)

@functools.cached_property
def parents(self:UOp) -> dict[UOp, None]:
ret = {s:None for s in self.src}
for s in self.src: ret.update(s.parents)
return ret
@property
def sparents(self:UOp) -> dict[UOp, None]: return {self:None, **self.parents}

def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
@@ -382,6 +390,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def cast_vec(self, dtype:DType): return UOp(Ops.CAST, dtype.vec(self.dtype.count), (self,))
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:Union[tuple[int, ...], int]):
if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
if isinstance(i, int):
# NOTE: these are just shortcuts to not have to create and fold later
if self.op is Ops.VECTORIZE: return self.src[i]
@@ -389,7 +398,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
i = (i,)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
def load(self, *src:UOp, **kwargs):
if 'dtype' not in kwargs: kwargs['dtype'] = self.dtype.base
return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
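Editor's note — a minimal sketch (not from the diff; the DEFINE_GLOBAL construction is illustrative) of what the new load default above means: a LOAD built from an INDEX now inherits the pointer's base dtype instead of requiring an explicit dtype= argument.

    from tinygrad.ops import UOp, Ops
    from tinygrad.dtype import dtypes

    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)   # a float* buffer
    ld = buf.index(UOp.const(dtypes.int, 3)).load()           # no explicit dtype= needed anymore
    assert ld.dtype == dtypes.float                           # inherited from the pointer's base dtype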
@@ -412,6 +423,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def fuse(self): return self.alu(Ops.FUSE)
@@ -758,7 +770,7 @@ class UPat(MathTrait):
# copied from UOp
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)