lowerer img index (#5592)

This commit is contained in:
George Hotz
2024-07-19 14:22:02 -07:00
committed by GitHub
parent 3acd8559f4
commit 2e617ca59e
3 changed files with 34 additions and 25 deletions

View File

@@ -202,7 +202,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_no_factor(self):
  # Both terms are scaled by 5 and the divisor is 2; the expected rendering keeps the
  # floor-division un-simplified. The 0 and 25 are presumably the expected min/max
  # bounds checked by helper_test_variable -- confirm against the helper.
  scaled_a = Variable("a", 0, 7)*5
  scaled_b = Variable("b", 0, 3)*5
  halved = Node.sum([scaled_a, scaled_b]) // 2
  self.helper_test_variable(halved, 0, 25, "(((a*5)+(b*5))//2)")
@unittest.expectedFailure
def test_mod_factor(self):
  # The a*100 term is a multiple of the modulus, so only (b*50)%100 is expected to remain.
  # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
  term_a = Variable("a", 0, 7)*100
  term_b = Variable("b", 0, 3)*50
  reduced = Node.sum([term_a, term_b]) % 100
  self.helper_test_variable(reduced, 0, 99, "((b*50)%100)")

View File

@@ -4,7 +4,7 @@ import functools
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import sint
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.renderer import Renderer
from tinygrad.helpers import getenv, all_int, get_contraction
@@ -33,7 +33,7 @@ if getenv("UOP_IS_SYMBOLIC"):
if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
return iexpr, vexpr
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
for view in reversed(st.views[0:-1]):
view = view.minify()
@@ -43,13 +43,20 @@ if getenv("UOP_IS_SYMBOLIC"):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _uop_view(view, idxs[::-1], valid)
if isinstance(dtype, ImageDType):
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
return idx, valid
else:
# Translate a ShapeTracker plus per-dimension index UOps into an (index, valid) pair of UOps.
# One placeholder Variable is created per dimension of st.shape, st.expr_idxs() produces
# index/validity expressions over those placeholders, and render() swaps the placeholders
# for the real UOps via ctx.
# NOTE(review): two consecutive `def` headers follow, differing only by the `dtype` parameter.
# This looks like before/after residue of a change -- confirm which signature is live.
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
# placeholder variable __idx{i} ranges over 0..s-1 for each dimension size s
fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
idx, valid = st.expr_idxs(fake_idxs)
# maps each placeholder to the caller-supplied index UOp at the same position
ctx = dict(zip(fake_idxs, idxs))
# NOTE(review): the next line renders both uidx and uvalid, and the lines after it compute
# both again; it appears to be the superseded form of the same step -- confirm.
uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx)
uvalid = valid.render(render_ops, ctx)
if isinstance(dtype, ImageDType):
# image buffers get a 3-component index built from the flat idx:
#   (idx // 4) % shape[1],  idx // (4 * shape[1]),  idx % 4
# presumably (x, y, lane within a 4-wide pixel) with shape[1] the image width -- TODO confirm
image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
# each component is rendered separately, then packed into one int vec(3) VECTORIZE UOp
uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
else:
# non-image buffers keep the single flat index
uidx = idx.render(render_ops, ctx)
# a constant validity is rebuilt as a bool-typed constant so the assert below holds
if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
@@ -130,7 +137,8 @@ class IndependentLowerer:
def _to_uop(self, x:LazyOp) -> UOp:
if x.op in BufferOps:
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs)
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is BufferOps.CONST:

View File

@@ -79,23 +79,18 @@ class PatternMatcher:
# ***** image handling *****
def fix_image_idx(ls:UOp):
  """Rewrite a flat-indexed LOAD/STORE on an image buffer to use a 2-component index.

  `ls` is a LOAD or STORE UOp whose src[0] is the buffer and src[1] the index.
  Returns None (no rewrite) when the index has no dtype, is not a scalar
  (dtype.count != 1), or the buffer dtype is not an ImageDType.

  Otherwise the flat index is split into ((idxy // 4) % shape[1], idxy // (4 * shape[1]))
  and packed into a 2-wide VECTORIZE. A scalar LOAD is widened to a 4-wide load and the
  wanted lane (idxy % 4) is picked with a chain of WHERE/GEP ops; any other LOAD/STORE
  is re-emitted with only its index replaced.
  """
  if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
  if not isinstance(ls.src[0].dtype, ImageDType): return None
  assert ls.op is not UOps.STORE or cast(DType, ls.src[2].dtype).count == 4, "image store must be float4"
  idxy = ls.src[1]
  # NOTE: divisibility of the index by 4 is not enforced; the check below was left disabled
  #if not idxy.divides(4): raise RuntimeError("image index must divide 4")
  base_shape = ls.src[0].dtype.shape
  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
  image_idx = UOp(UOps.VECTORIZE, cast(DType, idxy.dtype).vec(2), (idx, idy))
  if ls.op is UOps.LOAD and cast(DType, ls.dtype).count == 1:
    # src[3] only exists when the load has more than 3 sources; the previous `>= 3` guard
    # indexed src[3] on a 3-source load and raised IndexError. With `> 3`, a 3-source load
    # simply keeps its src[2] (via the slice below) and gets no widened extra operand.
    cconst = (UOp(UOps.VECTORIZE, cast(DType, ls.dtype).vec(4), src=(ls.src[3], ls.src[3], ls.src[3], ls.src[3])),) if len(ls.src) > 3 else ()
    loaded = UOp(ls.op, cast(DType, ls.dtype).vec(4), (ls.src[0], image_idx) + ls.src[2:3] + cconst, ls.arg)
    subidx = idxy%4
    # start from 0 and, for each lane i, take GEP(loaded, i) unless subidx != i
    ret = UOp.const(ls.dtype, 0)
    for i in range(4): ret = UOp.alu(TernaryOps.WHERE, subidx.ne(i), ret, UOp(UOps.GEP, ls.dtype, (loaded,), i))
    return ret
  return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
