refactor group_for_reduce a little

This commit is contained in:
George Hotz
2023-01-27 08:51:23 -08:00
parent 82e58108e3
commit 123993156d


@@ -300,6 +300,15 @@ class CLASTKernel(ASTKernel):
     self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
     self.simplify_ones()
+    # group for reduce
+    if len(self.group_for_reduce):
+      # with permute for memory coalesing
+      if len(self.group_for_reduce) == 2:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
+      else:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
+      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)

   # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
   # group_for_reduce will have to be better first
   def codegen(self):
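To make the hand_coded_optimizations hunk concrete: the lambda splits the first reduce axis R into (R//G, G), and the permute moves the new size-G axis in front of the quotient axis, so the grouped axis ends up at index first_reduce. A minimal standalone sketch of the single-group case (the helper name and example numbers are invented for illustration; this is not tinygrad code):

def split_reduce_axis(shape, first_reduce, G):
  x = list(shape)
  # split the first reduce axis R into (R//G, G), guarding tiny axes
  reshaped = x[:first_reduce] + [max(1, x[first_reduce]//G), min(x[first_reduce], G)] + x[first_reduce+1:]
  # swap the two new axes so the size-G group axis lands at first_reduce
  permute = list(range(first_reduce)) + [first_reduce+1, first_reduce] + list(range(first_reduce+2, len(shape)+1))
  return [reshaped[i] for i in permute]

print(split_reduce_axis([64, 128], first_reduce=1, G=16))  # -> [64, 16, 8]

After this transform, idx1 can index the 16 work items of a workgroup while the size-8 axis remains a serial reduce loop inside each work item.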
@@ -309,25 +318,12 @@ class CLASTKernel(ASTKernel):
     self.hand_coded_optimizations()

     # group for reduce
+    # TODO: clean this up to be in optimizations
-    self.output_shape = self.shapes[0][:self.first_reduce]
-    if len(self.group_for_reduce):
-      # with permute for memory coalesing
-      if len(self.group_for_reduce) == 2:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
-      else:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
-      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
-      self.first_reduce += len(self.group_for_reduce)
-      self.output_shape += self.group_for_reduce
+    self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce

     if DEBUG >= 3:
-      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len}")
+      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
       print("output shape", self.output_shape)
       for i in range(len(self.bufs)):
-        print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.shapes[i], self.strides[i])
+        print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])

     self.bufs_to_delete : Set[int] = set()
     self.loaded_keys : Dict[Tuple[int,int], Token] = {}
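Net effect of the codegen hunk: since the reshape/permute now happens in hand_coded_optimizations, codegen no longer bumps first_reduce; it only appends group_for_reduce to the output shape, so the group axes get launch indices rather than serial loops. A hedged sketch of the bookkeeping, with invented example values continuing the [64, 16, 8] shape from above:

first_reduce = 1                         # true output axes
group_for_reduce = [16]                  # locally-reduced axes
shape0 = [64, 16, 8]                     # buffer 0 shape after the split above

output_shape = list(shape0[:first_reduce]) + group_for_reduce           # [64, 16]
serial_axes = range(first_reduce + len(group_for_reduce), len(shape0))  # only axis 2

Presumably output_shape then drives the kernel launch, with the trailing prod(group_for_reduce) entries covered by the workgroup's local size.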
@@ -354,20 +350,22 @@ class CLASTKernel(ASTKernel):
       full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]

       self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
-      self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce, self.shape_len)]
-      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - self.first_reduce)
+      self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
+      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))

     # middle
     if self.group_for_reduce:
       self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
       if self.upcast_in_mid_reduce:
+        assert len(self.group_for_reduce) == 2
         # it should be the last dimension
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-2}*{self.group_for_reduce[1]} + idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
-        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce-1] + [self.first_reduce-1])
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
         self.upcast()
       else:
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        assert len(self.group_for_reduce) == 1
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
       self.kernel.append("if (mid_idx == 0) {\n")
       accumulators = [Token("output", self.buftokens[0].typ)]
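Assembled, the strings above produce the classic two-stage workgroup reduction: each work item accumulates its serial loop into a register, parks it in __local temp, barriers, and work item 0 folds temp into the final output. A hand-written sketch of that kernel shape for a 16-wide sum over the [64, 16, 8] example, embedded as a Python string in the codegen style (illustrative only, not actual tinygrad output):

KERNEL = r"""
__kernel void reduce(__global float *output, __global const float *data) {
  int gid = get_group_id(0);    // one workgroup per output element
  int idx1 = get_local_id(0);   // the group_for_reduce axis, local size 16
  float acc = 0.0f;
  for (int idx2 = 0; idx2 < 8; idx2++)
    acc += data[gid*128 + idx2*16 + idx1];  // adjacent work items touch adjacent floats: the coalescing permute
  __local float temp[16];  // second stage
  int mid_idx = idx1; temp[mid_idx] = acc; barrier(CLK_LOCAL_MEM_FENCE);
  if (mid_idx == 0) {
    float out = 0.0f;
    for (int i = 0; i < 16; i++) out += temp[i];
    output[gid] = out;
  }
}
"""

The barrier before the mid_idx == 0 branch is what the appended CLK_LOCAL_MEM_FENCE line guarantees: every partial result is visible in temp before work item 0 reduces it.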