remove deprecated variables

This commit is contained in:
Szymon Ożóg
2023-08-19 13:56:37 +02:00
parent fecc58cc2b
commit 4123920bcc

View File

@@ -43,9 +43,6 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
bufs = []
def kk(s): kernel.append(" "*depth+s)
full_local_shape: Tuple[Any, ...] = ()
acc_local_shape = 1
gid = [f"tl.program_id({i})" for i in range(3)]
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x: f"tl.math.exp2({x})",
@@ -69,10 +66,8 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
global_size.append(var.max+1)
kk(f"{var.expr} = {gid[i]} # {var.max+1}")
elif args[1] == "local":
full_local_shape = tuple([var.max+1 for var in args[0]])
assert var.min == 0, "local loop must start at 0"
kk(f"{var.expr} = tl.arange({0}, {next_power_of_2(var.max+1)})[{', '.join([':' if i == j else 'None' for j in range(len(args[0]))])}]")
acc_local_shape *= var.max+1
local_size.append(var.max+1)
else:
kk(f"for {var.expr} in range({var.min}, {var.max+1}):")