diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index cb88277669..fd829d0e41 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -88,21 +88,24 @@ class CLASTKernel(ASTKernel):
       val = self.bufs[buf_index]._backing[0]
       assert not math.isnan(val)
       const = Token(f"({val}f)", Types.FLOAT)
-
-    can_merge = (not self.bufs[buf_index].st.needs_valid() and len(self.bufs[buf_index].st.views) == 1) or "Image" in str(type(self.bufs[buf_index]._buf))
-    should_upcast = not CLANG and const is None and can_merge and self.buftokens[buf_index].can_float4()
-
+    should_upcast = not CLANG and const is None and self.buftokens[buf_index].can_float4()
     tokens = []
     for o in self.buftokens[buf_index].offsets():
       key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
       if (buf_index, o) not in self.loaded_keys:
         idxy, valid = self.sts[buf_index].expr_idxs(o)
+        if should_upcast:
+          can_merge = True
+          for j in range(1,4):
+            idxy_test, valid_test = self.sts[buf_index].expr_idxs(o+j)
+            can_merge = can_merge and valid.render() == valid_test.render()
+            can_merge = can_merge and (idxy+j).render() == idxy_test.render()
         if const is not None:
           ldr = const
         elif isinstance(self.bufs[buf_index]._buf, CLImage):
-          assert should_upcast, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
+          assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
           ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
-        elif should_upcast:
+        elif should_upcast and can_merge:
           ldr = Token(f"(({CLProgram.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
         else:
           ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
@@ -111,7 +114,7 @@ class CLASTKernel(ASTKernel):
           self.loaded_keys[(buf_index,o)] = ldr
         else:
           self.kernel.append(f"{ldr.decltype()} {key} = {ldr.tok};\n")
-          if should_upcast:
+          if should_upcast and can_merge:
            for j in range(4):
              self.loaded_keys[(buf_index,o+j)] = Token(key+f'.{"xyzw"[j]}', Types.FLOAT)
           else:
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 1e5e80c769..ab1ef43cdf 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -12,6 +12,7 @@ class Node:
     if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
     return ops[type(self)](self, ops, ctx)
   def __add__(self, b:int): return Variable.sum([self, Variable.num(b)])
+  def __sub__(self, b:int): return Variable.sum([self, Variable.num(-b)])
   def __mul__(self, b:int):
     if b == 0: return NumNode(0)
     elif b == 1: return self
@@ -77,9 +78,12 @@ class Node:
   def sum(nodes:List[Node]) -> Node:
     nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
     num_sum = sum([x.b for x in num_nodes])
-    # TODO: this is broken due to something with negatives mods
-    if num_sum > 0: nodes.append(NumNode(num_sum))
-    else: nodes += [NumNode(x.b) for x in num_nodes if x.b != 0]
+    # TODO: this is broken due to something with negative mods. $50 for a PR that fixes this
+    if num_sum >= 0: nodes.append(NumNode(num_sum))
+    else:
+      lte_0, rest = partition(num_nodes, lambda x: x.b <= 0)
+      nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0]
+      if len(rest): nodes += [NumNode(sum([x.b for x in rest]))]
     if any([isinstance(x, SumNode) for x in nodes]):
       nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))
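
The ops_gpu.py change trades the old whole-buffer can_merge flag for a per-offset check: a float4 load is only emitted when offsets o+1..o+3 render to the same valid expression and to idxy+j, i.e. the four values really are contiguous in the underlying buffer. Below is a minimal, self-contained sketch of that contiguity test; ROW, PITCH, expr_idx, and can_merge_float4 are hypothetical stand-ins for the ShapeTracker machinery, not tinygrad code.

# Minimal sketch of the can_merge idea above -- not tinygrad code.
# ROW/PITCH model a hypothetical padded view standing in for the real index expression.
ROW, PITCH = 10, 16   # 10 valid elements per logical row, stored with a pitch of 16

def expr_idx(o: int) -> int:
  # stand-in for expr_idxs(o): flat buffer index for logical offset o
  return (o // ROW) * PITCH + (o % ROW)

def can_merge_float4(o: int) -> bool:
  # a float4 load of o..o+3 is only safe if each neighbour maps to exactly base+j
  base = expr_idx(o)
  return all(expr_idx(o + j) == base + j for j in range(1, 4))

print(can_merge_float4(0))   # True: offsets 0..3 map to buffer indices 0..3
print(can_merge_float4(8))   # False: offsets 8..11 straddle the padded row boundary

The real check compares rendered symbolic expressions ((idxy+j).render() and valid.render()) rather than plain integers, but the contiguity condition it enforces is the same.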