MLIR  16.0.0git
InliningUtils.cpp
Go to the documentation of this file.
1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
///
/// Every operation visited by `block.walk` in `inlinedBlocks` has its location
/// replaced with `CallSiteLoc::get(originalLoc, callerLoc)`. The wrapped
/// location is cached per original location, so each distinct location is
/// wrapped only once. Block argument locations are not updated by this helper.
/// NOTE(review): this rendered listing omits source line 30 (function name and
/// first parameter). Per the declaration listed elsewhere on this page, the
/// signature is `remapInlinedLocations(iterator_range<Region::iterator>
/// inlinedBlocks, Location callerLoc)`.
static void
31  Location callerLoc) {
  // Cache: original location -> call-site-wrapped location.
32  DenseMap<Location, Location> mappedLocations;
33  auto remapOpLoc = [&](Operation *op) {
34  auto it = mappedLocations.find(op->getLoc());
35  if (it == mappedLocations.end()) {
  // First time this location is seen: build and cache its CallSiteLoc.
36  auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37  it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38  }
39  op->setLoc(it->second);
40  };
41  for (auto &block : inlinedBlocks)
42  block.walk(remapOpLoc);
43 }
44 
46  BlockAndValueMapping &mapper) {
  // Rewrites in place every operand, of every operation visited by
  // `block.walk` in `inlinedBlocks`, that has an entry in `mapper`. Operands
  // with no mapping are left untouched. Within this file it is only used when
  // blocks were moved rather than cloned, because moved operations still
  // reference the original values.
  // NOTE(review): source line 45 (start of the signature; per the declaration
  // listed on this page, `static void remapInlinedOperands(
  // iterator_range<Region::iterator> inlinedBlocks,`) is missing from this
  // rendered listing.
47  auto remapOperands = [&](Operation *op) {
48  for (auto &operand : op->getOpOperands())
49  if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50  operand.set(mappedOp);
51  };
52  for (auto &block : inlinedBlocks)
53  block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
61  bool wouldBeCloned) const {
  // Forwards the call/callable legality query to the dialect interface that
  // getInterfaceFor(call) returns. If no interface is registered, the call
  // is conservatively reported as not legal to inline.
  // NOTE(review): source line 60 (start of this signature, presumably
  // `bool InlinerInterface::isLegalToInline(Operation *call, Operation
  // *callable,`) is missing from this rendered listing.
62  if (auto *handler = getInterfaceFor(call))
63  return handler->isLegalToInline(call, callable, wouldBeCloned);
64  return false;
65 }
66 
68  Region *dest, Region *src, bool wouldBeCloned,
69  BlockAndValueMapping &valueMapping) const {
  // Region-into-region legality is decided by the interface of the operation
  // that owns the destination region. If there is no interface, the answer is
  // "not legal".
  // NOTE(review): source line 67 (start of this signature) is missing from
  // this rendered listing.
70  if (auto *handler = getInterfaceFor(dest->getParentOp()))
71  return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72  return false;
73 }
74 
76  Operation *op, Region *dest, bool wouldBeCloned,
77  BlockAndValueMapping &valueMapping) const {
  // Per-operation legality is decided by the interface registered for `op`
  // itself (not for the destination). If there is no interface, the answer
  // is "not legal".
  // NOTE(review): source line 75 (start of this signature) is missing from
  // this rendered listing.
78  if (auto *handler = getInterfaceFor(op))
79  return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80  return false;
81 }
82 
  // Whether the legality check should descend into `op`'s nested regions.
  // Unlike the legality hooks, the fallback when no interface is registered
  // is `true`, i.e. nested regions are analyzed.
  // NOTE(review): source line 83 (this method's signature and opening brace)
  // is missing from this rendered listing.
