[FRONTEND] Correct the debug syntax (#1705)

- If the environment variable `TRITON_DEBUG=1` is set, all triton functions
are compiled in debug mode.
- Otherwise, a triton function `f`'s debug flag is either `True`,
`False` or `None` (default).
    - If `True`, `f` is compiled in debug mode.
    - If `False`, `f` is compiled in normal mode.
    - If `None`, `f` is compiled based on its caller's debug flag. The root
      (kernel) function's debug flag can also be set through the `compile`
      function.

cc @ngimel , @Chillee
This commit is contained in:
Keren Zhou
2023-05-25 02:24:29 -04:00
committed by GitHub
parent 74dbb2fc0a
commit 0341953466
5 changed files with 113 additions and 7 deletions

View File

@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit(debug=False)
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y, with a device-side assert that every
    # loaded element equals zero. The kernel is declared with debug=False.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    tl.device_assert(values == 0, "x != 0")
    tl.store(Y + offsets, values)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -43,6 +50,9 @@ def test_assert(func: str):
if func == "device_assert":
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
elif func == "no_debug":
# Setting TRITON_DEBUG=1 in the environment overrides the kernel's debug=False flag
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, BLOCK=shape[0])
elif func == "static_assert":
@@ -50,5 +60,72 @@ def test_assert(func: str):
assert_close(y, x)
@triton.jit
def jit_device_assert_none(x):
    # Helper with no explicit debug flag on the decorator: device-asserts
    # that every element of x is zero.
    is_zero = x == 0
    tl.device_assert(is_zero, "x != 0")
@triton.jit(debug=True)
def jit_device_assert_true(x):
    # Helper declared with debug=True: device-asserts that every element of
    # x is zero.
    is_zero = x == 0
    tl.device_assert(is_zero, "x != 0")
@triton.jit(debug=False)
def jit_device_assert_false(x):
    # Helper declared with debug=False: contains a device assert that every
    # element of x is zero.
    is_zero = x == 0
    tl.device_assert(is_zero, "x != 0")
@triton.jit
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Root kernel with no explicit debug flag. Loads BLOCK elements from X,
    # runs the assert helper selected by jit_debug ("true", "false", or
    # anything else for the helper without a flag), then copies them to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "false":
        jit_device_assert_false(values)
    elif jit_debug == "true":
        jit_device_assert_true(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
@triton.jit(debug=True)
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Root kernel declared with debug=True. Loads BLOCK elements from X,
    # runs the assert helper selected by jit_debug ("true", "false", or
    # anything else for the helper without a flag), then copies them to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "false":
        jit_device_assert_false(values)
    elif jit_debug == "true":
        jit_device_assert_true(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
@triton.jit(debug=False)
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Root kernel declared with debug=False. Loads BLOCK elements from X,
    # runs the assert helper selected by jit_debug ("true", "false", or
    # anything else for the helper without a flag), then copies them to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "false":
        jit_device_assert_false(values)
    elif jit_debug == "true":
        jit_device_assert_true(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
def test_assert_nested(caller: str, callee: str):
    """Launch the nested-assert kernel selected by ``caller`` and check the copy.

    ``caller`` picks the root kernel ("none", "true" or "false") and
    ``callee`` is forwarded as ``jit_debug`` so the kernel picks the matching
    assert helper. Any other ``caller`` launches nothing, so the final
    comparison runs against the untouched zero-filled output.
    """
    n_elements = 128
    src = torch.arange(0, n_elements, dtype=torch.int32, device='cuda')
    dst = torch.zeros((n_elements, ), dtype=src.dtype, device="cuda")
    kernels_by_caller = {
        "none": kernel_device_assert_nested,
        "true": kernel_device_assert_nested_true,
        "false": kernel_device_assert_nested_false,
    }
    kernel = kernels_by_caller.get(caller)
    if kernel is not None:
        kernel[(1,)](src, dst, BLOCK=n_elements, jit_debug=callee)
    assert_close(dst, src)
if __name__ == "__main__":
    # Two command-line arguments select the nested test (caller flag, callee
    # flag); otherwise the single argument names the assert flavour to run.
    # The earlier unconditional test_assert(sys.argv[1]) call is dropped: it
    # ran test_assert twice in the one-argument case and passed the caller
    # flag to test_assert in the two-argument case.
    if len(sys.argv) == 3:
        test_assert_nested(sys.argv[1], sys.argv[2])
    else:
        test_assert(sys.argv[1])

View File

@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")
# TODO: bfloat16 after LLVM-15
func_types = ["device_assert", "assert", "static_assert"]
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
assert num_errs == 127
else:
assert num_errs == 0
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
def test_assert_nested(caller_type, callee_type):
    """Run the assert helper script in a subprocess and count fired asserts.

    The helper is invoked with the caller and callee flags, and every stderr
    line containing "x != 0" counts as one fired assert. The expected count
    is 127 when the callee flag is "true", or when the callee flag is not
    "false" and the caller flag is "true"; otherwise it is 0.
    """
    proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
    _, stderr_bytes = proc.communicate()
    num_errs = sum("x != 0" in line.decode("utf-8") for line in stderr_bytes.splitlines())

    if caller_type == "true":
        # A debug caller fires the assert unless the callee opts out.
        expected = 0 if callee_type == "false" else 127
    elif caller_type in ("none", "false"):
        # Otherwise only a callee that opts in fires the assert.
        expected = 127 if callee_type == "true" else 0
    else:
        # Unknown caller flags are not checked.
        return
    assert num_errs == expected

View File

@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
assert len(kernel_add.cache[device]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 2
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
@triton.jit

View File

@@ -859,7 +859,9 @@ class CodeGenerator(ast.NodeVisitor):
if not self.module.has_function(fn_name):
prototype = language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
# If the callee's debug flag is unset (None), inherit the caller's debug setting
debug = self.debug if fn.debug is None else fn.debug
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type

View File

@@ -356,7 +356,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty