mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Correct the debug syntax (#1705)
- If `TRITON_DEBUG=True`, all Triton functions are compiled in debug mode.
- Otherwise, a Triton function `f`'s debug flag is either `True`,
`False`, or `None` (the default).
  - If `True`, `f` is compiled in debug mode.
  - If `False`, `f` is compiled in normal mode.
  - If `None`, `f` is compiled according to its caller's debug flag. The root
(kernel) function's debug flag can also be set through the `compile`
function.
cc @ngimel , @Chillee
This commit is contained in:
@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
    """Copy X to Y with debug explicitly disabled; the device_assert is a
    no-op unless overridden externally (e.g. by TRITON_DEBUG=1 — see the
    note in the test driver)."""
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    tl.device_assert(vals == 0, "x != 0")
    tl.store(Y + offsets, vals)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
@@ -43,6 +50,9 @@ def test_assert(func: str):
|
||||
if func == "device_assert":
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
@@ -50,5 +60,72 @@ def test_assert(func: str):
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
@triton.jit
def jit_device_assert_none(x):
    """Callee with no explicit debug flag; whether the assert fires is
    inherited from the caller's debug mode."""
    tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=True)
def jit_device_assert_true(x):
    """Callee that forces debug on: its device_assert is always active,
    regardless of the caller's debug mode."""
    tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=False)
def jit_device_assert_false(x):
    """Callee that forces debug off: its device_assert never fires,
    regardless of the caller's debug mode."""
    tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    """Caller kernel with no explicit debug flag. Loads a block from X,
    dispatches (at compile time, via the jit_debug constexpr) to the callee
    whose @triton.jit debug setting matches, then stores the block to Y."""
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(vals)
    elif jit_debug == "false":
        jit_device_assert_false(vals)
    else:
        jit_device_assert_none(vals)
    tl.store(Y + offsets, vals)
|
||||
|
||||
|
||||
@triton.jit(debug=True)
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    """Caller kernel compiled with debug=True. Same structure as
    kernel_device_assert_nested: load a block, dispatch to the callee
    selected by the jit_debug constexpr, store the block."""
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(vals)
    elif jit_debug == "false":
        jit_device_assert_false(vals)
    else:
        jit_device_assert_none(vals)
    tl.store(Y + offsets, vals)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    """Caller kernel compiled with debug=False. Same structure as
    kernel_device_assert_nested: load a block, dispatch to the callee
    selected by the jit_debug constexpr, store the block."""
    offsets = tl.arange(0, BLOCK)
    vals = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(vals)
    elif jit_debug == "false":
        jit_device_assert_false(vals)
    else:
        jit_device_assert_none(vals)
    tl.store(Y + offsets, vals)
|
||||
|
||||
|
||||
def test_assert_nested(caller: str, callee: str):
    """Launch one nested device-assert kernel on CUDA.

    ``caller`` selects which caller kernel runs ("none"/"true"/"false",
    matching the kernel's own @triton.jit debug flag); ``callee`` is
    forwarded as the jit_debug constexpr that picks the callee inside it.
    An unrecognized ``caller`` launches nothing (assert_close then fails,
    since y stays zero).
    """
    shape = (128, )
    x = torch.arange(0, shape[0], dtype=torch.int32, device="cuda")
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    callers = {
        "none": kernel_device_assert_nested,
        "true": kernel_device_assert_nested_true,
        "false": kernel_device_assert_nested_false,
    }
    kernel = callers.get(caller)
    if kernel is not None:
        kernel[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
    assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI dispatch: with a caller and callee argument run the nested-debug
    # test, otherwise run the plain assert test on the single argument.
    # (Previously test_assert was also invoked unconditionally before the
    # branch, running it twice in the one-argument case and running it even
    # when the nested test was requested.)
    if len(sys.argv) == 3:
        test_assert_nested(sys.argv[1], sys.argv[2])
    else:
        test_assert(sys.argv[1])
|
||||
|
||||
@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
|
||||
assert_path = os.path.join(dir_path, "assert_helper.py")
|
||||
|
||||
# TODO: bfloat16 after LLVM-15
# Assertion variants exercised through assert_helper.py; "no_debug" covers
# the @triton.jit(debug=False) kernel. (A dead duplicate assignment of
# func_types without "no_debug" was removed — it was immediately shadowed.)
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
# Every caller/callee debug-flag combination for the nested-call tests.
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
def test_assert_nested(caller_type, callee_type):
    """Run assert_helper.py in a subprocess for one caller/callee debug-flag
    pair and count device-assert failure lines on stderr."""
    proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
    _, stderr = proc.communicate()
    # The helper asserts x == 0 over 128 elements, so an active device
    # assert yields 127 failing lanes / matching error lines.
    num_errs = sum(1 for line in stderr.splitlines() if "x != 0" in line.decode("utf-8"))
    if caller_type == "none":
        # With the caller flag unset, only a callee that forces debug on fires.
        assert num_errs == (127 if callee_type == "true" else 0)
    elif caller_type == "true":
        # Caller debug propagates to an unset callee; debug=False opts out.
        assert num_errs == (0 if callee_type == "false" else 127)
    elif caller_type == "false":
        # Caller debug off is inherited unless the callee forces debug on.
        assert num_errs == (127 if callee_type == "true" else 0)
|
||||
|
||||
@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
kernel_add.debug = False
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
kernel_add.debug = True
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
assert len(kernel_add.cache[device]) == 3
|
||||
bins = list(kernel_add.cache[device].values())
|
||||
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
|
||||
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@@ -859,7 +859,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not self.module.has_function(fn_name):
|
||||
prototype = language.function_type([], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
|
||||
# If the callee is not set, we use the same debug setting as the caller
|
||||
debug = self.debug if fn.debug is None else fn.debug
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
|
||||
@@ -356,7 +356,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
|
||||
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
|
||||
self.noinline = noinline
|
||||
# annotations
|
||||
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
|
||||
|
||||
Reference in New Issue
Block a user