84  auto *handler = getInterfaceFor(op);
85  return handler ? handler->shouldAnalyzeRecursively(op) : true;
86 }
87 
88 /// Handle the given inlined terminator by replacing it with a new operation
89 /// as necessary.
/// This overload receives the block to continue at. In this file it is used
/// when several blocks were inlined; `newDest` is then the block split off
/// after the inline point. A registered dialect interface is required
/// (asserted), not optional.
/// NOTE(review): source line 90 (this overload's signature) is missing from
/// this rendered listing.
91  auto *handler = getInterfaceFor(op);
92  assert(handler && "expected valid dialect handler");
93  handler->handleTerminator(op, newDest);
94 }
95 
96 /// Handle the given inlined terminator by replacing it with a new operation
97 /// as necessary.
/// This overload receives the values that the inlined region's results
/// should replace. In this file it is used when exactly one block was
/// inlined. A registered dialect interface is required (asserted).
/// NOTE(review): source line 98 (start of this signature) is missing from
/// this rendered listing.
99  ArrayRef<Value> valuesToRepl) const {
100  auto *handler = getInterfaceFor(op);
101  assert(handler && "expected valid dialect handler");
102  handler->handleTerminator(op, valuesToRepl);
103 }
104 
106  Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
  // Gives the dialect of `call` a chance to post-process the blocks that were
  // just inlined for that call. A registered interface is required
  // (asserted).
  // NOTE(review): source line 105 (start of this signature) is missing from
  // this rendered listing.
