diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index cede3649b1..11e2d0a66c 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -300,6 +300,15 @@ class CLASTKernel(ASTKernel):
         self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
         self.simplify_ones()
 
+    # group for reduce
+    if len(self.group_for_reduce):
+      # with permute for memory coalesing
+      if len(self.group_for_reduce) == 2:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
+      else:
+        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
+      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
+
   # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
   # group_for_reduce will have to be better first
   def codegen(self):
@@ -309,25 +318,12 @@ class CLASTKernel(ASTKernel):
 
     self.hand_coded_optimizations()
 
-    # group for reduce
-    # TODO: clean this up to be in optimizations
-    self.output_shape = self.shapes[0][:self.first_reduce]
-    if len(self.group_for_reduce):
-      # with permute for memory coalesing
-      if len(self.group_for_reduce) == 2:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
-      else:
-        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
-      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
-
-      self.first_reduce += len(self.group_for_reduce)
-      self.output_shape += self.group_for_reduce
-
+    self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
     if DEBUG >= 3:
-      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len}")
+      print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
       print("output shape", self.output_shape)
       for i in range(len(self.bufs)):
-        print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.shapes[i], self.strides[i])
+        print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])
 
     self.bufs_to_delete : Set[int] = set()
     self.loaded_keys : Dict[Tuple[int,int], Token] = {}
@@ -354,20 +350,22 @@ class CLASTKernel(ASTKernel):
       full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
 
       self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
-      self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce, self.shape_len)]
-      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - self.first_reduce)
+      self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
+      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, accumulators, do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
 
     # middle
     if self.group_for_reduce:
       self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
       if self.upcast_in_mid_reduce:
+        assert len(self.group_for_reduce) == 2
         # it should be the last dimension
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-2}*{self.group_for_reduce[1]} + idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
-        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce-1] + [self.first_reduce-1])
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
         self.upcast()
       else:
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce-1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
+        assert len(self.group_for_reduce) == 1
+        self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
       self.kernel.append("if (mid_idx == 0) {\n")
 
       accumulators = [Token("output", self.buftokens[0].typ)]
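For readers following the move of the group-for-reduce reshape into hand_coded_optimizations, below is a small standalone sketch (not part of the diff) of what the added reshape/permute does to a shape in the single-group case. The concrete numbers (first_reduce=2, a 64-long reduce axis, group_for_reduce=[16]) are made up for illustration; only the split and permute logic mirrors the added lines.

# Standalone illustration with hypothetical values, not part of the diff:
# how the added reshape/permute splits the first reduce axis when one group is set.
first_reduce, shape_len = 2, 4        # assumed: 2 output dims followed by 2 reduce dims
group_for_reduce = [16]               # reduce the first reduce axis in groups of 16
shape = [32, 32, 64, 8]               # [output dims..., reduce dims...]

# reshape: split axis `first_reduce` (64) into (64//16, 16)
reshaped = list(shape[:first_reduce]) + [
  max(1, shape[first_reduce] // group_for_reduce[0]),
  min(shape[first_reduce], group_for_reduce[0]),
] + list(shape[first_reduce+1:])

# permute (the single-group `else` branch): swap the split pair so the 16-wide
# grouped axis comes first, followed by the leftover serial loop axis
permute_axis = list(range(0, first_reduce)) + [first_reduce+1, first_reduce] + list(range(first_reduce+2, shape_len+1))
permuted = [reshaped[i] for i in permute_axis]

print(reshaped)   # [32, 32, 4, 16, 8]
print(permuted)   # [32, 32, 16, 4, 8]

The grouped 16-wide axis ends up directly after the output dims, which lines up with codegen now computing output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce and starting the serial reduce loops at first_reduce+len(group_for_reduce).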