mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix xpu stages logic (#2305)
This commit is contained in:
@@ -405,12 +405,6 @@ def compile(fn, **kwargs):
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
elif device_type == "hip":
|
||||
_device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
|
||||
elif device_type == "xpu":
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
# pass the user's configuration to the backend device.
|
||||
arch["num_warps"] = num_warps
|
||||
|
||||
Reference in New Issue
Block a user