From 89bda3c5507e8ff1ddb107cb0f12aa63a58dffd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= Date: Fri, 18 Aug 2023 17:46:46 +0200 Subject: [PATCH] upscale local index to power of 2 and add masking --- tinygrad/runtime/ops_triton.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index f7bffb580d..b891979065 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -32,6 +32,9 @@ class TritonProgram: self.program(*[x._buf for x in args], block = tuple(local_size), grid = tuple(global_size)) +def next_power_of_2(x): + return 1 << (x - 1).bit_length() + def uops_to_triton(function_name:str, uops:List[UOp]): kernel = [] global_size: List[int] = [] @@ -67,7 +70,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]): elif args[1] == "local": full_local_shape = tuple([var.max+1 for var in args[0]]) assert var.min == 0, "local loop must start at 0" - kk(f"{var.expr} = tl.arange({0}, {var.max+1})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]") + kk(f"{var.expr} = tl.arange({0}, {next_power_of_2(var.max+1)})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]") acc_local_shape *= var.max+1 local_size.append(var.max+1) else: @@ -84,13 +87,13 @@ def uops_to_triton(function_name:str, uops:List[UOp]): assert newvar is not None val = f"{args.name}" # defaults to render_python triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype] - if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}).to({triton_dtype})") + if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})") else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({val}, mask={args.valid.render()}), 0.0).to({triton_dtype})") elif uop == UOps.STORE: assert vin[0].dtype == dtypes.float, "unimplemented: float4 store" assert not isinstance(args.memory_dtype, ImageDType), "unimplemented: image store" assert args.valid.min == 1, "store must be valid" - kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()})") + kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()}, mask = {args.idx.render()}<{args.idx.max+1})") elif uop == UOps.DEFINE_GLOBAL: bufs.append(args) elif uop == UOps.CAST: raise NotImplementedError("unimplemented: cast")