mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Previously, if there were any specializations of "1" or "constexpr" mixed with unspecialized arguments in arbitrary order, we might have encountered errors due to passing incorrect arguments. This was because the length of the signature did not indicate the maximum index of regular arguments. https://github.com/openai/triton/issues/2229 @shunting314 @amjames More specifically for cases like: ``` kernel( b: tl.tensor, a: tl.constexpr, c: tl.int = 1, d, e: tl.constexpr, ... ) ```
297 lines
12 KiB
Python
297 lines
12 KiB
Python
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining
|
|
# a copy of this software and associated documentation files
|
|
# (the "Software"), to deal in the Software without restriction,
|
|
# including without limitation the rights to use, copy, modify, merge,
|
|
# publish, distribute, sublicense, and/or sell copies of the Software,
|
|
# and to permit persons to whom the Software is furnished to do so,
|
|
# subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be
|
|
# included in all copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
|
|
from __future__ import annotations
|
|
|
|
from ..runtime import driver
|
|
|
|
|
|
def generate_cu_signature(constants, signature, ids):
    """Append '*CUtensorMap' entries for every tensormap argument.

    CUtensorMap*s are always the last arguments, so they are assigned
    indices starting one past the highest existing signature index.
    Returns the (mutated) signature dict and the number of regular
    (non-tensormap) signature slots.
    """
    num_regular_signatures = (max(signature) + 1) if signature else 0
    tensormap_ids = ids["ids_of_tensormaps"]
    if tensormap_ids is not None:
        for offset in range(len(tensormap_ids)):
            signature[num_regular_signatures + offset] = '*CUtensorMap'
    return signature, num_regular_signatures
|
|
|
|
|
|
def dummy_tensormaps_info(n=2):
    """Return *n* dummy InfoFromBackendForTensorMap instances (testing aid)."""
    return [InfoFromBackendForTensorMap(dummy=True) for _ in range(n)]
|
|
|
|
|
|
def parse_tma_info(infos, ids_of_folded_args):
    """Wrap each backend info object and tag it with the folded-arg ids."""

    def _wrap(raw):
        entry = InfoFromBackendForTensorMap(infos=raw)
        entry.ids_of_folded_args = ids_of_folded_args
        return entry

    return [_wrap(raw) for raw in infos]
|
|
|
|
|
|
def get_tma_mapping(tensormaps_info):
    """Merge the address->descriptor-index mappings of all tensormaps.

    Returns None when no tensormap info is available.
    """
    if tensormaps_info is None:
        return None
    mapping = {}
    for entry in tensormaps_info:
        mapping.update(entry.get_address_tma_mapping())
    return mapping
|
|
|
|
|
|
def get_ids_of_tensormaps(tensormaps_info):
    """Collect each tensormap's runtime id; None when no info.

    Order is not relevant.
    """
    if tensormaps_info is None:
        return None
    return [entry.get_id_of_tensormap() for entry in tensormaps_info]
