mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[CI] Added H100 node (#1779)
This commit is contained in:
@@ -11,8 +11,6 @@ from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
|
||||
import torch
|
||||
|
||||
# import triton
|
||||
# from .. import compile, CompiledKernel
|
||||
from ..common.backend import get_backend
|
||||
@@ -289,6 +287,7 @@ class JITFunction(KernelInterface[T]):
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
|
||||
Reference in New Issue
Block a user