MLIR 22.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include <optional>
13
14using namespace mlir;
15using namespace mlir::scf;
16
17namespace {
18
19/// Flatten the given value ranges into a single vector of values.
22 for (const auto &vals : values)
23 llvm::append_range(result, vals);
24 return result;
25}
26
27// CRTP
28// A base class that takes care of 1:N type conversion, which maps the converted
29// op results (computed by the derived class) and materializes 1:N conversion.
30template <typename SourceOp, typename ConcretePattern>
31class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
32public:
33 using OpConversionPattern<SourceOp>::typeConverter;
34 using OpConversionPattern<SourceOp>::OpConversionPattern;
35 using OneToNOpAdaptor =
36 typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
37
38 //
39 // Derived classes should provide the following method which performs the
40 // actual conversion. It should return std::nullopt upon conversion failure
41 // and return the converted operation upon success.
42 //
43 // std::optional<SourceOp> convertSourceOp(
44 // SourceOp op, OneToNOpAdaptor adaptor,
45 // ConversionPatternRewriter &rewriter,
46 // TypeRange dstTypes) const;
47
48 LogicalResult
49 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
50 ConversionPatternRewriter &rewriter) const override {
51 SmallVector<Type> dstTypes;
52 SmallVector<unsigned> offsets;
53 offsets.push_back(0);
54 // Do the type conversion and record the offsets.
55 for (Value v : op.getResults()) {
56 if (failed(typeConverter->convertType(v, dstTypes)))
57 return rewriter.notifyMatchFailure(op, "could not convert result type");
58 offsets.push_back(dstTypes.size());
59 }
60
61 // Calls the actual converter implementation to convert the operation.
62 std::optional<SourceOp> newOp =
63 static_cast<const ConcretePattern *>(this)->convertSourceOp(
64 op, adaptor, rewriter, dstTypes);
65
66 if (!newOp)
67 return rewriter.notifyMatchFailure(op, "could not convert operation");
68
69 // Packs the return value.
70 SmallVector<ValueRange> packedRets;
71 for (unsigned i = 1, e = offsets.size(); i < e; i++) {
72 unsigned start = offsets[i - 1], end = offsets[i];
73 unsigned len = end - start;
74 ValueRange mappedValue = newOp->getResults().slice(start, len);
75 packedRets.push_back(mappedValue);
76 }
77
78 rewriter.replaceOpWithMultiple(op, packedRets);
79 return success();
80 }
81};
82
83class ConvertForOpTypes
84 : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
85public:
86 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
87
88 // The callback required by CRTP.
89 std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
90 ConversionPatternRewriter &rewriter,
91 TypeRange dstTypes) const {
92 // Create a empty new op and inline the regions from the old op.
93 //
94 // This is a little bit tricky. We have two concerns here:
95 //
96 // 1. We cannot update the op in place because the dialect conversion
97 // framework does not track type changes for ops updated in place, so it
98 // won't insert appropriate materializations on the changed result types.
99 // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100 // to clone the op.
101 //
102 // 2. We need to reuse the original region instead of cloning it, otherwise
103 // the dialect conversion framework thinks that we just inserted all the
104 // cloned child ops. But what we want is to "take" the child regions and let
105 // the dialect conversion framework continue recursively into ops inside
106 // those regions (which are already in its worklist; inlining them into the
107 // new op's regions doesn't remove the child ops from the worklist).
108
109 // convertRegionTypes already takes care of 1:N conversion.
110 if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
111 return std::nullopt;
112
113 // We can not do clone as the number of result types after conversion
114 // might be different.
115 ForOp newOp = ForOp::create(rewriter, op.getLoc(),
116 llvm::getSingleElement(adaptor.getLowerBound()),
117 llvm::getSingleElement(adaptor.getUpperBound()),
118 llvm::getSingleElement(adaptor.getStep()),
119 flattenValues(adaptor.getInitArgs()),
120 /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
121
122 // Reserve whatever attributes in the original op.
123 newOp->setAttrs(op->getAttrs());
124
125 // We do not need the empty block created by rewriter.
126 rewriter.eraseBlock(newOp.getBody(0));
127 // Inline the type converted region from the original operation.
128 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
129 newOp.getRegion().end());
130 return newOp;
131 }
132};
133} // namespace
134
135namespace {
136class ConvertIfOpTypes
137 : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
138public:
139 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
140
141 std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter,
143 TypeRange dstTypes) const {
144
145 IfOp newOp =
146 IfOp::create(rewriter, op.getLoc(), dstTypes,
147 llvm::getSingleElement(adaptor.getCondition()), true);
148 newOp->setAttrs(op->getAttrs());
149
150 // We do not need the empty blocks created by rewriter.
151 rewriter.eraseBlock(newOp.elseBlock());
152 rewriter.eraseBlock(newOp.thenBlock());
153
154 // Inlines block from the original operation.
155 rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
156 newOp.getThenRegion().end());
157 rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
158 newOp.getElseRegion().end());
159
160 return newOp;
161 }
162};
163} // namespace
164
165namespace {
166class ConvertWhileOpTypes
167 : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
168public:
169 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
170
171 std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter,
173 TypeRange dstTypes) const {
174 auto newOp = WhileOp::create(rewriter, op.getLoc(), dstTypes,
175 flattenValues(adaptor.getOperands()));
176
177 for (auto i : {0u, 1u}) {
178 if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
179 return std::nullopt;
180 auto &dstRegion = newOp.getRegion(i);
181 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
182 }
183 return newOp;
184 }
185};
186} // namespace
187
188namespace {
189class ConvertIndexSwitchOpTypes
190 : public Structural1ToNConversionPattern<IndexSwitchOp,
191 ConvertIndexSwitchOpTypes> {
192public:
193 using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
194
195 std::optional<IndexSwitchOp>
196 convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
197 ConversionPatternRewriter &rewriter,
198 TypeRange dstTypes) const {
199 auto newOp =
200 IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
201 op.getCases(), op.getNumCases());
202
203 for (unsigned i = 0u; i < op.getNumRegions(); i++) {
204 auto &dstRegion = newOp.getRegion(i);
205 rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
206 }
207 return newOp;
208 }
209};
210} // namespace
211
212namespace {
213// When the result types of a ForOp/IfOp get changed, the operand types of the
214// corresponding yield op need to be changed. In order to trigger the
215// appropriate type conversions / materializations, we need a dummy pattern.
216class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
217public:
218 using OpConversionPattern::OpConversionPattern;
219 LogicalResult
220 matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
221 ConversionPatternRewriter &rewriter) const override {
222 rewriter.replaceOpWithNewOp<scf::YieldOp>(
223 op, flattenValues(adaptor.getOperands()));
224 return success();
225 }
226};
227} // namespace
228
229namespace {
230class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
231public:
232 using OpConversionPattern<ConditionOp>::OpConversionPattern;
233 LogicalResult
234 matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter) const override {
236 rewriter.modifyOpInPlace(
237 op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
238 return success();
239 }
240};
241} // namespace
242
244 const TypeConverter &typeConverter, RewritePatternSet &patterns,
245 PatternBenefit benefit) {
246 patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
247 ConvertWhileOpTypes, ConvertConditionOpTypes,
248 ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
249 benefit);
250}
251
253 const TypeConverter &typeConverter, ConversionTarget &target) {
254 target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
255 [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
256 target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
257 // We only have conversions for a subset of ops that use scf.yield
258 // terminators.
259 if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
260 return true;
261 return typeConverter.isLegal(op.getOperands());
262 });
263 target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
264 [&](Operation *op) { return typeConverter.isLegal(op); });
265}
266
return success()
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns