mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move expr_idxs to shapetracker
This commit is contained in:
@@ -6,8 +6,8 @@ from tinygrad.helpers import prod
|
||||
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
|
||||
from tinygrad.ast import ASTKernel, Token, Types
|
||||
from tinygrad.lazy import IMAGE
|
||||
from tinygrad.shape import ShapeTracker, ZeroView
|
||||
from tinygrad.shape.symbolic import Variable, ModNode
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.shape.symbolic import ModNode # this will go away when VALIDHACKS does
|
||||
|
||||
CUDA = int(os.getenv("CUDA", "0"))
|
||||
if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram, CL # NOTE: using CL will not work for the CUDA runtime # noqa: F401
|
||||
@@ -38,15 +38,6 @@ class CLASTKernel(ASTKernel):
|
||||
}
|
||||
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
|
||||
|
||||
# TODO: move to shapetracker
|
||||
def compute_buf_index_symbolic(self, st, offset=0):
|
||||
idx = st.views[-1].expr_idxs([f"idx{i}" for i in range(self.shape_len)], offset)
|
||||
valid = Variable.num(1)
|
||||
for v in st.views[0:-1][::-1]:
|
||||
if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
|
||||
else: idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
|
||||
def image_idx(self, buf_index, idxy, validhacks=False):
|
||||
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
|
||||
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
|
||||
@@ -59,7 +50,7 @@ class CLASTKernel(ASTKernel):
|
||||
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
|
||||
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
|
||||
for v, o in zip(value, self.buftokens[buf_index].offsets()):
|
||||
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
|
||||
idxy, valid = self.sts[buf_index].expr_idxs(o)
|
||||
assert str(valid) == "1", "store must always be valid"
|
||||
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
|
||||
if isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
@@ -78,7 +69,7 @@ class CLASTKernel(ASTKernel):
|
||||
tokens = []
|
||||
for o in self.buftokens[buf_index].offsets():
|
||||
if (buf_index, o) not in self.loaded_keys:
|
||||
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
|
||||
idxy, valid = self.sts[buf_index].expr_idxs(o)
|
||||
if const is not None:
|
||||
ldr = const
|
||||
elif isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
@@ -266,7 +257,7 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
# middle
|
||||
if self.group_for_reduce:
|
||||
lidx, lvalid = self.compute_buf_index_symbolic(self.sts[-1])
|
||||
lidx, lvalid = self.sts[-1].expr_idxs()
|
||||
assert str(lvalid) == "1", "local buffer must be valid"
|
||||
|
||||
self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")
|
||||
|
||||
@@ -109,6 +109,15 @@ class ShapeTracker:
|
||||
@property
|
||||
def offset(self) -> int: return self.views[-1].offset
|
||||
|
||||
# TODO: pass in the idxs?
|
||||
def expr_idxs(self, offset=0):
|
||||
idx = self.views[-1].expr_idxs([f"idx{i}" for i in range(len(self.shape))], offset)
|
||||
valid = Variable.num(1)
|
||||
for v in self.views[0:-1][::-1]:
|
||||
if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
|
||||
else: idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
|
||||
def expr_node(self):
|
||||
idx = Variable('idx', 0, prod(self.shape)-1)
|
||||
valid = None #Variable.num(1)
|
||||
|
||||
Reference in New Issue
Block a user