const ops match local shape

This commit is contained in:
Szymon Ożóg
2023-08-19 14:31:36 +02:00
parent 320a012772
commit 5533db9a6a

View File

@@ -81,7 +81,11 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
elif uop == UOps.LOAD:
assert newvar is not None
triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype]
if isinstance(args, ConstOp): kk(f"{newvar.render()} = {args.value}")
if isinstance(args, ConstOp):
if len(local_size) > 0:
kk(f"{newvar.render()} = tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},), {args.value}, dtype={triton_dtype})")
else:
kk(f"{newvar.render()} = {args.value}")
elif args.valid.min == 1: kk(f"{newvar.render()} = tl.load({args.name} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})")
else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({args.name}), 0.0).to({triton_dtype})")
elif uop == UOps.STORE: