MLIR 22.0.0git
FuncConversions.cpp
Go to the documentation of this file.
1//===- FuncConversions.cpp - Function conversions -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
13using namespace mlir;
14using namespace mlir::func;
15
16/// Flatten the given value ranges into a single vector of values.
19 for (const auto &vals : values)
20 llvm::append_range(result, vals);
21 return result;
22}
23
24namespace {
25/// Converts the operand and result types of the CallOp, used together with the
26/// FuncOpSignatureConversion.
27struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
28 using OpConversionPattern<CallOp>::OpConversionPattern;
29
30 /// Hook for derived classes to implement combined matching and rewriting.
31 LogicalResult
32 matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
33 ConversionPatternRewriter &rewriter) const override {
34 // Convert the original function results. Keep track of how many result
35 // types an original result type is converted into.
36 SmallVector<size_t> numResultsReplacments;
37 SmallVector<Type, 1> convertedResults;
38 size_t numFlattenedResults = 0;
39 for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
40 if (failed(typeConverter->convertTypes(type, convertedResults)))
41 return failure();
42 numResultsReplacments.push_back(convertedResults.size() -
43 numFlattenedResults);
44 numFlattenedResults = convertedResults.size();
45 }
46
47 // Substitute with the new result types from the corresponding FuncType
48 // conversion.
49 auto newCallOp =
50 CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
51 convertedResults, flattenValues(adaptor.getOperands()));
52 SmallVector<ValueRange> replacements;
53 size_t offset = 0;
54 for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
55 replacements.push_back(
56 newCallOp->getResults().slice(offset, numResultsReplacments[i]));
57 offset += numResultsReplacments[i];
58 }
59 assert(offset == convertedResults.size() &&
60 "expected that all converted results are used");
61 rewriter.replaceOpWithMultiple(callOp, replacements);
62 return success();
63 }
64};
65} // namespace
66
68 const TypeConverter &converter,
69 PatternBenefit benefit) {
70 patterns.add<CallOpSignatureConversion>(converter, patterns.getContext(),
71 benefit);
72}
73
74namespace {
75/// Only needed to support partial conversion of functions where this pattern
76/// ensures that the branch operation arguments matches up with the succesor
77/// block arguments.
78class BranchOpInterfaceTypeConversion
79 : public OpInterfaceConversionPattern<BranchOpInterface> {
80public:
81 using OpInterfaceConversionPattern<
82 BranchOpInterface>::OpInterfaceConversionPattern;
83
84 BranchOpInterfaceTypeConversion(
85 const TypeConverter &typeConverter, MLIRContext *ctx,
86 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
87 PatternBenefit benefit)
88 : OpInterfaceConversionPattern(typeConverter, ctx, benefit),
89 shouldConvertBranchOperand(shouldConvertBranchOperand) {}
90
91 LogicalResult
92 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
93 ConversionPatternRewriter &rewriter) const final {
94 // For a branch operation, only some operands go to the target blocks, so
95 // only rewrite those.
96 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
97 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
98 succIdx < succEnd; ++succIdx) {
99 OperandRange forwardedOperands =
100 op.getSuccessorOperands(succIdx).getForwardedOperands();
101 if (forwardedOperands.empty())
102 continue;
103
104 for (int idx = forwardedOperands.getBeginOperandIndex(),
105 eidx = idx + forwardedOperands.size();
106 idx < eidx; ++idx) {
107 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
108 newOperands[idx] = operands[idx];
109 }
110 }
111 rewriter.modifyOpInPlace(
112 op, [newOperands, op]() { op->setOperands(newOperands); });
113 return success();
114 }
115
116private:
117 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
118};
119} // namespace
120
121namespace {
122/// Only needed to support partial conversion of functions where this pattern
123/// ensures that the branch operation arguments matches up with the succesor
124/// block arguments.
125class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
126public:
127 using OpConversionPattern<ReturnOp>::OpConversionPattern;
128
129 LogicalResult
130 matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter) const final {
132 rewriter.replaceOpWithNewOp<ReturnOp>(op,
133 flattenValues(adaptor.getOperands()));
134 return success();
135 }
136};
137} // namespace
138
140 RewritePatternSet &patterns, const TypeConverter &typeConverter,
141 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand,
142 PatternBenefit benefit) {
143 patterns.add<BranchOpInterfaceTypeConversion>(
144 typeConverter, patterns.getContext(), shouldConvertBranchOperand,
145 benefit);
146}
147
149 Operation *op, const TypeConverter &converter) {
150 // All successor operands of branch like operations must be rewritten.
151 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
152 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
153 auto successorOperands = branchOp.getSuccessorOperands(p);
154 if (!converter.isLegal(
155 successorOperands.getForwardedOperands().getTypes()))
156 return false;
157 }
158 return true;
159 }
160
161 return false;
162}
163
165 RewritePatternSet &patterns, const TypeConverter &typeConverter,
166 PatternBenefit benefit) {
167 patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext(),
168 benefit);
169}
170
172 Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
173 // If this is a `return` and the user pass wants to convert/transform across
174 // function boundaries, then `converter` is invoked to check whether the
175 // `return` op is legal.
176 if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
177 return converter.isLegal(op);
178
179 // ReturnLike operations have to be legalized with their parent. For
180 // return this is handled, for other ops they remain as is.
181 return op->hasTrait<OpTrait::ReturnLike>();
182}
183
185 // If it is not a terminator, ignore it.
187 return true;
188
189 // If it is not the last operation in the block, also ignore it. We do
190 // this to handle unknown operations, as well.
191 Block *block = op->getBlock();
192 if (!block || &block->back() != op)
193 return true;
194
195 // We don't want to handle terminators in nested regions, assume they are
196 // always legal.
197 if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
198 return true;
199
200 return false;
201}
return success()
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned getNumSuccessors()
Definition Block.cpp:265
Operation & back()
Definition Block.h:152
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides the API for ops that are known to be terminators.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition Operation.h:757
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
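Return true for ReturnLike ops other than return; for return, return true if returnOpAlwaysLegal is set, otherwise defer to converter.isLegal(op). All other ops yield false.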
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true unless op is a terminator that is the last operation of a block whose parent op is a FuncOp (terminators in nested regions are treated as legal).
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
const FrozenRewritePatternSet & patterns
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op, const TypeConverter &converter)
Return true if op is a BranchOpInterface op whose operands are all legal according to converter.
This trait indicates that a terminator operation is "return-like".