def image_contract_load(buf, idx, idy, id4, ls_allow_any_len):
  """Turn a scalar image LOAD indexed by (idx, idy, id4) into a 4-wide load plus lane select.

  The matched load is re-issued as a vec(4) LOAD addressed by a 2-wide (idx, idy) index.
  The scalar result is then rebuilt with a WHERE/GEP chain over lanes 0..3 keyed on id4,
  seeded with a NaN constant.
  """
  load = ls_allow_any_len
  scalar_dtype = load.dtype
  if len(load.src) <= 3:
    # NOTE: image load shouldn't have barrier and this shouldn't matter
    tail = load.src[2:]
  else:
    # TODO: there's no contract on the gate, is this okay?
    # src[2] is passed through unchanged; src[3] is replicated into a 4-wide vector
    alt = load.src[3]
    tail = (load.src[2], UOp(UOps.VECTORIZE, scalar_dtype.vec(4), (alt, alt, alt, alt)))
  wide_dtype = scalar_dtype.vec(4)
  coords = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))
  vec_load = UOp(UOps.LOAD, wide_dtype, (buf, coords) + tail)
  # fold over the four lanes: keep the running value where id4 != lane, else take that lane
  selected = UOp.const(scalar_dtype, float('nan'))
  for lane in range(4):
    selected = UOp.alu(TernaryOps.WHERE, id4.ne(lane), selected, UOp(UOps.GEP, scalar_dtype, (vec_load,), lane))
  return selected
def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
  """Turn an image STORE whose last index component is an EXPAND into one 4-wide store.

  The stored value is wrapped in a CONTRACT to vec(4), the address becomes a 2-wide
  (idx, idy) index, and any sources of the matched store past the first three are kept.
  """
  wide_dtype = var.dtype.vec(4)
  # ex.arg[0][0] is presumably the id of the expanded axis being contracted -- confirm
  contract_arg = (ex.arg[0][0],)
  packed = UOp(UOps.CONTRACT, wide_dtype, (var,), contract_arg)
  coords = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))
  store_srcs = (buf, coords, packed) + ls_allow_any_len.src[3:]
  return UOp(UOps.STORE, None, store_srcs)
# ***** float4 handling *****
@@ -142,6 +137,12 @@ float4_folding = PatternMatcher([
UOp(UOps.EXPAND).name("ex")+UOp.var("idx"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
(UOp(UOps.STORE, src=(UOp.var("buf"),
UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
# image handling
(UOp(UOps.LOAD, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'),
UOp.var('id4'))))).name("ls_allow_any_len"), image_contract_load),
(UOp(UOps.STORE, src=(UOp.var("buf"), UOp(UOps.VECTORIZE, dtypes.int.vec(3), (UOp.var('idx'), UOp.var('idy'),
UOp(UOps.EXPAND, src=tuple(UOp.const(dtypes.int, i) for i in range(4))).name("ex"))), UOp.var("var"))).name("ls_allow_any_len"),
image_contract_store),
])
# ***** transcendental *****
@@ -300,6 +301,9 @@ constant_folder = PatternMatcher([
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
# *** rules from symbolic ***
# mod divides
((UOp.cvar('c')*UOp.var('x')) % UOp.cvar('c'), lambda x,c: x.const(0)),
(((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c),
# two stage mul, (x*c1)*c2 = x*(c1*c2)
((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
# -(x+y) -> -x + -y
@@ -464,8 +468,6 @@ expander = PatternMatcher([
if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
# BARRIERs aren't actually expanded
(UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
# image indexing (needs to be here)
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
# empty EXPAND is NOOP
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
# no ALU on vectorized dtypes