upscale local index to power of 2 and add masking

This commit is contained in:
Szymon Ożóg
2023-08-18 17:46:46 +02:00
parent 2cfc7121b1
commit 89bda3c550

View File

@@ -32,6 +32,9 @@ class TritonProgram:
self.program(*[x._buf for x in args], block = tuple(local_size), grid = tuple(global_size))
def next_power_of_2(x):
return 1 << (x - 1).bit_length()
def uops_to_triton(function_name:str, uops:List[UOp]):
kernel = []
global_size: List[int] = []
@@ -67,7 +70,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
elif args[1] == "local":
full_local_shape = tuple([var.max+1 for var in args[0]])
assert var.min == 0, "local loop must start at 0"
kk(f"{var.expr} = tl.arange({0}, {var.max+1})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
kk(f"{var.expr} = tl.arange({0}, {next_power_of_2(var.max+1)})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
acc_local_shape *= var.max+1
local_size.append(var.max+1)
else:
@@ -84,13 +87,13 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
assert newvar is not None
val = f"{args.name}" # defaults to render_python
triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype]
if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}).to({triton_dtype})")
if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})")
else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({val}, mask={args.valid.render()}), 0.0).to({triton_dtype})")
elif uop == UOps.STORE:
assert vin[0].dtype == dtypes.float, "unimplemented: float4 store"
assert not isinstance(args.memory_dtype, ImageDType), "unimplemented: image store"
assert args.valid.min == 1, "store must be valid"
kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()})")
kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()}, mask = {args.idx.render()}<{args.idx.max+1})")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
elif uop == UOps.CAST: raise NotImplementedError("unimplemented: cast")