map tensor core bufs to reduceop (#4763)

* move tc_opts.bufs to its own map

* lint

* iterate reduceop bufs
This commit is contained in:
qazal
2024-05-29 03:07:39 +08:00
committed by GitHub
parent 6fcf220b21
commit a88aea626d
2 changed files with 9 additions and 6 deletions

View File

@@ -35,7 +35,6 @@ class Opt:
return self.axis
class TensorCoreOptions(NamedTuple):
bufs: Tuple[int, int] # the local aliased buffers for A and B
axes: List[int] # the location of the original N and M axes if still in the shape
axes_exist: List[bool] # true if the original N and M axes are still in the shape
def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
@@ -90,6 +89,8 @@ class Kernel:
self.local_alias: Dict[int, LocalBuffer] = {}
self.tensor_core: Optional[TensorCore] = None
self.tensor_core_opts: Optional[TensorCoreOptions] = None
# the local aliased buffers for A and B
self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
self.dont_use_locals: bool = False
# group simplifies
@@ -113,7 +114,8 @@ class Kernel:
# parameters for optimizations
ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
ret.tensor_core, ret.tensor_core_opts, ret.local_alias = self.tensor_core, self.tensor_core_opts, {}
ret.tensor_core, ret.tensor_core_opts, ret.local_alias, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, {}, \
self.bufs_for_tensor_core
# uncached since linearize didn't run
ret.applied_opts_cache = None
@@ -353,7 +355,8 @@ class Kernel:
if axis_pads and (opt_level < 2): continue
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
self.tensor_core_opts = (tc_opts:=TensorCoreOptions(axes=[s0, s1], axes_exist=[True, True]))
self.bufs_for_tensor_core[self.reduceop] = (buf0, buf1)
# attempt to pad the tensor axes that require it
try:

View File

@@ -320,10 +320,10 @@ class Linearizer(Kernel):
if self.applied_opts == self.applied_opts_cache: return self
# late alias the tensor core buffers
if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
if (tc:=self.tensor_core) and self.tensor_core_opts is not None:
alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
for tc_buf in tc_opts.bufs:
self.alias_buffer(tc_buf, alias_pattern)
for _,tc_bufs in self.bufs_for_tensor_core.items():
for tc_buf in tc_bufs: self.alias_buffer(tc_buf, alias_pattern) # TODO aliased buffers should map to the reduceop
# save backups
sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted