MLIR  22.0.0git
FuncConversions.cpp
Go to the documentation of this file.
1 //===- FuncConversions.cpp - Function conversions -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 /// Flatten the given value ranges into a single vector of values.
18  SmallVector<Value> result;
19  for (const auto &vals : values)
20  llvm::append_range(result, vals);
21  return result;
22 }
23 
24 namespace {
25 /// Converts the operand and result types of the CallOp, used together with the
26 /// FuncOpSignatureConversion.
27 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
29 
30  /// Hook for derived classes to implement combined matching and rewriting.
31  LogicalResult
32  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
33  ConversionPatternRewriter &rewriter) const override {
34  // Convert the original function results. Keep track of how many result
35  // types an original result type is converted into.
36  SmallVector<size_t> numResultsReplacments;
37  SmallVector<Type, 1> convertedResults;
38  size_t numFlattenedResults = 0;
39  for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
40  if (failed(typeConverter->convertTypes(type, convertedResults)))
41  return failure();
42  numResultsReplacments.push_back(convertedResults.size() -
43  numFlattenedResults);
44  numFlattenedResults = convertedResults.size();
45  }
46 
47  // Substitute with the new result types from the corresponding FuncType
48  // conversion.
49  auto newCallOp =
50  CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
51  convertedResults, flattenValues(adaptor.getOperands()));
52  SmallVector<ValueRange> replacements;
53  size_t offset = 0;
54  for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
55  replacements.push_back(
56  newCallOp->getResults().slice(offset, numResultsReplacments[i]));
57  offset += numResultsReplacments[i];
58  }
59  assert(offset == convertedResults.size() &&
60  "expected that all converted results are used");
61  rewriter.replaceOpWithMultiple(callOp, replacements);
62  return success();
63  }
64 };
65 } // namespace
66 
68  const TypeConverter &converter,
69  PatternBenefit benefit) {
70  patterns.add<CallOpSignatureConversion>(converter, patterns.getContext(),
71  benefit);
72 }
73 
74 namespace {
75 /// Only needed to support partial conversion of functions where this pattern
76 /// ensures that the branch operation arguments matches up with the succesor
77 /// block arguments.
78 class BranchOpInterfaceTypeConversion
79  : public OpInterfaceConversionPattern<BranchOpInterface> {
80 public:
82  BranchOpInterface>::OpInterfaceConversionPattern;
83 
84  BranchOpInterfaceTypeConversion(
85  const TypeConverter &typeConverter, MLIRContext *ctx,
86  function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
87  PatternBenefit benefit)
88  : OpInterfaceConversionPattern(typeConverter, ctx, benefit),
89  shouldConvertBranchOperand(shouldConvertBranchOperand) {}
90 
91  LogicalResult
92  matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
93  ConversionPatternRewriter &rewriter) const final {
94  // For a branch operation, only some operands go to the target blocks, so
95  // only rewrite those.
96  SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
97  for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
98  succIdx < succEnd; ++succIdx) {
99  OperandRange forwardedOperands =
100  op.getSuccessorOperands(succIdx).getForwardedOperands();
101  if (forwardedOperands.empty())
102  continue;
103 
104  for (int idx = forwardedOperands.getBeginOperandIndex(),
105  eidx = idx + forwardedOperands.size();
106  idx < eidx; ++idx) {
107  if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
108  newOperands[idx] = operands[idx];
109  }
110  }
111  rewriter.modifyOpInPlace(
112  op, [newOperands, op]() { op->setOperands(newOperands); });
113  return success();
114  }
115 
116 private:
117  function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
118 };
119 } // namespace
120 
121 namespace {
122 /// Only needed to support partial conversion of functions where this pattern
123 /// ensures that the branch operation arguments matches up with the succesor
124 /// block arguments.
125 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
126 public:
128 
129  LogicalResult
130  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
131  ConversionPatternRewriter &rewriter) const final {
132  rewriter.replaceOpWithNewOp<ReturnOp>(op,
133  flattenValues(adaptor.getOperands()));
134  return success();
135  }
136 };
137 } // namespace
138 
140  RewritePatternSet &patterns, const TypeConverter &typeConverter,
141  function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
142  PatternBenefit benefit) {
143  patterns.add<BranchOpInterfaceTypeConversion>(
144  typeConverter, patterns.getContext(), shouldConvertBranchOperand,
145  benefit);
146 }
147 
149  Operation *op, const TypeConverter &converter) {
150  // All successor operands of branch like operations must be rewritten.
151  if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
152  for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
153  auto successorOperands = branchOp.getSuccessorOperands(p);
154  if (!converter.isLegal(
155  successorOperands.getForwardedOperands().getTypes()))
156  return false;
157  }
158  return true;
159  }
160 
161  return false;
162 }
163 
165  RewritePatternSet &patterns, const TypeConverter &typeConverter,
166  PatternBenefit benefit) {
167  patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext(),
168  benefit);
169 }
170 
172  Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
173  // If this is a `return` and the user pass wants to convert/transform across
174  // function boundaries, then `converter` is invoked to check whether the
175  // `return` op is legal.
176  if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
177  return converter.isLegal(op);
178 
179  // ReturnLike operations have to be legalized with their parent. For
180  // return this is handled, for other ops they remain as is.
181  return op->hasTrait<OpTrait::ReturnLike>();
182 }
183 
185  // If it is not a terminator, ignore it.
187  return true;
188 
189  // If it is not the last operation in the block, also ignore it. We do
190  // this to handle unknown operations, as well.
191  Block *block = op->getBlock();
192  if (!block || &block->back() != op)
193  return true;
194 
195  // We don't want to handle terminators in nested regions, assume they are
196  // always legal.
197  if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
198  return true;
199 
200  return false;
201 }
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumSuccessors()
Definition: Block.cpp:265
Operation & back()
Definition: Block.h:152
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:757
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
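Return true if op needs no legalization as a function-body terminator, i.e. it cannot be a terminator, is not the last operation in its block, or is not directly nested in a FuncOp.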
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
const FrozenRewritePatternSet & patterns
bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op, const TypeConverter &converter)
Return true if op is a BranchOpInterface op whose operands are all legal according to converter.
This trait indicates that a terminator operation is "return-like".