mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
local buffer implied
This commit is contained in:
@@ -234,9 +234,7 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
# add a local buffer for multistage reduce
|
||||
if len(self.group_for_reduce):
|
||||
local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))
|
||||
self.bufs.append(local_buffer)
|
||||
self.sts.append(local_buffer.st.copy())
|
||||
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))))
|
||||
self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))
|
||||
|
||||
self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
|
||||
@@ -276,7 +274,7 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
# middle
|
||||
if self.group_for_reduce:
|
||||
lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st)
|
||||
lidx, lvalid = self.compute_buf_index_symbolic(self.sts[-1])
|
||||
assert str(lvalid) == "1", "local buffer must be valid"
|
||||
|
||||
self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")
|
||||
|
||||
Reference in New Issue
Block a user