mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[FRONTEND][OPTIMIZER] small perf improvements (#1945)
This commit is contained in:
@@ -809,11 +809,11 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
|
||||
pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1];
|
||||
Value insertSliceIndex = builder.create<arith::RemSIOp>(
|
||||
Value insertSliceIndex = builder.create<arith::RemUIOp>(
|
||||
nextIV.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||
loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2];
|
||||
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
||||
Value extractSliceIndex = builder.create<arith::RemUIOp>(
|
||||
nextIV.getLoc(), loopIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||
|
||||
|
||||
@@ -125,8 +125,8 @@ def download_and_copy_ptxas():
|
||||
|
||||
base_dir = os.path.dirname(__file__)
|
||||
src_path = "bin/ptxas"
|
||||
version = "12.1.105"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-64/cuda-nvcc-{version}-0.tar.bz2"
|
||||
version = "12.2.91"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.2.0/linux-64/cuda-nvcc-{version}-0.tar.bz2"
|
||||
dst_prefix = os.path.join(base_dir, "triton")
|
||||
dst_suffix = os.path.join("third_party", "cuda", src_path)
|
||||
dst_path = os.path.join(dst_prefix, dst_suffix)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, List, Sequence, TypeVar
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
from ..runtime.jit import jit
|
||||
from . import semantic
|
||||
from . import math, semantic
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@@ -1422,6 +1422,11 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
def _fast_max(x, y):
|
||||
return math.max(x, y)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
@@ -1434,7 +1439,13 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
|
||||
else:
|
||||
return reduce(input, axis, maximum)
|
||||
if constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if constexpr(input.dtype.is_floating()):
|
||||
input = input.to(float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, _fast_max)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -1468,6 +1479,11 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
def _fast_min(x, y):
|
||||
return math.min(x, y)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
@@ -1480,7 +1496,13 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
|
||||
else:
|
||||
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
|
||||
else:
|
||||
return reduce(input, axis, minimum)
|
||||
if constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if constexpr(input.dtype.is_floating()):
|
||||
input = input.to(float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, _fast_min)
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
BIN
python/triton/third_party/cuda/bin/ptxas
vendored
Binary file not shown.
@@ -13,6 +13,11 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def max_fn(x, y):
|
||||
return tl.math.max(x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
|
||||
@@ -37,8 +37,8 @@
|
||||
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
|
||||
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
@@ -110,8 +110,8 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
|
||||
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
|
||||
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
@@ -179,8 +179,8 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
|
||||
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
|
||||
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
|
||||
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
|
||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
|
||||
|
||||
Reference in New Issue
Block a user