|
|
|
|
|
|
# decouple information for tensormap from backend
|
|
# please ignore the naming style, xx_yy is compiler.py style, xxYy is to comply with cuda tensormap style
|
|
# mixing style is for readability
|
|
class InfoFromBackendForTensorMap:
    """Backend-decoupled description of a single CUtensorMap (TMA descriptor).

    Naming mixes two styles on purpose: xx_yy matches compiler.py style,
    xxYy matches the CUDA tensormap API style; mixing is for readability.
    Instances are hashable/comparable so they can be used as part of a
    cache key (see __hash__/__eq__).
    """

    # Dummy-mode bookkeeping: at most N dummy instances may ever be
    # created; n counts how many have been created so far.
    # NOTE(review): ntma is never read in this file — possibly dead.
    N = 2
    n = 0
    ntma = 0

    def __init__(self, infos=None, dummy=False):
        """Initialize from a backend object, a metadata dict, or dummy data.

        infos: backend-provided info object, or a dict (e.g. loaded from
            cached metadata JSON); ignored when dummy=True.
        dummy: when True, populate hard-coded example data instead.
        """
        self.dummy = dummy
        # Indices of kernel arguments that were folded away (specialized
        # constants); used to remap argument indices in getOriginArgIdx
        # and to shift the descriptor index in get_address_tma_mapping.
        self.ids_of_folded_args = ()
        if not dummy and not isinstance(infos, dict):
            self._extract_info_from_backend(infos)
        elif not dummy and isinstance(infos, dict):
            self._extract_info_from_dict(infos)
        elif dummy:
            self._dummy()

    def _dummy(self):
        """Fill in one of N hard-coded fp16 rank-4 example configurations.

        Uses (and increments) the class-level counter n; asserts if more
        than N dummy instances are requested.
        """
        assert InfoFromBackendForTensorMap.n < InfoFromBackendForTensorMap.N
        if InfoFromBackendForTensorMap.n == 0:
            self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]
            self.tensorRank = 4
            self.globalAddressArgIdx = 0
            self.globalStridesArgIdx = [7, 6, -1, -1]
            self.globalDimsArgIdx = [5, 3, -1, -1]
            self.boxDims = [16, 64, 1, 1]
            self.elementStrides = [1, 1, 1, 1]
            self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"]
            self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"]
            self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"]
            self.TMADescArgIdx = 11
            self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"]
            InfoFromBackendForTensorMap.n += 1
            return
        if InfoFromBackendForTensorMap.n == 1:
            # Second dummy differs only in globalAddressArgIdx and TMADescArgIdx.
            self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]
            self.tensorRank = 4
            self.globalAddressArgIdx = 1
            self.globalStridesArgIdx = [7, 6, -1, -1]
            self.globalDimsArgIdx = [5, 3, -1, -1]
            self.boxDims = [16, 64, 1, 1]
            self.elementStrides = [1, 1, 1, 1]
            self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"]
            self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"]
            self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"]
            self.TMADescArgIdx = 12
            self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"]
            InfoFromBackendForTensorMap.n += 1
            return

    def _extract_info_from_backend(self, infos):
        """Copy all tensormap fields from a backend-provided info object."""
        self.tensorDataType = infos.tensorDataType
        self.tensorRank = infos.tensorRank
        self.globalAddressArgIdx = infos.globalAddressArgIdx
        self.globalStridesArgIdx = infos.globalStridesArgIdx
        self.globalDimsArgIdx = infos.globalDimsArgIdx
        self.boxDims = infos.boxDims
        self.elementStrides = infos.elementStrides
        self.interleave = infos.interleave
        self.swizzle = infos.swizzle
        self.l2Promotion = infos.l2Promotion
        self.oobFill = infos.oobFill
        self.TMADescArgIdx = infos.TMADescArgIdx

    # dict could be from cached metadata json
    def _extract_info_from_dict(self, infos: dict):
        """Copy all tensormap fields from a dict (e.g. cached metadata JSON)."""
        self.tensorDataType = infos['tensorDataType']
        self.tensorRank = infos['tensorRank']
        self.globalAddressArgIdx = infos['globalAddressArgIdx']
        self.globalStridesArgIdx = infos['globalStridesArgIdx']
        self.globalDimsArgIdx = infos['globalDimsArgIdx']
        self.boxDims = infos['boxDims']
        self.elementStrides = infos['elementStrides']
        self.interleave = infos['interleave']
        self.swizzle = infos['swizzle']
        self.l2Promotion = infos['l2Promotion']
        self.oobFill = infos['oobFill']
        self.TMADescArgIdx = infos['TMADescArgIdx']

    def get_address_tma_mapping(self):
        """Map the global-address arg index to the descriptor's runtime index.

        The descriptor index is shifted by the number of folded args —
        presumably because TMADescArgIdx was computed against the folded
        argument list; verify against the caller.
        """
        return {self.globalAddressArgIdx: self.TMADescArgIdx + len(self.ids_of_folded_args)}

    def get_id_of_tensormap(self):
        """Return the tensormap's runtime argument index (folding-adjusted)."""
        return self.TMADescArgIdx + len(self.ids_of_folded_args)

    def getTMADescArgIdx(self):
        """Return the raw (unadjusted) TMA descriptor argument index."""
        return self.TMADescArgIdx

    # dtype:cuda.CUtensorMapDataType | int
    def bytes_from_type(self, dtype):
        """Return the element size in bytes for a CUtensorMap data type.

        Raises KeyError for data types not listed here.
        """
        return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4,
                driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype]

    def getTensorMapDataType(self):
        """Return the CUtensorMap data type enum value."""
        return self.tensorDataType

    def getInterleave(self):
        """Return the CUtensorMap interleave enum value."""
        return self.interleave

    def getSwizzle(self):
        """Return the CUtensorMap swizzle enum value."""
        return self.swizzle

    def getL2Promotion(self):
        """Return the CUtensorMap L2-promotion enum value."""
        return self.l2Promotion

    def getOobFill(self):
        """Return the CUtensorMap out-of-bounds fill enum value."""
        return self.oobFill

    def getTensorRank(self):
        """Return the tensor rank (number of dimensions)."""
        return self.tensorRank

    def getBoxDims(self):
        """Return the box (tile) dimensions."""
        return self.boxDims

    def getElementStrides(self):
        """Return the element strides."""
        return self.elementStrides

    def getGlobalAddress(self, args):
        """Return the tensor's base pointer from the captured kernel args."""
        idx = self.getOriginArgIdx(self.globalAddressArgIdx, args)
        return args[idx]

    # args, captured kernel args in runtime
    def getGlobalDims(self, args):
        """Resolve the global tensor dimensions from the runtime args."""
        shape = []
        for e in self.globalDimsArgIdx:
            t = 1
            # < 0 means folded arg or constant (-1 - value)
            # -1 means extended dim which is 1, -2 means folded arg with constant 1 (-1 - value)
            if e == -1:
                t = 1
            elif e < 0 and e != -1:
                # Decode the folded constant: stored as (-1 - value).
                t = -e - 1
            else:
                # Non-negative: index of the runtime argument holding the dim.
                idx = self.getOriginArgIdx(e, args)
                t = args[idx]
            shape.append(t)
        return shape

    def getGlobalStrides(self, args):
        """Compute the global strides in bytes (innermost stride dropped).

        Encoding of globalStridesArgIdx entries:
          -1: stride not passed — reconstructed as the product of the
              inner dimensions (dense layout);
          -2: stride was a folded constant 1;
          >=0: index of the runtime argument holding the stride.
        """
        t_globalDims = [int(e) for e in self.getGlobalDims(args)]
        t_globalStridesArgIdx = self.globalStridesArgIdx.copy()
        strides_in_elements = []
        # todo: get all stride from backend even in extended mode
        for i in range(self.tensorRank):
            t = 1
            if t_globalStridesArgIdx[i] == -1:
                for ii in range(i):
                    t *= t_globalDims[ii]
            # -2 means the stride in arguments is folded constant 1, we don't use 1 because it can not be distinguished from index 1
            elif t_globalStridesArgIdx[i] == -2:
                t = 1
            else:
                new_idx = self.getOriginArgIdx(t_globalStridesArgIdx[i], args)
                t = args[new_idx]

            strides_in_elements.append(t)

        # Drop the innermost stride — presumably cuTensorMapEncodeTiled
        # expects rank-1 strides with the innermost implied by the element
        # size; confirm against the driver API.
        strides_in_elements = strides_in_elements[1:]
        strides_in_bytes = [e * self.bytes_from_type(self.tensorDataType) for e in strides_in_elements]
        return strides_in_bytes

    def getOriginArgIdx(self, idx, args):
        """Translate *idx* (a position among the non-folded args) into an
        index into the full runtime ``args`` list; identity when nothing
        was folded.
        """
        if self.ids_of_folded_args:
            ids_before_folding_arg = [i for i in range(len(args)) if i not in self.ids_of_folded_args]
            return ids_before_folding_arg[idx]
        else:
            return idx

    def tensormap(self, args):
        """Encode and return the host-side CUtensorMap for these kernel args."""
        return driver.utils.cuTensorMapEncodeTiled(
            self.getTensorMapDataType(),
            self.getTensorRank(),
            self.getGlobalAddress(args),
            self.getGlobalDims(args),
            self.getGlobalStrides(args),
            self.getBoxDims(),
            self.getElementStrides(),
            self.getInterleave(),
            self.getSwizzle(),
            self.getL2Promotion(),
            self.getOobFill()
        )

    # make hashable to use as partial key in cache
    def __hash__(self):
        # List-valued fields are converted to tuples so the hash is defined.
        return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType,
                     self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill))

    def __eq__(self, other):
        # NOTE(review): returns False (not NotImplemented) for foreign types,
        # so mixed-type comparisons never defer to the other operand.
        if not isinstance(other, self.__class__):
            return False
        return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == (
            other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill)
