diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 2cca5665b3..9c9c24a81c 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -48,20 +48,26 @@ class CLASTKernel(ASTKernel): else: idx = v.expr_node(idx) return idx, valid + def image_idx(self, buf_index, idxy, validhacks=False): + assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}" + idx = (idxy//4)%self.bufs[buf_index]._base_shape[1] + idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0] + if validhacks: + if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a + if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a + return f"(int2)({idx.cl}, {idy.cl})" + def store(self, buf_index, value:List[Token]): if len(value) == self.buftokens[buf_index].size()*4: value = group_float4(value) if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value) assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}" for v, o in zip(value, self.buftokens[buf_index].offsets()): idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o) - assert str(valid) == "1" + assert str(valid) == "1", "store must always be valid" + assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}" if isinstance(self.bufs[buf_index]._buf, CLImage): - assert self.buftokens[buf_index].typ == Types.FLOAT4, "image must be FLOAT4" - idx = (idxy//4)%self.bufs[buf_index]._base_shape[1] - idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0] - self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.cl}, {idy.cl}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n") + self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n") else: - assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}" self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).cl}] = {v.tok};\n") def load(self, buf_index:int) -> List[Token]: @@ -85,17 +91,8 @@ class CLASTKernel(ASTKernel): if (buf_index, o) not in self.loaded_keys: idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o) if isinstance(self.bufs[buf_index]._buf, CLImage): - assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}" - idx = (idxy//4)%self.bufs[buf_index]._base_shape[1] - idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0] - - if VALIDHACKS: - if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a - if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a - valid = None - - ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.cl}, {idy.cl})) /* {self.bufs[buf_index]._base_shape} */" - ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and valid is not None else ldrt, Types.FLOAT4) + ldrt = f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */" + ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and not VALIDHACKS else ldrt, Types.FLOAT4) else: ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ) ldr = Token(f"({valid.cl} ? {ldr.tok} : 0.0f)", ldr.typ) if str(valid) != "1" else ldr