mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[OPTIMIZER] Add folder for MakeRangeOp (#2187)
This folds `tl.arange(x, x + 1)` into a constant. This shows up for example when autotuning and one of the block sizes gets set to 1. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -498,6 +498,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -428,6 +428,16 @@ LogicalResult mlir::triton::DotOp::verify() {
|
||||
bEncoding);
|
||||
}
|
||||
|
||||
//-- MakeRangeOp --
|
||||
OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
|
||||
// make_range(start, start + 1) -> constant(start)
|
||||
if (adaptor.getStart() + 1 == adaptor.getEnd()) {
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
return SplatElementsAttr::get(shapedType, adaptor.getStartAttr());
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//-- ReduceOp --
|
||||
static mlir::LogicalResult
|
||||
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
|
||||
|
||||
@@ -10,3 +10,18 @@ tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr<f16>>) {
|
||||
%b = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<32x128xf16>
|
||||
tt.return
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: make_range
|
||||
tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) {
|
||||
// CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32>
|
||||
%a = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32>
|
||||
%b = tt.expand_dims %a {axis = 1 : i32} : (tensor<1xi32>) -> tensor<1x1xi32>
|
||||
%c = tt.broadcast %b : (tensor<1x1xi32>) -> tensor<128x1xi32>
|
||||
|
||||
// CHECK-DAG: %[[d:.*]] = arith.constant dense<1> : tensor<1xi32>
|
||||
%d = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32>
|
||||
|
||||
// CHECK-DAG: tt.return %[[c]], %[[d]] : tensor<128x1xi32>, tensor<1xi32>
|
||||
tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user