MLIR  16.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
14 
15 using namespace mlir;
16 using namespace mlir::scf;
17 
18 namespace {
19 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
20 public:
23  matchAndRewrite(ForOp op, OpAdaptor adaptor,
24  ConversionPatternRewriter &rewriter) const override {
25  SmallVector<Type, 6> newResultTypes;
26  for (auto type : op.getResultTypes()) {
27  Type newType = typeConverter->convertType(type);
28  if (!newType)
29  return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
30  newResultTypes.push_back(newType);
31  }
32 
33  // Clone the op without the regions and inline the regions from the old op.
34  //
35  // This is a little bit tricky. We have two concerns here:
36  //
37  // 1. We cannot update the op in place because the dialect conversion
38  // framework does not track type changes for ops updated in place, so it
39  // won't insert appropriate materializations on the changed result types.
40  // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
41  // clone the op.
42  //
43  // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
44  // inefficient to recursively clone the regions, there is a correctness
45  // issue: if we clone with the regions, then the dialect conversion
46  // framework thinks that we just inserted all the cloned child ops. But what
47  // we want is to "take" the child regions and let the dialect conversion
48  // framework continue recursively into ops inside those regions (which are
49  // already in its worklist; inlining them into the new op's regions doesn't
50  // remove the child ops from the worklist).
51  ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
52  // Take the region from the old op and put it in the new op.
53  rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
54  newOp.getLoopBody().end());
55 
56  // Now, update all the types.
57 
58  // Convert the type of the entry block of the ForOp's body.
59  if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
60  *getTypeConverter()))) {
61  return rewriter.notifyMatchFailure(op, "could not convert body types");
62  }
63  // Change the clone to use the updated operands. We could have cloned with
64  // a BlockAndValueMapping, but this seems a bit more direct.
65  newOp->setOperands(adaptor.getOperands());
66  // Update the result types to the new converted types.
67  for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
68  std::get<0>(t).setType(std::get<1>(t));
69 
70  rewriter.replaceOp(op, newOp.getResults());
71  return success();
72  }
73 };
74 } // namespace
75 
76 namespace {
77 class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
78 public:
81  matchAndRewrite(IfOp op, OpAdaptor adaptor,
82  ConversionPatternRewriter &rewriter) const override {
83  // TODO: Generalize this to any type conversion, not just 1:1.
84  //
85  // We need to implement something more sophisticated here that tracks which
86  // types convert to which other types and does the appropriate
87  // materialization logic.
88  // For example, it's possible that one result type converts to 0 types and
89  // another to 2 types, so newResultTypes would at least be the right size to
90  // not crash in the llvm::zip call below, but then we would set the the
91  // wrong type on the SSA values! These edge cases are also why we cannot
92  // safely use the TypeConverter::convertTypes helper here.
93  SmallVector<Type, 6> newResultTypes;
94  for (auto type : op.getResultTypes()) {
95  Type newType = typeConverter->convertType(type);
96  if (!newType)
97  return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
98  newResultTypes.push_back(newType);
99  }
100 
101  // See comments in the ForOp pattern for why we clone without regions and
102  // then inline.
103  IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
104  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
105  newOp.getThenRegion().end());
106  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
107  newOp.getElseRegion().end());
108 
109  // Update the operands and types.
110  newOp->setOperands(adaptor.getOperands());
111  for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
112  std::get<0>(t).setType(std::get<1>(t));
113  rewriter.replaceOp(op, newOp.getResults());
114  return success();
115  }
116 };
117 } // namespace
118 
119 namespace {
120 // When the result types of a ForOp/IfOp get changed, the operand types of the
121 // corresponding yield op need to be changed. In order to trigger the
122 // appropriate type conversions / materializations, we need a dummy pattern.
123 class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
124 public:
127  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
128  ConversionPatternRewriter &rewriter) const override {
129  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
130  return success();
131  }
132 };
133 } // namespace
134 
135 namespace {
136 class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
137 public:
139 
141  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter) const override {
143  auto *converter = getTypeConverter();
144  assert(converter);
145  SmallVector<Type> newResultTypes;
146  if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
147  return failure();
148 
149  auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
150  adaptor.getOperands());
151  for (auto i : {0u, 1u}) {
152  auto &dstRegion = newOp.getRegion(i);
153  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
154  if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
155  return rewriter.notifyMatchFailure(op, "could not convert body types");
156  }
157  rewriter.replaceOp(op, newOp.getResults());
158  return success();
159  }
160 };
161 } // namespace
162 
163 namespace {
164 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
165 public:
168  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
169  ConversionPatternRewriter &rewriter) const override {
170  rewriter.updateRootInPlace(
171  op, [&]() { op->setOperands(adaptor.getOperands()); });
172  return success();
173  }
174 };
175 } // namespace
176 
178  TypeConverter &typeConverter, RewritePatternSet &patterns,
179  ConversionTarget &target) {
180  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
181  ConvertWhileOpTypes, ConvertConditionOpTypes>(
182  typeConverter, patterns.getContext());
183  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
184  return typeConverter.isLegal(op->getResultTypes());
185  });
186  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
187  // We only have conversions for a subset of ops that use scf.yield
188  // terminators.
189  if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
190  return true;
191  return typeConverter.isLegal(op.getOperandTypes());
192  });
193  target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
194  [&](Operation *op) { return typeConverter.isLegal(op); });
195 }
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:522
void populateSCFStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Type conversion class.
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
MLIRContext * getContext() const
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.