ROCm/python/triton/compiler/__init__.py
chengjunlu 6cb67185f8 [FRONTEND] Use the proper default num_warps and num_stages based on the device backend in JITFunction (#2130)
The default values that JITFunction uses for num_warps and num_stages are
tied to the NVIDIA GPU architecture. They should instead be chosen based on
the device backend that the kernel is compiled for.
1. Add two functions that return the default num_warps and num_stages for
a specific device backend.
2. Make JITFunction use the default num_warps and num_stages of the
target device backend.
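
The two steps above can be illustrated with a minimal, self-contained sketch. It does not reproduce the code from this commit: the registry, the helper names (default_num_warps, default_num_stages, resolve_launch_options), the backend keys, and the numeric values are all hypothetical placeholders chosen for illustration. It only shows the general idea of looking up defaults per backend and applying them when the caller has not set a value.

```python
from typing import Dict, Optional

# Hypothetical per-backend defaults. The backend names and numbers are
# illustrative placeholders, not values taken from the commit.
_DEFAULTS: Dict[str, Dict[str, int]] = {
    "cuda": {"num_warps": 4, "num_stages": 3},
    "xpu": {"num_warps": 8, "num_stages": 2},
}


def _lookup(device_type: str, key: str) -> int:
    """Fetch one default for a backend, failing loudly on unknown backends."""
    try:
        return _DEFAULTS[device_type][key]
    except KeyError:
        raise ValueError(f"no defaults registered for backend {device_type!r}") from None


def default_num_warps(device_type: str) -> int:
    """Step 1 (hypothetical helper): default num_warps for a backend."""
    return _lookup(device_type, "num_warps")


def default_num_stages(device_type: str) -> int:
    """Step 1 (hypothetical helper): default num_stages for a backend."""
    return _lookup(device_type, "num_stages")


def resolve_launch_options(device_type: str,
                           num_warps: Optional[int] = None,
                           num_stages: Optional[int] = None) -> Dict[str, int]:
    """Step 2: fill in only the options the caller left unset."""
    return {
        "num_warps": default_num_warps(device_type) if num_warps is None else num_warps,
        "num_stages": default_num_stages(device_type) if num_stages is None else num_stages,
    }
```

In this sketch an explicitly passed value always wins, and the backend default is consulted only when the argument is None, so existing callers that set num_warps or num_stages themselves would see no change in behavior.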

Co-authored-by: Wang Weihan <eikan.wang@intel.com>
2023-08-24 21:58:18 +08:00

6 lines
331 B
Python

from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
                       get_arch_default_num_warps, instance_descriptor)
from .errors import CompilationError
__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]