119 lines · 2.9 KiB · Python
import multiprocessing
import os
import shutil
from collections import namedtuple

import torch

import triton
import triton.language as tl

tmpdir = ".tmp"


def reset_tmp_dir():
    # Point the Triton compilation cache at a throwaway directory and clear any
    # stale contents left over from earlier runs.
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir, ignore_errors=True)


instance_descriptor = namedtuple("instance_descriptor",
                                 ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])

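# Illustrative sketch, not part of the original test (assumption: each field holds
# argument indices). A config marking the first four arguments as divisible by 16
# and specializing nothing else would be written as
#
#     instance_descriptor((0, 1, 2, 3), (), (), ())
#
# which matches the config built in test_compile_in_subproc below.
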
def get_device_type():
    # On CUDA builds torch.version.hip is None and the compute capability is
    # reported as major * 10 + minor; on ROCm/HIP builds there is no numeric cc.
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    if torch.version.hip is None:
        major, minor = torch.cuda.get_device_capability(0)
        cc = major * 10 + minor
        device_type = "cuda"
    else:
        cc = None
        device_type = "hip"
    return cc, device_type


def compile_fn(config, device_type, cc):
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    # Compile ahead of time just to warm the on-disk cache; the kernel is not
    # launched here.
    triton.compile(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        device_type=device_type,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


def test_compile_in_subproc() -> None:
    cc, device_type = get_device_type()
    config = instance_descriptor(tuple(range(4)), (), (), ())

    multiprocessing.set_start_method('fork')
    proc = multiprocessing.Process(
        target=compile_fn,
        args=(config, device_type, cc))
    proc.start()
    proc.join()
    assert proc.exitcode == 0


def compile_fn_dot(config, device_type, cc):
    @triton.jit
    def kernel_dot(Z):
        offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
        z = tl.load(Z + offs)
        z = tl.dot(z, z)
        tl.store(Z + offs, z)

    triton.compile(
        fn=kernel_dot,
        signature={0: "*fp32"},
        device=0,
        device_type=device_type,
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


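# Note: multiprocessing.set_start_method() may only be called once per interpreter,
# so the test below does not set it again; it only asserts that the 'fork' start
# method is already in effect.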
def test_compile_in_forked_subproc() -> None:
    reset_tmp_dir()
    cc, device_type = get_device_type()
    config = instance_descriptor(tuple(range(1)), (), (), ())

    assert multiprocessing.get_start_method() == 'fork'
    proc = multiprocessing.Process(
        target=compile_fn_dot,
        args=(config, device_type, cc))
    proc.start()
    proc.join()
    assert proc.exitcode == 0
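
# Usage sketch (not part of the original file): these tests are ordinarily
# collected by pytest, e.g.
#
#     pytest python/test/unit/runtime/test_subproc.py -v
#
# Each test forks a child process and asserts that triton.compile() succeeds there
# (exit code 0), exercising kernel compilation from a forked subprocess.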