diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d34666938c..1fec9e9f21 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -149,7 +149,7 @@ jobs:
   torchbackend:
     name: Torch Backend Tests
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
     - name: Checkout Code
       uses: actions/checkout@v4
@@ -186,7 +186,7 @@ jobs:
   torchbackendmore:
     name: Torch Backend Tests More
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
     - name: Checkout Code
       uses: actions/checkout@v4
diff --git a/test/test_arange.py b/test/test_arange.py
index 73d9185309..e300665588 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -46,6 +46,8 @@ class TestArange(unittest.TestCase):
   def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
   def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
+  @unittest.skip("doesn't work yet. TODO: this absolutely should work")
+  def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0)
   @unittest.skip("doesn't work yet")
   def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 232f04ae94..bb3851f7cd 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -1825,8 +1825,9 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
   lins: list[Kernel] = []
   outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]
+  device = real_bufs[0].device

-  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
+  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=device))

   def check_opt(opts, create_k, expected_color_size):
     k = create_k()
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 3845629d45..cee06b7dda 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -120,17 +120,19 @@ class TestUOpsStats(unittest.TestCase):
     # NOTE: ops also include indexing ops
     assert expected_ops <= ops and ops <= expected_ops * 2

-  def test_simple_matmul(self):
-    a = Tensor.empty(1024,1024)
-    b = Tensor.empty(1024,1024)
+  def test_simple_matmul(self, M=1024, N=1024, K=1024):
+    a = Tensor.empty(M,N)
+    b = Tensor.empty(N,K)
     c = a@b
     ops, mem = get_stats(c)
-    expected_ops = c.numel() * 1024 * 2
+    expected_ops = c.numel() * N * 2
     required_mem = a.nbytes() + b.nbytes() + c.nbytes()
     assert expected_ops <= ops and ops <= expected_ops * 1.2
     # NOTE: it's hard to assert on the memory here, all depends on caching
     assert required_mem <= mem

+  def test_simple_matmul_8192(self): self.test_simple_matmul(8192, 8192, 8192)
+
   #MULACC should have the same stats as MUL + ADD
   def test_mulacc(self):
     globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
@@ -154,7 +156,7 @@ class TestUOpsStats(unittest.TestCase):

     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

-N = 100
+N = 64
 @unittest.skipIf(getenv("PTX"), "wrong in PTX")  # maybe?
 class TestStatsOptimized(unittest.TestCase):
   @classmethod
@@ -174,6 +176,14 @@ class TestStatsOptimized(unittest.TestCase):
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)

+  def test_gemm_tc_unroll(self):
+    k = Kernel(self.ast_gemm)
+    if not k.apply_tensor_cores(): self.skipTest("no tensor cores")
+    k.apply_opt(Opt(OptOps.UNROLL, 0, 2))
+    p = k.to_program()
+    print(p.src)
+    self.check_gemm(p)
+
   # this is a good lesson about why UPCASTing is a good idea
   def test_gemm_one_upcasted(self):
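For context on the bound asserted in test_simple_matmul: an (M,N) @ (N,K) matmul has M*K output elements, and each needs N multiplies plus N adds, which is where expected_ops = c.numel() * N * 2 comes from. A standalone sanity check of that arithmetic (illustrative only, not part of the patch):

# FLOP count behind expected_ops = c.numel() * N * 2
M, N, K = 1024, 1024, 1024
numel_c = M * K                  # output elements of an (M,N) @ (N,K) matmul
expected_ops = numel_c * N * 2   # N multiplies + N adds per output element
assert expected_ops == 2 * 1024**3  # ~2.1 GFLOP for the 1024^3 case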
diff --git a/test/unit/test_linearizer_rewrite.py b/test/unit/test_linearizer_rewrite.py
new file mode 100644
index 0000000000..358c43cec6
--- /dev/null
+++ b/test/unit/test_linearizer_rewrite.py
@@ -0,0 +1,28 @@
+import unittest
+from tinygrad import Tensor, Context, Device
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+
+class TestLinearizerRewrite(unittest.TestCase):
+  def test_reduction(self):
+    t = Tensor.ones((64,64), device="NULL").contiguous().realize()
+    out = (t*2).sum(axis=1)
+    with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
+      si = out.schedule()[-1]
+      k = Kernel(si.ast, Device["CPU"].renderer)
+      k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
+      k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
+      prg = k.to_program()
+      print(prg.src)
+
+  def test_arange(self):
+    out = Tensor.arange(32, device="NULL")
+    with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
+      si = out.schedule()[-1]
+      k = Kernel(si.ast, Device["CPU"].renderer)
+      k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
+      k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
+      prg = k.to_program()
+      print(prg.src)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index 42e22dc3fa..b63d869739 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -1,10 +1,10 @@
 from typing import Optional, Any, Callable, cast
 import functools, operator, itertools
 from collections import defaultdict
+from dataclasses import dataclass
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
-from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
-from tinygrad.ops import graph_rewrite, GroupOp
-from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
+from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
+from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
 from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
 from tinygrad.renderer import Renderer
@@ -280,6 +280,38 @@ pm_render = PatternMatcher([
     lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
 ])

+# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
+
+@dataclass
+class ReduceContext:
+  acc_num: int = 0
+
+def reduce_to_acc(ctx:ReduceContext, red:UOp):
+  inp, reduce_range = red.src[0], red.src[1:]
+  # if this has a horizontal reduction component, do that first
+  if inp.dtype != red.dtype:
+    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
+    horizontal_amount = inp.dtype.count//red.dtype.count
+    lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
+  else:
+    lst = [inp]
+  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
+  # if we have a range
+  if len(reduce_range) != 0:
+    acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
+    lst = [acc] + lst  # put acc as the first element
+    ctx.acc_num += 1
+  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
+  return acc.assign(ret) if len(reduce_range) != 0 else ret
+
+pm_reduce = PatternMatcher([
+  # REDUCE -> DEFINE_ACC+ASSIGN
+  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
+  # tensor core built in accumulate
+  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
+   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
+])
+
 # *** uop graph ***

 def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
@@ -287,6 +319,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
   extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])

+  # remove reduce
+  sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
+
   # devectorize is optional
   if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
   elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
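reduce_to_acc's horizontal step splits a wide vector into horizontal_amount strided geps and combines them elementwise before any range accumulation happens. A pure-Python analogue of those gep calls (a sketch; real inputs are UOp vectors, not lists):

# mirrors lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) ...]
inp = list(range(8))                 # stand-in for a vec8 input
red_count = 4                        # reducing down to a vec4
h = len(inp) // red_count            # horizontal_amount = 2
lst = [inp[i::h] for i in range(h)]  # [[0, 2, 4, 6], [1, 3, 5, 7]]
out = [a + b for a, b in zip(*lst)]  # elementwise ADD
assert out == [1, 5, 9, 13]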
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index e1e93e70f2..fbb30b74a0 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -1,9 +1,9 @@
 # the job of the lowerer is to do indexing
-import functools, itertools, operator, math
+import itertools, operator, math
 from dataclasses import dataclass
 from typing import cast
 from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
 from tinygrad.codegen.expander import expand_rewrite
@@ -116,17 +116,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
   assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
   alu_op: Ops = x.arg[0]
   ret = x.src[0]
-  # create acc
-  acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
-  ctx.acc_num += 1
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
     ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
-    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
-  else:
-    ret = acc.alu(alu_op, ret)
-  if not len(reduce_range): return ret
-  # create ACC and assign
-  return acc.assign(ret)
+  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
+  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)

 def lower_load_store(ctx: IndexContext, x: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
@@ -135,8 +128,8 @@ def lower_load_store(ctx: IndexContext, x: UOp):
     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
     return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
-  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
-    reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
+  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE:
+    reduce_input = x.src[2].src[0]
     store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
   else: store_back = False
   # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
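The net effect of this lowerer change is a two-phase design: lower_reduce_axis only records the reduction (combining op plus its ranges) as an Ops.REDUCE node, and pm_reduce in the devectorizer later materializes the DEFINE_ACC/ASSIGN accumulator. A toy model of that split (a sketch of the control flow, not tinygrad's API):

import functools

# phase 1, "lowerer": tag the reduction, create no accumulator yet
def lower_reduce(values, alu): return {"alu": alu, "srcs": values}

# phase 2, "pm_reduce": materialize the accumulator and fold the op over it
def reduce_to_acc(red, identity=0):
  return functools.reduce(red["alu"], red["srcs"], identity)

assert reduce_to_acc(lower_reduce([1, 2, 3, 4], lambda x, y: x + y)) == 10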
diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py
index 70461cda81..6503ea7e19 100644
--- a/tinygrad/codegen/symbolic.py
+++ b/tinygrad/codegen/symbolic.py
@@ -172,6 +172,19 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
     if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
     return rem//(c//gcd)+quo

+def gep_through_wmma(gep:UOp, wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  wmma_idxs = gep.arg[::out_sz]
+  for i in range(out_sz):
+    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    src_args = []
+    ssz = prod(x[1] for x in sz)
+    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+    tsrcs.append(s.gep(tuple(src_args)))
+  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
 gep_pushing = PatternMatcher([
   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
   (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
@@ -193,6 +206,8 @@ gep_pushing = PatternMatcher([
     if not isinstance(x.dtype, PtrDType) else None),
   # VECTORIZE on same GEP
   (UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
+  # push some GEPs through WMMAs
+  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
 ])

 commutative = PatternMatcher([
@@ -395,19 +410,6 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
   for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
   return ret

-def gep_through_wmma(gep:UOp, wmma:UOp):
-  out_sz = prod(x[1] for x in wmma.arg[6][-1])
-  wmma_idxs = gep.arg[::out_sz]
-  for i in range(out_sz):
-    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
-  tsrcs = []
-  for s,sz in zip(wmma.src, wmma.arg[6]):
-    src_args = []
-    ssz = prod(x[1] for x in sz)
-    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
-    tsrcs.append(s.gep(tuple(src_args)))
-  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
-
 acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
 rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
@@ -431,14 +433,9 @@ sym = symbolic_flat+PatternMatcher([
   # VECTORIZE void is SINK
   (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
   (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
-  # push some GEPs through WMMAs
-  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
   # tensor core with a 0 input is acc
   (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
   (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
-  # tensor core cleanups
-  (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
-   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
   # threefry + remove longs
   (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
   (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x),  # cast there and back is noop (TODO: genericize)
@@ -489,4 +486,7 @@ sym = symbolic_flat+PatternMatcher([
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d),  # x/(1+x) -> 1-1/(1+x)
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
   (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
+  # move const multiply after REDUCE. TODO: enable later
+  #(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
+  # lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
 ])
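The "tensor core cleanups" rule removed here is the same rewrite that now lives in pm_reduce in the devectorizer: WMMA(a, b, c) + d folds into WMMA(a, b, c + d), which is legal because WMMA's third operand is its accumulator. A scalar stand-in for the identity (WMMA modeled as a*b + c; for floats this is a reassociation, exact for the values below):

def wmma(a, b, c): return a * b + c   # scalar model of WMMA's accumulate form

a, b, c, d = 2.0, 3.0, 1.0, 5.0
assert wmma(a, b, c) + d == wmma(a, b, c + d)  # fold the ADD into the accumulator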
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 04fbfb1cb9..d8b7b60a31 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
     if isinstance(ji.prg, ViewOp): continue
     ji_graph_dev: Optional[Compiled] = None  # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
-    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD", "NULL"}:
       ji_graph_dev = Device[ji.bufs[0].device]

     graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
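With "NULL" added to the BufferXfer set, a device-to-device copy between NULL devices can now be captured into a NullGraph inside a JIT, matching the NullAllocator._transfer stub added below. A possible exercise of that path (a sketch; whether a graph is actually built depends on the JIT's capture rules):

from tinygrad import Tensor, TinyJit

@TinyJit
def xfer(t: Tensor) -> Tensor:
  return t.to("NULL:1").realize()   # BufferXfer between two NULL devices

for _ in range(4):  # first calls warm up the JIT, later ones may run the captured graph
  xfer(Tensor.empty(16, device="NULL").realize())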
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 11f061a42e..275e38e5f9 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -715,6 +715,7 @@ class UPat(MathTrait):
     self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
     self.src: Any = None
     assert self.name != "ctx", "UPat can't be named ctx"
+    assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"

     # try all permutations if it's a list
     if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py
index 2c384f9dab..e90189c3ec 100644
--- a/tinygrad/runtime/ops_null.py
+++ b/tinygrad/runtime/ops_null.py
@@ -10,9 +10,11 @@ class NullProgram:
     return 1e-4

 class NullAllocator(Allocator):
+  dev = None
   def _alloc(self, size, options): pass
   def _copyin(self, dest, src:memoryview): pass
   def _copyout(self, dest:memoryview, src): pass
+  def _transfer(self, dest, src, sz:int, src_dev, dest_dev): pass

 class NullGraph(MultiGraphRunner):
   def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3
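The new UPat assert added above makes a malformed dtype argument fail at pattern-construction time instead of surfacing as a confusing mismatch during matching. What it accepts and rejects (a sketch; the commented line is the failure case):

from tinygrad import dtypes
from tinygrad.ops import UPat, Ops

UPat(Ops.CONST, dtype=dtypes.float32)                # ok: a single DType
UPat(Ops.CONST, dtype=(dtypes.int32, dtypes.int64))  # ok: a tuple of DTypes
#UPat(Ops.CONST, dtype="float32")  # not a DType: now raises AssertionError("invalid dtype float32")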