mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
lowerer img index (#5592)
This commit is contained in:
@@ -202,7 +202,6 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mod_factor(self):
|
||||
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
||||
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
|
||||
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import getenv, all_int, get_contraction
|
||||
@@ -33,7 +33,7 @@ if getenv("UOP_IS_SYMBOLIC"):
|
||||
if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
|
||||
return iexpr, vexpr
|
||||
|
||||
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
|
||||
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
|
||||
idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
|
||||
for view in reversed(st.views[0:-1]):
|
||||
view = view.minify()
|
||||
@@ -43,13 +43,20 @@ if getenv("UOP_IS_SYMBOLIC"):
|
||||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = _uop_view(view, idxs[::-1], valid)
|
||||
if isinstance(dtype, ImageDType):
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
|
||||
return idx, valid
|
||||
else:
|
||||
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
|
||||
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
|
||||
fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
|
||||
idx, valid = st.expr_idxs(fake_idxs)
|
||||
ctx = dict(zip(fake_idxs, idxs))
|
||||
uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx)
|
||||
uvalid = valid.render(render_ops, ctx)
|
||||
if isinstance(dtype, ImageDType):
|
||||
image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
|
||||
uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
|
||||
else:
|
||||
uidx = idx.render(render_ops, ctx)
|
||||
if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
|
||||
assert uvalid.dtype == dtypes.bool
|
||||
return uidx, uvalid
|
||||
@@ -130,7 +137,8 @@ class IndependentLowerer:
|
||||
|
||||
def _to_uop(self, x:LazyOp) -> UOp:
|
||||
if x.op in BufferOps:
|
||||
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs)
|
||||
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
|
||||
x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is BufferOps.CONST:
|
||||
|
||||
@@ -79,23 +79,18 @@ class PatternMatcher:
|
||||
|
||||
# ***** image handling *****
|
||||
|
||||
def fix_image_idx(ls:UOp):
|
||||
if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
|
||||
if not isinstance(ls.src[0].dtype, ImageDType): return None
|
||||
assert ls.op is not UOps.STORE or cast(DType, ls.src[2].dtype).count == 4, "image store must be float4"
|
||||
idxy = ls.src[1]
|
||||
#if not idxy.divides(4): raise RuntimeError("image index must divide 4")
|
||||
base_shape = ls.src[0].dtype.shape
|
||||
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
|
||||
image_idx = UOp(UOps.VECTORIZE, cast(DType, idxy.dtype).vec(2), (idx, idy))
|
||||
if ls.op is UOps.LOAD and cast(DType, ls.dtype).count == 1:
|
||||
cconst = (UOp(UOps.VECTORIZE, cast(DType, ls.dtype).vec(4), src=(ls.src[3], ls.src[3], ls.src[3], ls.src[3])),) if len(ls.src) >= 3 else ()
|
||||
loaded = UOp(ls.op, cast(DType, ls.dtype).vec(4), (ls.src[0], image_idx) + ls.src[2:3] + cconst, ls.arg)
|
||||
subidx = idxy%4
|
||||
ret = UOp.const(ls.dtype, 0)
|
||||
for i in range(4): ret = UOp.alu(TernaryOps.WHERE, subidx.ne(i), ret, UOp(UOps.GEP, ls.dtype, (loaded,), i))
|
||||
return ret
|
||||
return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
|
||||
def image_contract_load(buf, idx, idy, id4, ls_allow_any_len):
|
||||
if len(ls_allow_any_len.src) > 3:
|
||||
# TODO: there's no contract on the gate, is this okay?
|
||||
extra = (ls_allow_any_len.src[2], UOp(UOps.VECTORIZE, ls_allow_any_len.dtype.vec(4), (ls_allow_any_len.src[3],)*4))
|
||||
else: extra = ls_allow_any_len.src[2:] # NOTE: image load shouldn't have barrier and this shouldn't matter
|
||||
vec_load = UOp(UOps.LOAD, ls_allow_any_len.dtype.vec(4), (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))) + extra)
|
||||
return functools.reduce(lambda ret, i: UOp.alu(TernaryOps.WHERE, id4.ne(i), ret, UOp(UOps.GEP, ls_allow_any_len.dtype, (vec_load,), i)), range(4),
|
||||
UOp.const(ls_allow_any_len.dtype, float('nan')))
|
||||
|
||||
def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
|
||||
new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), (ex.arg[0][0],))
|
||||
return UOp(UOps.STORE, None, (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy)), new_var) + ls_allow_any_len.src[3:])
|
||||
|
||||
# ***** float4 handling *****
|
||||
|
||||
@@ -142,6 +137,12 @@ float4_folding = PatternMatcher([
|
||||
UOp(UOps.EXPAND).name("ex")+UOp.var("idx"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
|
||||
(UOp(UOps.STORE, src=(UOp.var("buf"),
|
||||
UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
|
||||
# image handling
|
||||
(UOp(UOps.LOAD, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'),
|
||||
UOp.var('id4'))))).name("ls_allow_any_len"), image_contract_load),
|
||||
(UOp(UOps.STORE, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'),
|
||||
UOp(UOps.EXPAND, src=tuple(UOp.const(dtypes.int, i) for i in range(4))).name("ex"))), UOp.var("var"))).name("ls_allow_any_len"),
|
||||
image_contract_store),
|
||||
])
|
||||
|
||||
# ***** transcendental *****
|
||||
@@ -300,6 +301,9 @@ constant_folder = PatternMatcher([
|
||||
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
||||
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
|
||||
# *** rules from symbolic ***
|
||||
# mod divides
|
||||
((UOp.cvar('c')*UOp.var('x')) % UOp.cvar('c'), lambda x,c: x.const(0)),
|
||||
(((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c),
|
||||
# two stage mul, (x*c1)*c2 = x*(c1*c2)
|
||||
((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
|
||||
# -(x+y) -> -x + -y
|
||||
@@ -464,8 +468,6 @@ expander = PatternMatcher([
|
||||
if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
|
||||
# BARRIERs aren't actually expanded
|
||||
(UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
|
||||
# image indexing (needs to be here)
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
|
||||
# empty EXPAND is NOOP
|
||||
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
|
||||
# no ALU on vectorized dtypes
|
||||
|
||||
Reference in New Issue
Block a user