From 2e617ca59eb31b39bfe1d32622c3ec514503895e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:22:02 -0700 Subject: [PATCH] lowerer img index (#5592) --- test/unit/test_uop_symbolic.py | 1 - tinygrad/codegen/lowerer.py | 18 ++++++++++----- tinygrad/codegen/uopgraph.py | 40 ++++++++++++++++++---------------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index de73c4f569..866da88ca8 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -202,7 +202,6 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_no_factor(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") - @unittest.expectedFailure def test_mod_factor(self): # NOTE: even though the mod max is 50, it can't know this without knowing about the mul self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index babec5737d..ca27b5aee3 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -4,7 +4,7 @@ import functools from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.symbolic import sint from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType -from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo +from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer from tinygrad.codegen.uops import UOp, UOps from tinygrad.renderer import Renderer from tinygrad.helpers import getenv, all_int, get_contraction @@ -33,7 +33,7 @@ if getenv("UOP_IS_SYMBOLIC"): if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1])) return iexpr, vexpr - def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]: + def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]: idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True)) for view in reversed(st.views[0:-1]): view = view.minify() @@ -43,13 +43,20 @@ if getenv("UOP_IS_SYMBOLIC"): idxs.append((idx//acc)%d) acc *= d idx, valid = _uop_view(view, idxs[::-1], valid) + if isinstance(dtype, ImageDType): + idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4)) return idx, valid else: - def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]: + def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]: fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)] idx, valid = st.expr_idxs(fake_idxs) ctx = dict(zip(fake_idxs, idxs)) - uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx) + uvalid = valid.render(render_ops, ctx) + if isinstance(dtype, ImageDType): + image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4 + uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs)) + else: + uidx = idx.render(render_ops, ctx) if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg) assert uvalid.dtype == dtypes.bool return uidx, uvalid @@ -130,7 +137,8 @@ class IndependentLowerer: def _to_uop(self, x:LazyOp) -> UOp: if x.op in BufferOps: - idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs) + idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs, + x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype) # TODO: check has_valid in UPat, not here has_valid = valid.op is not UOps.CONST or valid.arg is not True if x.op is BufferOps.CONST: diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 467174c219..b81001917a 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -79,23 +79,18 @@ class PatternMatcher: # ***** image handling ***** -def fix_image_idx(ls:UOp): - if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None - if not isinstance(ls.src[0].dtype, ImageDType): return None - assert ls.op is not UOps.STORE or cast(DType, ls.src[2].dtype).count == 4, "image store must be float4" - idxy = ls.src[1] - #if not idxy.divides(4): raise RuntimeError("image index must divide 4") - base_shape = ls.src[0].dtype.shape - idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1])) - image_idx = UOp(UOps.VECTORIZE, cast(DType, idxy.dtype).vec(2), (idx, idy)) - if ls.op is UOps.LOAD and cast(DType, ls.dtype).count == 1: - cconst = (UOp(UOps.VECTORIZE, cast(DType, ls.dtype).vec(4), src=(ls.src[3], ls.src[3], ls.src[3], ls.src[3])),) if len(ls.src) >= 3 else () - loaded = UOp(ls.op, cast(DType, ls.dtype).vec(4), (ls.src[0], image_idx) + ls.src[2:3] + cconst, ls.arg) - subidx = idxy%4 - ret = UOp.const(ls.dtype, 0) - for i in range(4): ret = UOp.alu(TernaryOps.WHERE, subidx.ne(i), ret, UOp(UOps.GEP, ls.dtype, (loaded,), i)) - return ret - return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg) +def image_contract_load(buf, idx, idy, id4, ls_allow_any_len): + if len(ls_allow_any_len.src) > 3: + # TODO: there's no contract on the gate, is this okay? + extra = (ls_allow_any_len.src[2], UOp(UOps.VECTORIZE, ls_allow_any_len.dtype.vec(4), (ls_allow_any_len.src[3],)*4)) + else: extra = ls_allow_any_len.src[2:] # NOTE: image load shouldn't have barrier and this shouldn't matter + vec_load = UOp(UOps.LOAD, ls_allow_any_len.dtype.vec(4), (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))) + extra) + return functools.reduce(lambda ret, i: UOp.alu(TernaryOps.WHERE, id4.ne(i), ret, UOp(UOps.GEP, ls_allow_any_len.dtype, (vec_load,), i)), range(4), + UOp.const(ls_allow_any_len.dtype, float('nan'))) + +def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var): + new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), (ex.arg[0][0],)) + return UOp(UOps.STORE, None, (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy)), new_var) + ls_allow_any_len.src[3:]) # ***** float4 handling ***** @@ -142,6 +137,12 @@ float4_folding = PatternMatcher([ UOp(UOps.EXPAND).name("ex")+UOp.var("idx"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store), (UOp(UOps.STORE, src=(UOp.var("buf"), UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store), + # image handling + (UOp(UOps.LOAD, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'), + UOp.var('id4'))))).name("ls_allow_any_len"), image_contract_load), + (UOp(UOps.STORE, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'), + UOp(UOps.EXPAND, src=tuple(UOp.const(dtypes.int, i) for i in range(4))).name("ex"))), UOp.var("var"))).name("ls_allow_any_len"), + image_contract_store), ]) # ***** transcendental ***** @@ -300,6 +301,9 @@ constant_folder = PatternMatcher([ ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))), ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))), # *** rules from symbolic *** + # mod divides + ((UOp.cvar('c')*UOp.var('x')) % UOp.cvar('c'), lambda x,c: x.const(0)), + (((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c), # two stage mul, (x*c1)*c2 = x*(c1*c2) ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))), # -(x+y) -> -x + -y @@ -464,8 +468,6 @@ expander = PatternMatcher([ if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None), # BARRIERs aren't actually expanded (UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)), - # image indexing (needs to be here) - (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx), # empty EXPAND is NOOP (UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x), # no ALU on vectorized dtypes