mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
upscale local index to power of 2 and add masking
This commit is contained in:
@@ -32,6 +32,9 @@ class TritonProgram:
|
||||
|
||||
self.program(*[x._buf for x in args], block = tuple(local_size), grid = tuple(global_size))
|
||||
|
||||
def next_power_of_2(x):
|
||||
return 1 << (x - 1).bit_length()
|
||||
|
||||
def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
kernel = []
|
||||
global_size: List[int] = []
|
||||
@@ -67,7 +70,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
elif args[1] == "local":
|
||||
full_local_shape = tuple([var.max+1 for var in args[0]])
|
||||
assert var.min == 0, "local loop must start at 0"
|
||||
kk(f"{var.expr} = tl.arange({0}, {var.max+1})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
|
||||
kk(f"{var.expr} = tl.arange({0}, {next_power_of_2(var.max+1)})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
|
||||
acc_local_shape *= var.max+1
|
||||
local_size.append(var.max+1)
|
||||
else:
|
||||
@@ -84,13 +87,13 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
assert newvar is not None
|
||||
val = f"{args.name}" # defaults to render_python
|
||||
triton_dtype = {dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64"}[newvar.dtype]
|
||||
if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}).to({triton_dtype})")
|
||||
if args.valid.min == 1: kk(f"{newvar.render()} = tl.load({val} + {args.idx.render()}, mask = {args.idx.render()}<{args.idx.max+1}).to({triton_dtype})")
|
||||
else: kk(f"{newvar.render()} = tl.where({args.valid.render()}, tl.load({val}, mask={args.valid.render()}), 0.0).to({triton_dtype})")
|
||||
elif uop == UOps.STORE:
|
||||
assert vin[0].dtype == dtypes.float, "unimplemented: float4 store"
|
||||
assert not isinstance(args.memory_dtype, ImageDType), "unimplemented: image store"
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()})")
|
||||
kk(f"tl.store({args.name} + {args.idx.render()}, {vin[0].render()}, mask = {args.idx.render()}<{args.idx.max+1})")
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
bufs.append(args)
|
||||
elif uop == UOps.CAST: raise NotImplementedError("unimplemented: cast")
|
||||
|
||||
Reference in New Issue
Block a user