From be8958e26b7939a73b67dcbc2948e1f6f7cc6fc8 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 4 Aug 2024 16:17:33 -0700 Subject: [PATCH] use CONTRACT before REDUCE (#5903) * use CONTRACT before REDUCE [run_process_replay] * support half expand * EXPAND GEP --- test/test_uop_graph.py | 16 ++++++++++++++++ test/test_uops.py | 5 +++-- tinygrad/codegen/lowerer.py | 12 +++++++++--- tinygrad/codegen/uopgraph.py | 7 ++++--- tinygrad/codegen/uops.py | 1 + 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 5ee245ddbe..e17fa1aca2 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -352,6 +352,22 @@ class TestExpander(unittest.TestCase): self.assertListEqual([x.arg for x in sink.src[2].src], [4,6]) self.assertListEqual([x.arg for x in sink.src[3].src], [5,7]) + def test_contract_no_expand(self): + e1 = UOp(UOps.DEFINE_VAR, dtypes.int) + con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),)) + sink = expander_rewrite(con) + assert sink.op is UOps.VECTORIZE and len(sink.src) == 2 + assert sink.src[0] == sink.src[1] + + def test_contract_half_expand(self): + e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2))) + sink = expander_rewrite(con) + assert sink.op is UOps.VECTORIZE and len(sink.src) == 8 + assert sink.src[0] == sink.src[1] + assert sink.src[0] != sink.src[2] + assert sink.src[6] == sink.src[7] + def test_expand_same_axis(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),)) diff --git a/test/test_uops.py b/test/test_uops.py index 3595b621f1..41603a8d92 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List import unittest, math import numpy as np from tinygrad.tensor import Tensor, _to_np_dtype -from tinygrad.helpers import CI, DEBUG, getenv +from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.device import Buffer, Device from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu # noqa F401 @@ -364,7 +364,8 @@ class TestUOpStr(TestEqUOps): t = Tensor.arange(10) t = t + t * Tensor.rand(10) # nice big complicated uop - sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink + with Context(NOOPT=1): + sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink self.assert_equiv_uops(sink, eval(str(sink))) def test_nop_str(self): diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 556943a544..9e8445a188 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -4,10 +4,10 @@ import functools from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.symbolic import sint from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType -from tinygrad.ops import BufferOps, LazyOp, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer +from tinygrad.ops import BufferOps, LazyOp, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer, BinaryOps from tinygrad.codegen.uops import UOp, UOps from tinygrad.renderer import Renderer -from tinygrad.helpers import getenv, all_int, get_contraction, prod +from tinygrad.helpers import getenv, all_int, get_contraction, prod, partition, flatten # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode @@ -203,7 +203,13 @@ class IndependentLowerer: UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2]) # NOTE: always using ridxs is fine here - return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op) + reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg], lambda y: y.op is UOps.RANGE) + ret = in_uops[0] + if len(contract_axis:=flatten(x.arg for x in reduce_expand)): + alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, x.op)] + ret = UOp(UOps.CONTRACT, dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis)) + ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(cast(DType, ret.dtype).count)]) + return UOp(UOps.REDUCE, dtype, (ret,) + tuple(reduce_range), x.op) if len(reduce_range) else ret return in_uops[0].alu(x.op, *in_uops[1:]) def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index fa1654ca5a..70d6ddf88e 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -400,9 +400,7 @@ def do_contract(con:UOp): ex = con.src[0] assert con.dtype is not None # CONTRACT without EXPAND repeats the element VECTORIZED - if ex.op is not UOps.EXPAND or not all(x in ex.arg for x in con.arg): - assert ex.op is not UOps.EXPAND or not any(x in ex.arg for x in con.arg), "partial contract not supported" - return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count) + if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count) # CONTRACT may remove several axes from EXPAND assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong" srcs = [] @@ -439,6 +437,9 @@ expander = PatternMatcher([ (NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)), # empty EXPAND is NOOP (NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x), + # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU + (NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(8))), + lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(8)), ex.arg)), ]) def delete_redundant_gates(root:UOp) -> Optional[UOp]: diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 5e2b856805..65bc9b41b4 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -48,6 +48,7 @@ class UOp: def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) + def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) def __neg__(self): return self.alu(UnaryOps.NEG) def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))