mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
remove hacks from can_merge
This commit is contained in:
@@ -88,21 +88,24 @@ class CLASTKernel(ASTKernel):
|
||||
val = self.bufs[buf_index]._backing[0]
|
||||
assert not math.isnan(val)
|
||||
const = Token(f"({val}f)", Types.FLOAT)
|
||||
|
||||
can_merge = (not self.bufs[buf_index].st.needs_valid() and len(self.bufs[buf_index].st.views) == 1) or "Image" in str(type(self.bufs[buf_index]._buf))
|
||||
should_upcast = not CLANG and const is None and can_merge and self.buftokens[buf_index].can_float4()
|
||||
|
||||
should_upcast = not CLANG and const is None and self.buftokens[buf_index].can_float4()
|
||||
tokens = []
|
||||
for o in self.buftokens[buf_index].offsets():
|
||||
key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
|
||||
if (buf_index, o) not in self.loaded_keys:
|
||||
idxy, valid = self.sts[buf_index].expr_idxs(o)
|
||||
if should_upcast:
|
||||
can_merge = True
|
||||
for j in range(1,4):
|
||||
idxy_test, valid_test = self.sts[buf_index].expr_idxs(o+j)
|
||||
can_merge = can_merge and valid.render() == valid_test.render()
|
||||
can_merge = can_merge and (idxy+j).render() == idxy_test.render()
|
||||
if const is not None:
|
||||
ldr = const
|
||||
elif isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
assert should_upcast, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
|
||||
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
|
||||
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
|
||||
elif should_upcast:
|
||||
elif should_upcast and can_merge:
|
||||
ldr = Token(f"(({CLProgram.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
|
||||
else:
|
||||
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
|
||||
@@ -111,7 +114,7 @@ class CLASTKernel(ASTKernel):
|
||||
self.loaded_keys[(buf_index,o)] = ldr
|
||||
else:
|
||||
self.kernel.append(f"{ldr.decltype()} {key} = {ldr.tok};\n")
|
||||
if should_upcast:
|
||||
if should_upcast and can_merge:
|
||||
for j in range(4):
|
||||
self.loaded_keys[(buf_index,o+j)] = Token(key+f'.{"xyzw"[j]}', Types.FLOAT)
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@ class Node:
|
||||
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def __add__(self, b:int):
  """Implement ``self + b``.

  ``b`` is wrapped with ``Variable.num`` and the pair ``[self, wrapped]`` is
  handed to ``Variable.sum``; whatever that call returns is the result.
  """
  addend = Variable.num(b)
  return Variable.sum([self, addend])
|
||||
def __sub__(self, b:int):
  """Implement ``self - b`` as an addition of the negated operand.

  ``-b`` is wrapped with ``Variable.num`` and the pair ``[self, wrapped]`` is
  handed to ``Variable.sum``; whatever that call returns is the result.
  """
  negated = Variable.num(-b)
  return Variable.sum([self, negated])
|
||||
def __mul__(self, b:int):
|
||||
if b == 0: return NumNode(0)
|
||||
elif b == 1: return self
|
||||
@@ -77,9 +78,12 @@ class Node:
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
|
||||
num_sum = sum([x.b for x in num_nodes])
|
||||
# TODO: this is broken due to something with negatives mods
|
||||
if num_sum > 0: nodes.append(NumNode(num_sum))
|
||||
else: nodes += [NumNode(x.b) for x in num_nodes if x.b != 0]
|
||||
# TODO: this is broken due to something with negative mods. $50 for a PR that fixes this
|
||||
if num_sum >= 0: nodes.append(NumNode(num_sum))
|
||||
else:
|
||||
lte_0, rest = partition(num_nodes, lambda x: x.b <= 0)
|
||||
nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0]
|
||||
if len(rest): nodes += [NumNode(sum([x.b for x in rest]))]
|
||||
|
||||
if any([isinstance(x, SumNode) for x in nodes]):
|
||||
nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))
|
||||
|
||||
Reference in New Issue
Block a user