mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTING] fix get_max_simd_tflops (#1318)
`_triton.runtime.num_sm`, `_triton.runtime.clock_rate`, and `_triton.runtime.cc` no longer seem to exist. Use the corresponding methods from `get_max_tensorcore_tflops` in the same file instead.
This commit is contained in:
@@ -454,10 +454,12 @@ def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
|
||||
triton.compiler.init_cuda_utils()
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
if dtype == torch.float32:
|
||||
ops_per_sub_core = 32 # 2*16
|
||||
elif dtype == torch.float16:
|
||||
|
||||
Reference in New Issue
Block a user