mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
factor out image_idx
This commit is contained in:
@@ -48,20 +48,26 @@ class CLASTKernel(ASTKernel):
|
||||
else: idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
|
||||
def image_idx(self, buf_index, idxy, validhacks=False):
|
||||
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
|
||||
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
|
||||
idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
|
||||
if validhacks:
|
||||
if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a
|
||||
if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a
|
||||
return f"(int2)({idx.cl}, {idy.cl})"
|
||||
|
||||
def store(self, buf_index, value:List[Token]):
|
||||
if len(value) == self.buftokens[buf_index].size()*4: value = group_float4(value)
|
||||
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
|
||||
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
|
||||
for v, o in zip(value, self.buftokens[buf_index].offsets()):
|
||||
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
|
||||
assert str(valid) == "1"
|
||||
assert str(valid) == "1", "store must always be valid"
|
||||
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
|
||||
if isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
assert self.buftokens[buf_index].typ == Types.FLOAT4, "image must be FLOAT4"
|
||||
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
|
||||
idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
|
||||
self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.cl}, {idy.cl}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
|
||||
self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
|
||||
else:
|
||||
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
|
||||
self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).cl}] = {v.tok};\n")
|
||||
|
||||
def load(self, buf_index:int) -> List[Token]:
|
||||
@@ -85,17 +91,8 @@ class CLASTKernel(ASTKernel):
|
||||
if (buf_index, o) not in self.loaded_keys:
|
||||
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
|
||||
if isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
|
||||
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
|
||||
idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
|
||||
|
||||
if VALIDHACKS:
|
||||
if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a
|
||||
if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a
|
||||
valid = None
|
||||
|
||||
ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.cl}, {idy.cl})) /* {self.bufs[buf_index]._base_shape} */"
|
||||
ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and valid is not None else ldrt, Types.FLOAT4)
|
||||
ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */"
|
||||
ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and not VALIDHACKS else ldrt, Types.FLOAT4)
|
||||
else:
|
||||
ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)
|
||||
ldr = Token(f"({valid.cl} ? {ldr.tok} : 0.0f)", ldr.typ) if str(valid) != "1" else ldr
|
||||
|
||||
Reference in New Issue
Block a user