MLIR  21.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::scf;
16 
17 namespace {
18 
19 /// Flatten the given value ranges into a single vector of values.
21  SmallVector<Value> result;
22  for (const auto &vals : values)
23  llvm::append_range(result, vals);
24  return result;
25 }
26 
27 // CRTP
28 // A base class that takes care of 1:N type conversion, which maps the converted
29 // op results (computed by the derived class) and materializes 1:N conversion.
30 template <typename SourceOp, typename ConcretePattern>
31 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
32 public:
35  using OneToNOpAdaptor =
37 
38  //
39  // Derived classes should provide the following method which performs the
40  // actual conversion. It should return std::nullopt upon conversion failure
41  // and return the converted operation upon success.
42  //
43  // std::optional<SourceOp> convertSourceOp(
44  // SourceOp op, OneToNOpAdaptor adaptor,
45  // ConversionPatternRewriter &rewriter,
46  // TypeRange dstTypes) const;
47 
48  LogicalResult
49  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  SmallVector<Type> dstTypes;
52  SmallVector<unsigned> offsets;
53  offsets.push_back(0);
54  // Do the type conversion and record the offsets.
55  for (Type type : op.getResultTypes()) {
56  if (failed(typeConverter->convertTypes(type, dstTypes)))
57  return rewriter.notifyMatchFailure(op, "could not convert result type");
58  offsets.push_back(dstTypes.size());
59  }
60 
61  // Calls the actual converter implementation to convert the operation.
62  std::optional<SourceOp> newOp =
63  static_cast<const ConcretePattern *>(this)->convertSourceOp(
64  op, adaptor, rewriter, dstTypes);
65 
66  if (!newOp)
67  return rewriter.notifyMatchFailure(op, "could not convert operation");
68 
69  // Packs the return value.
70  SmallVector<ValueRange> packedRets;
71  for (unsigned i = 1, e = offsets.size(); i < e; i++) {
72  unsigned start = offsets[i - 1], end = offsets[i];
73  unsigned len = end - start;
74  ValueRange mappedValue = newOp->getResults().slice(start, len);
75  packedRets.push_back(mappedValue);
76  }
77 
78  rewriter.replaceOpWithMultiple(op, packedRets);
79  return success();
80  }
81 };
82 
83 class ConvertForOpTypes
84  : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
85 public:
86  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
87 
88  // The callback required by CRTP.
89  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter,
91  TypeRange dstTypes) const {
92  // Create a empty new op and inline the regions from the old op.
93  //
94  // This is a little bit tricky. We have two concerns here:
95  //
96  // 1. We cannot update the op in place because the dialect conversion
97  // framework does not track type changes for ops updated in place, so it
98  // won't insert appropriate materializations on the changed result types.
99  // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100  // to clone the op.
101  //
102  // 2. We need to resue the original region instead of cloning it, otherwise
103  // the dialect conversion framework thinks that we just inserted all the
104  // cloned child ops. But what we want is to "take" the child regions and let
105  // the dialect conversion framework continue recursively into ops inside
106  // those regions (which are already in its worklist; inlining them into the
107  // new op's regions doesn't remove the child ops from the worklist).
108 
109  // convertRegionTypes already takes care of 1:N conversion.
110  if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
111  return std::nullopt;
112 
113  // We can not do clone as the number of result types after conversion
114  // might be different.
115  ForOp newOp = rewriter.create<ForOp>(
116  op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
117  llvm::getSingleElement(adaptor.getUpperBound()),
118  llvm::getSingleElement(adaptor.getStep()),
119  flattenValues(adaptor.getInitArgs()));
120 
121  // Reserve whatever attributes in the original op.
122  newOp->setAttrs(op->getAttrs());
123 
124  // We do not need the empty block created by rewriter.
125  rewriter.eraseBlock(newOp.getBody(0));
126  // Inline the type converted region from the original operation.
127  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
128  newOp.getRegion().end());
129 
130  return newOp;
131  }
132 };
133 } // namespace
134 
135 namespace {
136 class ConvertIfOpTypes
137  : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
138 public:
139  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
140 
141  std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter,
143  TypeRange dstTypes) const {
144 
145  IfOp newOp = rewriter.create<IfOp>(
146  op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
147  true);
148  newOp->setAttrs(op->getAttrs());
149 
150  // We do not need the empty blocks created by rewriter.
151  rewriter.eraseBlock(newOp.elseBlock());
152  rewriter.eraseBlock(newOp.thenBlock());
153 
154  // Inlines block from the original operation.
155  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
156  newOp.getThenRegion().end());
157  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
158  newOp.getElseRegion().end());
159 
160  return newOp;
161  }
162 };
163 } // namespace
164 
165 namespace {
166 class ConvertWhileOpTypes
167  : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
168 public:
169  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
170 
171  std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter,
173  TypeRange dstTypes) const {
174  auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
175  flattenValues(adaptor.getOperands()));
176 
177  for (auto i : {0u, 1u}) {
178  if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
179  return std::nullopt;
180  auto &dstRegion = newOp.getRegion(i);
181  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
182  }
183  return newOp;
184  }
185 };
186 } // namespace
187 
188 namespace {
189 // When the result types of a ForOp/IfOp get changed, the operand types of the
190 // corresponding yield op need to be changed. In order to trigger the
191 // appropriate type conversions / materializations, we need a dummy pattern.
192 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
193 public:
195  LogicalResult
196  matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198  rewriter.replaceOpWithNewOp<scf::YieldOp>(
199  op, flattenValues(adaptor.getOperands()));
200  return success();
201  }
202 };
203 } // namespace
204 
205 namespace {
206 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
207 public:
209  LogicalResult
210  matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
211  ConversionPatternRewriter &rewriter) const override {
212  rewriter.modifyOpInPlace(
213  op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
214  return success();
215  }
216 };
217 } // namespace
218 
220  const TypeConverter &typeConverter, RewritePatternSet &patterns) {
221  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
222  ConvertWhileOpTypes, ConvertConditionOpTypes>(
223  typeConverter, patterns.getContext());
224 }
225 
227  const TypeConverter &typeConverter, ConversionTarget &target) {
228  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
229  return typeConverter.isLegal(op->getResultTypes());
230  });
231  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
232  // We only have conversions for a subset of ops that use scf.yield
233  // terminators.
234  if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
235  return true;
236  return typeConverter.isLegal(op.getOperandTypes());
237  });
238  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
239  [&](Operation *op) { return typeConverter.isLegal(op); });
240 }
241 
243  const TypeConverter &typeConverter, RewritePatternSet &patterns,
244  ConversionTarget &target) {
246  populateSCFStructuralTypeConversionTarget(typeConverter, target);
247 }
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns