pre index load and store [pr] (#7535)

* pre index load and store [pr]

* check ptrtype
This commit is contained in:
George Hotz
2024-11-05 01:21:14 +08:00
committed by GitHub
parent e34b89645a
commit cb57774b64
2 changed files with 7 additions and 18 deletions

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
@@ -114,17 +114,17 @@ def lower_load_store(ctx: IndexContext, x: UOp):
buf = x.src[0]
if x.op is Ops.LOAD:
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \
x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL
store_back = cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE and \
x.src[2].src[0].op is Ops.LOAD and cast(PtrDType, x.src[2].src[0].src[0].dtype).local
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
if x.src[0].op is Ops.DEFINE_GLOBAL or store_back:
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not Ops.CONST or valid.arg is not True
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
pm_lowerer = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
@@ -133,14 +133,6 @@ pm_lowerer = PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])
def idx_load_store(x:UOp):
  """Rebuild a LOAD/STORE so that its buffer and index sources become one indexed source.

  The first two sources of `x` (buffer, index) are combined with `buf.index(...)`;
  when `x` has more than three sources, `x.src[3]` is passed along as the second
  argument of `index` (presumably the validity gate -- confirm against callers).
  The returned UOp keeps the op, dtype and arg of `x`.
  """
  is_load = x.op is Ops.LOAD
  gated = len(x.src) > 3
  buf = x.src[0]
  pointer = buf.index(x.src[1], x.src[3] if gated else None)

  # count of the value being moved: the LOAD's own dtype, or the dtype of the stored value
  width = x.dtype.count if is_load else x.src[2].dtype.count
  if width > 1 and not isinstance(buf.dtype, ImageDType):
    # non-image buffers with a multi-element value: cast the indexed pointer to a vector pointer of that width
    pointer = pointer.cast(pointer.dtype.base.vec(width).ptr(pointer.dtype.local))

  # sources that follow the ones already folded into `pointer` (and the stored value, for STORE)
  if gated: trailing = x.src[4:]
  elif is_load: trailing = x.src[2:]
  else: trailing = x.src[3:]

  # a LOAD carries only the pointer; a STORE also carries the value being written
  head = (pointer,) if is_load else (pointer, x.src[2])
  return UOp(x.op, x.dtype, head + trailing, x.arg)
def do_reduce(ctx:List[int], root:UOp):
acc = UOp(Ops.DEFINE_ACC, root.dtype,
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
@@ -148,8 +140,6 @@ def do_reduce(ctx:List[int], root:UOp):
return acc.assign(acc.alu(root.arg, root.src[0]))
# Rewrite rules pairing a UPat pattern with the function that handles a match.
just_reduce = PatternMatcher([
# LOAD/STORE whose first source is a DEFINE_GLOBAL or DEFINE_LOCAL is handed to idx_load_store
# (allow_any_len presumably lets further sources follow the buffer -- confirm in UPat)
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
# REDUCE ops are handed to do_reduce, bound under the name "root"
(UPat(Ops.REDUCE, name="root"), do_reduce),
])

View File

@@ -791,8 +791,7 @@ spec = PatternMatcher([
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),