From 03dd1201dc6cd3fadaf4dbb8b38b4084c9dc68cf Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 28 Jan 2023 12:06:28 -0800 Subject: [PATCH] local buffer implied --- tinygrad/llops/ops_gpu.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 362b1240a0..7c0c3a842d 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -234,9 +234,7 @@ class CLASTKernel(ASTKernel): # add a local buffer for multistage reduce if len(self.group_for_reduce): - local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce)) - self.bufs.append(local_buffer) - self.sts.append(local_buffer.st.copy()) + self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce)))) self.buftokens.append(Token("temp", Types.FLOAT, ptr=True)) self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce @@ -276,7 +274,7 @@ class CLASTKernel(ASTKernel): # middle if self.group_for_reduce: - lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st) + lidx, lvalid = self.compute_buf_index_symbolic(self.sts[-1]) assert str(lvalid) == "1", "local buffer must be valid" self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")