107  auto *handler = getInterfaceFor(call);
108  assert(handler && "expected valid dialect handler");
109  handler->processInlinedCallBlocks(call, inlinedBlocks);
110 }
111 
112 /// Utility to check that all of the operations within 'src' can be inlined.
113 static bool isLegalToInline(InlinerInterface &interface, Region *src,
114  Region *insertRegion, bool shouldCloneInlinedRegion,
115  BlockAndValueMapping &valueMapping) {
116  for (auto &block : *src) {
117  for (auto &op : block) {
118  // Check this operation.
119  if (!interface.isLegalToInline(&op, insertRegion,
120  shouldCloneInlinedRegion, valueMapping)) {
121  LLVM_DEBUG({
122  llvm::dbgs() << "* Illegal to inline because of op: ";
123  op.dump();
124  });
125  return false;
126  }
127  // Check any nested regions.
128  if (interface.shouldAnalyzeRecursively(&op) &&
129  llvm::any_of(op.getRegions(), [&](Region &region) {
130  return !isLegalToInline(interface, &region, insertRegion,
131  shouldCloneInlinedRegion, valueMapping);
132  }))
133  return false;
134  }
135  }
136  return true;
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Inline Methods
141 //===----------------------------------------------------------------------===//
142 
143 static LogicalResult
144 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
145  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
146  ValueRange resultsToReplace, TypeRange regionResultTypes,
147  Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
148  Operation *call = nullptr) {
149  assert(resultsToReplace.size() == regionResultTypes.size());
150  // We expect the region to have at least one block.
151  if (src->empty())
152  return failure();
153 
154  // Check that all of the region arguments have been mapped.
155  auto *srcEntryBlock = &src->front();
156  if (llvm::any_of(srcEntryBlock->getArguments(),
157  [&](BlockArgument arg) { return !mapper.contains(arg); }))
158  return failure();
159 
160  // Check that the operations within the source region are valid to inline.
161  Region *insertRegion = inlineBlock->getParent();
162  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
163  mapper) ||
164  !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
165  mapper))
166  return failure();
167 
168  // Check to see if the region is being cloned, or moved inline. In either
169  // case, move the new blocks after the 'insertBlock' to improve IR
170  // readability.
171  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
172  if (shouldCloneInlinedRegion)
173  src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
174  else
175  insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
176  src->getBlocks(), src->begin(),
177  src->end());
178 
179  // Get the range of newly inserted blocks.
180  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
181  postInsertBlock->getIterator());
182  Block *firstNewBlock = &*newBlocks.begin();
183 
184  // Remap the locations of the inlined operations if a valid source location
185  // was provided.
186  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
187  remapInlinedLocations(newBlocks, *inlineLoc);
188 
189  // If the blocks were moved in-place, make sure to remap any necessary
190  // operands.
191  if (!shouldCloneInlinedRegion)
192  remapInlinedOperands(newBlocks, mapper);
193 
194  // Process the newly inlined blocks.
195  if (call)
196  interface.processInlinedCallBlocks(call, newBlocks);
197  interface.processInlinedBlocks(newBlocks);
198 
199  // Handle the case where only a single block was inlined.
200  if (std::next(newBlocks.begin()) == newBlocks.end()) {
201  // Have the interface handle the terminator of this block.
202  auto *firstBlockTerminator = firstNewBlock->getTerminator();
203  interface.handleTerminator(firstBlockTerminator,
204  llvm::to_vector<6>(resultsToReplace));
205  firstBlockTerminator->erase();
206 
207  // Merge the post insert block into the cloned entry block.
208  firstNewBlock->getOperations().splice(firstNewBlock->end(),
209  postInsertBlock->getOperations());
210  postInsertBlock->erase();
211  } else {
212  // Otherwise, there were multiple blocks inlined. Add arguments to the post
213  // insertion block to represent the results to replace.
214  for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
215  resultToRepl.value().replaceAllUsesWith(
216  postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
217  resultToRepl.value().getLoc()));
218  }
219 
220  /// Handle the terminators for each of the new blocks.
221  for (auto &newBlock : newBlocks)
222  interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
223  }
224 
225  // Splice the instructions of the inlined entry block into the insert block.
226  inlineBlock->getOperations().splice(inlineBlock->end(),
227  firstNewBlock->getOperations());
228  firstNewBlock->erase();
229  return success();
230 }
231 
232 static LogicalResult
233 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234  Block::iterator inlinePoint, ValueRange inlinedOperands,
235  ValueRange resultsToReplace, Optional<Location> inlineLoc,
236  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
237  // We expect the region to have at least one block.
238  if (src->empty())
239  return failure();
240 
241  auto *entryBlock = &src->front();
242  if (inlinedOperands.size() != entryBlock->getNumArguments())
243  return failure();
244 
245  // Map the provided call operands to the arguments of the region.
246  BlockAndValueMapping mapper;
247  for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
248  // Verify that the types of the provided values match the function argument
249  // types.
250  BlockArgument regionArg = entryBlock->getArgument(i);
251  if (inlinedOperands[i].getType() != regionArg.getType())
252  return failure();
253  mapper.map(regionArg, inlinedOperands[i]);
254  }
255 
256  // Call into the main region inliner function.
257  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258  resultsToReplace, resultsToReplace.getTypes(),
259  inlineLoc, shouldCloneInlinedRegion, call);
260 }
261 
263  Operation *inlinePoint,
264  BlockAndValueMapping &mapper,
265  ValueRange resultsToReplace,
266  TypeRange regionResultTypes,
267  Optional<Location> inlineLoc,
268  bool shouldCloneInlinedRegion) {
  // Operation-anchored overload: the region is inlined immediately AFTER
  // `inlinePoint` (note the pre-increment of its iterator), within the
  // block that contains it.
  // NOTE(review): source line 262 (start of this signature) is missing from
  // this rendered listing.
269  return inlineRegion(interface, src, inlinePoint->getBlock(),
270  ++inlinePoint->getIterator(), mapper, resultsToReplace,
271  regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272 }
274 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276  ValueRange resultsToReplace, TypeRange regionResultTypes,
277  Optional<Location> inlineLoc,
278  bool shouldCloneInlinedRegion) {
  // Public block/iterator overload that takes a caller-provided mapping.
  // There is no call operation here, so the call-specific processing hook in
  // inlineRegionImpl is not invoked (`call` stays at its nullptr default).
  // NOTE(review): source line 273 (the return type line) is missing from
  // this rendered listing.
279  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280  resultsToReplace, regionResultTypes, inlineLoc,
281  shouldCloneInlinedRegion);
282 }
283 
285  Operation *inlinePoint,
286  ValueRange inlinedOperands,
287  ValueRange resultsToReplace,
288  Optional<Location> inlineLoc,
289  bool shouldCloneInlinedRegion) {
  // Operation-anchored overload taking operands for the entry-block arguments.
  // The region is inlined immediately AFTER `inlinePoint` (pre-incremented
  // iterator), within the block that contains it.
  // NOTE(review): source line 284 (start of this signature) is missing from
  // this rendered listing.
290  return inlineRegion(interface, src, inlinePoint->getBlock(),
291  ++inlinePoint->getIterator(), inlinedOperands,
292  resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293 }
295 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296  Block::iterator inlinePoint, ValueRange inlinedOperands,
297  ValueRange resultsToReplace, Optional<Location> inlineLoc,
298  bool shouldCloneInlinedRegion) {
  // Public block/iterator overload taking operands for the entry-block
  // arguments. Operand/argument count and types must match exactly, as
  // enforced by the operand-based inlineRegionImpl. No call op is involved.
  // NOTE(review): source line 294 (the return type line) is missing from
  // this rendered listing.
299  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300  inlinedOperands, resultsToReplace, inlineLoc,
301  shouldCloneInlinedRegion);
302 }
303 
304 /// Utility function used to generate a cast operation from the given interface,
305 /// or return nullptr if a cast could not be generated.
/// A successfully created cast is appended to `castOps`, so that the caller
/// can undo it on failure. The cast is built with `castBuilder` at
/// `conversionLoc`, and must have exactly one operand (`arg`) and one result
/// of type `type` (asserted).
/// NOTE(review): source lines 306-307 (start of the signature) are missing
/// from this rendered listing. Per the declaration listed on this page, the
/// signature is `static Value materializeConversion(const
/// DialectInlinerInterface *interface, SmallVectorImpl<Operation *> &castOps,
/// OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)`.
308  OpBuilder &castBuilder, Value arg, Type type,
309  Location conversionLoc) {
310  if (!interface)
311  return nullptr;
312 
313  // Check to see if the interface for the call can materialize a conversion.
314  Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
315  type, conversionLoc);
316  if (!castOp)
317  return nullptr;
  // Record the cast before validating it, so the caller's cleanup sees it.
