mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
pre index load and store [pr] (#7535)
* pre index load and store [pr] * check ptrtype
This commit is contained in:
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
@@ -114,17 +114,17 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
buf = x.src[0]
|
||||
if x.op is Ops.LOAD:
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
|
||||
return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \
|
||||
x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL
|
||||
store_back = cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE and \
|
||||
x.src[2].src[0].op is Ops.LOAD and cast(PtrDType, x.src[2].src[0].src[0].dtype).local
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
|
||||
if x.src[0].op is Ops.DEFINE_GLOBAL or store_back:
|
||||
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
has_valid = valid.op is not Ops.CONST or valid.arg is not True
|
||||
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
|
||||
|
||||
pm_lowerer = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
@@ -133,14 +133,6 @@ pm_lowerer = PatternMatcher([
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None)
|
||||
v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count
|
||||
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
|
||||
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:])
|
||||
if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
|
||||
return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg)
|
||||
|
||||
def do_reduce(ctx:List[int], root:UOp):
|
||||
acc = UOp(Ops.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
|
||||
@@ -148,8 +140,6 @@ def do_reduce(ctx:List[int], root:UOp):
|
||||
return acc.assign(acc.alu(root.arg, root.src[0]))
|
||||
|
||||
just_reduce = PatternMatcher([
|
||||
# use indexing for LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
# do reduce
|
||||
(UPat(Ops.REDUCE, name="root"), do_reduce),
|
||||
])
|
||||
|
||||
@@ -791,8 +791,7 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
|
||||
# and SHL/SHR, the shift distance can be an int
|
||||
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="alu"),
|
||||
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
|
||||
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
|
||||
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user