Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
The default values that JITFunction uses for num_warps and num_stages are tied to the NVIDIA GPU architecture. We should instead use default values appropriate to the device backend that the kernel is compiled for.

1. Add two functions that return the default num_warps and num_stages for a given device backend.
2. Make JITFunction use the backend-appropriate default num_warps and num_stages.

Co-authored-by: Wang Weihan <eikan.wang@intel.com>
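A minimal sketch of the idea above, assuming the per-backend defaults can be kept in a simple lookup table keyed by device type. The names `BackendDefaults`, `default_num_warps`, `default_num_stages`, `resolve_launch_config` and the `example_backend` entry are hypothetical, and every number is an illustrative placeholder, not a value taken from Triton. This is not the implementation behind `get_arch_default_num_warps` / `get_arch_default_num_stages`; it only shows the "fill in unset launch parameters from the target backend" pattern.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class BackendDefaults:
    """Hypothetical per-backend launch defaults."""
    num_warps: int
    num_stages: int


# Hypothetical registry: device type -> defaults.
# All values are illustrative placeholders.
_BACKEND_DEFAULTS: Dict[str, BackendDefaults] = {
    "cuda": BackendDefaults(num_warps=4, num_stages=3),
    "example_backend": BackendDefaults(num_warps=8, num_stages=2),
}


def _lookup(device_type: str) -> BackendDefaults:
    # Fail loudly instead of silently falling back to one vendor's defaults.
    try:
        return _BACKEND_DEFAULTS[device_type]
    except KeyError:
        raise ValueError(f"unknown device backend: {device_type!r}") from None


def default_num_warps(device_type: str) -> int:
    """Return the default warp count for the given device backend."""
    return _lookup(device_type).num_warps


def default_num_stages(device_type: str) -> int:
    """Return the default pipeline-stage count for the given device backend."""
    return _lookup(device_type).num_stages


def resolve_launch_config(device_type: str,
                          num_warps: Optional[int] = None,
                          num_stages: Optional[int] = None) -> Tuple[int, int]:
    """Fill in any launch parameter the caller left unset (None)
    with the default for the backend the kernel is compiled for."""
    if num_warps is None:
        num_warps = default_num_warps(device_type)
    if num_stages is None:
        num_stages = default_num_stages(device_type)
    return num_warps, num_stages


print(resolve_launch_config("example_backend"))          # (8, 2)
print(resolve_launch_config("cuda", num_warps=8))        # (8, 3)
```

Keeping explicit caller values and consulting the backend only for `None` mirrors step 2: the JIT layer no longer hard-codes one architecture's numbers, it asks the target backend.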
Python, 6 lines, 331 B
from .compiler import (CompiledKernel, compile, get_arch_default_num_stages,
                       get_arch_default_num_warps, instance_descriptor)
from .errors import CompilationError

__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"]
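Listing the two new helpers in `__all__` is what makes them part of the package's star-import surface. The sketch below demonstrates that standard Python behavior with an in-memory stand-in module; the module name `demo_compiler` and the placeholder function bodies are hypothetical and unrelated to the real compiler package.

```python
import sys
import types


def get_arch_default_num_warps(device_type):
    # Placeholder body; stands in for the real helper.
    return 4


def get_arch_default_num_stages(device_type):
    # Placeholder body; stands in for the real helper.
    return 3


# Build a stand-in module and register it so it can be imported by name.
demo = types.ModuleType("demo_compiler")
demo.get_arch_default_num_warps = get_arch_default_num_warps
demo.get_arch_default_num_stages = get_arch_default_num_stages
demo.not_exported = "hidden"  # public-looking name, but absent from __all__
demo.__all__ = ["get_arch_default_num_warps", "get_arch_default_num_stages"]
sys.modules["demo_compiler"] = demo

# `from ... import *` copies only the names listed in __all__.
namespace = {}
exec("from demo_compiler import *", namespace)
exported = sorted(name for name in namespace if name != "__builtins__")
print(exported)  # ['get_arch_default_num_stages', 'get_arch_default_num_warps']
```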