MLIR  19.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include <optional>
13 
14 using namespace mlir;
15 using namespace mlir::scf;
16 
17 namespace {
18 
19 // Unpacks the single unrealized_conversion_cast using the list of inputs
20 // e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
21 static void unpackUnrealizedConversionCast(Value v,
22  SmallVectorImpl<Value> &unpacked) {
23  if (auto cast =
24  dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
25  if (cast.getInputs().size() != 1) {
26  // 1 : N type conversion.
27  unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
28  return;
29  }
30  }
31  // 1 : 1 type conversion.
32  unpacked.push_back(v);
33 }
34 
35 // CRTP
36 // A base class that takes care of 1:N type conversion, which maps the converted
37 // op results (computed by the derived class) and materializes 1:N conversion.
// `ConcretePattern` supplies `convertSourceOp` (described below). This class
// converts the result types, invokes that hook, and replaces the original op.
38 template <typename SourceOp, typename ConcretePattern>
39 class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
40 public:
// NOTE(review): `typeConverter` (used in matchAndRewrite) and the inherited
// constructors relied on by derived patterns are not declared in the lines
// shown here; presumably using-declarations from OpConversionPattern — confirm.
43  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
44
45  //
46  // Derived classes should provide the following method which performs the
47  // actual conversion. It should return std::nullopt upon conversion failure
48  // and return the converted operation upon success.
49  //
50  // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
51  // ConversionPatternRewriter &rewriter,
52  // TypeRange dstTypes) const;
53
55  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
56  ConversionPatternRewriter &rewriter) const override {
  // Steps:
  //  1. Convert every result type of `op`, appending the converted types to
  //     `dstTypes`. offsets[i-1]..offsets[i] is the slice of `dstTypes`
  //     that belongs to original result i-1.
  //  2. Let the derived class build the replacement op with `dstTypes`.
  //  3. For each original result, reuse the single new value (1:1). For any
  //     other count, including 0, build a source materialization back to
  //     the original type from that slice of new results.
  // Each step reports a match failure if it cannot be completed.
57  SmallVector<Type> dstTypes;
58  SmallVector<unsigned> offsets;
59  offsets.push_back(0);
60  // Do the type conversion and record the offsets.
61  for (Type type : op.getResultTypes()) {
62  if (failed(typeConverter->convertTypes(type, dstTypes)))
63  return rewriter.notifyMatchFailure(op, "could not convert result type");
64  offsets.push_back(dstTypes.size());
65  }
66
67  // Calls the actual converter implementation to convert the operation.
  // Static (CRTP) dispatch: no virtual call is involved.
68  std::optional<SourceOp> newOp =
69  static_cast<const ConcretePattern *>(this)->convertSourceOp(
70  op, adaptor, rewriter, dstTypes);
71
72  if (!newOp)
73  return rewriter.notifyMatchFailure(op, "could not convert operation");
74
75  // Packs the return value.
  // Produces exactly one replacement value per original result.
76  SmallVector<Value> packedRets;
77  for (unsigned i = 1, e = offsets.size(); i < e; i++) {
78  unsigned start = offsets[i - 1], end = offsets[i];
79  unsigned len = end - start;
80  ValueRange mappedValue = newOp->getResults().slice(start, len);
81  if (len != 1) {
82  // 1 : N type conversion.
  // This branch also covers len == 0 (a result that was converted to no types).
83  Type origType = op.getResultTypes()[i - 1];
84  Value mat = typeConverter->materializeSourceConversion(
85  rewriter, op.getLoc(), origType, mappedValue);
86  if (!mat) {
87  return rewriter.notifyMatchFailure(
88  op, "Failed to materialize 1:N type conversion");
89  }
90  packedRets.push_back(mat);
91  } else {
92  // 1 : 1 type conversion.
93  packedRets.push_back(mappedValue.front());
94  }
95  }
96
97  rewriter.replaceOp(op, packedRets);
98  return success();
99  }
100 };
101 
102 class ConvertForOpTypes
103  : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
104 public:
105  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
106 
107  // The callback required by CRTP.
108  std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
109  ConversionPatternRewriter &rewriter,
110  TypeRange dstTypes) const {
111  // Create a empty new op and inline the regions from the old op.
112  //
113  // This is a little bit tricky. We have two concerns here:
114  //
115  // 1. We cannot update the op in place because the dialect conversion
116  // framework does not track type changes for ops updated in place, so it
117  // won't insert appropriate materializations on the changed result types.
118  // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
119  // to clone the op.
120  //
121  // 2. We need to resue the original region instead of cloning it, otherwise
122  // the dialect conversion framework thinks that we just inserted all the
123  // cloned child ops. But what we want is to "take" the child regions and let
124  // the dialect conversion framework continue recursively into ops inside
125  // those regions (which are already in its worklist; inlining them into the
126  // new op's regions doesn't remove the child ops from the worklist).
127 
128  // convertRegionTypes already takes care of 1:N conversion.
129  if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
130  return std::nullopt;
131 
132  // Unpacked the iteration arguments.
133  SmallVector<Value> flatArgs;
134  for (Value arg : adaptor.getInitArgs())
135  unpackUnrealizedConversionCast(arg, flatArgs);
136 
137  // We can not do clone as the number of result types after conversion
138  // might be different.
139  ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
140  adaptor.getUpperBound(),
141  adaptor.getStep(), flatArgs);
142 
143  // Reserve whatever attributes in the original op.
144  newOp->setAttrs(op->getAttrs());
145 
146  // We do not need the empty block created by rewriter.
147  rewriter.eraseBlock(newOp.getBody(0));
148  // Inline the type converted region from the original operation.
149  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
150  newOp.getRegion().end());
151 
152  return newOp;
153  }
154 };
155 } // namespace
156 
157 namespace {
158 class ConvertIfOpTypes
159  : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
160 public:
161  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
162 
163  std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter,
165  TypeRange dstTypes) const {
166 
167  IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
168  adaptor.getCondition(), true);
169  newOp->setAttrs(op->getAttrs());
170 
171  // We do not need the empty blocks created by rewriter.
172  rewriter.eraseBlock(newOp.elseBlock());
173  rewriter.eraseBlock(newOp.thenBlock());
174 
175  // Inlines block from the original operation.
176  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
177  newOp.getThenRegion().end());
178  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
179  newOp.getElseRegion().end());
180 
181  return newOp;
182  }
183 };
184 } // namespace
185 
186 namespace {
187 class ConvertWhileOpTypes
188  : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
189 public:
190  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
191 
192  std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
193  ConversionPatternRewriter &rewriter,
194  TypeRange dstTypes) const {
195  // Unpacked the iteration arguments.
196  SmallVector<Value> flatArgs;
197  for (Value arg : adaptor.getOperands())
198  unpackUnrealizedConversionCast(arg, flatArgs);
199 
200  auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
201 
202  for (auto i : {0u, 1u}) {
203  if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
204  return std::nullopt;
205  auto &dstRegion = newOp.getRegion(i);
206  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
207  }
208  return newOp;
209  }
210 };
211 } // namespace
212 
213 namespace {
214 // When the result types of a ForOp/IfOp/WhileOp get changed, the operand types of the
215 // corresponding yield op need to be changed. In order to trigger the
216 // appropriate type conversions / materializations, we need a dummy pattern.
// The pattern replaces the scf.yield with a new one built from the adaptor's
// operands. Any operand produced by a 1:N unrealized_conversion_cast is
// expanded into that cast's inputs (see unpackUnrealizedConversionCast).
217 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
218 public:
221  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
222  ConversionPatternRewriter &rewriter) const override {
223  SmallVector<Value> unpackedYield;
224  for (Value operand : adaptor.getOperands())
225  unpackUnrealizedConversionCast(operand, unpackedYield);
226
  // Replace (rather than update in place) so the new operand list may have
  // a different length than the original one.
227  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
228  return success();
229  }
230 };
231 } // namespace
232 
233 namespace {
// Companion to the scf.while conversion. It rewrites the scf.condition
// terminator so that every operand is passed through
// unpackUnrealizedConversionCast, which expands 1:N converted values into
// their parts. Unlike the scf.yield pattern, the op is updated in place.
234 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
235 public:
238  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
239  ConversionPatternRewriter &rewriter) const override {
  // (The local is named "unpackedYield" but holds the condition op's operands.)
240  SmallVector<Value> unpackedYield;
241  for (Value operand : adaptor.getOperands())
242  unpackUnrealizedConversionCast(operand, unpackedYield);
243
244  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
245  return success();
246  }
247 };
248 } // namespace
249 
251  TypeConverter &typeConverter, RewritePatternSet &patterns) {
  // Register the structural patterns for scf.for, scf.if and scf.while and
  // their terminators (scf.yield, scf.condition). Each pattern is constructed
  // from `typeConverter` and the pattern set's context.
252  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
253  ConvertWhileOpTypes, ConvertConditionOpTypes>(
254  typeConverter, patterns.getContext());
255 }
256 
258  const TypeConverter &typeConverter, ConversionTarget &target) {
  // Every callback below captures `typeConverter` by reference and is
  // registered on `target`. The referenced TypeConverter must therefore
  // outlive any legalization that uses `target`.
  //
  // scf.for / scf.if: legal once all of their result types are legal. Only
  // result types are checked here.
259  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
260  return typeConverter.isLegal(op->getResultTypes());
261  });
  // scf.yield: constrained only when its parent is one of the ops converted
  // in this file. In that case it is legal once its operand types are legal.
262  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
263  // We only have conversions for a subset of ops that use scf.yield
264  // terminators.
265  if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
266  return true;
267  return typeConverter.isLegal(op.getOperandTypes());
268  });
  // scf.while / scf.condition: legal when TypeConverter::isLegal(Operation *)
  // accepts the op (presumably checks operand and result types — confirm).
269  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
270  [&](Operation *op) { return typeConverter.isLegal(op); });
271 }
271 }
272 
274  TypeConverter &typeConverter, RewritePatternSet &patterns,
275  ConversionTarget &target) {
  // Convenience entry point: adds the structural patterns to `patterns` and
  // the matching dynamic legality rules to `target`.
276  populateSCFStructuralTypeConversions(typeConverter, patterns);
277  populateSCFStructuralTypeConversionTarget(typeConverter, target);
278 }
278 }
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
iterator end()
Definition: Region.h:56
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void populateSCFStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion target.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type converter.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26