MLIR  20.0.0git
OneToNTypeConversion.cpp
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The patterns in this file are heavily inspired by (and copied from)
10 // lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
11 // type conversions.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
19 
20 using namespace mlir;
21 using namespace mlir::scf;
22 
24 public:
26 
27  LogicalResult
28  matchAndRewrite(IfOp op, OpAdaptor adaptor,
29  OneToNPatternRewriter &rewriter) const override {
30  Location loc = op->getLoc();
31  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
32 
33  // Nothing to do if there is no non-identity conversion.
34  if (!resultMapping.hasNonIdentityConversion())
35  return failure();
36 
37  // Create new IfOp.
38  TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
39  auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
40  op.getCondition(), true);
41  newOp->setAttrs(op->getAttrs());
42 
43  // We do not need the empty blocks created by rewriter.
44  rewriter.eraseBlock(newOp.elseBlock());
45  rewriter.eraseBlock(newOp.thenBlock());
46 
47  // Inlines block from the original operation.
48  rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
49  newOp.getThenRegion().end());
50  rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
51  newOp.getElseRegion().end());
52 
53  rewriter.replaceOp(op, newOp->getResults(), resultMapping);
54  return success();
55  }
56 };
57 
59 public:
61 
62  LogicalResult
63  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
64  OneToNPatternRewriter &rewriter) const override {
65  Location loc = op->getLoc();
66 
67  const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
68  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
69 
70  // Nothing to do if the op doesn't have any non-identity conversions for its
71  // operands or results.
72  if (!operandMapping.hasNonIdentityConversion() &&
73  !resultMapping.hasNonIdentityConversion())
74  return failure();
75 
76  // Create new WhileOp.
77  TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
78 
79  auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
80  adaptor.getFlatOperands());
81  newOp->setAttrs(op->getAttrs());
82 
83  // Update block signatures.
84  std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
85  resultMapping};
86  for (unsigned int i : {0u, 1u}) {
87  Region *region = &op.getRegion(i);
88  Block *block = &region->front();
89 
90  rewriter.applySignatureConversion(block, blockMappings[i]);
91 
92  // Move updated region to new WhileOp.
93  Region &dstRegion = newOp.getRegion(i);
94  rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
95  }
96 
97  rewriter.replaceOp(op, newOp->getResults(), resultMapping);
98  return success();
99  }
100 };
101 
103 public:
105 
106  LogicalResult
107  matchAndRewrite(YieldOp op, OpAdaptor adaptor,
108  OneToNPatternRewriter &rewriter) const override {
109  // Nothing to do if there is no non-identity conversion.
110  if (!adaptor.getOperandMapping().hasNonIdentityConversion())
111  return failure();
112 
113  // Convert operands.
114  rewriter.modifyOpInPlace(
115  op, [&] { op->setOperands(adaptor.getFlatOperands()); });
116 
117  return success();
118  }
119 };
120 
122  : public OneToNOpConversionPattern<ConditionOp> {
123 public:
125 
126  LogicalResult
127  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
128  OneToNPatternRewriter &rewriter) const override {
129  // Nothing to do if there is no non-identity conversion.
130  if (!adaptor.getOperandMapping().hasNonIdentityConversion())
131  return failure();
132 
133  // Convert operands.
134  rewriter.modifyOpInPlace(
135  op, [&] { op->setOperands(adaptor.getFlatOperands()); });
136 
137  return success();
138  }
139 };
140 
142 public:
144 
145  LogicalResult
146  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
147  OneToNPatternRewriter &rewriter) const override {
148  const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
149  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
150 
151  // Nothing to do if there is no non-identity conversion.
152  if (!operandMapping.hasNonIdentityConversion() &&
153  !resultMapping.hasNonIdentityConversion())
154  return failure();
155 
156  // If the lower-bound, upper-bound, or step were expanded, abort the
157  // conversion. This conversion does not know what to do in such cases.
158  ValueRange lbs = adaptor.getLowerBound();
159  ValueRange ubs = adaptor.getUpperBound();
160  ValueRange steps = adaptor.getStep();
161  if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
162  return rewriter.notifyMatchFailure(
163  forOp, "index operands converted to multiple values");
164 
165  Location loc = forOp.getLoc();
166 
167  Region *region = &forOp.getRegion();
168  Block *block = &region->front();
169 
170  // Construct the new for-op with an empty body.
171  ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
172  auto newOp =
173  rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
174  newOp->setAttrs(forOp->getAttrs());
175 
176  // We do not need the empty blocks created by rewriter.
177  rewriter.eraseBlock(newOp.getBody());
178 
179  // Convert the signature of the body region.
180  OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
181  if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
182  bodyTypeMapping)))
183  return failure();
184 
185  // Perform signature conversion on the body block.
186  rewriter.applySignatureConversion(block, bodyTypeMapping);
187 
188  // Splice the old body region into the new for-op.
189  Region &dstRegion = newOp.getBodyRegion();
190  rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
191 
192  rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
193 
194  return success();
195  }
196 };
197 
198 namespace mlir {
199 namespace scf {
200 
202  const TypeConverter &typeConverter, RewritePatternSet &patterns) {
203  patterns.add<
204  // clang-format off
210  // clang-format on
211  >(typeConverter, patterns.getContext());
212 }
213 
214 } // namespace scf
215 } // namespace mlir
LogicalResult matchAndRewrite(ConditionOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(ForOp forOp, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(IfOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(WhileOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(YieldOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override
Overload that derived classes have to override for their op type.
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
bool hasNonIdentityConversion() const
Returns true iff at least one type conversion maps an input type to a type that is different from itself.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & front()
Definition: Region.h:65
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
void populateSCFStructuralOneToNTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the provided pattern set with patterns that do 1:N type conversions on (some) SCF ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns