From cb57774b64b8c11691e2d41c0ee52a12801481e4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 5 Nov 2024 01:21:14 +0800 Subject: [PATCH] pre index load and store [pr] (#7535) * pre index load and store [pr] * check ptrtype --- tinygrad/codegen/lowerer.py | 22 ++++++---------------- tinygrad/ops.py | 3 +-- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index f83b12bec4..04026705a5 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import List, Tuple, cast, Optional from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import variable_to_uop -from tinygrad.dtype import dtypes, ImageDType +from tinygrad.dtype import dtypes, PtrDType from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, prod, partition, flatten @@ -114,17 +114,17 @@ def lower_load_store(ctx: IndexContext, x: UOp): buf = x.src[0] if x.op is Ops.LOAD: barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else () - return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier) + return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier) # NOTE: only store the local reduceop in the threads that are actually doing the reduce - store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \ - x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL + store_back = cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE and \ + x.src[2].src[0].op is Ops.LOAD and cast(PtrDType, x.src[2].src[0].src[0].dtype).local # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs]) - if x.src[0].op is Ops.DEFINE_GLOBAL or store_back: + if (not cast(PtrDType, x.src[0].dtype).local) or store_back: for oidx, ridx in zip(ctx.idxs, ctx.ridxs): if oidx is not ridx: valid = valid * oidx.eq(0) has_valid = valid.op is not Ops.CONST or valid.arg is not True - return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ())) + return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2])) pm_lowerer = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), @@ -133,14 +133,6 @@ pm_lowerer = PatternMatcher([ (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), ]) -def idx_load_store(x:UOp): - idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None) - v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count - if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local)) - post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:]) - if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg) - return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg) - def do_reduce(ctx:List[int], root:UOp): acc = UOp(Ops.DEFINE_ACC, root.dtype, (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],)) @@ -148,8 +140,6 @@ def do_reduce(ctx:List[int], root:UOp): return acc.assign(acc.alu(root.arg, root.src[0])) just_reduce = PatternMatcher([ - # use indexing for LOAD/STORE - (UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), # do reduce (UPat(Ops.REDUCE, name="root"), do_reduce), ]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d09e81f421..f07cf69e93 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -791,8 +791,7 @@ spec = PatternMatcher([ (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype), (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance can be an int - (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"), - lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), + (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)), (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),