minor global dim cleanup (#1724)

This commit is contained in:
George Hotz
2023-08-31 12:23:39 -07:00
committed by GitHub
parent 94b1257f5e
commit c18a497dde
3 changed files with 12 additions and 6 deletions

View File

@@ -111,6 +111,9 @@ class Kernel:
@property
def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
@property
def global_dims(self) -> int: return self.first_reduce-self.local_dims
# there's seven chunks of the shape
# blue -- global dims
# cyan -- local dims
@@ -123,9 +126,9 @@ class Kernel:
# yellow -- normal upcasted dimensions
def colors(self) -> List[str]:
# up to first_reduce, they are all global (blue)
colors = ["blue"] * (self.first_reduce-self.local_dims)
colors = ["blue"] * self.global_dims
# except the local_dims, these are non-reduce locals (cyan)
colors += ["cyan"] * (self.local_dims)
colors += ["cyan"] * self.local_dims
# between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green)
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)

View File

@@ -200,7 +200,8 @@ class Linearizer(OptimizedKernel):
self.process()
# limit dims if we need to
if self.opts.global_max and self.opts.local_max: self.limit_global_dims(3, self.opts.global_max, self.opts.local_max)
self.limit_global_dim_count(3)
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
# uops
self.uops: List[UOp] = []

View File

@@ -92,14 +92,16 @@ class OptimizedKernel(Kernel):
new_shape[next_idx] = new_shape[next_idx] * 2
return tuple(new_shape)
def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
def limit_global_dim_count(self, limit:int):
# sometimes, there's more dimensions than len(self.lang.gid).
# compact all the dimensions into the first
# NOTE: this might make multiview shapetrackers
if (self.first_reduce-self.local_dims) > limit:
num_to_merge = ((self.first_reduce-self.local_dims) - limit)+1
if self.global_dims > limit:
num_to_merge = (self.global_dims - limit)+1
self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
if DEBUG >= 3: print("reshaped to", self.full_shape, "due to too many global dimensions")
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
# Check the global allocation limit, current the global_size will be flipped during codegen
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
global_dims = self.first_reduce-self.local_dims