318  castOps.push_back(castOp);
319 
320  // Ensure that the generated cast is correct.
321  assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
322  castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
323  return castOp->getResult(0);
324 }
325 
326 /// This function inlines a given region, 'src', of a callable operation,
327 /// 'callable', into the location defined by the given call operation. This
328 /// function returns failure if inlining is not possible, success otherwise. On
329 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
330 /// corresponds to whether the source region should be cloned into the 'call' or
331 /// spliced directly.
///
/// Type mismatches between call operands/results and the region's arguments
/// and the callable's results are bridged with casts, built by the call
/// dialect's materializeCallConversion hook. Any casts created before a
/// later failure are removed again by `cleanupState`.
/// NOTE(review): this rendered listing omits source line 332 (start of the
/// signature, per the declaration on this page `LogicalResult
/// inlineCall(InlinerInterface &interface,`) and line 352 (the local
/// declaration of `castOps`, presumably a SmallVector<Operation *, N>, since
/// it is passed as SmallVectorImpl<Operation *>&).
333  CallOpInterface call,
334  CallableOpInterface callable, Region *src,
335  bool shouldCloneInlinedRegion) {
336  // We expect the region to have at least one block.
337  if (src->empty())
338  return failure();
339  auto *entryBlock = &src->front();
340  ArrayRef<Type> callableResultTypes = callable.getCallableResults();
341 
342  // Make sure that the number of arguments and results matchup between the call
343  // and the region.
344  SmallVector<Value, 8> callOperands(call.getArgOperands());
345  SmallVector<Value, 8> callResults(call->getResults());
346  if (callOperands.size() != entryBlock->getNumArguments() ||
347  callResults.size() != callableResultTypes.size())
348  return failure();
349 
350  // A set of cast operations generated to matchup the signature of the region
351  // with the signature of the call.
353  castOps.reserve(callOperands.size() + callResults.size());
354 
355  // Functor used to cleanup generated state on failure.
  // Each cast is undone by forwarding its input to its users and erasing it,
  // which restores the IR to its state before this function ran.
356  auto cleanupState = [&] {
357  for (auto *op : castOps) {
358  op->getResult(0).replaceAllUsesWith(op->getOperand(0));
359  op->erase();
360  }
361  return failure();
362  };
363 
364  // Builder used for any conversion operations that need to be materialized.
365  OpBuilder castBuilder(call);
366  Location castLoc = call.getLoc();
367  const auto *callInterface = interface.getInterfaceFor(call->getDialect());
368 
369  // Map the provided call operands to the arguments of the region.
370  BlockAndValueMapping mapper;
371  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
372  BlockArgument regionArg = entryBlock->getArgument(i);
373  Value operand = callOperands[i];
374 
375  // If the call operand doesn't match the expected region argument, try to
376  // generate a cast.
377  Type regionArgType = regionArg.getType();
378  if (operand.getType() != regionArgType) {
379  if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
380  operand, regionArgType, castLoc)))
381  return cleanupState();
382  }
383  mapper.map(regionArg, operand);
384  }
385 
386  // Ensure that the resultant values of the call match the callable.
387  castBuilder.setInsertionPointAfter(call);
388  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
389  Value callResult = callResults[i];
390  if (callResult.getType() == callableResultTypes[i])
391  continue;
392 
393  // Generate a conversion that will produce the original type, so that the IR
394  // is still valid after the original call gets replaced.
  // At creation time the cast's input is `callResult` itself. Its operand is
  // later rewritten when inlineRegionImpl replaces `callResults` with values
  // of the callable's result types.
