line reduction and cleanups

This commit is contained in:
George Hotz
2023-01-28 12:17:40 -08:00
parent 03dd1201dc
commit f2e81f7208

View File

@@ -51,9 +51,7 @@ class CLASTKernel(ASTKernel):
assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
if validhacks:
if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a
if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a
if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.max < x.b*2 else x for x in (idx, idy)]
return f"(int2)({idx.cl}, {idy.cl})"
def store(self, buf_index, value:List[Token]):
@@ -70,26 +68,20 @@ class CLASTKernel(ASTKernel):
self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).cl}] = {v.tok};\n")
def load(self, buf_index:int) -> List[Token]:
tokens = []
# constant folding
const = None
if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
assert self.buftokens[buf_index].typ == Types.FLOAT
self.bufs_to_delete.add(buf_index)
const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
if self.bufs[buf_index].st.needs_valid():
for o in self.buftokens[buf_index].offsets():
_, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const)
return tokens
else:
return [const]*self.buftokens[buf_index].size()
# not constant folded
tokens = []
for o in self.buftokens[buf_index].offsets():
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
if isinstance(self.bufs[buf_index]._buf, CLImage):
if const is not None:
ldr = const
elif isinstance(self.bufs[buf_index]._buf, CLImage):
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)