From f2e81f7208322212f310641a0208cc4c68aeaad2 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 28 Jan 2023 12:17:40 -0800 Subject: [PATCH] line reduction and cleanups --- tinygrad/llops/ops_gpu.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 7c0c3a842d..8488653e53 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -51,9 +51,7 @@ class CLASTKernel(ASTKernel): assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}" idx = (idxy//4)%self.bufs[buf_index]._base_shape[1] idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0] - if validhacks: - if isinstance(idx, ModNode) and idx.max < idx.b*2: idx = idx.a - if isinstance(idy, ModNode) and idy.max < idy.b*2: idy = idy.a + if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.max < x.b*2 else x for x in (idx, idy)] return f"(int2)({idx.cl}, {idy.cl})" def store(self, buf_index, value:List[Token]): @@ -70,26 +68,20 @@ class CLASTKernel(ASTKernel): self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).cl}] = {v.tok};\n") def load(self, buf_index:int) -> List[Token]: - tokens = [] - # constant folding + const = None if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None: assert self.buftokens[buf_index].typ == Types.FLOAT self.bufs_to_delete.add(buf_index) const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ) - if self.bufs[buf_index].st.needs_valid(): - for o in self.buftokens[buf_index].offsets(): - _, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o) - tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const) - return tokens - else: - return [const]*self.buftokens[buf_index].size() - # not constant folded + tokens = [] for o in self.buftokens[buf_index].offsets(): if (buf_index, o) not in self.loaded_keys: idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o) - if isinstance(self.bufs[buf_index]._buf, CLImage): + if const is not None: + ldr = const + elif isinstance(self.bufs[buf_index]._buf, CLImage): ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4) else: ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)