Files
ROCm/python/test/unit/language/assert_helper.py
Jason Furmanek 4c4e42e524 Merge remote-tracking branch 'openai/main' into IFU-230517
Conflicts:
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/test/unit/language/assert_helper.py
	python/triton/third_party/cuda/bin/ptxas
	test/Conversion/tritongpu_to_llvm.mlir

 It looks like you may be committing a merge.
 If this is not correct, please remove the file
	.git/MERGE_HEAD
 and try again.
2023-05-17 15:03:42 +00:00

59 lines
1.5 KiB
Python

import sys
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
    """Copy BLOCK elements from X to Y, device-asserting each one is zero.

    The element-wise condition ``vals == 0`` is checked with
    ``tl.device_assert`` and reports the message "x != 0" on failure.
    """
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    tl.device_assert(vals == 0, "x != 0")
    tl.store(Y + offsets, vals)
@triton.jit
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
    """Copy BLOCK elements from X to Y under an always-true device assert.

    The asserted condition is the constant comparison ``0 == 0``. It never
    depends on the loaded data, so this assert never fires. The name suggests
    it exercises ``tl.device_assert`` with a scalar rather than a block-shaped
    condition (TODO confirm intent).
    """
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    # Trivial assert: condition is independent of the loaded values.
    tl.device_assert(0 == 0, "x != 0")
    tl.store(Y + offsets, vals)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
    """Copy BLOCK elements from X to Y, checking them with a plain ``assert``.

    Unlike ``kernel_device_assert``, this kernel uses the Python ``assert``
    statement inside the JIT function. The Triton frontend is expected to
    lower it to a device-side check (NOTE(review): behaviour depends on the
    Triton version / debug setting).
    """
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    assert vals == 0, "x != 0"
    tl.store(Y + offsets, vals)
@triton.jit
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
    """Copy BLOCK elements from X to Y after statically asserting BLOCK == 128.

    ``BLOCK`` is a ``tl.constexpr``, so ``tl.static_assert`` can check it when
    the kernel is compiled rather than per element at run time.
    """
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    tl.static_assert(BLOCK == 128, "BLOCK != 128")
    tl.store(Y + offsets, vals)
def test_assert(func: str):
    """Launch the kernel variant selected by ``func`` on a 128-element input.

    Args:
        func: One of ``"device_assert"``, ``"assert"`` or ``"static_assert"``.
            Any other value launches nothing, and the final comparison then
            fails because ``y`` is still all zeros.

    ``x`` holds ``0..127``, so every element except the first violates the
    ``x == 0`` condition checked by the device-side assert kernels. The
    static-assert kernel's ``BLOCK == 128`` condition holds for this shape.
    Each kernel stores the loaded values into ``y``, which is finally compared
    against ``x`` with ``assert_close``.
    """
    shape = (128, )
    x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    if func == "device_assert":
        # Merge resolution: keep this fork's num_warps=2, which matches the
        # "assert" branch below, and also run upstream's scalar-condition
        # variant. NOTE(review): confirm num_warps=2 is still required here.
        kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
        kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
    elif func == "assert":
        kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
    elif func == "static_assert":
        kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)
if __name__ == "__main__":
    # The first command-line argument names the assert flavour to exercise.
    selected_func = sys.argv[1]
    test_assert(selected_func)