From 7ea89779fa5a6ac3eacf09137f9db186fda9784d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 15 Jan 2023 08:58:10 -0800 Subject: [PATCH] add returns between views --- tinygrad/llops/ops_gpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index da878561a6..330bbc52d9 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -113,11 +113,11 @@ class CLASTKernel(ASTKernel): key = f"{buf_index}_{offset}" # add the index if we don't have it if key not in self.seen_idx: - idx_pieces = [str(st.offset + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0] + idx_pieces = [str(self.offsets[buf_index] + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0] if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;") self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n') if len(st.views) > 1: - extra_idx = ';'.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']]) + extra_idx = ';\n '.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']]) self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n") self.seen_idx.add(key) return key