MLIR  14.0.0git
InliningUtils.cpp
Go to the documentation of this file.
1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Transforms/InliningUtils.h"

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
31  Location callerLoc) {
32  DenseMap<Location, Location> mappedLocations;
33  auto remapOpLoc = [&](Operation *op) {
34  auto it = mappedLocations.find(op->getLoc());
35  if (it == mappedLocations.end()) {
36  auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37  it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38  }
39  op->setLoc(it->second);
40  };
41  for (auto &block : inlinedBlocks)
42  block.walk(remapOpLoc);
43 }
44 
46  BlockAndValueMapping &mapper) {
47  auto remapOperands = [&](Operation *op) {
48  for (auto &operand : op->getOpOperands())
49  if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50  operand.set(mappedOp);
51  };
52  for (auto &block : inlinedBlocks)
53  block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
61  bool wouldBeCloned) const {
62  if (auto *handler = getInterfaceFor(call))
63  return handler->isLegalToInline(call, callable, wouldBeCloned);
64  return false;
65 }
66 
68  Region *dest, Region *src, bool wouldBeCloned,
69  BlockAndValueMapping &valueMapping) const {
70  // Regions can always be inlined into functions.
71  if (isa<FuncOp>(dest->getParentOp()))
72  return true;
73 
74  if (auto *handler = getInterfaceFor(dest->getParentOp()))
75  return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
76  return false;
77 }
78 
80  Operation *op, Region *dest, bool wouldBeCloned,
81  BlockAndValueMapping &valueMapping) const {
82  if (auto *handler = getInterfaceFor(op))
83  return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
84  return false;
85 }
86 
88  auto *handler = getInterfaceFor(op);
89  return handler ? handler->shouldAnalyzeRecursively(op) : true;
90 }
91 
92 /// Handle the given inlined terminator by replacing it with a new operation
93 /// as necessary.
95  auto *handler = getInterfaceFor(op);
96  assert(handler && "expected valid dialect handler");
97  handler->handleTerminator(op, newDest);
98 }
99 
100 /// Handle the given inlined terminator by replacing it with a new operation
101 /// as necessary.
103  ArrayRef<Value> valuesToRepl) const {
104  auto *handler = getInterfaceFor(op);
105  assert(handler && "expected valid dialect handler");
106  handler->handleTerminator(op, valuesToRepl);
107 }
108 
110  Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
111  auto *handler = getInterfaceFor(call);
112  assert(handler && "expected valid dialect handler");
113  handler->processInlinedCallBlocks(call, inlinedBlocks);
114 }
115 
116 /// Utility to check that all of the operations within 'src' can be inlined.
117 static bool isLegalToInline(InlinerInterface &interface, Region *src,
118  Region *insertRegion, bool shouldCloneInlinedRegion,
119  BlockAndValueMapping &valueMapping) {
120  for (auto &block : *src) {
121  for (auto &op : block) {
122  // Check this operation.
123  if (!interface.isLegalToInline(&op, insertRegion,
124  shouldCloneInlinedRegion, valueMapping)) {
125  LLVM_DEBUG({
126  llvm::dbgs() << "* Illegal to inline because of op: ";
127  op.dump();
128  });
129  return false;
130  }
131  // Check any nested regions.
132  if (interface.shouldAnalyzeRecursively(&op) &&
133  llvm::any_of(op.getRegions(), [&](Region &region) {
134  return !isLegalToInline(interface, &region, insertRegion,
135  shouldCloneInlinedRegion, valueMapping);
136  }))
137  return false;
138  }
139  }
140  return true;
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Inline Methods
145 //===----------------------------------------------------------------------===//
146 
/// Core region inliner. Moves (or clones, when `shouldCloneInlinedRegion` is
/// set) the blocks of `src` into the region of `inlineBlock`, positioned at
/// `inlinePoint`. It then uses the interface to rewrite the inlined
/// terminators so that the values in `resultsToReplace` are replaced by what
/// the region produces.
///
/// Parameters:
///  - `mapper` must already contain a mapping for every argument of the entry
///    block of `src`; this is checked below. It is passed to cloneInto when
///    cloning, or used to remap operands when the blocks are moved.
///  - `regionResultTypes` must have one entry per value in `resultsToReplace`
///    (asserted). It gives the types of the block arguments created when more
///    than one block is inlined.
///  - If `inlineLoc` is present and not an UnknownLoc, the locations of the
///    inlined operations are wrapped in CallSiteLocs.
///  - If `call` is non-null, the interface's processInlinedCallBlocks hook is
///    run for it.
///
/// Returns failure if `src` is empty, if an entry argument is unmapped, or if
/// the interface rejects the inlining. All of these checks happen before the
/// block is split, so no IR is modified by this function on failure.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
                 ValueRange resultsToReplace, TypeRange regionResultTypes,
                 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
                 Operation *call = nullptr) {
  assert(resultsToReplace.size() == regionResultTypes.size());
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();

  // Check that the operations within the source region are valid to inline.
  Region *insertRegion = inlineBlock->getParent();
  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                 mapper) ||
      !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                       mapper))
    return failure();

  // Check to see if the region is being cloned, or moved inline. In either
  // case, move the new blocks after the 'insertBlock' to improve IR
  // readability. The operations from `inlinePoint` onward are split off into
  // `postInsertBlock`. The new blocks are placed between `inlineBlock` and
  // `postInsertBlock`.
  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks.
  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. (When cloning, `mapper` was already applied by cloneInto.)
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  if (call)
    interface.processInlinedCallBlocks(call, newBlocks);
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator,
                               llvm::to_vector<6>(resultsToReplace));
    firstBlockTerminator->erase();

    // Merge the post insert block into the inlined entry block (cloned or
    // moved), so the code after the inline point follows the inlined ops.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(
          postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
                                       resultToRepl.value().getLoc()));
    }

    // Handle the terminators for each of the new blocks. Every new block's
    // terminator is passed in, not only return-like ones. Presumably each
    // dialect handler ignores terminators it does not need to rewrite;
    // verify against the dialect implementations.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block.
  inlineBlock->getOperations().splice(inlineBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
235 
236 static LogicalResult
237 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
238  Block::iterator inlinePoint, ValueRange inlinedOperands,
239  ValueRange resultsToReplace, Optional<Location> inlineLoc,
240  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
241  // We expect the region to have at least one block.
242  if (src->empty())
243  return failure();
244 
245  auto *entryBlock = &src->front();
246  if (inlinedOperands.size() != entryBlock->getNumArguments())
247  return failure();
248 
249  // Map the provided call operands to the arguments of the region.
250  BlockAndValueMapping mapper;
251  for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
252  // Verify that the types of the provided values match the function argument
253  // types.
254  BlockArgument regionArg = entryBlock->getArgument(i);
255  if (inlinedOperands[i].getType() != regionArg.getType())
256  return failure();
257  mapper.map(regionArg, inlinedOperands[i]);
258  }
259 
260  // Call into the main region inliner function.
261  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
262  resultsToReplace, resultsToReplace.getTypes(),
263  inlineLoc, shouldCloneInlinedRegion, call);
264 }
265 
267  Operation *inlinePoint,
268  BlockAndValueMapping &mapper,
269  ValueRange resultsToReplace,
270  TypeRange regionResultTypes,
271  Optional<Location> inlineLoc,
272  bool shouldCloneInlinedRegion) {
273  return inlineRegion(interface, src, inlinePoint->getBlock(),
274  ++inlinePoint->getIterator(), mapper, resultsToReplace,
275  regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
276 }
278 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
279  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
280  ValueRange resultsToReplace, TypeRange regionResultTypes,
281  Optional<Location> inlineLoc,
282  bool shouldCloneInlinedRegion) {
283  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
284  resultsToReplace, regionResultTypes, inlineLoc,
285  shouldCloneInlinedRegion);
286 }
287 
289  Operation *inlinePoint,
290  ValueRange inlinedOperands,
291  ValueRange resultsToReplace,
292  Optional<Location> inlineLoc,
293  bool shouldCloneInlinedRegion) {
294  return inlineRegion(interface, src, inlinePoint->getBlock(),
295  ++inlinePoint->getIterator(), inlinedOperands,
296  resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
297 }
299 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
300  Block::iterator inlinePoint, ValueRange inlinedOperands,
301  ValueRange resultsToReplace, Optional<Location> inlineLoc,
302  bool shouldCloneInlinedRegion) {
303  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
304  inlinedOperands, resultsToReplace, inlineLoc,
305  shouldCloneInlinedRegion);
306 }
307 
308 /// Utility function used to generate a cast operation from the given interface,
309 /// or return nullptr if a cast could not be generated.
312  OpBuilder &castBuilder, Value arg, Type type,
313  Location conversionLoc) {
314  if (!interface)
315  return nullptr;
316 
317  // Check to see if the interface for the call can materialize a conversion.
318  Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
319  type, conversionLoc);
320  if (!castOp)
321  return nullptr;
322  castOps.push_back(castOp);
323 
324  // Ensure that the generated cast is correct.
325  assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
326  castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
327  return castOp->getResult(0);
328 }
329 
330 /// This function inlines a given region, 'src', of a callable operation,
331 /// 'callable', into the location defined by the given call operation. This
332 /// function returns failure if inlining is not possible, success otherwise. On
333 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
334 /// corresponds to whether the source region should be cloned into the 'call' or
335 /// spliced directly.
337  CallOpInterface call,
338  CallableOpInterface callable, Region *src,
339  bool shouldCloneInlinedRegion) {
340  // We expect the region to have at least one block.
341  if (src->empty())
342  return failure();
343  auto *entryBlock = &src->front();
344  ArrayRef<Type> callableResultTypes = callable.getCallableResults();
345 
346  // Make sure that the number of arguments and results matchup between the call
347  // and the region.
348  SmallVector<Value, 8> callOperands(call.getArgOperands());
349  SmallVector<Value, 8> callResults(call->getResults());
350  if (callOperands.size() != entryBlock->getNumArguments() ||
351  callResults.size() != callableResultTypes.size())
352  return failure();
353 
354  // A set of cast operations generated to matchup the signature of the region
355  // with the signature of the call.
357  castOps.reserve(callOperands.size() + callResults.size());
358 
359  // Functor used to cleanup generated state on failure.
360  auto cleanupState = [&] {
361  for (auto *op : castOps) {
362  op->getResult(0).replaceAllUsesWith(op->getOperand(0));
363  op->erase();
364  }
365  return failure();
366  };
367 
368  // Builder used for any conversion operations that need to be materialized.
369  OpBuilder castBuilder(call);
370  Location castLoc = call.getLoc();
371  const auto *callInterface = interface.getInterfaceFor(call->getDialect());
372 
373  // Map the provided call operands to the arguments of the region.
374  BlockAndValueMapping mapper;
375  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
376  BlockArgument regionArg = entryBlock->getArgument(i);
377  Value operand = callOperands[i];
378 
379  // If the call operand doesn't match the expected region argument, try to
380  // generate a cast.
381  Type regionArgType = regionArg.getType();
382  if (operand.getType() != regionArgType) {
383  if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
384  operand, regionArgType, castLoc)))
385  return cleanupState();
386  }
387  mapper.map(regionArg, operand);
388  }
389 
390  // Ensure that the resultant values of the call match the callable.
391  castBuilder.setInsertionPointAfter(call);
392  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
393  Value callResult = callResults[i];
394  if (callResult.getType() == callableResultTypes[i])
395  continue;
396 
397  // Generate a conversion that will produce the original type, so that the IR
398  // is still valid after the original call gets replaced.
399  Value castResult =
400  materializeConversion(callInterface, castOps, castBuilder, callResult,
401  callResult.getType(), castLoc);
402  if (!castResult)
403  return cleanupState();
404  callResult.replaceAllUsesWith(castResult);
405  castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
406  }
407 
408  // Check that it is legal to inline the callable into the call.
409  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
410  return cleanupState();
411 
412  // Attempt to inline the call.
413  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
414  ++call->getIterator(), mapper, callResults,
415  callableResultTypes, call.getLoc(),
416  shouldCloneInlinedRegion, call)))
417  return cleanupState();
418  return success();
419 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
static void remapInlinedLocations(iterator_range< Region::iterator > inlinedBlocks, Location callerLoc)
Remap locations from the inlined blocks with CallSiteLoc locations with the provided caller location.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
BlockListType & getBlocks()
Definition: Region.h:45
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
Value getOperand(unsigned idx)
Definition: Operation.h:219
OpListType & getOperations()
Definition: Block.h:128
LogicalResult inlineRegion(InlinerInterface &interface, Region *src, Operation *inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc=llvm::None, bool shouldCloneInlinedRegion=true)
This function inlines a region, 'src', into another.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
const DialectInlinerInterface * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:161
Region * getParent() const
Provide a &#39;getParent&#39; method for ilist_node_with_parent methods.
Definition: Block.cpp:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:343
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
OpListType::iterator iterator
Definition: Block.h:131
bool empty()
Definition: Region.h:60
iterator begin()
Definition: Region.h:55
static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, BlockAndValueMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, Optional< Location > inlineLoc, bool shouldCloneInlinedRegion, Operation *call=nullptr)
iterator end()
Definition: Block.h:135
Block * lookupOrNull(Block *from) const
Lookup a mapped value within the map.
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation *> &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:41
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
virtual bool shouldAnalyzeRecursively(Operation *op) const
virtual void processInlinedCallBlocks(Operation *call, iterator_range< Region::iterator > inlinedBlocks) const
void cloneInto(Region *dest, BlockAndValueMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
This class represents an argument of a Block.
Definition: Value.h:298
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This interface provides the hooks into the inlining interface.
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
virtual void handleTerminator(Operation *op, Block *newDest) const
Handle the given inlined terminator by replacing it with a new operation as necessary.
bool contains(Block *from) const
Checks to see if a mapping for 'from' exists.
Type getType() const
Return the type of this value.
Definition: Value.h:117
result_type_iterator result_type_begin()
Definition: Operation.h:295
iterator end()
Definition: Region.h:56
type_range getTypes() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
virtual Operation * materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const
Attempt to materialize a conversion for a type mismatch between a call from this dialect, and a callable region.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
static void remapInlinedOperands(iterator_range< Region::iterator > inlinedBlocks, BlockAndValueMapping &mapper)
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
Definition: Operation.cpp:190
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289