diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py
index 98e42100d..6019726ef 100644
--- a/python/triton/language/extern.py
+++ b/python/triton/language/extern.py
@@ -65,9 +65,9 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
     if not all_scalar:
         broadcast_arg = dispatch_args[0]
         # Get the broadcast shape over all the arguments
-        for i in range(len(dispatch_args)):
+        for item in dispatch_args:
             _, broadcast_arg = semantic.binary_op_type_checking_impl(
-                dispatch_args[i], broadcast_arg, _builder)
+                item, broadcast_arg, _builder)
         # Change the shape of each argument based on the broadcast shape
         for i in range(len(dispatch_args)):
             dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index af51da74e..67245037b 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -533,10 +533,10 @@ def broadcast_impl_shape(input: tl.tensor,
         raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
     if shape == src_shape:
         return input
-    for i in range(len(src_shape)):
-        if shape[i] != src_shape[i] and src_shape[i] != 1:
+    for i, item in enumerate(src_shape):
+        if shape[i] != item and item != 1:
             raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
-                             f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
+                             f" must match the existing size ({item}) at non-singleton dimension"
                              f" {i}: {src_shape}, {shape}")
     ret_ty = tl.block_type(input.type.scalar, shape)
     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
@@ -576,8 +576,7 @@ def broadcast_impl_value(lhs: tl.tensor,
     assert len(rhs_shape) == len(lhs_shape)

     ret_shape = []
-    for i in range(len(lhs_shape)):
-        left = lhs_shape[i]
+    for i, left in enumerate(lhs_shape):
         right = rhs_shape[i]
         if left == 1:
             ret_shape.append(right)