remove hacks from can_merge

This commit is contained in:
George Hotz
2023-02-28 15:30:20 -08:00
parent e21df1701b
commit 1702a5779f
2 changed files with 17 additions and 10 deletions

View File

@@ -88,21 +88,24 @@ class CLASTKernel(ASTKernel):
val = self.bufs[buf_index]._backing[0]
assert not math.isnan(val)
const = Token(f"({val}f)", Types.FLOAT)
can_merge = (not self.bufs[buf_index].st.needs_valid() and len(self.bufs[buf_index].st.views) == 1) or "Image" in str(type(self.bufs[buf_index]._buf))
should_upcast = not CLANG and const is None and can_merge and self.buftokens[buf_index].can_float4()
should_upcast = not CLANG and const is None and self.buftokens[buf_index].can_float4()
tokens = []
for o in self.buftokens[buf_index].offsets():
key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.sts[buf_index].expr_idxs(o)
if should_upcast:
can_merge = True
for j in range(1,4):
idxy_test, valid_test = self.sts[buf_index].expr_idxs(o+j)
can_merge = can_merge and valid.render() == valid_test.render()
can_merge = can_merge and (idxy+j).render() == idxy_test.render()
if const is not None:
ldr = const
elif isinstance(self.bufs[buf_index]._buf, CLImage):
assert should_upcast, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
elif should_upcast:
elif should_upcast and can_merge:
ldr = Token(f"(({CLProgram.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
@@ -111,7 +114,7 @@ class CLASTKernel(ASTKernel):
self.loaded_keys[(buf_index,o)] = ldr
else:
self.kernel.append(f"{ldr.decltype()} {key} = {ldr.tok};\n")
if should_upcast:
if should_upcast and can_merge:
for j in range(4):
self.loaded_keys[(buf_index,o+j)] = Token(key+f'.{"xyzw"[j]}', Types.FLOAT)
else:

View File

@@ -12,6 +12,7 @@ class Node:
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
return ops[type(self)](self, ops, ctx)
def __add__(self, b:int):
  # `node + b`: wrap the integer in a node with Variable.num, then let Variable.sum combine the two terms
  offset = Variable.num(b)
  return Variable.sum([self, offset])
def __sub__(self, b:int):
  # `node - b`: expressed as adding the negated integer, built with Variable.num and combined by Variable.sum
  negated = Variable.num(-b)
  return Variable.sum([self, negated])
def __mul__(self, b:int):
if b == 0: return NumNode(0)
elif b == 1: return self
@@ -77,9 +78,12 @@ class Node:
def sum(nodes:List[Node]) -> Node:
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
num_sum = sum([x.b for x in num_nodes])
# TODO: this is broken due to something with negatives mods
if num_sum > 0: nodes.append(NumNode(num_sum))
else: nodes += [NumNode(x.b) for x in num_nodes if x.b != 0]
# TODO: this is broken due to something with negative mods. $50 for a PR that fixes this
if num_sum >= 0: nodes.append(NumNode(num_sum))
else:
lte_0, rest = partition(num_nodes, lambda x: x.b <= 0)
nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0]
if len(rest): nodes += [NumNode(sum([x.b for x in rest]))]
if any([isinstance(x, SumNode) for x in nodes]):
nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))