flip output shape extra dimension indexing for speed

This commit is contained in:
George Hotz
2023-01-25 11:00:37 -08:00
parent 335a261a2e
commit ee1f6ab3ca

View File

@@ -82,6 +82,7 @@ class CLProgram:
CLProgram.kernel_cnt[name] += 1
def __call__(self, *args):
CL.kernel_count += 1
if DEBUG >= 4: print(args[0], args[1], self.prg)
if OSX and DEBUG >= 2: st = time.monotonic_ns()
if CL.CACHE is not None: CL.CACHE.append((self, args))
else: e = self.clprg(CL().cl_queue, *args)
@@ -90,7 +91,6 @@ class CLProgram:
# NOTE: Profiling is (sadly) broken in OS X, so we take the real kernel time
# BOUNTY: will paypal $50 to anyone who fixes this
et = (time.monotonic_ns() - st) if OSX else (e.profile.end - e.profile.start)
if DEBUG >= 4: print(self.prg)
if DEBUG >= 1:
CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else et
CL.ops_sum += self.op_estimate
@@ -326,12 +326,17 @@ class CLASTKernel(ASTKernel):
self.loaded_keys : Dict[Tuple[int,int], Token] = {}
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"]
self.kernel += [f"int idx{i} = get_global_id({min(3, len(self.output_shape))-1-i}); /* {self.output_shape[i]} */\n" for i in range(min(3, len(self.output_shape)))]
# output_shape[-1] is get_global_id(0)
self.kernel += [f"int idx{len(self.output_shape)-1-i} = get_global_id({i}); /* {self.output_shape[-1-i]} */\n" for i in range(min(3, len(self.output_shape)))]
if len(self.output_shape) > 3:
# compact all the dimensions into the final one
for i in range(len(self.output_shape)-1, 2, -1):
self.kernel += [f"int idx{i} = idx2 % {self.output_shape[i]};", f"idx2 = idx2 / {self.output_shape[i]};\n"]
self.output_shape = list(self.output_shape[0:2]) + [prod(self.output_shape[2:])]
# sometimes, there's more dimensions. compact all the dimensions into the first one
final_dimension = len(self.output_shape)-3
for i in range(len(self.output_shape)-4, -1, -1):
self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
self.output_shape = [prod(self.output_shape[0:-2])] + self.output_shape[-2:]
if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}")
# early ast
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]