MLIR  19.0.0git
ConcatOpPatterns.cpp
Go to the documentation of this file.
1 //===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
15 
16 using namespace mlir;
17 using namespace mlir::tensor;
18 
19 namespace {
20 
21 /// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
22 ///
23 /// %concat = tensor.concat dim(1) %0, %1 :
24 /// (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
25 ///
26 /// Becomes
27 ///
28 /// %empty = tensor.empty() : tensor<2x7xf32>
29 /// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
30 /// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
31 struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
33 
34  LogicalResult matchAndRewrite(ConcatOp concatOp,
35  PatternRewriter &rewriter) const override {
36  Location loc = concatOp.getLoc();
37  FailureOr<Value> dest =
38  tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
39  if (failed(dest))
40  return failure();
41 
42  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
43  if (!empty)
44  return failure();
45 
46  int64_t dim = concatOp.getDim();
47  Value dimValue =
48  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
49 
50  int64_t rank = concatOp.getResultType().getRank();
51  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
52  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
53 
54  // Compute the partial sums for the slice offsets.
55  AffineExpr sum = rewriter.getAffineDimExpr(0);
56  SmallVector<AffineExpr> partialSums = {sum};
57  SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
58  for (auto [idx, input] :
59  llvm::enumerate(concatOp.getInputs().drop_back())) {
60  sum = sum + rewriter.getAffineDimExpr(idx + 1);
61  partialSums.push_back(sum);
62  offsetStrides.push_back(
63  rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
64  }
65  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
66  partialSums, rewriter.getContext());
67  SmallVector<OpFoldResult> dimOffsets =
69  rewriter, loc, partialSumMap, offsetStrides);
70 
71  // Construct the chain of insert_slice ops into the destination.
72  Value result = *dest;
73  for (auto [input, offset] :
74  llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
76  tensor::getMixedSizes(rewriter, loc, input);
77  offsets[dim] = offset;
78  result = rewriter.createOrFold<tensor::InsertSliceOp>(
79  loc, input, result, offsets, sizes, strides);
80  }
81 
82  rewriter.replaceOpWithNewOp<tensor::CastOp>(
83  concatOp, concatOp.getResultType(), result);
84  return success();
85  }
86 };
87 
88 } // namespace
89 
91  RewritePatternSet &patterns) {
92  patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
93 }
Base type for affine expression.
Definition: AffineExpr.h:69
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1235
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
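Definition: Matchers.h:285 (Doxygen cross-reference only; the listing above calls `llvm::enumerate`, not this tuple overload.)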
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the concatenated size, followed by a chain of tensor.insert_slice operations on the inputs.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:70
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:358