From fa03b921097ac088b2fbdbcd46f54b50fc186d94 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Sat, 26 Aug 2023 17:44:13 +0100 Subject: [PATCH] [OPTIMIZER] Add folder for MakeRangeOp (#2187) This folds `tl.arange(x, x + 1)` into a constant. This shows up for example when autotuning and one of the block sizes gets set to 1. Co-authored-by: Philippe Tillet --- include/triton/Dialect/Triton/IR/TritonOps.td | 2 ++ lib/Dialect/Triton/IR/Ops.cpp | 10 ++++++++++ test/Triton/canonicalize.mlir | 15 +++++++++++++++ 3 files changed, 27 insertions(+) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ef0255925..69cad2bcf 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -498,6 +498,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { let results = (outs TT_IntTensor:$result); let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; } // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 0adf9e737..2d7bec31c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -428,6 +428,16 @@ LogicalResult mlir::triton::DotOp::verify() { bEncoding); } +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = getType().cast(); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + //-- ReduceOp -- static mlir::LogicalResult inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index 589a085bb..c532b4686 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -10,3 +10,18 @@ tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr>) { %b = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<32x128xf16> tt.return } + + +// CHECK-LABEL: make_range +tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { + // CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32> + %a = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %b = tt.expand_dims %a {axis = 1 : i32} : (tensor<1xi32>) -> tensor<1x1xi32> + %c = tt.broadcast %b : (tensor<1x1xi32>) -> tensor<128x1xi32> + + // CHECK-DAG: %[[d:.*]] = arith.constant dense<1> : tensor<1xi32> + %d = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + + // CHECK-DAG: tt.return %[[c]], %[[d]] : tensor<128x1xi32>, tensor<1xi32> + tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32> +}