diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index 706d01af32..9487778284 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -81,7 +81,11 @@ def uops_to_triton(function_name:str, uops:List[UOp]): elif uop == UOps.LOAD: assert newvar is not None triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype] - if isinstance(args, ConstOp): kk(f"{newvar.render()} = {args.value}") + if isinstance(args, ConstOp): + if len(local_size) > 0: + kk(f"{newvar.render()} = tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},), {args.value}, dtype={triton_dtype})") + else: + kk(f"{newvar.render()} = {args.value}") elif args.valid.min == 1: kk(f"{newvar.render()} = tl.load({args.name} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})") else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({args.name}), 0.0).to({triton_dtype})") elif uop == UOps.STORE: