refactor group_for_reduce a little
@@ -300,6 +300,15 @@ class CLASTKernel(ASTKernel):
     self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
     self.simplify_ones()

+    # group for reduce
+    if len(self.group_for_reduce):
+      # with permute for memory coalescing
+      if len(self.group_for_reduce) == 2:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
+      else:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
+      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
+
   # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
   # group_for_reduce will have to be better first
   def codegen(self):
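For reference, a minimal standalone sketch of what the moved block computes in the single-axis case (len(self.group_for_reduce) == 1). All concrete values here are hypothetical:

    # Hypothetical values: shape (8, 8, 256, 4), first_reduce = 2, group size 32.
    first_reduce, group_for_reduce = 2, [32]
    shape = [8, 8, 256, 4]
    shape_len = len(shape)

    # The reshape lambda splits the first reduce axis into (outer, local) parts.
    split = list(shape[0:first_reduce]) \
          + [max(1, shape[first_reduce]//group_for_reduce[0]), min(shape[first_reduce], group_for_reduce[0])] \
          + list(shape[first_reduce+1:])
    assert split == [8, 8, 8, 32, 4]

    # The single-axis permute (the else branch) then swaps the two parts so the
    # local axis lands at position first_reduce, adjacent to the output dims.
    permute_axis = list(range(0, first_reduce)) + [first_reduce+1, first_reduce] + list(range(first_reduce+2, shape_len+1))
    assert permute_axis == [0, 1, 3, 2, 4]
    assert [split[i] for i in permute_axis] == [8, 8, 32, 8, 4]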
@@ -309,25 +318,12 @@ class CLASTKernel(ASTKernel):

     self.hand_coded_optimizations()

-    # group for reduce
-    # TODO: clean this up to be in optimizations
-    self.output_shape = self.shapes[0][:self.first_reduce]
-    if len(self.group_for_reduce):
-      # with permute for memory coalescing
-      if len(self.group_for_reduce) == 2:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
-      else:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
-      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
-
-      self.first_reduce += len(self.group_for_reduce)
-      self.output_shape += self.group_for_reduce
-
+    self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
     if DEBUG >= 3:
-      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len}")
+      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
       print("output shape", self.output_shape)
       for i in range(len(self.bufs)):
-        print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.shapes[i], self.strides[i])
+        print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])

     self.bufs_to_delete : Set[int] = set()
     self.loaded_keys : Dict[Tuple[int,int], Token] = {}
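In codegen the grouping now reduces to a single expression; a small sketch of the equivalence with the removed two-step version, using the same hypothetical values as above:

    # Hypothetical: shapes[0] after the split/permute, first_reduce = 2, group [32].
    shapes0 = (8, 8, 32, 8, 4)
    first_reduce, group_for_reduce = 2, [32]

    # Removed version: two steps, bumping first_reduce as a side effect.
    output_shape = list(shapes0[:first_reduce])
    first_reduce += len(group_for_reduce)
    output_shape += group_for_reduce

    # New version: one expression, first_reduce untouched; the loop bounds in
    # the next hunk add len(self.group_for_reduce) explicitly instead.
    assert output_shape == list(shapes0[:2]) + group_for_reduce == [8, 8, 32]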
@@ -354,20 +350,22 @@ class CLASTKernel(ASTKernel):
       full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]

     self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
-    self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce, self.shape_len)]
-    self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - self.first_reduce)
+    self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
+    self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
+
     # middle
     if self.group_for_reduce:
       self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")

       if self.upcast_in_mid_reduce:
         assert len(self.group_for_reduce) == 2
         # it should be the last dimension
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-2}*{self.group_for_reduce[1]} + idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
-        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce-1] + [self.first_reduce-1])
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
         self.upcast()
       else:
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        assert len(self.group_for_reduce) == 1
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")

       self.kernel.append("if (mid_idx == 0) {\n")
       accumulators = [Token("output", self.buftokens[0].typ)]
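For a float sum with group_for_reduce = [32], the appended strings come out to roughly the skeleton below. This is a hand-assembled sketch, not captured output; the accumulator name, loop bounds, and load placeholder are all illustrative:

    # Hypothetical: first_reduce = 2, group_for_reduce = [32], shape_len = 5,
    # full_shape = (8, 8, 32, 8, 4), one accumulator token "acc0".
    kernel = ["float acc0 = 0.0;\n"]  # start_for_op value for a sum
    # the reduce loops now start at first_reduce + len(group_for_reduce) == 3
    kernel += [f"for (int idx{i} = 0; idx{i} < {n}; idx{i}++) {{\n" for i, n in [(3, 8), (4, 4)]]
    kernel += ["acc0 += /* loaded value */;\n"] + ["}\n"] * 2
    # second stage: each workitem parks its partial sum in __local memory, then
    # workitem 0 of the group (mid_idx == 0) folds temp[] into the output
    kernel += ["__local float temp[32]; // second stage\n",
               "int mid_idx = idx2; temp[mid_idx] = acc0; barrier(CLK_LOCAL_MEM_FENCE);\n",
               "if (mid_idx == 0) {\n"]
    print("".join(kernel))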