MLIR  20.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::scf;
16 
17 namespace {
18 
19 /// Flatten the given value ranges into a single vector of values.
21  SmallVector<Value> result;
22  for (const auto &vals : values)
23  llvm::append_range(result, vals);
24  return result;
25 }
26 
27 /// Assert that the given value range contains a single value and return it.
28 static Value getSingleValue(ValueRange values) {
29  assert(values.size() == 1 && "expected single value");
30  return values.front();
31 }
32 
33 // CRTP
34 // A base class that takes care of 1:N type conversion, which maps the converted
35 // op results (computed by the derived class) and materializes 1:N conversion.
36 template <typename SourceOp, typename ConcretePattern>
37 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
38 public:
41  using OneToNOpAdaptor =
43 
44  //
45  // Derived classes should provide the following method which performs the
46  // actual conversion. It should return std::nullopt upon conversion failure
47  // and return the converted operation upon success.
48  //
49  // std::optional<SourceOp> convertSourceOp(
50  // SourceOp op, OneToNOpAdaptor adaptor,
51  // ConversionPatternRewriter &rewriter,
52  // TypeRange dstTypes) const;
53 
54  LogicalResult
55  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
56  ConversionPatternRewriter &rewriter) const override {
57  SmallVector<Type> dstTypes;
58  SmallVector<unsigned> offsets;
59  offsets.push_back(0);
60  // Do the type conversion and record the offsets.
61  for (Type type : op.getResultTypes()) {
62  if (failed(typeConverter->convertTypes(type, dstTypes)))
63  return rewriter.notifyMatchFailure(op, "could not convert result type");
64  offsets.push_back(dstTypes.size());
65  }
66 
67  // Calls the actual converter implementation to convert the operation.
68  std::optional<SourceOp> newOp =
69  static_cast<const ConcretePattern *>(this)->convertSourceOp(
70  op, adaptor, rewriter, dstTypes);
71 
72  if (!newOp)
73  return rewriter.notifyMatchFailure(op, "could not convert operation");
74 
75  // Packs the return value.
76  SmallVector<ValueRange> packedRets;
77  for (unsigned i = 1, e = offsets.size(); i < e; i++) {
78  unsigned start = offsets[i - 1], end = offsets[i];
79  unsigned len = end - start;
80  ValueRange mappedValue = newOp->getResults().slice(start, len);
81  packedRets.push_back(mappedValue);
82  }
83 
84  rewriter.replaceOpWithMultiple(op, packedRets);
85  return success();
86  }
87 };
88 
89 class ConvertForOpTypes
90  : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
91 public:
92  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
93 
94  // The callback required by CRTP.
95  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
96  ConversionPatternRewriter &rewriter,
97  TypeRange dstTypes) const {
98  // Create a empty new op and inline the regions from the old op.
99  //
100  // This is a little bit tricky. We have two concerns here:
101  //
102  // 1. We cannot update the op in place because the dialect conversion
103  // framework does not track type changes for ops updated in place, so it
104  // won't insert appropriate materializations on the changed result types.
105  // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
106  // to clone the op.
107  //
108  // 2. We need to resue the original region instead of cloning it, otherwise
109  // the dialect conversion framework thinks that we just inserted all the
110  // cloned child ops. But what we want is to "take" the child regions and let
111  // the dialect conversion framework continue recursively into ops inside
112  // those regions (which are already in its worklist; inlining them into the
113  // new op's regions doesn't remove the child ops from the worklist).
114 
115  // convertRegionTypes already takes care of 1:N conversion.
116  if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
117  return std::nullopt;
118 
119  // We can not do clone as the number of result types after conversion
120  // might be different.
121  ForOp newOp = rewriter.create<ForOp>(
122  op.getLoc(), getSingleValue(adaptor.getLowerBound()),
123  getSingleValue(adaptor.getUpperBound()),
124  getSingleValue(adaptor.getStep()),
125  flattenValues(adaptor.getInitArgs()));
126 
127  // Reserve whatever attributes in the original op.
128  newOp->setAttrs(op->getAttrs());
129 
130  // We do not need the empty block created by rewriter.
131  rewriter.eraseBlock(newOp.getBody(0));
132  // Inline the type converted region from the original operation.
133  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
134  newOp.getRegion().end());
135 
136  return newOp;
137  }
138 };
139 } // namespace
140 
141 namespace {
142 class ConvertIfOpTypes
143  : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
144 public:
145  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
146 
147  std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter,
149  TypeRange dstTypes) const {
150 
151  IfOp newOp = rewriter.create<IfOp>(
152  op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
153  newOp->setAttrs(op->getAttrs());
154 
155  // We do not need the empty blocks created by rewriter.
156  rewriter.eraseBlock(newOp.elseBlock());
157  rewriter.eraseBlock(newOp.thenBlock());
158 
159  // Inlines block from the original operation.
160  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
161  newOp.getThenRegion().end());
162  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
163  newOp.getElseRegion().end());
164 
165  return newOp;
166  }
167 };
168 } // namespace
169 
170 namespace {
171 class ConvertWhileOpTypes
172  : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
173 public:
174  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
175 
176  std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
177  ConversionPatternRewriter &rewriter,
178  TypeRange dstTypes) const {
179  auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
180  flattenValues(adaptor.getOperands()));
181 
182  for (auto i : {0u, 1u}) {
183  if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
184  return std::nullopt;
185  auto &dstRegion = newOp.getRegion(i);
186  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
187  }
188  return newOp;
189  }
190 };
191 } // namespace
192 
193 namespace {
194 // When the result types of a ForOp/IfOp get changed, the operand types of the
195 // corresponding yield op need to be changed. In order to trigger the
196 // appropriate type conversions / materializations, we need a dummy pattern.
197 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
198 public:
200  LogicalResult
201  matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  rewriter.replaceOpWithNewOp<scf::YieldOp>(
204  op, flattenValues(adaptor.getOperands()));
205  return success();
206  }
207 };
208 } // namespace
209 
210 namespace {
211 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
212 public:
214  LogicalResult
215  matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
216  ConversionPatternRewriter &rewriter) const override {
217  rewriter.modifyOpInPlace(
218  op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
219  return success();
220  }
221 };
222 } // namespace
223 
225  const TypeConverter &typeConverter, RewritePatternSet &patterns) {
226  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
227  ConvertWhileOpTypes, ConvertConditionOpTypes>(
228  typeConverter, patterns.getContext());
229 }
230 
232  const TypeConverter &typeConverter, ConversionTarget &target) {
233  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
234  return typeConverter.isLegal(op->getResultTypes());
235  });
236  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
237  // We only have conversions for a subset of ops that use scf.yield
238  // terminators.
239  if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
240  return true;
241  return typeConverter.isLegal(op.getOperandTypes());
242  });
243  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
244  [&](Operation *op) { return typeConverter.isLegal(op); });
245 }
246 
248  const TypeConverter &typeConverter, RewritePatternSet &patterns,
249  ConversionTarget &target) {
251  populateSCFStructuralTypeConversionTarget(typeConverter, target);
252 }
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
static Value getSingleValue(ValueRange values)
Assert that the given value range contains a single value and return it.
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns