move expr_idxs to shapetracker

This commit is contained in:
George Hotz
2023-01-28 12:25:05 -08:00
parent f2e81f7208
commit 0f34c24aeb
2 changed files with 14 additions and 14 deletions

View File

@@ -6,8 +6,8 @@ from tinygrad.helpers import prod
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.shape.symbolic import Variable, ModNode
from tinygrad.shape import ShapeTracker
from tinygrad.shape.symbolic import ModNode # this will go away when VALIDHACKS does
CUDA = int(os.getenv("CUDA", "0"))
if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram, CL # NOTE: using CL will not work for the CUDA runtime # noqa: F401
@@ -38,15 +38,6 @@ class CLASTKernel(ASTKernel):
}
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
# TODO: move to shapetracker
def compute_buf_index_symbolic(self, st, offset=0):
  """Build the symbolic index and validity expressions for shape tracker `st`.

  The last view of `st` produces the initial index expression from one
  name per kernel dimension ("idx0", "idx1", ...) plus `offset`. The
  remaining views are then applied from last to first: a ZeroView updates
  the validity expression, any other view rewrites the index expression.

  Returns the tuple (idx, valid).
  """
  innermost = st.views[-1]
  dim_names = [f"idx{dim}" for dim in range(self.shape_len)]
  index_expr = innermost.expr_idxs(dim_names, offset)
  # validity starts out as the constant 1 and is only changed by ZeroViews
  valid_expr = Variable.num(1)
  for view in reversed(st.views[:-1]):
    if isinstance(view, ZeroView):
      valid_expr = view.expr_node(valid_expr, index_expr)
    else:
      index_expr = view.expr_node(index_expr)
  return index_expr, valid_expr
def image_idx(self, buf_index, idxy, validhacks=False):
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
@@ -59,7 +50,7 @@ class CLASTKernel(ASTKernel):
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
for v, o in zip(value, self.buftokens[buf_index].offsets()):
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
idxy, valid = self.sts[buf_index].expr_idxs(o)
assert str(valid) == "1", "store must always be valid"
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
if isinstance(self.bufs[buf_index]._buf, CLImage):
@@ -78,7 +69,7 @@ class CLASTKernel(ASTKernel):
tokens = []
for o in self.buftokens[buf_index].offsets():
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
idxy, valid = self.sts[buf_index].expr_idxs(o)
if const is not None:
ldr = const
elif isinstance(self.bufs[buf_index]._buf, CLImage):
@@ -266,7 +257,7 @@ class CLASTKernel(ASTKernel):
# middle
if self.group_for_reduce:
lidx, lvalid = self.compute_buf_index_symbolic(self.sts[-1])
lidx, lvalid = self.sts[-1].expr_idxs()
assert str(lvalid) == "1", "local buffer must be valid"
self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")

View File

@@ -109,6 +109,15 @@ class ShapeTracker:
@property
def offset(self) -> int: return self.views[-1].offset
# TODO: pass in the idxs?
def expr_idxs(self, offset=0):
  """Return the symbolic (idx, valid) expressions for this tracker.

  The last view builds the initial index expression from one name per
  dimension of `self.shape` ("idx0", "idx1", ...) plus `offset`. Every
  earlier view is then applied in reverse order: a ZeroView updates the
  validity expression from the current index, any other view rewrites the
  index expression.
  """
  last_view = self.views[-1]
  names = [f"idx{axis}" for axis in range(len(self.shape))]
  cur_idx = last_view.expr_idxs(names, offset)
  # validity starts out as the constant 1 and is only changed by ZeroViews
  cur_valid = Variable.num(1)
  for view in reversed(self.views[:-1]):
    if isinstance(view, ZeroView):
      cur_valid = view.expr_node(cur_valid, cur_idx)
    else:
      cur_idx = view.expr_node(cur_idx)
  return cur_idx, cur_valid
def expr_node(self):
idx = Variable('idx', 0, prod(self.shape)-1)
valid = None #Variable.num(1)