mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix the hard code builder.arch that could block third_party tests (#1859)
For CUDA devices, the `builder.arch` is an int. For third_party devices, this line would be a TypeError. For example: ``` TypeError: '<' not supported between instances of 'dict' and 'int' ``` Co-authored-by: Wang Weihan <eikan.wang@intel.com>
This commit is contained in:
@@ -663,6 +663,11 @@ def bitcast(input: tl.tensor,
|
||||
dst_ty)
|
||||
|
||||
|
||||
# TODO: architecture descriptor class
|
||||
def _is_cuda(arch):
|
||||
return isinstance(arch, int)
|
||||
|
||||
|
||||
def cast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
@@ -677,7 +682,7 @@ def cast(input: tl.tensor,
|
||||
src_sca_ty = src_ty.scalar
|
||||
dst_sca_ty = dst_ty.scalar
|
||||
|
||||
if builder.arch < 89 and \
|
||||
if _is_cuda(builder.arch) and builder.arch < 89 and \
|
||||
(src_sca_ty.is_fp8e4() or dst_sca_ty.is_fp8e4()):
|
||||
warnings.warn("Standard tl.float8e4 format will be deprecated on SM < 89. "
|
||||
"Please use tl.float8e4b15.", DeprecationWarning)
|
||||
|
||||
Reference in New Issue
Block a user