|
|
|
|
|
|
class TensorMapManager:
    """Cache of device-resident CUtensorMap descriptors.

    Keys are ``(info, args)`` tuples where ``info`` is an
    InfoFromBackendForTensorMap and ``args`` are the captured kernel
    arguments; values are device allocations holding the encoded
    tensormap.  Allocations are freed when the manager is destroyed.
    """

    def __init__(self):
        # key -> device allocation of the uploaded tensormap
        self.tensormaps_device = {}

    def __getitem__(self, key: tuple):
        """Return the device address (as int) of the tensormap for *key*,
        encoding and uploading it on first use.
        """
        # EAFP: single dict lookup instead of `key in` check + second lookup.
        try:
            return int(self.tensormaps_device[key])
        except KeyError:
            pass
        (e, args) = key
        t_tensormap = e.tensormap(args)  # host-side encoding
        # A CUtensorMap descriptor occupies 128 bytes.
        TENSORMAP_SIZE_IN_BYTES = 128
        t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
        driver.utils.cuMemcpyHtoD(
            t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES)
        self.tensormaps_device[key] = t_tensormap_device
        return int(t_tensormap_device)

    def __del__(self):
        # Only the values (device allocations) are needed — iterate
        # .values() rather than .items() with an unused key.
        for device_ptr in self.tensormaps_device.values():
            driver.utils.cuMemFree(device_ptr)
|