[FRONTEND][OPTIMIZER] small perf improvements (#1945)

This commit is contained in:
Philippe Tillet
2023-07-14 15:11:36 -07:00
committed by GitHub
parent 80163a9c1e
commit 8207eabd7b
6 changed files with 40 additions and 13 deletions

View File

@@ -809,11 +809,11 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
nextIV, newForOp.getUpperBound());
pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1];
Value insertSliceIndex = builder.create<arith::RemSIOp>(
Value insertSliceIndex = builder.create<arith::RemUIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2];
Value extractSliceIndex = builder.create<arith::RemSIOp>(
Value extractSliceIndex = builder.create<arith::RemUIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));

View File

@@ -125,8 +125,8 @@ def download_and_copy_ptxas():
base_dir = os.path.dirname(__file__)
src_path = "bin/ptxas"
version = "12.1.105"
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-64/cuda-nvcc-{version}-0.tar.bz2"
version = "12.2.91"
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.2.0/linux-64/cuda-nvcc-{version}-0.tar.bz2"
dst_prefix = os.path.join(base_dir, "triton")
dst_suffix = os.path.join("third_party", "cuda", src_path)
dst_path = os.path.join(dst_prefix, dst_suffix)

View File

@@ -7,7 +7,7 @@ from typing import Callable, List, Sequence, TypeVar
from .._C.libtriton.triton import ir
from ..runtime.jit import jit
from . import semantic
from . import math, semantic
T = TypeVar('T')
@@ -1422,6 +1422,11 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
return _argmax_combine(value1, index1, value2, index2, False)
@jit
def _fast_max(x, y):
    """Reduction combine function: elementwise maximum of two tensors.

    Delegates to ``math.max`` (the hardware max intrinsic) rather than the
    generic ``maximum`` combine — used by ``max`` as the fast reduce path.
    """
    result = math.max(x, y)
    return result
@jit
@_add_reduction_docstr("maximum",
return_indices_arg="return_indices",
@@ -1434,7 +1439,13 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
else:
return reduce(input, axis, maximum)
if constexpr(input.dtype.primitive_bitwidth) < 32:
if constexpr(input.dtype.is_floating()):
input = input.to(float32)
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, _fast_max)
@jit
@@ -1468,6 +1479,11 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
return _argmin_combine(value1, index1, value2, index2, False)
@jit
def _fast_min(x, y):
    """Reduction combine function: elementwise minimum of two tensors.

    Delegates to ``math.min`` (the hardware min intrinsic) rather than the
    generic ``minimum`` combine — used by ``min`` as the fast reduce path.
    """
    result = math.min(x, y)
    return result
@jit
@_add_reduction_docstr("minimum",
return_indices_arg="return_indices",
@@ -1480,7 +1496,13 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
else:
return reduce(input, axis, minimum)
if constexpr(input.dtype.primitive_bitwidth) < 32:
if constexpr(input.dtype.is_floating()):
input = input.to(float32)
else:
assert input.dtype.is_integer_type()
input = input.to(int32)
return reduce(input, axis, _fast_min)
@jit

Binary file not shown.

View File

@@ -13,6 +13,11 @@ import triton
import triton.language as tl
@triton.jit
def max_fn(x, y):
    """Combine function for reductions: elementwise maximum via tl.math.max."""
    result = tl.math.max(x, y)
    return result
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,

View File

@@ -37,8 +37,8 @@
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
@@ -110,8 +110,8 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
@@ -179,8 +179,8 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]