mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
const ops match local shape
This commit is contained in:
@@ -81,7 +81,11 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
elif uop == UOps.LOAD:
|
||||
assert newvar is not None
|
||||
triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype]
|
||||
if isinstance(args, ConstOp): kk(f"{newvar.render()} = {args.value}")
|
||||
if isinstance(args, ConstOp):
|
||||
if len(local_size) > 0:
|
||||
kk(f"{newvar.render()} = tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},), {args.value}, dtype={triton_dtype})")
|
||||
else:
|
||||
kk(f"{newvar.render()} = {args.value}")
|
||||
elif args.valid.min == 1: kk(f"{newvar.render()} = tl.load({args.name} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})")
|
||||
else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({args.name}), 0.0).to({triton_dtype})")
|
||||
elif uop == UOps.STORE:
|
||||
|
||||
Reference in New Issue
Block a user