mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
minor global dim cleanup (#1724)
This commit is contained in:
@@ -111,6 +111,9 @@ class Kernel:
|
||||
@property
|
||||
def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
|
||||
|
||||
@property
|
||||
def global_dims(self) -> int: return self.first_reduce-self.local_dims
|
||||
|
||||
# there's seven chunks of the shape
|
||||
# blue -- global dims
|
||||
# cyan -- local dims
|
||||
@@ -123,9 +126,9 @@ class Kernel:
|
||||
# yellow -- normal upcasted dimensions
|
||||
def colors(self) -> List[str]:
|
||||
# up to first_reduce, they are all global (blue)
|
||||
colors = ["blue"] * (self.first_reduce-self.local_dims)
|
||||
colors = ["blue"] * self.global_dims
|
||||
# except the local_dims, these are non-reduce locals (cyan)
|
||||
colors += ["cyan"] * (self.local_dims)
|
||||
colors += ["cyan"] * self.local_dims
|
||||
# between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
|
||||
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
|
||||
|
||||
@@ -200,7 +200,8 @@ class Linearizer(OptimizedKernel):
|
||||
self.process()
|
||||
|
||||
# limit dims if we need to
|
||||
if self.opts.global_max and self.opts.local_max: self.limit_global_dims(3, self.opts.global_max, self.opts.local_max)
|
||||
self.limit_global_dim_count(3)
|
||||
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
|
||||
|
||||
# uops
|
||||
self.uops: List[UOp] = []
|
||||
|
||||
@@ -92,14 +92,16 @@ class OptimizedKernel(Kernel):
|
||||
new_shape[next_idx] = new_shape[next_idx] * 2
|
||||
return tuple(new_shape)
|
||||
|
||||
def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
|
||||
def limit_global_dim_count(self, limit:int):
|
||||
# sometimes, there's more dimensions than len(self.lang.gid).
|
||||
# compact all the dimensions into the first
|
||||
# NOTE: this might make multiview shapetrackers
|
||||
if (self.first_reduce-self.local_dims) > limit:
|
||||
num_to_merge = ((self.first_reduce-self.local_dims) - limit)+1
|
||||
if self.global_dims > limit:
|
||||
num_to_merge = (self.global_dims - limit)+1
|
||||
self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
|
||||
if DEBUG >= 3: print("reshaped to", self.full_shape, "due to too many global dimensions")
|
||||
|
||||
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
|
||||
# Check the global allocation limit, current the global_size will be flipped during codegen
|
||||
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
|
||||
global_dims = self.first_reduce-self.local_dims
|
||||
|
||||
Reference in New Issue
Block a user