factor out image_idx

This commit is contained in:
George Hotz
2023-01-28 00:22:54 -08:00
parent bd8a5c2ced
commit 2b5bc5d4a1

View File

@@ -48,20 +48,26 @@ class CLASTKernel(ASTKernel):
else: idx = v.expr_node(idx)
return idx, valid
def image_idx(self, buf_index, idxy, validhacks=False):
    """Build the OpenCL ``(int2)(x, y)`` coordinate expression for an image access.

    The flat symbolic index ``idxy`` is turned into two coordinates:
    x is ``(idxy // 4) % _base_shape[1]`` and
    y is ``(idxy // (4 * _base_shape[1])) % _base_shape[0]``.
    NOTE(review): presumably the 4 is the number of floats per FLOAT4 pixel and
    ``_base_shape`` starts with (rows, columns) -- confirm against the buffer type.

    Args:
        buf_index: position of the buffer in ``self.bufs`` / ``self.buftokens``.
        idxy: symbolic index node; must support ``//`` and ``%`` and the results
            must expose ``.cl``.
        validhacks: when truthy, a coordinate that is a ``ModNode`` whose ``max``
            is below twice its modulus ``b`` is replaced by its operand ``a``,
            i.e. the modulo is dropped.
            NOTE(review): presumably this relies on the image sampler handling
            out-of-range coordinates -- confirm with the caller.

    Returns:
        The string ``"(int2)(<x>, <y>)"`` built from the ``.cl`` rendering of
        both coordinates.

    Raises:
        AssertionError: if the buffer's token type is not ``Types.FLOAT4``.
    """
    assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"

    # Look the shape up once instead of once per use.
    shape = self.bufs[buf_index]._base_shape
    cols = shape[1]
    rows = shape[0]

    coords = [(idxy//4) % cols, (idxy//(4*cols)) % rows]
    if validhacks:
        # Drop the modulo wherever the node's range is below two periods.
        coords = [
            node.a if isinstance(node, ModNode) and node.max < node.b*2 else node
            for node in coords
        ]

    x_node, y_node = coords
    return f"(int2)({x_node.cl}, {y_node.cl})"
# Emit kernel source that writes the tokens in `value` to buffer `buf_index`.
# `value` is regrouped with group_float4 / split_float4 when its length differs
# from the buffer token's size() by a factor of 4; after that the lengths must
# match. One statement per (token, offset) pair is appended to self.kernel.
# NOTE(review): this text comes from a commit view without +/- markers, so the
# removed and added versions of some lines appear back to back below.
def store(self, buf_index, value:List[Token]):
# Four times too many tokens: group them; four times too few: split them.
if len(value) == self.buftokens[buf_index].size()*4: value = group_float4(value)
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
for v, o in zip(value, self.buftokens[buf_index].offsets()):
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
# NOTE(review): the two asserts below look like the before/after versions of one
# line; the second only adds a failure message.
assert str(valid) == "1"
assert str(valid) == "1", "store must always be valid"
# The buffer's token type must equal the type of the token being stored.
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
if isinstance(self.bufs[buf_index]._buf, CLImage):
# Image-backed buffer: the write goes through write_imagef with an int2 coordinate.
# NOTE(review): the inline idx/idy computation with its write_imagef, and the
# write_imagef that calls self.image_idx(), look like the before/after versions
# of the same store -- confirm which one is live.
assert self.buftokens[buf_index].typ == Types.FLOAT4, "image must be FLOAT4"
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
idy = (idxy//(4*self.bufs[buf_index]._base_shape[0]))%self.bufs[buf_index]._base_shape[0]
self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.cl}, {idy.cl}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
else:
# NOTE(review): same check as the assert before the isinstance test; probably
# the pre-change position of that line.
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
# Non-image buffer: array-style store; the index is idxy//4 for FLOAT4 tokens and idxy otherwise.
self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).cl}] = {v.tok};\n")
def load(self, buf_index:int) -> List[Token]:
@@ -85,17 +91,8 @@ class CLASTKernel(ASTKernel):
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
if isinstance(self.bufs[buf_index]._buf, CLImage):
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
if VALIDHACKS:
if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a
if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a
valid = None
ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.cl}, {idy.cl})) /* {self.bufs[buf_index]._base_shape} */"
ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and valid is not None else ldrt, Types.FLOAT4)
ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */"
ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and not VALIDHACKS else ldrt, Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)
ldr = Token(f"({valid.cl} ? {ldr.tok} : 0.0f)", ldr.typ) if str(valid) != "1" else ldr