MLIR  16.0.0git
InliningUtils.cpp
Go to the documentation of this file.
1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
31  Location callerLoc) {
32  DenseMap<Location, Location> mappedLocations;
33  auto remapOpLoc = [&](Operation *op) {
34  auto it = mappedLocations.find(op->getLoc());
35  if (it == mappedLocations.end()) {
36  auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37  it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38  }
39  op->setLoc(it->second);
40  };
41  for (auto &block : inlinedBlocks)
42  block.walk(remapOpLoc);
43 }
44 
46  BlockAndValueMapping &mapper) {
47  auto remapOperands = [&](Operation *op) {
48  for (auto &operand : op->getOpOperands())
49  if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50  operand.set(mappedOp);
51  };
52  for (auto &block : inlinedBlocks)
53  block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
61  bool wouldBeCloned) const {
62  if (auto *handler = getInterfaceFor(call))
63  return handler->isLegalToInline(call, callable, wouldBeCloned);
64  return false;
65 }
66 
68  Region *dest, Region *src, bool wouldBeCloned,
69  BlockAndValueMapping &valueMapping) const {
70  if (auto *handler = getInterfaceFor(dest->getParentOp()))
71  return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72  return false;
73 }
74 
76  Operation *op, Region *dest, bool wouldBeCloned,
77  BlockAndValueMapping &valueMapping) const {
78  if (auto *handler = getInterfaceFor(op))
79  return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80  return false;
81 }
82 
84  auto *handler = getInterfaceFor(op);
85  return handler ? handler->shouldAnalyzeRecursively(op) : true;
86 }
87 
88 /// Handle the given inlined terminator by replacing it with a new operation
89 /// as necessary.
91  auto *handler = getInterfaceFor(op);
92  assert(handler && "expected valid dialect handler");
93  handler->handleTerminator(op, newDest);
94 }
95 
96 /// Handle the given inlined terminator by replacing it with a new operation
97 /// as necessary.
99  ArrayRef<Value> valuesToRepl) const {
100  auto *handler = getInterfaceFor(op);
101  assert(handler && "expected valid dialect handler");
102  handler->handleTerminator(op, valuesToRepl);
103 }
104 
106  Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
107  auto *handler = getInterfaceFor(call);
108  assert(handler && "expected valid dialect handler");
109  handler->processInlinedCallBlocks(call, inlinedBlocks);
110 }
111 
112 /// Utility to check that all of the operations within 'src' can be inlined.
113 static bool isLegalToInline(InlinerInterface &interface, Region *src,
114  Region *insertRegion, bool shouldCloneInlinedRegion,
115  BlockAndValueMapping &valueMapping) {
116  for (auto &block : *src) {
117  for (auto &op : block) {
118  // Check this operation.
119  if (!interface.isLegalToInline(&op, insertRegion,
120  shouldCloneInlinedRegion, valueMapping)) {
121  LLVM_DEBUG({
122  llvm::dbgs() << "* Illegal to inline because of op: ";
123  op.dump();
124  });
125  return false;
126  }
127  // Check any nested regions.
128  if (interface.shouldAnalyzeRecursively(&op) &&
129  llvm::any_of(op.getRegions(), [&](Region &region) {
130  return !isLegalToInline(interface, &region, insertRegion,
131  shouldCloneInlinedRegion, valueMapping);
132  }))
133  return false;
134  }
135  }
136  return true;
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Inline Methods
141 //===----------------------------------------------------------------------===//
142 
143 static LogicalResult
144 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
145  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
146  ValueRange resultsToReplace, TypeRange regionResultTypes,
147  Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
148  Operation *call = nullptr) {
149  assert(resultsToReplace.size() == regionResultTypes.size());
150  // We expect the region to have at least one block.
151  if (src->empty())
152  return failure();
153 
154  // Check that all of the region arguments have been mapped.
155  auto *srcEntryBlock = &src->front();
156  if (llvm::any_of(srcEntryBlock->getArguments(),
157  [&](BlockArgument arg) { return !mapper.contains(arg); }))
158  return failure();
159 
160  // Check that the operations within the source region are valid to inline.
161  Region *insertRegion = inlineBlock->getParent();
162  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
163  mapper) ||
164  !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
165  mapper))
166  return failure();
167 
168  // Check to see if the region is being cloned, or moved inline. In either
169  // case, move the new blocks after the 'insertBlock' to improve IR
170  // readability.
171  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
172  if (shouldCloneInlinedRegion)
173  src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
174  else
175  insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
176  src->getBlocks(), src->begin(),
177  src->end());
178 
179  // Get the range of newly inserted blocks.
180  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
181  postInsertBlock->getIterator());
182  Block *firstNewBlock = &*newBlocks.begin();
183 
184  // Remap the locations of the inlined operations if a valid source location
185  // was provided.
186  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
187  remapInlinedLocations(newBlocks, *inlineLoc);
188 
189  // If the blocks were moved in-place, make sure to remap any necessary
190  // operands.
191  if (!shouldCloneInlinedRegion)
192  remapInlinedOperands(newBlocks, mapper);
193 
194  // Process the newly inlined blocks.
195  if (call)
196  interface.processInlinedCallBlocks(call, newBlocks);
197  interface.processInlinedBlocks(newBlocks);
198 
199  // Handle the case where only a single block was inlined.
200  if (std::next(newBlocks.begin()) == newBlocks.end()) {
201  // Have the interface handle the terminator of this block.
202  auto *firstBlockTerminator = firstNewBlock->getTerminator();
203  interface.handleTerminator(firstBlockTerminator,
204  llvm::to_vector<6>(resultsToReplace));
205  firstBlockTerminator->erase();
206 
207  // Merge the post insert block into the cloned entry block.
208  firstNewBlock->getOperations().splice(firstNewBlock->end(),
209  postInsertBlock->getOperations());
210  postInsertBlock->erase();
211  } else {
212  // Otherwise, there were multiple blocks inlined. Add arguments to the post
213  // insertion block to represent the results to replace.
214  for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
215  resultToRepl.value().replaceAllUsesWith(
216  postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
217  resultToRepl.value().getLoc()));
218  }
219 
220  /// Handle the terminators for each of the new blocks.
221  for (auto &newBlock : newBlocks)
222  interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
223  }
224 
225  // Splice the instructions of the inlined entry block into the insert block.
226  inlineBlock->getOperations().splice(inlineBlock->end(),
227  firstNewBlock->getOperations());
228  firstNewBlock->erase();
229  return success();
230 }
231 
232 static LogicalResult
233 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234  Block::iterator inlinePoint, ValueRange inlinedOperands,
235  ValueRange resultsToReplace, Optional<Location> inlineLoc,
236  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
237  // We expect the region to have at least one block.
238  if (src->empty())
239  return failure();
240 
241  auto *entryBlock = &src->front();
242  if (inlinedOperands.size() != entryBlock->getNumArguments())
243  return failure();
244 
245  // Map the provided call operands to the arguments of the region.
246  BlockAndValueMapping mapper;
247  for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
248  // Verify that the types of the provided values match the function argument
249  // types.
250  BlockArgument regionArg = entryBlock->getArgument(i);
251  if (inlinedOperands[i].getType() != regionArg.getType())
252  return failure();
253  mapper.map(regionArg, inlinedOperands[i]);
254  }
255 
256  // Call into the main region inliner function.
257  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258  resultsToReplace, resultsToReplace.getTypes(),
259  inlineLoc, shouldCloneInlinedRegion, call);
260 }
261 
263  Operation *inlinePoint,
264  BlockAndValueMapping &mapper,
265  ValueRange resultsToReplace,
266  TypeRange regionResultTypes,
267  Optional<Location> inlineLoc,
268  bool shouldCloneInlinedRegion) {
269  return inlineRegion(interface, src, inlinePoint->getBlock(),
270  ++inlinePoint->getIterator(), mapper, resultsToReplace,
271  regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272 }
274 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276  ValueRange resultsToReplace, TypeRange regionResultTypes,
277  Optional<Location> inlineLoc,
278  bool shouldCloneInlinedRegion) {
279  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280  resultsToReplace, regionResultTypes, inlineLoc,
281  shouldCloneInlinedRegion);
282 }
283 
285  Operation *inlinePoint,
286  ValueRange inlinedOperands,
287  ValueRange resultsToReplace,
288  Optional<Location> inlineLoc,
289  bool shouldCloneInlinedRegion) {
290  return inlineRegion(interface, src, inlinePoint->getBlock(),
291  ++inlinePoint->getIterator(), inlinedOperands,
292  resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293 }
295 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296  Block::iterator inlinePoint, ValueRange inlinedOperands,
297  ValueRange resultsToReplace, Optional<Location> inlineLoc,
298  bool shouldCloneInlinedRegion) {
299  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300  inlinedOperands, resultsToReplace, inlineLoc,
301  shouldCloneInlinedRegion);
302 }
303 
304 /// Utility function used to generate a cast operation from the given interface,
305 /// or return nullptr if a cast could not be generated.
308  OpBuilder &castBuilder, Value arg, Type type,
309  Location conversionLoc) {
310  if (!interface)
311  return nullptr;
312 
313  // Check to see if the interface for the call can materialize a conversion.
314  Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
315  type, conversionLoc);
316  if (!castOp)
317  return nullptr;
318  castOps.push_back(castOp);
319 
320  // Ensure that the generated cast is correct.
321  assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
322  castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
323  return castOp->getResult(0);
324 }
325 
326 /// This function inlines a given region, 'src', of a callable operation,
327 /// 'callable', into the location defined by the given call operation. This
328 /// function returns failure if inlining is not possible, success otherwise. On
329 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
330 /// corresponds to whether the source region should be cloned into the 'call' or
331 /// spliced directly.
333  CallOpInterface call,
334  CallableOpInterface callable, Region *src,
335  bool shouldCloneInlinedRegion) {
336  // We expect the region to have at least one block.
337  if (src->empty())
338  return failure();
339  auto *entryBlock = &src->front();
340  ArrayRef<Type> callableResultTypes = callable.getCallableResults();
341 
342  // Make sure that the number of arguments and results matchup between the call
343  // and the region.
344  SmallVector<Value, 8> callOperands(call.getArgOperands());
345  SmallVector<Value, 8> callResults(call->getResults());
346  if (callOperands.size() != entryBlock->getNumArguments() ||
347  callResults.size() != callableResultTypes.size())
348  return failure();
349 
350  // A set of cast operations generated to matchup the signature of the region
351  // with the signature of the call.
353  castOps.reserve(callOperands.size() + callResults.size());
354 
355  // Functor used to cleanup generated state on failure.
356  auto cleanupState = [&] {
357  for (auto *op : castOps) {
358  op->getResult(0).replaceAllUsesWith(op->getOperand(0));
359  op->erase();
360  }
361  return failure();
362  };
363 
364  // Builder used for any conversion operations that need to be materialized.
365  OpBuilder castBuilder(call);
366  Location castLoc = call.getLoc();
367  const auto *callInterface = interface.getInterfaceFor(call->getDialect());
368 
369  // Map the provided call operands to the arguments of the region.
370  BlockAndValueMapping mapper;
371  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
372  BlockArgument regionArg = entryBlock->getArgument(i);
373  Value operand = callOperands[i];
374 
375  // If the call operand doesn't match the expected region argument, try to
376  // generate a cast.
377  Type regionArgType = regionArg.getType();
378  if (operand.getType() != regionArgType) {
379  if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
380  operand, regionArgType, castLoc)))
381  return cleanupState();
382  }
383  mapper.map(regionArg, operand);
384  }
385 
386  // Ensure that the resultant values of the call match the callable.
387  castBuilder.setInsertionPointAfter(call);
388  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
389  Value callResult = callResults[i];
390  if (callResult.getType() == callableResultTypes[i])
391  continue;
392 
393  // Generate a conversion that will produce the original type, so that the IR
394  // is still valid after the original call gets replaced.
395  Value castResult =
396  materializeConversion(callInterface, castOps, castBuilder, callResult,
397  callResult.getType(), castLoc);
398  if (!castResult)
399  return cleanupState();
400  callResult.replaceAllUsesWith(castResult);
401  castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
402  }
403 
404  // Check that it is legal to inline the callable into the call.
405  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
406  return cleanupState();
407 
408  // Attempt to inline the call.
409  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
410  ++call->getIterator(), mapper, callResults,
411  callableResultTypes, call.getLoc(),
412  shouldCloneInlinedRegion, call)))
413  return cleanupState();
414  return success();
415 }
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc, bool shouldCloneInlinedRegion, Operation *call=nullptr)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, BlockAndValueMapping &mapper)
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap locations from the inlined blocks with CallSiteLoc locations with the provided caller location.
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
Block * lookupOrNull(Block *from) const
Lookup a mapped value within the map.
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
OpListType & getOperations()
Definition: Block.h:126
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:41
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect,...
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This interface provides the hooks into the inlining interface.
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:364
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:184
Value getOperand(unsigned idx)
Definition: Operation.h:267
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
unsigned getNumOperands()
Definition: Operation.h:263
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
result_type_iterator result_type_begin()
Definition: Operation.h:343
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
iterator begin()
Definition: Region.h:55
BlockListType & getBlocks()
Definition: Region.h:45
void cloneInto(Region *dest, BlockAndValueMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
Block & front()
Definition: Region.h:65
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:158
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc=std::nullopt, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26