MLIR  22.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::scf;
16 
17 namespace {
18 
19 /// Flatten the given value ranges into a single vector of values.
21  SmallVector<Value> result;
22  for (const auto &vals : values)
23  llvm::append_range(result, vals);
24  return result;
25 }
26 
27 // CRTP
28 // A base class that takes care of 1:N type conversion, which maps the converted
29 // op results (computed by the derived class) and materializes 1:N conversion.
30 template <typename SourceOp, typename ConcretePattern>
31 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
32 public:
35  using OneToNOpAdaptor =
37 
38  //
39  // Derived classes should provide the following method which performs the
40  // actual conversion. It should return std::nullopt upon conversion failure
41  // and return the converted operation upon success.
42  //
43  // std::optional<SourceOp> convertSourceOp(
44  // SourceOp op, OneToNOpAdaptor adaptor,
45  // ConversionPatternRewriter &rewriter,
46  // TypeRange dstTypes) const;
47 
48  LogicalResult
49  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  SmallVector<Type> dstTypes;
52  SmallVector<unsigned> offsets;
53  offsets.push_back(0);
54  // Do the type conversion and record the offsets.
55  for (Value v : op.getResults()) {
56  if (failed(typeConverter->convertType(v, dstTypes)))
57  return rewriter.notifyMatchFailure(op, "could not convert result type");
58  offsets.push_back(dstTypes.size());
59  }
60 
61  // Calls the actual converter implementation to convert the operation.
62  std::optional<SourceOp> newOp =
63  static_cast<const ConcretePattern *>(this)->convertSourceOp(
64  op, adaptor, rewriter, dstTypes);
65 
66  if (!newOp)
67  return rewriter.notifyMatchFailure(op, "could not convert operation");
68 
69  // Packs the return value.
70  SmallVector<ValueRange> packedRets;
71  for (unsigned i = 1, e = offsets.size(); i < e; i++) {
72  unsigned start = offsets[i - 1], end = offsets[i];
73  unsigned len = end - start;
74  ValueRange mappedValue = newOp->getResults().slice(start, len);
75  packedRets.push_back(mappedValue);
76  }
77 
78  rewriter.replaceOpWithMultiple(op, packedRets);
79  return success();
80  }
81 };
82 
83 class ConvertForOpTypes
84  : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
85 public:
86  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
87 
88  // The callback required by CRTP.
89  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter,
91  TypeRange dstTypes) const {
92  // Create a empty new op and inline the regions from the old op.
93  //
94  // This is a little bit tricky. We have two concerns here:
95  //
96  // 1. We cannot update the op in place because the dialect conversion
97  // framework does not track type changes for ops updated in place, so it
98  // won't insert appropriate materializations on the changed result types.
99  // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100  // to clone the op.
101  //
102  // 2. We need to reuse the original region instead of cloning it, otherwise
103  // the dialect conversion framework thinks that we just inserted all the
104  // cloned child ops. But what we want is to "take" the child regions and let
105  // the dialect conversion framework continue recursively into ops inside
106  // those regions (which are already in its worklist; inlining them into the
107  // new op's regions doesn't remove the child ops from the worklist).
108 
109  // convertRegionTypes already takes care of 1:N conversion.
110  if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
111  return std::nullopt;
112 
113  // We can not do clone as the number of result types after conversion
114  // might be different.
115  ForOp newOp = ForOp::create(rewriter, op.getLoc(),
116  llvm::getSingleElement(adaptor.getLowerBound()),
117  llvm::getSingleElement(adaptor.getUpperBound()),
118  llvm::getSingleElement(adaptor.getStep()),
119  flattenValues(adaptor.getInitArgs()),
120  /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
121 
122  // Reserve whatever attributes in the original op.
123  newOp->setAttrs(op->getAttrs());
124 
125  // We do not need the empty block created by rewriter.
126  rewriter.eraseBlock(newOp.getBody(0));
127  // Inline the type converted region from the original operation.
128  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
129  newOp.getRegion().end());
130  return newOp;
131  }
132 };
133 } // namespace
134 
135 namespace {
136 class ConvertIfOpTypes
137  : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
138 public:
139  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
140 
141  std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter,
143  TypeRange dstTypes) const {
144 
145  IfOp newOp =
146  IfOp::create(rewriter, op.getLoc(), dstTypes,
147  llvm::getSingleElement(adaptor.getCondition()), true);
148  newOp->setAttrs(op->getAttrs());
149 
150  // We do not need the empty blocks created by rewriter.
151  rewriter.eraseBlock(newOp.elseBlock());
152  rewriter.eraseBlock(newOp.thenBlock());
153 
154  // Inlines block from the original operation.
155  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
156  newOp.getThenRegion().end());
157  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
158  newOp.getElseRegion().end());
159 
160  return newOp;
161  }
162 };
163 } // namespace
164 
165 namespace {
166 class ConvertWhileOpTypes
167  : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
168 public:
169  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
170 
171  std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter,
173  TypeRange dstTypes) const {
174  auto newOp = WhileOp::create(rewriter, op.getLoc(), dstTypes,
175  flattenValues(adaptor.getOperands()));
176 
177  for (auto i : {0u, 1u}) {
178  if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
179  return std::nullopt;
180  auto &dstRegion = newOp.getRegion(i);
181  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
182  }
183  return newOp;
184  }
185 };
186 } // namespace
187 
188 namespace {
189 class ConvertIndexSwitchOpTypes
190  : public Structural1ToNConversionPattern<IndexSwitchOp,
191  ConvertIndexSwitchOpTypes> {
192 public:
193  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
194 
195  std::optional<IndexSwitchOp>
196  convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
197  ConversionPatternRewriter &rewriter,
198  TypeRange dstTypes) const {
199  auto newOp =
200  IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
201  op.getCases(), op.getNumCases());
202 
203  for (unsigned i = 0u; i < op.getNumRegions(); i++) {
204  auto &dstRegion = newOp.getRegion(i);
205  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
206  }
207  return newOp;
208  }
209 };
210 } // namespace
211 
212 namespace {
213 // When the result types of a ForOp/IfOp get changed, the operand types of the
214 // corresponding yield op need to be changed. In order to trigger the
215 // appropriate type conversions / materializations, we need a dummy pattern.
216 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
217 public:
219  LogicalResult
220  matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
221  ConversionPatternRewriter &rewriter) const override {
222  rewriter.replaceOpWithNewOp<scf::YieldOp>(
223  op, flattenValues(adaptor.getOperands()));
224  return success();
225  }
226 };
227 } // namespace
228 
229 namespace {
230 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
231 public:
233  LogicalResult
234  matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
235  ConversionPatternRewriter &rewriter) const override {
236  rewriter.modifyOpInPlace(
237  op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
238  return success();
239  }
240 };
241 } // namespace
242 
244  const TypeConverter &typeConverter, RewritePatternSet &patterns,
245  PatternBenefit benefit) {
246  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
247  ConvertWhileOpTypes, ConvertConditionOpTypes,
248  ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
249  benefit);
250 }
251 
253  const TypeConverter &typeConverter, ConversionTarget &target) {
254  target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
255  [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
256  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
257  // We only have conversions for a subset of ops that use scf.yield
258  // terminators.
259  if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
260  return true;
261  return typeConverter.isLegal(op.getOperands());
262  });
263  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
264  [&](Operation *op) { return typeConverter.isLegal(op); });
265 }
266 
268  const TypeConverter &typeConverter, RewritePatternSet &patterns,
269  ConversionTarget &target, PatternBenefit benefit) {
270  populateSCFStructuralTypeConversions(typeConverter, patterns, benefit);
271  populateSCFStructuralTypeConversionTarget(typeConverter, target);
272 }
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion target.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type converter.
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
Definition: XeGPUUtils.cpp:32
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns