local buffer implied

This commit is contained in:
George Hotz
2023-01-28 12:06:28 -08:00
parent b3e4e678e8
commit 03dd1201dc

View File

@@ -234,9 +234,7 @@ class CLASTKernel(ASTKernel):
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))
self.bufs.append(local_buffer)
self.sts.append(local_buffer.st.copy())
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))))
self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))
self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
@@ -276,7 +274,7 @@ class CLASTKernel(ASTKernel):
# middle
if self.group_for_reduce:
lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st)
lidx, lvalid = self.compute_buf_index_symbolic(self.sts[-1])
assert str(lvalid) == "1", "local buffer must be valid"
self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")