395  Value castResult =
396  materializeConversion(callInterface, castOps, castBuilder, callResult,
397  callResult.getType(), castLoc);
398  if (!castResult)
399  return cleanupState();
  // Redirect users of the call result to the cast. The RAUW also rewrites
  // the cast's own operand to its result, so restore that operand to
  // `callResult`.
400  callResult.replaceAllUsesWith(castResult);
401  castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
402  }
403 
404  // Check that it is legal to inline the callable into the call.
405  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
406  return cleanupState();
407 
408  // Attempt to inline the call.
  // The region is inlined immediately after the call op. `call` is passed so
  // that the call-specific processing hook runs as well.
409  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
410  ++call->getIterator(), mapper, callResults,
411  callableResultTypes, call.getLoc(),
412  shouldCloneInlinedRegion, call)))
413  return cleanupState();
414  return success();
415 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap locations from the inlined blocks with CallSiteLoc locations with the provided caller location.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
BlockListType & getBlocks()
Definition: Region.h:45
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
Value getOperand(unsigned idx)
Definition: Operation.h:267
OpListType & getOperations()
Definition: Block.h:128
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc=llvm::None, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:263
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:358
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
OpListType::iterator iterator
Definition: Block.h:131
bool empty()
Definition: Region.h:60
iterator begin()
Definition: Region.h:55
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc, bool shouldCloneInlinedRegion, Operation *call=nullptr)
iterator end()
Definition: Block.h:135
Block * lookupOrNull(Block *from) const
Lookup a mapped value within the map.
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation *> &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
type_range getTypes() const
Definition: ValueRange.cpp:44
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:40
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
void cloneInto(Region *dest, BlockAndValueMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
This class represents an argument of a Block.
Definition: Value.h:300
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This interface provides the hooks into the inlining interface.
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
bool contains(Block *from) const
Checks to see if a mapping for 'from' exists.
Type getType() const
Return the type of this value.
Definition: Value.h:118
result_type_iterator result_type_begin()
Definition: Operation.h:343
iterator end()
Definition: Region.h:56
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect, and a callable region.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, BlockAndValueMapping &mapper)
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:180
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289