mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
Simple mechanism to run Triton kernels on PyTorch for debugging purposes (upstreamed from Kernl). TODO: random grid iteration; support for atomic ops; more unit tests; cover new APIs?
19 lines
353 B
Python
19 lines
353 B
Python
try:
|
|
import torch as _torch
|
|
except ImportError:
|
|
_torch = None
|
|
|
|
|
|
class TorchWrapper:
    """
    Helps in making torch an optional dependency.

    Attribute lookups on an instance are forwarded to the module-level
    ``_torch`` (the real ``torch`` module, or ``None`` if it failed to
    import). When torch is unavailable, accessing a regular attribute raises
    ``ImportError`` with an explanatory message.

    Dunder names (``__xxx__``) are answered with ``AttributeError`` instead
    when torch is unavailable. Generic introspection such as ``hasattr``,
    ``copy.deepcopy`` and ``inspect.unwrap`` (used by doctest) only tolerates
    ``AttributeError``, so this lets those tools treat the wrapper as an
    ordinary object instead of crashing.
    """

    def __getattr__(self, name):
        """Return attribute ``name`` of the wrapped torch module.

        Python only calls this for names not found on the wrapper itself.

        Raises:
            AttributeError: torch has no attribute ``name``, or torch is not
                installed and ``name`` is a dunder (protocol/introspection)
                lookup.
            ImportError: torch is not installed and ``name`` is a regular
                attribute.
        """
        if _torch is None:
            if name.startswith("__") and name.endswith("__"):
                # hasattr/getattr-with-default only swallow AttributeError;
                # an ImportError here would break copy, inspect and doctest.
                raise AttributeError(
                    f"{type(self).__name__!r} object has no attribute {name!r}"
                )
            raise ImportError("Triton requires PyTorch to be installed")
        return getattr(_torch, name)
|
|
|
|
|
|
# Module-level proxy for torch: attribute access such as ``torch.tensor`` is
# forwarded to the real torch module when it imported successfully, and
# raises ImportError when PyTorch is not installed (see TorchWrapper).
torch = TorchWrapper()
|