MLIR  14.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
34  Location regionLoc,
36  ConversionTarget *target = nullptr) {
37  if (llvm::empty(region))
38  return success();
39 
40  // Traverse starting from the entry block.
41  SmallVector<Block *, 16> worklist(1, &*region.begin());
42  DenseSet<Block *> visitedBlocks;
43  visitedBlocks.insert(worklist.front());
44  while (!worklist.empty()) {
45  Block *block = worklist.pop_back_val();
46 
47  // Compute the conversion set of each of the nested operations.
48  for (Operation &op : *block) {
49  toConvert.emplace_back(&op);
50 
51  // Don't check this operation's children for conversion if the operation
52  // is recursively legal.
53  auto legalityInfo = target ? target->isLegal(&op)
55  if (legalityInfo && legalityInfo->isRecursivelyLegal)
56  continue;
57  for (auto &region : op.getRegions()) {
58  if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
59  toConvert, target)))
60  return failure();
61  }
62  }
63 
64  // Recurse to children that haven't been visited.
65  for (Block *succ : block->getSuccessors())
66  if (visitedBlocks.insert(succ).second)
67  worklist.push_back(succ);
68  }
69 
70  // Check that all blocks in the region were visited.
71  if (llvm::any_of(llvm::drop_begin(region, 1),
72  [&](Block &block) { return !visitedBlocks.count(&block); }))
73  return emitError(regionLoc, "unreachable blocks were not converted");
74  return success();
75 }
76 
77 /// A utility function to log a successful result for the given reason.
78 template <typename... Args>
79 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
80  LLVM_DEBUG({
81  os.unindent();
82  os.startLine() << "} -> SUCCESS";
83  if (!fmt.empty())
84  os.getOStream() << " : "
85  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
86  os.getOStream() << "\n";
87  });
88 }
89 
90 /// A utility function to log a failure result for the given reason.
91 template <typename... Args>
92 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
93  LLVM_DEBUG({
94  os.unindent();
95  os.startLine() << "} -> FAILURE : "
96  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
97  << "\n";
98  });
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // ConversionValueMapping
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 /// This class wraps a BlockAndValueMapping to provide recursive lookup
107 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
108 struct ConversionValueMapping {
109  /// Lookup a mapped value within the map. If a mapping for the provided value
110  /// does not exist then return the provided value. If `desiredType` is
111  /// non-null, returns the most recently mapped value with that type. If an
112  /// operand of that type does not exist, defaults to normal behavior.
113  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
114 
115  /// Lookup a mapped value within the map, or return null if a mapping does not
116  /// exist. If a mapping exists, this follows the same behavior of
117  /// `lookupOrDefault`.
118  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
119 
120  /// Map a value to the one provided.
121  void map(Value oldVal, Value newVal) {
122  LLVM_DEBUG({
123  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
124  assert(it != oldVal && "inserting cyclic mapping");
125  });
126  mapping.map(oldVal, newVal);
127  }
128 
129  /// Try to map a value to the one provided. Returns false if a transitive
130  /// mapping from the new value to the old value already exists, true if the
131  /// map was updated.
132  bool tryMap(Value oldVal, Value newVal);
133 
134  /// Drop the last mapping for the given value.
135  void erase(Value value) { mapping.erase(value); }
136 
137  /// Returns the inverse raw value mapping (without recursive query support).
138  DenseMap<Value, SmallVector<Value>> getInverse() const {
140  for (auto &it : mapping.getValueMap())
141  inverse[it.second].push_back(it.first);
142  return inverse;
143  }
144 
145 private:
146  /// Current value mappings.
147  BlockAndValueMapping mapping;
148 };
149 } // namespace
150 
151 Value ConversionValueMapping::lookupOrDefault(Value from,
152  Type desiredType) const {
153  // If there was no desired type, simply find the leaf value.
154  if (!desiredType) {
155  // If this value had a valid mapping, unmap that value as well in the case
156  // that it was also replaced.
157  while (auto mappedValue = mapping.lookupOrNull(from))
158  from = mappedValue;
159  return from;
160  }
161 
162  // Otherwise, try to find the deepest value that has the desired type.
163  Value desiredValue;
164  do {
165  if (from.getType() == desiredType)
166  desiredValue = from;
167 
168  Value mappedValue = mapping.lookupOrNull(from);
169  if (!mappedValue)
170  break;
171  from = mappedValue;
172  } while (true);
173 
174  // If the desired value was found use it, otherwise default to the leaf value.
175  return desiredValue ? desiredValue : from;
176 }
177 
178 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
179  Value result = lookupOrDefault(from, desiredType);
180  if (result == from || (desiredType && result.getType() != desiredType))
181  return nullptr;
182  return result;
183 }
184 
185 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
186  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
187  if (it == oldVal)
188  return false;
189  map(oldVal, newVal);
190  return true;
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // Rewriter and Translation State
195 //===----------------------------------------------------------------------===//
196 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// before a set of rewrites makes it possible to undo exactly those rewrites
/// afterwards.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                unsigned numReplacements, unsigned numArgReplacements,
                unsigned numBlockActions, unsigned numIgnoredOperations,
                unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps),
        numUnresolvedMaterializations(numUnresolvedMaterializations),
        numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  /// Number of operations created so far.
  unsigned numCreatedOps;

  /// Number of unresolved materializations recorded so far.
  unsigned numUnresolvedMaterializations;

  /// Number of operation replacements queued so far.
  unsigned numReplacements;

  /// Number of block argument replacements queued so far.
  unsigned numArgReplacements;

  /// Number of block actions performed so far.
  unsigned numBlockActions;

  /// Number of operations marked as ignored so far.
  unsigned numIgnoredOperations;

  /// Number of operations updated in place so far.
  unsigned numRootUpdates;
};
233 
234 //===----------------------------------------------------------------------===//
235 // OperationTransactionState
236 
237 /// The state of an operation that was updated by a pattern in-place. This
238 /// contains all of the necessary information to reconstruct an operation that
239 /// was updated in place.
240 class OperationTransactionState {
241 public:
242  OperationTransactionState() = default;
243  OperationTransactionState(Operation *op)
244  : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
245  operands(op->operand_begin(), op->operand_end()),
246  successors(op->successor_begin(), op->successor_end()) {}
247 
248  /// Discard the transaction state and reset the state of the original
249  /// operation.
250  void resetOperation() const {
251  op->setLoc(loc);
252  op->setAttrs(attrs);
253  op->setOperands(operands);
254  for (const auto &it : llvm::enumerate(successors))
255  op->setSuccessor(it.value(), it.index());
256  }
257 
258  /// Return the original operation of this state.
259  Operation *getOperation() const { return op; }
260 
261 private:
262  Operation *op;
263  LocationAttr loc;
264  DictionaryAttr attrs;
265  SmallVector<Value, 8> operands;
266  SmallVector<Block *, 2> successors;
267 };
268 
269 //===----------------------------------------------------------------------===//
270 // OpReplacement
271 
272 /// This class represents one requested operation replacement via 'replaceOp' or
273 /// 'eraseOp`.
274 struct OpReplacement {
275  OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {}
276 
277  /// An optional type converter that can be used to materialize conversions
278  /// between the new and old values if necessary.
279  TypeConverter *converter;
280 };
281 
282 //===----------------------------------------------------------------------===//
283 // BlockAction
284 
/// Identifies which block action was performed during the rewrite. Every
/// action can be undone if the conversion fails.
enum class BlockActionKind { Create, Erase, Merge, Move, Split, TypeConversion };
295 
296 /// Original position of the given block in its parent region. During undo
297 /// actions, the block needs to be placed after `insertAfterBlock`.
298 struct BlockPosition {
299  Region *region;
300  Block *insertAfterBlock;
301 };
302 
303 /// Information needed to undo the merge actions.
304 /// - the source block, and
305 /// - the Operation that was the last operation in the dest block before the
306 /// merge (could be null if the dest block was empty).
307 struct MergeInfo {
308  Block *sourceBlock;
309  Operation *destBlockLastInst;
310 };
311 
312 /// The storage class for an undoable block action (one of BlockActionKind),
313 /// contains the information necessary to undo this action.
314 struct BlockAction {
315  static BlockAction getCreate(Block *block) {
316  return {BlockActionKind::Create, block, {}};
317  }
318  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
319  return {BlockActionKind::Erase, block, {originalPosition}};
320  }
321  static BlockAction getMerge(Block *block, Block *sourceBlock) {
322  BlockAction action{BlockActionKind::Merge, block, {}};
323  action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
324  return action;
325  }
326  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
327  return {BlockActionKind::Move, block, {originalPosition}};
328  }
329  static BlockAction getSplit(Block *block, Block *originalBlock) {
330  BlockAction action{BlockActionKind::Split, block, {}};
331  action.originalBlock = originalBlock;
332  return action;
333  }
334  static BlockAction getTypeConversion(Block *block) {
335  return BlockAction{BlockActionKind::TypeConversion, block, {}};
336  }
337 
338  // The action kind.
339  BlockActionKind kind;
340 
341  // A pointer to the block that was created by the action.
342  Block *block;
343 
344  union {
345  // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
346  // contains a pointer to the region that originally contained the block as
347  // well as the position of the block in that region.
348  BlockPosition originalPosition;
349  // In use if kind == BlockActionKind::Split and contains a pointer to the
350  // block that was split into two parts.
351  Block *originalBlock;
352  // In use if kind == BlockActionKind::Merge, and contains the information
353  // needed to undo the merge.
354  MergeInfo mergeInfo;
355  };
356 };
357 
358 //===----------------------------------------------------------------------===//
359 // UnresolvedMaterialization
360 
361 /// This class represents an unresolved materialization, i.e. a materialization
362 /// that was inserted during conversion that needs to be legalized at the end of
363 /// the conversion process.
364 class UnresolvedMaterialization {
365 public:
366  /// The type of materialization.
367  enum Kind {
368  /// This materialization materializes a conversion for an illegal block
369  /// argument type, to a legal one.
370  Argument,
371 
372  /// This materialization materializes a conversion from an illegal type to a
373  /// legal one.
374  Target
375  };
376 
377  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
378  TypeConverter *converter = nullptr,
379  Kind kind = Target, Type origOutputType = nullptr)
380  : op(op), converterAndKind(converter, kind),
381  origOutputType(origOutputType) {}
382 
383  /// Return the temporary conversion operation inserted for this
384  /// materialization.
385  UnrealizedConversionCastOp getOp() const { return op; }
386 
387  /// Return the type converter of this materialization (which may be null).
388  TypeConverter *getConverter() const { return converterAndKind.getPointer(); }
389 
390  /// Return the kind of this materialization.
391  Kind getKind() const { return converterAndKind.getInt(); }
392 
393  /// Set the kind of this materialization.
394  void setKind(Kind kind) { converterAndKind.setInt(kind); }
395 
396  /// Return the original illegal output type of the input values.
397  Type getOrigOutputType() const { return origOutputType; }
398 
399 private:
400  /// The unresolved materialization operation created during conversion.
401  UnrealizedConversionCastOp op;
402 
403  /// The corresponding type converter to use when resolving this
404  /// materialization, and the kind of this materialization.
405  llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
406 
407  /// The original output type. This is only used for argument conversions.
408  Type origOutputType;
409 };
410 } // namespace
411 
412 /// Build an unresolved materialization operation given an output type and set
413 /// of input operands.
415  UnresolvedMaterialization::Kind kind, Block *insertBlock,
416  Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
417  Type origOutputType, TypeConverter *converter,
418  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
419  // Avoid materializing an unnecessary cast.
420  if (inputs.size() == 1 && inputs.front().getType() == outputType)
421  return inputs.front();
422 
423  // Create an unresolved materialization. We use a new OpBuilder to avoid
424  // tracking the materialization like we do for other operations.
425  OpBuilder builder(insertBlock, insertPt);
426  auto convertOp =
427  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
428  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
429  origOutputType);
430  return convertOp.getResult(0);
431 }
433  PatternRewriter &rewriter, Location loc, ValueRange inputs,
434  Type origOutputType, Type outputType, TypeConverter *converter,
435  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
438  rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
439  converter, unresolvedMaterializations);
440 }
442  Location loc, Value input, Type outputType, TypeConverter *converter,
443  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
444  Block *insertBlock = input.getParentBlock();
445  Block::iterator insertPt = insertBlock->begin();
446  if (OpResult inputRes = input.dyn_cast<OpResult>())
447  insertPt = ++inputRes.getOwner()->getIterator();
448 
450  UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
451  outputType, outputType, converter, unresolvedMaterializations);
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // ArgConverter
456 //===----------------------------------------------------------------------===//
457 namespace {
458 /// This class provides a simple interface for converting the types of block
459 /// arguments. This is done by creating a new block that contains the new legal
460 /// types and extracting the block that contains the old illegal types to allow
461 /// for undoing pending rewrites in the case of failure.
462 struct ArgConverter {
463  ArgConverter(
464  PatternRewriter &rewriter,
465  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
466  : rewriter(rewriter),
467  unresolvedMaterializations(unresolvedMaterializations) {}
468 
469  /// This structure contains the information pertaining to an argument that has
470  /// been converted.
471  struct ConvertedArgInfo {
472  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
473  Value castValue = nullptr)
474  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
475 
476  /// The start index of in the new argument list that contains arguments that
477  /// replace the original.
478  unsigned newArgIdx;
479 
480  /// The number of arguments that replaced the original argument.
481  unsigned newArgSize;
482 
483  /// The cast value that was created to cast from the new arguments to the
484  /// old. This only used if 'newArgSize' > 1.
485  Value castValue;
486  };
487 
488  /// This structure contains information pertaining to a block that has had its
489  /// signature converted.
490  struct ConvertedBlockInfo {
491  ConvertedBlockInfo(Block *origBlock, TypeConverter *converter)
492  : origBlock(origBlock), converter(converter) {}
493 
494  /// The original block that was requested to have its signature converted.
495  Block *origBlock;
496 
497  /// The conversion information for each of the arguments. The information is
498  /// None if the argument was dropped during conversion.
500 
501  /// The type converter used to convert the arguments.
502  TypeConverter *converter;
503  };
504 
505  /// Return if the signature of the given block has already been converted.
506  bool hasBeenConverted(Block *block) const {
507  return conversionInfo.count(block) || convertedBlocks.count(block);
508  }
509 
510  /// Set the type converter to use for the given region.
511  void setConverter(Region *region, TypeConverter *typeConverter) {
512  assert(typeConverter && "expected valid type converter");
513  regionToConverter[region] = typeConverter;
514  }
515 
516  /// Return the type converter to use for the given region, or null if there
517  /// isn't one.
518  TypeConverter *getConverter(Region *region) {
519  return regionToConverter.lookup(region);
520  }
521 
522  //===--------------------------------------------------------------------===//
523  // Rewrite Application
524  //===--------------------------------------------------------------------===//
525 
526  /// Erase any rewrites registered for the blocks within the given operation
527  /// which is about to be removed. This merely drops the rewrites without
528  /// undoing them.
529  void notifyOpRemoved(Operation *op);
530 
531  /// Cleanup and undo any generated conversions for the arguments of block.
532  /// This method replaces the new block with the original, reverting the IR to
533  /// its original state.
534  void discardRewrites(Block *block);
535 
536  /// Fully replace uses of the old arguments with the new.
537  void applyRewrites(ConversionValueMapping &mapping);
538 
539  /// Materialize any necessary conversions for converted arguments that have
540  /// live users, using the provided `findLiveUser` to search for a user that
541  /// survives the conversion process.
543  materializeLiveConversions(ConversionValueMapping &mapping,
544  OpBuilder &builder,
545  function_ref<Operation *(Value)> findLiveUser);
546 
547  //===--------------------------------------------------------------------===//
548  // Conversion
549  //===--------------------------------------------------------------------===//
550 
551  /// Attempt to convert the signature of the given block, if successful a new
552  /// block is returned containing the new arguments. Returns `block` if it did
553  /// not require conversion.
555  convertSignature(Block *block, TypeConverter *converter,
556  ConversionValueMapping &mapping,
557  SmallVectorImpl<BlockArgument> &argReplacements);
558 
559  /// Apply the given signature conversion on the given block. The new block
560  /// containing the updated signature is returned. If no conversions were
561  /// necessary, e.g. if the block has no arguments, `block` is returned.
562  /// `converter` is used to generate any necessary cast operations that
563  /// translate between the origin argument types and those specified in the
564  /// signature conversion.
565  Block *applySignatureConversion(
566  Block *block, TypeConverter *converter,
567  TypeConverter::SignatureConversion &signatureConversion,
568  ConversionValueMapping &mapping,
569  SmallVectorImpl<BlockArgument> &argReplacements);
570 
571  /// Insert a new conversion into the cache.
572  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
573 
574  /// A collection of blocks that have had their arguments converted. This is a
575  /// map from the new replacement block, back to the original block.
576  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
577 
578  /// The set of original blocks that were converted.
579  DenseSet<Block *> convertedBlocks;
580 
581  /// A mapping from valid regions, to those containing the original blocks of a
582  /// conversion.
584 
585  /// A mapping of regions to type converters that should be used when
586  /// converting the arguments of blocks within that region.
587  DenseMap<Region *, TypeConverter *> regionToConverter;
588 
589  /// The pattern rewriter to use when materializing conversions.
590  PatternRewriter &rewriter;
591 
592  /// An ordered set of unresolved materializations during conversion.
593  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
594 };
595 } // namespace
596 
597 //===----------------------------------------------------------------------===//
598 // Rewrite Application
599 
600 void ArgConverter::notifyOpRemoved(Operation *op) {
601  if (conversionInfo.empty())
602  return;
603 
604  for (Region &region : op->getRegions()) {
605  for (Block &block : region) {
606  // Drop any rewrites from within.
607  for (Operation &nestedOp : block)
608  if (nestedOp.getNumRegions())
609  notifyOpRemoved(&nestedOp);
610 
611  // Check if this block was converted.
612  auto it = conversionInfo.find(&block);
613  if (it == conversionInfo.end())
614  continue;
615 
616  // Drop all uses of the original arguments and delete the original block.
617  Block *origBlock = it->second.origBlock;
618  for (BlockArgument arg : origBlock->getArguments())
619  arg.dropAllUses();
620  conversionInfo.erase(it);
621  }
622  }
623 }
624 
625 void ArgConverter::discardRewrites(Block *block) {
626  auto it = conversionInfo.find(block);
627  if (it == conversionInfo.end())
628  return;
629  Block *origBlock = it->second.origBlock;
630 
631  // Drop all uses of the new block arguments and replace uses of the new block.
632  for (int i = block->getNumArguments() - 1; i >= 0; --i)
633  block->getArgument(i).dropAllUses();
634  block->replaceAllUsesWith(origBlock);
635 
636  // Move the operations back the original block and the delete the new block.
637  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
638  origBlock->moveBefore(block);
639  block->erase();
640 
641  convertedBlocks.erase(origBlock);
642  conversionInfo.erase(it);
643 }
644 
645 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
646  for (auto &info : conversionInfo) {
647  ConvertedBlockInfo &blockInfo = info.second;
648  Block *origBlock = blockInfo.origBlock;
649 
650  // Process the remapping for each of the original arguments.
651  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
652  Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
653  BlockArgument origArg = origBlock->getArgument(i);
654 
655  // Handle the case of a 1->0 value mapping.
656  if (!argInfo) {
657  if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
658  origArg.replaceAllUsesWith(newArg);
659  continue;
660  }
661 
662  // Otherwise this is a 1->1+ value mapping.
663  Value castValue = argInfo->castValue;
664  assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
665 
666  // If the argument is still used, replace it with the generated cast.
667  if (!origArg.use_empty()) {
668  origArg.replaceAllUsesWith(
669  mapping.lookupOrDefault(castValue, origArg.getType()));
670  }
671  }
672  }
673 }
674 
675 LogicalResult ArgConverter::materializeLiveConversions(
676  ConversionValueMapping &mapping, OpBuilder &builder,
677  function_ref<Operation *(Value)> findLiveUser) {
678  for (auto &info : conversionInfo) {
679  Block *newBlock = info.first;
680  ConvertedBlockInfo &blockInfo = info.second;
681  Block *origBlock = blockInfo.origBlock;
682 
683  // Process the remapping for each of the original arguments.
684  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
685  // If the type of this argument changed and the argument is still live, we
686  // need to materialize a conversion.
687  BlockArgument origArg = origBlock->getArgument(i);
688  if (mapping.lookupOrNull(origArg, origArg.getType()))
689  continue;
690  Operation *liveUser = findLiveUser(origArg);
691  if (!liveUser)
692  continue;
693 
694  Value replacementValue = mapping.lookupOrDefault(origArg);
695  bool isDroppedArg = replacementValue == origArg;
696  if (isDroppedArg)
697  rewriter.setInsertionPointToStart(newBlock);
698  else
699  rewriter.setInsertionPointAfterValue(replacementValue);
700  Value newArg;
701  if (blockInfo.converter) {
702  newArg = blockInfo.converter->materializeSourceConversion(
703  rewriter, origArg.getLoc(), origArg.getType(),
704  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
705  assert((!newArg || newArg.getType() == origArg.getType()) &&
706  "materialization hook did not provide a value of the expected "
707  "type");
708  }
709  if (!newArg) {
711  emitError(origArg.getLoc())
712  << "failed to materialize conversion for block argument #" << i
713  << " that remained live after conversion, type was "
714  << origArg.getType();
715  if (!isDroppedArg)
716  diag << ", with target type " << replacementValue.getType();
717  diag.attachNote(liveUser->getLoc())
718  << "see existing live user here: " << *liveUser;
719  return failure();
720  }
721  mapping.map(origArg, newArg);
722  }
723  }
724  return success();
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // Conversion
729 
730 FailureOr<Block *> ArgConverter::convertSignature(
731  Block *block, TypeConverter *converter, ConversionValueMapping &mapping,
732  SmallVectorImpl<BlockArgument> &argReplacements) {
733  // Check if the block was already converted. If the block is detached,
734  // conservatively assume it is going to be deleted.
735  if (hasBeenConverted(block) || !block->getParent())
736  return block;
737  // If a converter wasn't provided, and the block wasn't already converted,
738  // there is nothing we can do.
739  if (!converter)
740  return failure();
741 
742  // Try to convert the signature for the block with the provided converter.
743  if (auto conversion = converter->convertBlockSignature(block))
744  return applySignatureConversion(block, converter, *conversion, mapping,
745  argReplacements);
746  return failure();
747 }
748 
749 Block *ArgConverter::applySignatureConversion(
750  Block *block, TypeConverter *converter,
751  TypeConverter::SignatureConversion &signatureConversion,
752  ConversionValueMapping &mapping,
753  SmallVectorImpl<BlockArgument> &argReplacements) {
754  // If no arguments are being changed or added, there is nothing to do.
755  unsigned origArgCount = block->getNumArguments();
756  auto convertedTypes = signatureConversion.getConvertedTypes();
757  if (origArgCount == 0 && convertedTypes.empty())
758  return block;
759 
760  // Split the block at the beginning to get a new block to use for the updated
761  // signature.
762  Block *newBlock = block->splitBlock(block->begin());
763  block->replaceAllUsesWith(newBlock);
764 
765  // FIXME: We should map the new arguments to proper locations.
766  SmallVector<Location> newLocs(convertedTypes.size(),
767  rewriter.getUnknownLoc());
768  SmallVector<Value, 4> newArgRange(
769  newBlock->addArguments(convertedTypes, newLocs));
770  ArrayRef<Value> newArgs(newArgRange);
771 
772  // Remap each of the original arguments as determined by the signature
773  // conversion.
774  ConvertedBlockInfo info(block, converter);
775  info.argInfo.resize(origArgCount);
776 
777  OpBuilder::InsertionGuard guard(rewriter);
778  rewriter.setInsertionPointToStart(newBlock);
779  for (unsigned i = 0; i != origArgCount; ++i) {
780  auto inputMap = signatureConversion.getInputMapping(i);
781  if (!inputMap)
782  continue;
783  BlockArgument origArg = block->getArgument(i);
784 
785  // If inputMap->replacementValue is not nullptr, then the argument is
786  // dropped and a replacement value is provided to be the remappedValue.
787  if (inputMap->replacementValue) {
788  assert(inputMap->size == 0 &&
789  "invalid to provide a replacement value when the argument isn't "
790  "dropped");
791  mapping.map(origArg, inputMap->replacementValue);
792  argReplacements.push_back(origArg);
793  continue;
794  }
795 
796  // Otherwise, this is a 1->1+ mapping.
797  auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
798  Value newArg;
799 
800  // If this is a 1->1 mapping and the types of new and replacement arguments
801  // match (i.e. it's an identity map), then the argument is mapped to its
802  // original type.
803  // FIXME: We simply pass through the replacement argument if there wasn't a
804  // converter, which isn't great as it allows implicit type conversions to
805  // appear. We should properly restructure this code to handle cases where a
806  // converter isn't provided and also to properly handle the case where an
807  // argument materialization is actually a temporary source materialization
808  // (e.g. in the case of 1->N).
809  if (replArgs.size() == 1 &&
810  (!converter || replArgs[0].getType() == origArg.getType())) {
811  newArg = replArgs.front();
812  } else {
813  Type origOutputType = origArg.getType();
814 
815  // Legalize the argument output type.
816  Type outputType = origOutputType;
817  if (Type legalOutputType = converter->convertType(outputType))
818  outputType = legalOutputType;
819 
821  rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
822  converter, unresolvedMaterializations);
823  }
824 
825  mapping.map(origArg, newArg);
826  argReplacements.push_back(origArg);
827  info.argInfo[i] =
828  ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
829  }
830 
831  // Remove the original block from the region and return the new one.
832  insertConversion(newBlock, std::move(info));
833  return newBlock;
834 }
835 
836 void ArgConverter::insertConversion(Block *newBlock,
837  ConvertedBlockInfo &&info) {
838  // Get a region to insert the old block.
839  Region *region = newBlock->getParent();
840  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
841  if (!mappedRegion)
842  mappedRegion = std::make_unique<Region>(region->getParentOp());
843 
844  // Move the original block to the mapped region and emplace the conversion.
845  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
846  info.origBlock->getIterator());
847  convertedBlocks.insert(info.origBlock);
848  conversionInfo.insert({newBlock, std::move(info)});
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // ConversionPatternRewriterImpl
853 //===----------------------------------------------------------------------===//
854 namespace mlir {
855 namespace detail {
858  : argConverter(rewriter, unresolvedMaterializations),
859  notifyCallback(nullptr) {}
860 
861  /// Cleanup and destroy any generated rewrite operations. This method is
862  /// invoked when the conversion process fails.
863  void discardRewrites();
864 
865  /// Apply all requested operation rewrites. This method is invoked when the
866  /// conversion process succeeds.
867  void applyRewrites();
868 
869  //===--------------------------------------------------------------------===//
870  // State Management
871  //===--------------------------------------------------------------------===//
872 
873  /// Return the current state of the rewriter.
874  RewriterState getCurrentState();
875 
876  /// Reset the state of the rewriter to a previously saved point.
877  void resetState(RewriterState state);
878 
879  /// Erase any blocks that were unlinked from their regions and stored in block
880  /// actions.
881  void eraseDanglingBlocks();
882 
883  /// Undo the block actions (motions, splits) one by one in reverse order until
884  /// "numActionsToKeep" actions remain.
885  void undoBlockActions(unsigned numActionsToKeep = 0);
886 
887  /// Remap the given values to those with potentially different types. Returns
888  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
889  /// is the tag used when describing a value within a diagnostic, e.g.
890  /// "operand".
891  LogicalResult remapValues(StringRef valueDiagTag, Optional<Location> inputLoc,
892  PatternRewriter &rewriter, ValueRange values,
893  SmallVectorImpl<Value> &remapped);
894 
895  /// Returns true if the given operation is ignored, and does not need to be
896  /// converted.
897  bool isOpIgnored(Operation *op) const;
898 
899  /// Recursively marks the nested operations under 'op' as ignored. This
900  /// removes them from being considered for legalization.
901  void markNestedOpsIgnored(Operation *op);
902 
903  //===--------------------------------------------------------------------===//
904  // Type Conversion
905  //===--------------------------------------------------------------------===//
906 
907  /// Convert the signature of the given block.
908  FailureOr<Block *> convertBlockSignature(
909  Block *block, TypeConverter *converter,
910  TypeConverter::SignatureConversion *conversion = nullptr);
911 
912  /// Apply a signature conversion on the given region, using `converter` for
913  /// materializations if not null.
914  Block *
915  applySignatureConversion(Region *region,
917  TypeConverter *converter);
918 
919  /// Convert the types of block arguments within the given region.
921  convertRegionTypes(Region *region, TypeConverter &converter,
922  TypeConverter::SignatureConversion *entryConversion);
923 
924  /// Convert the types of non-entry block arguments within the given region.
925  LogicalResult convertNonEntryRegionTypes(
926  Region *region, TypeConverter &converter,
927  ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
928 
929  //===--------------------------------------------------------------------===//
930  // Rewriter Notification Hooks
931  //===--------------------------------------------------------------------===//
932 
933  /// PatternRewriter hook for replacing the results of an operation.
934  void notifyOpReplaced(Operation *op, ValueRange newValues);
935 
936  /// Notifies that a block is about to be erased.
937  void notifyBlockIsBeingErased(Block *block);
938 
939  /// Notifies that a block was created.
940  void notifyCreatedBlock(Block *block);
941 
942  /// Notifies that a block was split.
943  void notifySplitBlock(Block *block, Block *continuation);
944 
945  /// Notifies that `block` is being merged with `srcBlock`.
946  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
947 
948  /// Notifies that the blocks of a region are about to be moved.
949  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
950  Region::iterator before);
951 
952  /// Notifies that the blocks of a region were cloned into another.
953  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
954  Location origRegionLoc);
955 
956  /// Notifies that a pattern match failed for the given reason.
958  notifyMatchFailure(Location loc,
959  function_ref<void(Diagnostic &)> reasonCallback);
960 
961  //===--------------------------------------------------------------------===//
962  // State
963  //===--------------------------------------------------------------------===//
964 
965  // Mapping between replaced values that differ in type. This happens when
966  // replacing a value with one of a different type.
967  ConversionValueMapping mapping;
968 
969  /// Utility used to convert block arguments.
970  ArgConverter argConverter;
971 
972  /// Ordered vector of all of the newly created operations during conversion.
974 
975  /// Ordered vector of all unresolved type conversion materializations during
976  /// conversion.
978 
979  /// Ordered map of requested operation replacements.
980  llvm::MapVector<Operation *, OpReplacement> replacements;
981 
982  /// Ordered vector of any requested block argument replacements.
984 
985  /// Ordered list of block operations (creations, splits, motions).
987 
988  /// A set of operations that should no longer be considered for legalization,
989  /// but were not directly replaced/erased/etc. by a pattern. These are
990  /// generally child operations of other operations that were
991  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
992  /// operations, but the minimal set that can be used to detect if a given
993  /// operation should be `ignored`. For example, we may add the operations that
994  /// define non-empty regions to the set, but not any of the others. This
995  /// simplifies the amount of memory needed as we can query if the parent
996  /// operation was ignored.
998 
999  /// A transaction state for each of operations that were updated in-place.
1001 
1002  /// A vector of indices into `replacements` of operations that were replaced
1003  /// with values with different result types than the original operation, e.g.
1004  /// 1->N conversion of some kind.
1006 
1007  /// The current type converter, or nullptr if no type converter is currently
1008  /// active.
1009  TypeConverter *currentTypeConverter = nullptr;
1010 
1011  /// This allows the user to collect the match failure message.
1013 
1014 #ifndef NDEBUG
1015  /// A set of operations that have pending updates. This tracking isn't
1016  /// strictly necessary, and is thus only active during debug builds for extra
1017  /// verification.
1019 
1020  /// A logger used to emit diagnostics during the conversion process.
1021  llvm::ScopedPrinter logger{llvm::dbgs()};
1022 #endif
1023 };
1024 } // namespace detail
1025 } // namespace mlir
1026 
1027 /// Detach any operations nested in the given operation from their parent
1028 /// blocks, and erase the given operation. This can be used when the nested
1029 /// operations are scheduled for erasure themselves, so deleting the regions of
1030 /// the given operation together with their content would result in double-free.
1031 /// This happens, for example, when rolling back op creation in the reverse
1032 /// order and if the nested ops were created before the parent op. This function
1033 /// does not need to collect nested ops recursively because it is expected to
1034 /// also be called for each nested op when it is about to be deleted.
1036  for (Region &region : op->getRegions()) {
1037  for (Block &block : region.getBlocks()) {
1038  while (!block.getOperations().empty())
1039  block.getOperations().remove(block.getOperations().begin());
1040  block.dropAllDefinedValueUses();
1041  }
1042  }
1043  op->dropAllUses();
1044  op->erase();
1045 }
1046 
1048  // Reset any operations that were updated in place.
1049  for (auto &state : rootUpdates)
1050  state.resetOperation();
1051 
1052  undoBlockActions();
1053 
1054  // Remove any newly created ops.
1055  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1056  detachNestedAndErase(materialization.getOp());
1057  for (auto *op : llvm::reverse(createdOps))
1059 }
1060 
1062  // Apply all of the rewrites replacements requested during conversion.
1063  for (auto &repl : replacements) {
1064  for (OpResult result : repl.first->getResults())
1065  if (Value newValue = mapping.lookupOrNull(result, result.getType()))
1066  result.replaceAllUsesWith(newValue);
1067 
1068  // If this operation defines any regions, drop any pending argument
1069  // rewrites.
1070  if (repl.first->getNumRegions())
1071  argConverter.notifyOpRemoved(repl.first);
1072  }
1073 
1074  // Apply all of the requested argument replacements.
1075  for (BlockArgument arg : argReplacements) {
1076  Value repl = mapping.lookupOrNull(arg, arg.getType());
1077  if (!repl)
1078  continue;
1079 
1080  if (repl.isa<BlockArgument>()) {
1081  arg.replaceAllUsesWith(repl);
1082  continue;
1083  }
1084 
1085  // If the replacement value is an operation, we check to make sure that we
1086  // don't replace uses that are within the parent operation of the
1087  // replacement value.
1088  Operation *replOp = repl.cast<OpResult>().getOwner();
1089  Block *replBlock = replOp->getBlock();
1090  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
1091  Operation *user = operand.getOwner();
1092  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1093  });
1094  }
1095 
1096  // Drop all of the unresolved materialization operations created during
1097  // conversion.
1098  for (auto &mat : unresolvedMaterializations) {
1099  mat.getOp()->dropAllUses();
1100  mat.getOp()->erase();
1101  }
1102 
1103  // In a second pass, erase all of the replaced operations in reverse. This
1104  // allows processing nested operations before their parent region is
1105  // destroyed. Because we process in reverse order, producers may be deleted
1106  // before their users (a pattern deleting a producer and then the consumer)
1107  // so we first drop all uses explicitly.
1108  for (auto &repl : llvm::reverse(replacements)) {
1109  repl.first->dropAllUses();
1110  repl.first->erase();
1111  }
1112 
1113  argConverter.applyRewrites(mapping);
1114 
1115  // Now that the ops have been erased, also erase dangling blocks.
1116  eraseDanglingBlocks();
1117 }
1118 
1119 //===----------------------------------------------------------------------===//
1120 // State Management
1121 
1123  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1124  replacements.size(), argReplacements.size(),
1125  blockActions.size(), ignoredOps.size(),
1126  rootUpdates.size());
1127 }
1128 
1130  // Reset any operations that were updated in place.
1131  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
1132  rootUpdates[i].resetOperation();
1133  rootUpdates.resize(state.numRootUpdates);
1134 
1135  // Reset any replaced arguments.
1136  for (BlockArgument replacedArg :
1137  llvm::drop_begin(argReplacements, state.numArgReplacements))
1138  mapping.erase(replacedArg);
1139  argReplacements.resize(state.numArgReplacements);
1140 
1141  // Undo any block actions.
1142  undoBlockActions(state.numBlockActions);
1143 
1144  // Reset any replaced operations and undo any saved mappings.
1145  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
1146  for (auto result : repl.first->getResults())
1147  mapping.erase(result);
1148  while (replacements.size() != state.numReplacements)
1149  replacements.pop_back();
1150 
1151  // Pop all of the newly inserted materializations.
1152  while (unresolvedMaterializations.size() !=
1153  state.numUnresolvedMaterializations) {
1154  UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
1155  UnrealizedConversionCastOp op = mat.getOp();
1156 
1157  // If this was a target materialization, drop the mapping that was inserted.
1158  if (mat.getKind() == UnresolvedMaterialization::Target) {
1159  for (Value input : op->getOperands())
1160  mapping.erase(input);
1161  }
1163  }
1164 
1165  // Pop all of the newly created operations.
1166  while (createdOps.size() != state.numCreatedOps) {
1167  detachNestedAndErase(createdOps.back());
1168  createdOps.pop_back();
1169  }
1170 
1171  // Pop all of the recorded ignored operations that are no longer valid.
1172  while (ignoredOps.size() != state.numIgnoredOperations)
1173  ignoredOps.pop_back();
1174 
1175  // Reset operations with changed results.
1176  while (!operationsWithChangedResults.empty() &&
1177  operationsWithChangedResults.back() >= state.numReplacements)
1178  operationsWithChangedResults.pop_back();
1179 }
1180 
1182  for (auto &action : blockActions)
1183  if (action.kind == BlockActionKind::Erase)
1184  delete action.block;
1185 }
1186 
1188  unsigned numActionsToKeep) {
1189  for (auto &action :
1190  llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
1191  switch (action.kind) {
1192  // Delete the created block.
1193  case BlockActionKind::Create: {
1194  // Unlink all of the operations within this block, they will be deleted
1195  // separately.
1196  auto &blockOps = action.block->getOperations();
1197  while (!blockOps.empty())
1198  blockOps.remove(blockOps.begin());
1199  action.block->dropAllDefinedValueUses();
1200  action.block->erase();
1201  break;
1202  }
1203  // Put the block (owned by action) back into its original position.
1204  case BlockActionKind::Erase: {
1205  auto &blockList = action.originalPosition.region->getBlocks();
1206  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1207  blockList.insert((insertAfterBlock
1208  ? std::next(Region::iterator(insertAfterBlock))
1209  : blockList.begin()),
1210  action.block);
1211  break;
1212  }
1213  // Split the block at the position which was originally the end of the
1214  // destination block (owned by action), and put the instructions back into
1215  // the block used before the merge.
1216  case BlockActionKind::Merge: {
1217  Block *sourceBlock = action.mergeInfo.sourceBlock;
1218  Block::iterator splitPoint =
1219  (action.mergeInfo.destBlockLastInst
1220  ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1221  : action.block->begin());
1222  sourceBlock->getOperations().splice(sourceBlock->begin(),
1223  action.block->getOperations(),
1224  splitPoint, action.block->end());
1225  break;
1226  }
1227  // Move the block back to its original position.
1228  case BlockActionKind::Move: {
1229  Region *originalRegion = action.originalPosition.region;
1230  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1231  originalRegion->getBlocks().splice(
1232  (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1233  : originalRegion->end()),
1234  action.block->getParent()->getBlocks(), action.block);
1235  break;
1236  }
1237  // Merge back the block that was split out.
1238  case BlockActionKind::Split: {
1239  action.originalBlock->getOperations().splice(
1240  action.originalBlock->end(), action.block->getOperations());
1241  action.block->dropAllDefinedValueUses();
1242  action.block->erase();
1243  break;
1244  }
1245  // Undo the type conversion.
1246  case BlockActionKind::TypeConversion: {
1247  argConverter.discardRewrites(action.block);
1248  break;
1249  }
1250  }
1251  }
1252  blockActions.resize(numActionsToKeep);
1253 }
1254 
1256  StringRef valueDiagTag, Optional<Location> inputLoc,
1257  PatternRewriter &rewriter, ValueRange values,
1258  SmallVectorImpl<Value> &remapped) {
1259  remapped.reserve(llvm::size(values));
1260 
1261  SmallVector<Type, 1> legalTypes;
1262  for (const auto &it : llvm::enumerate(values)) {
1263  Value operand = it.value();
1264  Type origType = operand.getType();
1265 
1266  // If a converter was provided, get the desired legal types for this
1267  // operand.
1268  Type desiredType;
1269  if (currentTypeConverter) {
1270  // If there is no legal conversion, fail to match this pattern.
1271  legalTypes.clear();
1272  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1273  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1274  return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1275  diag << "unable to convert type for " << valueDiagTag << " #"
1276  << it.index() << ", type was " << origType;
1277  });
1278  }
1279  // TODO: There currently isn't any mechanism to do 1->N type conversion
1280  // via the PatternRewriter replacement API, so for now we just ignore it.
1281  if (legalTypes.size() == 1)
1282  desiredType = legalTypes.front();
1283  } else {
1284  // TODO: What we should do here is just set `desiredType` to `origType`
1285  // and then handle the necessary type conversions after the conversion
1286  // process has finished. Unfortunately a lot of patterns currently rely on
1287  // receiving the new operands even if the types change, so we keep the
1288  // original behavior here for now until all of the patterns relying on
1289  // this get updated.
1290  }
1291  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1292 
1293  // Handle the case where the conversion was 1->1 and the new operand type
1294  // isn't legal.
1295  Type newOperandType = newOperand.getType();
1296  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1297  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1299  operandLoc, newOperand, desiredType, currentTypeConverter,
1300  unresolvedMaterializations);
1301  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1302  newOperand = castValue;
1303  }
1304  remapped.push_back(newOperand);
1305  }
1306  return success();
1307 }
1308 
1310  // Check to see if this operation was replaced or its parent ignored.
1311  return replacements.count(op) || ignoredOps.count(op->getParentOp());
1312 }
1313 
1315  // Walk this operation and collect nested operations that define non-empty
1316  // regions. We mark such operations as 'ignored' so that we know we don't have
1317  // to convert them, or their nested ops.
1318  if (op->getNumRegions() == 0)
1319  return;
1320  op->walk([&](Operation *op) {
1321  if (llvm::any_of(op->getRegions(),
1322  [](Region &region) { return !region.empty(); }))
1323  ignoredOps.insert(op);
1324  });
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // Type Conversion
1329 
1331  Block *block, TypeConverter *converter,
1332  TypeConverter::SignatureConversion *conversion) {
1333  FailureOr<Block *> result =
1334  conversion ? argConverter.applySignatureConversion(
1335  block, converter, *conversion, mapping, argReplacements)
1336  : argConverter.convertSignature(block, converter, mapping,
1337  argReplacements);
1338  if (failed(result))
1339  return failure();
1340  if (Block *newBlock = result.getValue()) {
1341  if (newBlock != block)
1342  blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1343  }
1344  return result;
1345 }
1346 
1348  Region *region, TypeConverter::SignatureConversion &conversion,
1349  TypeConverter *converter) {
1350  if (!region->empty())
1351  return *convertBlockSignature(&region->front(), converter, &conversion);
1352  return nullptr;
1353 }
1354 
1356  Region *region, TypeConverter &converter,
1357  TypeConverter::SignatureConversion *entryConversion) {
1358  argConverter.setConverter(region, &converter);
1359  if (region->empty())
1360  return nullptr;
1361 
1362  if (failed(convertNonEntryRegionTypes(region, converter)))
1363  return failure();
1364 
1365  FailureOr<Block *> newEntry =
1366  convertBlockSignature(&region->front(), &converter, entryConversion);
1367  return newEntry;
1368 }
1369 
1371  Region *region, TypeConverter &converter,
1373  argConverter.setConverter(region, &converter);
1374  if (region->empty())
1375  return success();
1376 
1377  // Convert the arguments of each block within the region.
1378  int blockIdx = 0;
1379  assert((blockConversions.empty() ||
1380  blockConversions.size() == region->getBlocks().size() - 1) &&
1381  "expected either to provide no SignatureConversions at all or to "
1382  "provide a SignatureConversion for each non-entry block");
1383 
1384  for (Block &block :
1385  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1386  TypeConverter::SignatureConversion *blockConversion =
1387  blockConversions.empty()
1388  ? nullptr
1389  : const_cast<TypeConverter::SignatureConversion *>(
1390  &blockConversions[blockIdx++]);
1391 
1392  if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1393  return failure();
1394  }
1395  return success();
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // Rewriter Notification Hooks
1400 
1402  ValueRange newValues) {
1403  assert(newValues.size() == op->getNumResults());
1404  assert(!replacements.count(op) && "operation was already replaced");
1405 
1406  // Track if any of the results changed, e.g. erased and replaced with null.
1407  bool resultChanged = false;
1408 
1409  // Create mappings for each of the new result values.
1410  Value newValue, result;
1411  for (auto it : llvm::zip(newValues, op->getResults())) {
1412  std::tie(newValue, result) = it;
1413  if (!newValue) {
1414  resultChanged = true;
1415  continue;
1416  }
1417  // Remap, and check for any result type changes.
1418  mapping.map(result, newValue);
1419  resultChanged |= (newValue.getType() != result.getType());
1420  }
1421  if (resultChanged)
1422  operationsWithChangedResults.push_back(replacements.size());
1423 
1424  // Record the requested operation replacement.
1425  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
1426 
1427  // Mark this operation as recursively ignored so that we don't need to
1428  // convert any nested operations.
1429  markNestedOpsIgnored(op);
1430 }
1431 
1433  Region *region = block->getParent();
1434  Block *origPrevBlock = block->getPrevNode();
1435  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1436 }
1437 
1439  blockActions.push_back(BlockAction::getCreate(block));
1440 }
1441 
1443  Block *continuation) {
1444  blockActions.push_back(BlockAction::getSplit(continuation, block));
1445 }
1446 
1448  Block *srcBlock) {
1449  blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1450 }
1451 
1453  Region &region, Region &parent, Region::iterator before) {
1454  if (region.empty())
1455  return;
1456  Block *laterBlock = &region.back();
1457  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1458  blockActions.push_back(
1459  BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1460  laterBlock = &earlierBlock;
1461  }
1462  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1463 }
1464 
1466  iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1467  for (Block &block : blocks)
1468  blockActions.push_back(BlockAction::getCreate(&block));
1469 
1470  // Compute the conversion set for the inlined region.
1471  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1472 
1473  // This original region has already had its conversion set computed, so there
1474  // shouldn't be any new failures.
1475  (void)result;
1476  assert(succeeded(result) && "expected region to have no unreachable blocks");
1477 }
1478 
1480  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1481  LLVM_DEBUG({
1483  reasonCallback(diag);
1484  logger.startLine() << "** Failure : " << diag.str() << "\n";
1485  if (notifyCallback)
1486  notifyCallback(diag);
1487  });
1488  return failure();
1489 }
1490 
1491 //===----------------------------------------------------------------------===//
1492 // ConversionPatternRewriter
1493 //===----------------------------------------------------------------------===//
1494 
1496  : PatternRewriter(ctx),
1497  impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1499 
1501  Operation *op, ValueRange newValues, bool *allUsesReplaced,
1502  llvm::unique_function<bool(OpOperand &) const> functor) {
1503  // TODO: To support this we will need to rework a bit of how replacements are
1504  // tracked, given that this isn't guaranteed to replace all of the uses of an
1505  // operation. The main change is that now an operation can be replaced
1506  // multiple times, in parts. The current "set" based tracking is mainly useful
1507  // for tracking if a replaced operation should be ignored, i.e. if all of the
1508  // uses will be replaced.
1509  llvm_unreachable(
1510  "replaceOpWithIf is currently not supported by DialectConversion");
1511 }
1512 
1514  LLVM_DEBUG({
1515  impl->logger.startLine()
1516  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1517  });
1518  impl->notifyOpReplaced(op, newValues);
1519 }
1520 
1522  LLVM_DEBUG({
1523  impl->logger.startLine()
1524  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1525  });
1526  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1527  impl->notifyOpReplaced(op, nullRepls);
1528 }
1529 
1531  impl->notifyBlockIsBeingErased(block);
1532 
1533  // Mark all ops for erasure.
1534  for (Operation &op : *block)
1535  eraseOp(&op);
1536 
1537  // Unlink the block from its parent region. The block is kept in the block
1538  // action and will be actually destroyed when rewrites are applied. This
1539  // allows us to keep the operations in the block live and undo the removal by
1540  // re-inserting the block.
1541  block->getParent()->getBlocks().remove(block);
1542 }
1543 
1545  Region *region, TypeConverter::SignatureConversion &conversion,
1546  TypeConverter *converter) {
1547  return impl->applySignatureConversion(region, conversion, converter);
1548 }
1549 
1551  Region *region, TypeConverter &converter,
1552  TypeConverter::SignatureConversion *entryConversion) {
1553  return impl->convertRegionTypes(region, converter, entryConversion);
1554 }
1555 
1557  Region *region, TypeConverter &converter,
1559  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1560 }
1561 
1563  Value to) {
1564  LLVM_DEBUG({
1565  Operation *parentOp = from.getOwner()->getParentOp();
1566  impl->logger.startLine() << "** Replace Argument : '" << from
1567  << "'(in region of '" << parentOp->getName()
1568  << "'(" << from.getOwner()->getParentOp() << ")\n";
1569  });
1570  impl->argReplacements.push_back(from);
1571  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1572 }
1573 
1575  SmallVector<Value> remappedValues;
1576  if (failed(impl->remapValues("value", /*inputLoc=*/llvm::None, *this, key,
1577  remappedValues)))
1578  return nullptr;
1579  return remappedValues.front();
1580 }
1581 
1584  SmallVectorImpl<Value> &results) {
1585  if (keys.empty())
1586  return success();
1587  return impl->remapValues("value", /*inputLoc=*/llvm::None, *this, keys,
1588  results);
1589 }
1590 
1592  impl->notifyCreatedBlock(block);
1593 }
1594 
1596  Block::iterator before) {
1597  auto *continuation = PatternRewriter::splitBlock(block, before);
1598  impl->notifySplitBlock(block, continuation);
1599  return continuation;
1600 }
1601 
1603  ValueRange argValues) {
1604  impl->notifyBlocksBeingMerged(dest, source);
1605  assert(llvm::all_of(source->getPredecessors(),
1606  [dest](Block *succ) { return succ == dest; }) &&
1607  "expected 'source' to have no predecessors or only 'dest'");
1608  assert(argValues.size() == source->getNumArguments() &&
1609  "incorrect # of argument replacement values");
1610  for (auto it : llvm::zip(source->getArguments(), argValues))
1611  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1612  dest->getOperations().splice(dest->end(), source->getOperations());
1613  eraseBlock(source);
1614 }
1615 
1617  Region &parent,
1618  Region::iterator before) {
1619  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1620  PatternRewriter::inlineRegionBefore(region, parent, before);
1621 }
1622 
1624  Region &region, Region &parent, Region::iterator before,
1625  BlockAndValueMapping &mapping) {
1626  if (region.empty())
1627  return;
1628  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1629 
1630  // Collect the range of the cloned blocks.
1631  auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1632  auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1633  impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1634 }
1635 
1637  LLVM_DEBUG({
1638  impl->logger.startLine()
1639  << "** Insert : '" << op->getName() << "'(" << op << ")\n";
1640  });
1641  impl->createdOps.push_back(op);
1642 }
1643 
1645 #ifndef NDEBUG
1646  impl->pendingRootUpdates.insert(op);
1647 #endif
1648  impl->rootUpdates.emplace_back(op);
1649 }
1650 
1652  // There is nothing to do here, we only need to track the operation at the
1653  // start of the update.
1654 #ifndef NDEBUG
1655  assert(impl->pendingRootUpdates.erase(op) &&
1656  "operation did not have a pending in-place update");
1657 #endif
1658 }
1659 
1661 #ifndef NDEBUG
1662  assert(impl->pendingRootUpdates.erase(op) &&
1663  "operation did not have a pending in-place update");
1664 #endif
1665  // Erase the last update for this operation.
1666  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1667  auto &rootUpdates = impl->rootUpdates;
1668  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1669  assert(it != rootUpdates.rend() && "no root update started on op");
1670  (*it).resetOperation();
1671  int updateIdx = std::prev(rootUpdates.rend()) - it;
1672  rootUpdates.erase(rootUpdates.begin() + updateIdx);
1673 }
1674 
1676  Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1677  return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1678 }
1679 
1681  return *impl;
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // ConversionPattern
1686 //===----------------------------------------------------------------------===//
1687 
// Adapter from the generic pattern entry point to the conversion-specific
// one: remaps the operands of `op` through the rewriter implementation and
// then calls the `matchAndRewrite` overload that takes the remapped operands.
// Returns failure if the operands cannot be remapped.
// NOTE(review): the leading signature lines are missing from this listing;
// presumably `LogicalResult ConversionPattern::matchAndRewrite(Operation *op,`
// — confirm against the upstream file.
1690  PatternRewriter &rewriter) const {
// The rewriter is assumed to be a ConversionPatternRewriter; the cast below
// is unchecked.
1691  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1692  auto &rewriterImpl = dialectRewriter.getImpl();
1693 
1694  // Track the current conversion pattern type converter in the rewriter.
// The previous value is restored when the guard goes out of scope.
1695  llvm::SaveAndRestore<TypeConverter *> currentConverterGuard(
1696  rewriterImpl.currentTypeConverter, getTypeConverter());
1697 
1698  // Remap the operands of the operation.
1699  SmallVector<Value, 4> operands;
1700  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1701  op->getOperands(), operands))) {
1702  return failure();
1703  }
1704  return matchAndRewrite(op, operands, dialectRewriter);
1705 }
1706 
1707 //===----------------------------------------------------------------------===//
1708 // OperationLegalizer
1709 //===----------------------------------------------------------------------===//
1710 
1711 namespace {
1712 /// A set of rewrite patterns that can be used to legalize a given operation.
1713 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1714 
1715 /// This class defines a recursive operation legalizer.
1716 class OperationLegalizer {
1717 public:
1718  using LegalizationAction = ConversionTarget::LegalizationAction;
1719 
1720  OperationLegalizer(ConversionTarget &targetInfo,
1721  const FrozenRewritePatternSet &patterns);
1722 
1723  /// Returns true if the given operation is known to be illegal on the target.
1724  bool isIllegal(Operation *op) const;
1725 
1726  /// Attempt to legalize the given operation. Returns success if the operation
1727  /// was legalized, failure otherwise.
1728  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1729 
1730  /// Returns the conversion target in use by the legalizer.
1731  ConversionTarget &getTarget() { return target; }
1732 
1733 private:
1734  /// Attempt to legalize the given operation by folding it.
1735  LogicalResult legalizeWithFold(Operation *op,
1736  ConversionPatternRewriter &rewriter);
1737 
1738  /// Attempt to legalize the given operation by applying a pattern. Returns
1739  /// success if the operation was legalized, failure otherwise.
1740  LogicalResult legalizeWithPattern(Operation *op,
1741  ConversionPatternRewriter &rewriter);
1742 
1743  /// Return true if the given pattern may be applied to the given operation,
1744  /// false otherwise.
1745  bool canApplyPattern(Operation *op, const Pattern &pattern,
1746  ConversionPatternRewriter &rewriter);
1747 
1748  /// Legalize the resultant IR after successfully applying the given pattern.
1749  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1750  ConversionPatternRewriter &rewriter,
1751  RewriterState &curState);
1752 
1753  /// Legalizes the actions registered during the execution of a pattern.
1754  LogicalResult legalizePatternBlockActions(Operation *op,
1755  ConversionPatternRewriter &rewriter,
1757  RewriterState &state,
1758  RewriterState &newState);
1759  LogicalResult legalizePatternCreatedOperations(
1761  RewriterState &state, RewriterState &newState);
1762  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1764  RewriterState &state,
1765  RewriterState &newState);
1766 
1767  //===--------------------------------------------------------------------===//
1768  // Cost Model
1769  //===--------------------------------------------------------------------===//
1770 
1771  /// Build an optimistic legalization graph given the provided patterns. This
1772  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1773  /// patterns for operations that are not directly legal, but may be
1774  /// transitively legal for the current target given the provided patterns.
1775  void buildLegalizationGraph(
1776  LegalizationPatterns &anyOpLegalizerPatterns,
1778 
1779  /// Compute the benefit of each node within the computed legalization graph.
1780  /// This orders the patterns within 'legalizerPatterns' based upon two
1781  /// criteria:
1782  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1783  /// represent the more direct mapping to the target.
1784  /// 2) When comparing patterns with the same legalization depth, prefer the
1785  /// pattern with the highest PatternBenefit. This allows for users to
1786  /// prefer specific legalizations over others.
1787  void computeLegalizationGraphBenefit(
1788  LegalizationPatterns &anyOpLegalizerPatterns,
1790 
1791  /// Compute the legalization depth when legalizing an operation of the given
1792  /// type.
1793  unsigned computeOpLegalizationDepth(
1794  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1796 
1797  /// Apply the conversion cost model to the given set of patterns, and return
1798  /// the smallest legalization depth of any of the patterns. See
1799  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1800  unsigned applyCostModelToPatterns(
1801  LegalizationPatterns &patterns,
1802  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1804 
1805  /// The current set of patterns that have been applied.
1806  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1807 
1808  /// The legalization information provided by the target.
1809  ConversionTarget &target;
1810 
1811  /// The pattern applicator to use for conversions.
1812  PatternApplicator applicator;
1813 };
1814 } // namespace
1815 
1816 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1817  const FrozenRewritePatternSet &patterns)
1818  : target(targetInfo), applicator(patterns) {
1819  // The set of patterns that can be applied to illegal operations to transform
1820  // them into legal ones.
1822  LegalizationPatterns anyOpLegalizerPatterns;
1823 
1824  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1825  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1826 }
1827 
1828 bool OperationLegalizer::isIllegal(Operation *op) const {
1829  return target.isIllegal(op);
1830 }
1831 
1833 OperationLegalizer::legalize(Operation *op,
1834  ConversionPatternRewriter &rewriter) {
1835 #ifndef NDEBUG
1836  const char *logLineComment =
1837  "//===-------------------------------------------===//\n";
1838 
1839  auto &logger = rewriter.getImpl().logger;
1840 #endif
1841  LLVM_DEBUG({
1842  logger.getOStream() << "\n";
1843  logger.startLine() << logLineComment;
1844  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1845  << op << ") {\n";
1846  logger.indent();
1847 
1848  // If the operation has no regions, just print it here.
1849  if (op->getNumRegions() == 0) {
1850  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1851  logger.getOStream() << "\n\n";
1852  }
1853  });
1854 
1855  // Check if this operation is legal on the target.
1856  if (auto legalityInfo = target.isLegal(op)) {
1857  LLVM_DEBUG({
1858  logSuccess(
1859  logger, "operation marked legal by the target{0}",
1860  legalityInfo->isRecursivelyLegal
1861  ? "; NOTE: operation is recursively legal; skipping internals"
1862  : "");
1863  logger.startLine() << logLineComment;
1864  });
1865 
1866  // If this operation is recursively legal, mark its children as ignored so
1867  // that we don't consider them for legalization.
1868  if (legalityInfo->isRecursivelyLegal)
1869  rewriter.getImpl().markNestedOpsIgnored(op);
1870  return success();
1871  }
1872 
1873  // Check to see if the operation is ignored and doesn't need to be converted.
1874  if (rewriter.getImpl().isOpIgnored(op)) {
1875  LLVM_DEBUG({
1876  logSuccess(logger, "operation marked 'ignored' during conversion");
1877  logger.startLine() << logLineComment;
1878  });
1879  return success();
1880  }
1881 
1882  // If the operation isn't legal, try to fold it in-place.
1883  // TODO: Should we always try to do this, even if the op is
1884  // already legal?
1885  if (succeeded(legalizeWithFold(op, rewriter))) {
1886  LLVM_DEBUG({
1887  logSuccess(logger, "operation was folded");
1888  logger.startLine() << logLineComment;
1889  });
1890  return success();
1891  }
1892 
1893  // Otherwise, we need to apply a legalization pattern to this operation.
1894  if (succeeded(legalizeWithPattern(op, rewriter))) {
1895  LLVM_DEBUG({
1896  logSuccess(logger, "");
1897  logger.startLine() << logLineComment;
1898  });
1899  return success();
1900  }
1901 
1902  LLVM_DEBUG({
1903  logFailure(logger, "no matched legalization pattern");
1904  logger.startLine() << logLineComment;
1905  });
1906  return failure();
1907 }
1908 
1910 OperationLegalizer::legalizeWithFold(Operation *op,
1911  ConversionPatternRewriter &rewriter) {
1912  auto &rewriterImpl = rewriter.getImpl();
1913  RewriterState curState = rewriterImpl.getCurrentState();
1914 
1915  LLVM_DEBUG({
1916  rewriterImpl.logger.startLine() << "* Fold {\n";
1917  rewriterImpl.logger.indent();
1918  });
1919 
1920  // Try to fold the operation.
1921  SmallVector<Value, 2> replacementValues;
1922  rewriter.setInsertionPoint(op);
1923  if (failed(rewriter.tryFold(op, replacementValues))) {
1924  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1925  return failure();
1926  }
1927 
1928  // Insert a replacement for 'op' with the folded replacement values.
1929  rewriter.replaceOp(op, replacementValues);
1930 
1931  // Recursively legalize any new constant operations.
1932  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1933  i != e; ++i) {
1934  Operation *cstOp = rewriterImpl.createdOps[i];
1935  if (failed(legalize(cstOp, rewriter))) {
1936  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1937  "failed to legalize generated constant '{0}'",
1938  cstOp->getName()));
1939  rewriterImpl.resetState(curState);
1940  return failure();
1941  }
1942  }
1943 
1944  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1945  return success();
1946 }
1947 
1949 OperationLegalizer::legalizeWithPattern(Operation *op,
1950  ConversionPatternRewriter &rewriter) {
1951  auto &rewriterImpl = rewriter.getImpl();
1952 
1953  // Functor that returns if the given pattern may be applied.
1954  auto canApply = [&](const Pattern &pattern) {
1955  return canApplyPattern(op, pattern, rewriter);
1956  };
1957 
1958  // Functor that cleans up the rewriter state after a pattern failed to match.
1959  RewriterState curState = rewriterImpl.getCurrentState();
1960  auto onFailure = [&](const Pattern &pattern) {
1961  LLVM_DEBUG({
1962  logFailure(rewriterImpl.logger, "pattern failed to match");
1963  if (rewriterImpl.notifyCallback) {
1965  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1966  << "\" on op:\n"
1967  << *op;
1968  rewriterImpl.notifyCallback(diag);
1969  }
1970  });
1971  rewriterImpl.resetState(curState);
1972  appliedPatterns.erase(&pattern);
1973  };
1974 
1975  // Functor that performs additional legalization when a pattern is
1976  // successfully applied.
1977  auto onSuccess = [&](const Pattern &pattern) {
1978  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1979  appliedPatterns.erase(&pattern);
1980  if (failed(result))
1981  rewriterImpl.resetState(curState);
1982  return result;
1983  };
1984 
1985  // Try to match and rewrite a pattern on this operation.
1986  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1987  onSuccess);
1988 }
1989 
1990 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1991  ConversionPatternRewriter &rewriter) {
1992  LLVM_DEBUG({
1993  auto &os = rewriter.getImpl().logger;
1994  os.getOStream() << "\n";
1995  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1996  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1997  os.getOStream() << ")' {\n";
1998  os.indent();
1999  });
2000 
2001  // Ensure that we don't cycle by not allowing the same pattern to be
2002  // applied twice in the same recursion stack if it is not known to be safe.
2003  if (!pattern.hasBoundedRewriteRecursion() &&
2004  !appliedPatterns.insert(&pattern).second) {
2005  LLVM_DEBUG(
2006  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2007  return false;
2008  }
2009  return true;
2010 }
2011 
2013 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2014  ConversionPatternRewriter &rewriter,
2015  RewriterState &curState) {
2016  auto &impl = rewriter.getImpl();
2017 
2018 #ifndef NDEBUG
2019  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2020 #endif
2021 
2022  // Check that the root was either replaced or updated in place.
2023  auto replacedRoot = [&] {
2024  return llvm::any_of(
2025  llvm::drop_begin(impl.replacements, curState.numReplacements),
2026  [op](auto &it) { return it.first == op; });
2027  };
2028  auto updatedRootInPlace = [&] {
2029  return llvm::any_of(
2030  llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
2031  [op](auto &state) { return state.getOperation() == op; });
2032  };
2033  (void)replacedRoot;
2034  (void)updatedRootInPlace;
2035  assert((replacedRoot() || updatedRootInPlace()) &&
2036  "expected pattern to replace the root operation");
2037 
2038  // Legalize each of the actions registered during application.
2039  RewriterState newState = impl.getCurrentState();
2040  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
2041  newState)) ||
2042  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2043  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2044  newState))) {
2045  return failure();
2046  }
2047 
2048  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2049  return success();
2050 }
2051 
2052 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2053  Operation *op, ConversionPatternRewriter &rewriter,
2054  ConversionPatternRewriterImpl &impl, RewriterState &state,
2055  RewriterState &newState) {
2056  SmallPtrSet<Operation *, 16> operationsToIgnore;
2057 
2058  // If the pattern moved or created any blocks, make sure the types of block
2059  // arguments get legalized.
2060  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2061  ++i) {
2062  auto &action = impl.blockActions[i];
2063  if (action.kind == BlockActionKind::TypeConversion ||
2064  action.kind == BlockActionKind::Erase)
2065  continue;
2066  // Only check blocks outside of the current operation.
2067  Operation *parentOp = action.block->getParentOp();
2068  if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2069  continue;
2070 
2071  // If the region of the block has a type converter, try to convert the block
2072  // directly.
2073  if (auto *converter =
2074  impl.argConverter.getConverter(action.block->getParent())) {
2075  if (failed(impl.convertBlockSignature(action.block, converter))) {
2076  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2077  "block"));
2078  return failure();
2079  }
2080  continue;
2081  }
2082 
2083  // Otherwise, check that this operation isn't one generated by this pattern.
2084  // This is because we will attempt to legalize the parent operation, and
2085  // blocks in regions created by this pattern will already be legalized later
2086  // on. If we haven't built the set yet, build it now.
2087  if (operationsToIgnore.empty()) {
2088  auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2089  .drop_front(state.numCreatedOps);
2090  operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2091  }
2092 
2093  // If this operation should be considered for re-legalization, try it.
2094  if (operationsToIgnore.insert(parentOp).second &&
2095  failed(legalize(parentOp, rewriter))) {
2096  LLVM_DEBUG(logFailure(
2097  impl.logger, "operation '{0}'({1}) became illegal after block action",
2098  parentOp->getName(), parentOp));
2099  return failure();
2100  }
2101  }
2102  return success();
2103 }
2104 
2105 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2107  RewriterState &state, RewriterState &newState) {
2108  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2109  Operation *op = impl.createdOps[i];
2110  if (failed(legalize(op, rewriter))) {
2111  LLVM_DEBUG(logFailure(impl.logger,
2112  "failed to legalize generated operation '{0}'({1})",
2113  op->getName(), op));
2114  return failure();
2115  }
2116  }
2117  return success();
2118 }
2119 
2120 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2122  RewriterState &state, RewriterState &newState) {
2123  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2124  Operation *op = impl.rootUpdates[i].getOperation();
2125  if (failed(legalize(op, rewriter))) {
2126  LLVM_DEBUG(logFailure(
2127  impl.logger, "failed to legalize operation updated in-place '{0}'",
2128  op->getName()));
2129  return failure();
2130  }
2131  }
2132  return success();
2133 }
2134 
2135 //===----------------------------------------------------------------------===//
2136 // Cost Model
2137 
2138 void OperationLegalizer::buildLegalizationGraph(
2139  LegalizationPatterns &anyOpLegalizerPatterns,
2140  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2141  // A mapping between an operation and a set of operations that can be used to
2142  // generate it.
2144  // A mapping between an operation and any currently invalid patterns it has.
2146  // A worklist of patterns to consider for legality.
2147  SetVector<const Pattern *> patternWorklist;
2148 
2149  // Build the mapping from operations to the parent ops that may generate them.
2150  applicator.walkAllPatterns([&](const Pattern &pattern) {
2151  Optional<OperationName> root = pattern.getRootKind();
2152 
2153  // If the pattern has no specific root, we can't analyze the relationship
2154  // between the root op and generated operations. Given that, add all such
2155  // patterns to the legalization set.
2156  if (!root) {
2157  anyOpLegalizerPatterns.push_back(&pattern);
2158  return;
2159  }
2160 
2161  // Skip operations that are always known to be legal.
2162  if (target.getOpAction(*root) == LegalizationAction::Legal)
2163  return;
2164 
2165  // Add this pattern to the invalid set for the root op and record this root
2166  // as a parent for any generated operations.
2167  invalidPatterns[*root].insert(&pattern);
2168  for (auto op : pattern.getGeneratedOps())
2169  parentOps[op].insert(*root);
2170 
2171  // Add this pattern to the worklist.
2172  patternWorklist.insert(&pattern);
2173  });
2174 
2175  // If there are any patterns that don't have a specific root kind, we can't
2176  // make direct assumptions about what operations will never be legalized.
2177  // Note: Technically we could, but it would require an analysis that may
2178  // recurse into itself. It would be better to perform this kind of filtering
2179  // at a higher level than here anyways.
2180  if (!anyOpLegalizerPatterns.empty()) {
2181  for (const Pattern *pattern : patternWorklist)
2182  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2183  return;
2184  }
2185 
2186  while (!patternWorklist.empty()) {
2187  auto *pattern = patternWorklist.pop_back_val();
2188 
2189  // Check to see if any of the generated operations are invalid.
2190  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2191  Optional<LegalizationAction> action = target.getOpAction(op);
2192  return !legalizerPatterns.count(op) &&
2193  (!action || action == LegalizationAction::Illegal);
2194  }))
2195  continue;
2196 
2197  // Otherwise, if all of the generated operation are valid, this op is now
2198  // legal so add all of the child patterns to the worklist.
2199  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2200  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2201 
2202  // Add any invalid patterns of the parent operations to see if they have now
2203  // become legal.
2204  for (auto op : parentOps[*pattern->getRootKind()])
2205  patternWorklist.set_union(invalidPatterns[op]);
2206  }
2207 }
2208 
2209 void OperationLegalizer::computeLegalizationGraphBenefit(
2210  LegalizationPatterns &anyOpLegalizerPatterns,
2211  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2212  // The smallest pattern depth, when legalizing an operation.
2213  DenseMap<OperationName, unsigned> minOpPatternDepth;
2214 
2215  // For each operation that is transitively legal, compute a cost for it.
2216  for (auto &opIt : legalizerPatterns)
2217  if (!minOpPatternDepth.count(opIt.first))
2218  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2219  legalizerPatterns);
2220 
2221  // Apply the cost model to the patterns that can match any operation. Those
2222  // with a specific operation type are already resolved when computing the op
2223  // legalization depth.
2224  if (!anyOpLegalizerPatterns.empty())
2225  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2226  legalizerPatterns);
2227 
2228  // Apply a cost model to the pattern applicator. We order patterns first by
2229  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2230  // decreasing benefit.
2231  applicator.applyCostModel([&](const Pattern &pattern) {
2232  ArrayRef<const Pattern *> orderedPatternList;
2233  if (Optional<OperationName> rootName = pattern.getRootKind())
2234  orderedPatternList = legalizerPatterns[*rootName];
2235  else
2236  orderedPatternList = anyOpLegalizerPatterns;
2237 
2238  // If the pattern is not found, then it was removed and cannot be matched.
2239  auto *it = llvm::find(orderedPatternList, &pattern);
2240  if (it == orderedPatternList.end())
2242 
2243  // Patterns found earlier in the list have higher benefit.
2244  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2245  });
2246 }
2247 
2248 unsigned OperationLegalizer::computeOpLegalizationDepth(
2249  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2250  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2251  // Check for existing depth.
2252  auto depthIt = minOpPatternDepth.find(op);
2253  if (depthIt != minOpPatternDepth.end())
2254  return depthIt->second;
2255 
2256  // If a mapping for this operation does not exist, then this operation
2257  // is always legal. Return 0 as the depth for a directly legal operation.
2258  auto opPatternsIt = legalizerPatterns.find(op);
2259  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2260  return 0u;
2261 
2262  // Record this initial depth in case we encounter this op again when
2263  // recursively computing the depth.
2264  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2265 
2266  // Apply the cost model to the operation patterns, and update the minimum
2267  // depth.
2268  unsigned minDepth = applyCostModelToPatterns(
2269  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2270  minOpPatternDepth[op] = minDepth;
2271  return minDepth;
2272 }
2273 
2274 unsigned OperationLegalizer::applyCostModelToPatterns(
2275  LegalizationPatterns &patterns,
2276  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2277  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2278  unsigned minDepth = std::numeric_limits<unsigned>::max();
2279 
2280  // Compute the depth for each pattern within the set.
2282  patternsByDepth.reserve(patterns.size());
2283  for (const Pattern *pattern : patterns) {
2284  unsigned depth = 1;
2285  for (auto generatedOp : pattern->getGeneratedOps()) {
2286  unsigned generatedOpDepth = computeOpLegalizationDepth(
2287  generatedOp, minOpPatternDepth, legalizerPatterns);
2288  depth = std::max(depth, generatedOpDepth + 1);
2289  }
2290  patternsByDepth.emplace_back(pattern, depth);
2291 
2292  // Update the minimum depth of the pattern list.
2293  minDepth = std::min(minDepth, depth);
2294  }
2295 
2296  // If the operation only has one legalization pattern, there is no need to
2297  // sort them.
2298  if (patternsByDepth.size() == 1)
2299  return minDepth;
2300 
2301  // Sort the patterns by those likely to be the most beneficial.
2302  llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2303  [](const std::pair<const Pattern *, unsigned> *lhs,
2304  const std::pair<const Pattern *, unsigned> *rhs) {
2305  // First sort by the smaller pattern legalization
2306  // depth.
2307  if (lhs->second != rhs->second)
2308  return llvm::array_pod_sort_comparator<unsigned>(
2309  &lhs->second, &rhs->second);
2310 
2311  // Then sort by the larger pattern benefit.
2312  auto lhsBenefit = lhs->first->getBenefit();
2313  auto rhsBenefit = rhs->first->getBenefit();
2314  return llvm::array_pod_sort_comparator<PatternBenefit>(
2315  &rhsBenefit, &lhsBenefit);
2316  });
2317 
2318  // Update the legalization pattern to use the new sorted list.
2319  patterns.clear();
2320  for (auto &patternIt : patternsByDepth)
2321  patterns.push_back(patternIt.first);
2322  return minDepth;
2323 }
2324 
2325 //===----------------------------------------------------------------------===//
2326 // OperationConverter
2327 //===----------------------------------------------------------------------===//
2328 namespace {
2330  /// In this mode, the conversion will ignore failed conversions to allow
2331  /// illegal operations to co-exist in the IR.
2332  Partial,
2333 
2334  /// In this mode, all operations must be legal for the given target for the
2335  /// conversion to succeed.
2336  Full,
2337 
2338  /// In this mode, operations are analyzed for legality. No actual rewrites are
2339  /// applied to the operations on success.
2340  Analysis,
2341 };
2342 
2343 // This class converts operations to a given conversion target via a set of
2344 // rewrite patterns. The conversion behaves differently depending on the
2345 // conversion mode.
2346 struct OperationConverter {
2347  explicit OperationConverter(ConversionTarget &target,
2348  const FrozenRewritePatternSet &patterns,
2349  OpConversionMode mode,
2350  DenseSet<Operation *> *trackedOps = nullptr)
2351  : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2352 
2353  /// Converts the given operations to the conversion target.
2355  convertOperations(ArrayRef<Operation *> ops,
2356  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2357 
2358 private:
2359  /// Converts an operation with the given rewriter.
2360  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2361 
2362  /// This method is called after the conversion process to legalize any
2363  /// remaining artifacts and complete the conversion.
2364  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2365 
2366  /// Legalize the types of converted block arguments.
2368  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2369  ConversionPatternRewriterImpl &rewriterImpl);
2370 
2371  /// Legalize any unresolved type materializations.
2372  LogicalResult legalizeUnresolvedMaterializations(
2373  ConversionPatternRewriter &rewriter,
2374  ConversionPatternRewriterImpl &rewriterImpl,
2375  Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2376 
2377  /// Legalize an operation result that was marked as "erased".
2379  legalizeErasedResult(Operation *op, OpResult result,
2380  ConversionPatternRewriterImpl &rewriterImpl);
2381 
2382  /// Legalize an operation result that was replaced with a value of a different
2383  /// type.
2384  LogicalResult legalizeChangedResultType(
2385  Operation *op, OpResult result, Value newValue,
2386  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2387  ConversionPatternRewriterImpl &rewriterImpl,
2388  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2389 
2390  /// The legalizer to use when converting operations.
2391  OperationLegalizer opLegalizer;
2392 
2393  /// The conversion mode to use when legalizing operations.
2394  OpConversionMode mode;
2395 
2396  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2397  /// this is populated with ops found to be legalizable to the target.
2398  /// When mode == OpConversionMode::Partial, this is populated with ops found
2399  /// *not* to be legalizable to the target.
2400  DenseSet<Operation *> *trackedOps;
2401 };
2402 } // namespace
2403 
2404 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2405  Operation *op) {
2406  // Legalize the given operation.
2407  if (failed(opLegalizer.legalize(op, rewriter))) {
2408  // Handle the case of a failed conversion for each of the different modes.
2409  // Full conversions expect all operations to be converted.
2410  if (mode == OpConversionMode::Full)
2411  return op->emitError()
2412  << "failed to legalize operation '" << op->getName() << "'";
2413  // Partial conversions allow conversions to fail iff the operation was not
2414  // explicitly marked as illegal. If the user provided a nonlegalizableOps
2415  // set, non-legalizable ops are included.
2416  if (mode == OpConversionMode::Partial) {
2417  if (opLegalizer.isIllegal(op))
2418  return op->emitError()
2419  << "failed to legalize operation '" << op->getName()
2420  << "' that was explicitly marked illegal";
2421  if (trackedOps)
2422  trackedOps->insert(op);
2423  }
2424  } else if (mode == OpConversionMode::Analysis) {
2425  // Analysis conversions don't fail if any operations fail to legalize,
2426  // they are only interested in the operations that were successfully
2427  // legalized.
2428  trackedOps->insert(op);
2429  }
2430  return success();
2431 }
2432 
2433 LogicalResult OperationConverter::convertOperations(
2435  function_ref<void(Diagnostic &)> notifyCallback) {
2436  if (ops.empty())
2437  return success();
2438  ConversionTarget &target = opLegalizer.getTarget();
2439 
2440  // Compute the set of operations and blocks to convert.
2441  SmallVector<Operation *> toConvert;
2442  for (auto *op : ops) {
2443  toConvert.emplace_back(op);
2444  for (auto &region : op->getRegions())
2445  if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2446  toConvert, &target)))
2447  return failure();
2448  }
2449 
2450  // Convert each operation and discard rewrites on failure.
2451  ConversionPatternRewriter rewriter(ops.front()->getContext());
2452  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2453  rewriterImpl.notifyCallback = notifyCallback;
2454 
2455  for (auto *op : toConvert)
2456  if (failed(convert(rewriter, op)))
2457  return rewriterImpl.discardRewrites(), failure();
2458 
2459  // Now that all of the operations have been converted, finalize the conversion
2460  // process to ensure any lingering conversion artifacts are cleaned up and
2461  // legalized.
2462  if (failed(finalize(rewriter)))
2463  return rewriterImpl.discardRewrites(), failure();
2464 
2465  // After a successful conversion, apply rewrites if this is not an analysis
2466  // conversion.
2467  if (mode == OpConversionMode::Analysis) {
2468  rewriterImpl.discardRewrites();
2469  } else {
2470  rewriterImpl.applyRewrites();
2471 
2472  // It is possible for a later pattern to erase an op that was originally
2473  // identified as illegal and added to the trackedOps, remove it now after
2474  // replacements have been computed.
2475  if (trackedOps)
2476  for (auto &repl : rewriterImpl.replacements)
2477  trackedOps->erase(repl.first);
2478  }
2479  return success();
2480 }
2481 
/// Finalizes the conversion once every operation has been converted. First
/// legalizes unresolved materializations and converted block argument types,
/// then visits each replaced operation recorded in
/// `operationsWithChangedResults`: results replaced with null are checked for
/// remaining live users, and results whose replacement has a different type
/// are legalized via `legalizeChangedResultType`. Returns failure as soon as
/// any step fails.
/// NOTE(review): the return-type line (2482) and the declaration of
/// `inverseMapping` (2484) are missing from this listing; `inverseMapping` is
/// presumably a local optional inverse of `rewriterImpl.mapping` - confirm
/// against the original source.
2483 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2485  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2486  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2487  inverseMapping)) ||
2488  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2489  return failure();
2490 
 // Nothing else to do when no replaced operation changed its results.
2491  if (rewriterImpl.operationsWithChangedResults.empty())
2492  return success();
2493 
2494  // Process requested operation replacements.
2495  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2496  i != e; ++i) {
 // Each entry is an index into the ordered `replacements` map.
2497  unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2498  auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2499  for (OpResult result : repl.first->getResults()) {
2500  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2501 
2502  // If the operation result was replaced with null, all of the uses of this
2503  // value should be replaced.
2504  if (!newValue) {
2505  if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2506  return failure();
2507  continue;
2508  }
2509 
2510  // Otherwise, check to see if the type of the result changed.
2511  if (result.getType() == newValue.getType())
2512  continue;
2513 
2514  // Compute the inverse mapping only if it is really needed.
2515  if (!inverseMapping)
2516  inverseMapping = rewriterImpl.mapping.getInverse();
2517 
2518  // Legalize this result.
2519  rewriter.setInsertionPoint(repl.first);
2520  if (failed(legalizeChangedResultType(repl.first, result, newValue,
2521  repl.second.converter, rewriter,
2522  rewriterImpl, *inverseMapping)))
2523  return failure();
2524 
2525  // Update the end iterator for this loop in the case it was updated
2526  // when legalizing generated conversion operations.
2527  e = rewriterImpl.operationsWithChangedResults.size();
2528  }
2529  }
2530  return success();
2531 }
2532 
/// Delegates to the rewriter's argument converter to materialize conversions
/// that are still live, passing it the current value mapping and a callback
/// that reports a live user of a value. Returns whatever
/// `materializeLiveConversions` returns.
2533 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2534  ConversionPatternRewriter &rewriter,
2535  ConversionPatternRewriterImpl &rewriterImpl) {
2536  // Functor used to check if all users of a value will be dead after
2537  // conversion.
 // Returns the first user that is not ignored by the rewriter, or nullptr
 // when every user is ignored.
2538  auto findLiveUser = [&](Value val) {
2539  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2540  return rewriterImpl.isOpIgnored(user);
2541  });
2542  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2543  };
2544  return rewriterImpl.argConverter.materializeLiveConversions(
2545  rewriterImpl.mapping, rewriter, findLiveUser);
2546 }
2547 
2548 /// Replace the results of a materialization operation with the given values.
/// All uses of `matResults` are redirected to `values`, and every value that
/// `inverseMapping` records as mapped to a materialization result is
/// re-mapped to the corresponding replacement value.
/// NOTE(review): the signature line carrying the function name and the
/// `rewriterImpl` parameter (2550) is missing from this listing.
2549 static void
2551  ResultRange matResults, ValueRange values,
2552  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2553  matResults.replaceAllUsesWith(values);
2554 
2555  // For each of the materialization results, update the inverse mappings to
2556  // point to the replacement values.
2557  for (auto it : llvm::zip(matResults, values)) {
2558  Value matResult, newValue;
2559  std::tie(matResult, newValue) = it;
2560  auto inverseMapIt = inverseMapping.find(matResult);
 // Nothing was mapped to this result, so there is nothing to update.
2561  if (inverseMapIt == inverseMapping.end())
2562  continue;
2563 
2564  // Update the reverse mapping, or remove the mapping if we couldn't update
2565  // it. Not being able to update signals that the mapping would have become
2566  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2567  // propagated through temporary materializations. We simply drop the
2568  // mapping, and let the post-conversion replacement logic handle updating
2569  // uses.
2570  for (Value inverseMapVal : inverseMapIt->second)
2571  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2572  rewriterImpl.mapping.erase(inverseMapVal);
2573  }
2574 }
2575 
2576 /// Compute all of the unresolved materializations that will persist beyond the
2577 /// conversion process, and require inserting a proper user materialization for.
/// On return, `necessaryMaterializations` holds the materializations that were
/// found to have live users; `materializationOps` is filled with a map from
/// each materialization's cast operation to its record.
/// NOTE(review): the first signature lines (2578-2579, function name and the
/// `materializationOps` parameter) are missing from this listing, as are the
/// declaration of `worklist` (2613) and the body of the argument check (2662).
2580  ConversionPatternRewriter &rewriter,
2581  ConversionPatternRewriterImpl &rewriterImpl,
2582  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2583  SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
 // A value is live when it has at least one user that is neither ignored by
 // the rewriter nor a materialization absent from `necessaryMaterializations`.
2584  auto isLive = [&](Value value) {
 // Returns true for users that do NOT keep the value alive.
2585  auto findFn = [&](Operation *user) {
2586  auto matIt = materializationOps.find(user);
2587  if (matIt != materializationOps.end())
2588  return !necessaryMaterializations.count(matIt->second);
2589  return rewriterImpl.isOpIgnored(user);
2590  };
2591  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2592  };
2593 
 // Looks for a value of type `type` that `value` was remapped to, other than
 // `invalidRoot`; recurses through single-operand unrealized casts. Returns a
 // null Value when nothing suitable is found.
2594  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2595  [&](Value invalidRoot, Value value, Type type) {
2596  // Check to see if the input operation was remapped to a variant of the
2597  // output.
2598  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2599  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2600  return remappedValue;
2601 
2602  // Check to see if the input is a materialization operation that
2603  // provides an inverse conversion. We just check blindly for
2604  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2605  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2606  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2607  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2608  type);
2609 
2610  return Value();
2611  };
2612 
 // Seed the worklist with every unresolved materialization and index each one
 // by its cast operation.
2614  for (auto &mat : rewriterImpl.unresolvedMaterializations) {
2615  materializationOps.try_emplace(mat.getOp(), &mat);
2616  worklist.insert(&mat);
2617  }
2618  while (!worklist.empty()) {
2619  UnresolvedMaterialization *mat = worklist.pop_back_val();
2620  UnrealizedConversionCastOp op = mat->getOp();
2621 
2622  // We currently only handle target materializations here.
2623  assert(op->getNumResults() == 1 && "unexpected materialization type");
2624  OpResult opResult = op->getOpResult(0);
2625  Type outputType = opResult.getType();
2626  Operation::operand_range inputOperands = op.getOperands();
2627 
2628  // Try to forward propagate operands for user conversion casts that result
2629  // in the input types of the current cast.
 // Early-inc iteration is used because replacing the result's uses changes
 // the user list being walked.
2630  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2631  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2632  if (!castOp)
2633  continue;
2634  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2635  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2636  inverseMapping);
2637  necessaryMaterializations.remove(materializationOps.lookup(user));
2638  }
2639  }
2640 
2641  // Try to avoid materializing a resolved materialization if possible.
2642  // Handle the case of a 1-1 materialization.
2643  if (inputOperands.size() == 1) {
2644  // Check to see if the input operation was remapped to a variant of the
2645  // output.
2646  Value remappedValue =
2647  lookupRemappedValue(opResult, inputOperands[0], outputType);
2648  if (remappedValue && remappedValue != opResult) {
2649  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2650  inverseMapping);
2651  necessaryMaterializations.remove(mat);
2652  continue;
2653  }
2654  } else {
2655  // TODO: Avoid materializing other types of conversions here.
2656  }
2657 
2658  // Check to see if this is an argument materialization.
 // NOTE(review): the body of this `if` (line 2662) is missing from this
 // listing; given the kind check further down it presumably marks `mat` as an
 // argument materialization - confirm against the original source.
2659  auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
2660  if (llvm::any_of(op->getOperands(), isBlockArg) ||
2661  llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
2663  }
2664 
2665  // If the materialization does not have any live users, we don't need to
2666  // generate a user materialization for it.
2667  // FIXME: For argument materializations, we currently need to check if any
2668  // of the inverse mapped values are used because some patterns expect blind
2669  // value replacement even if the types differ in some cases. When those
2670  // patterns are fixed, we can drop the argument special case here.
2671  bool isMaterializationLive = isLive(opResult);
2672  if (mat->getKind() == UnresolvedMaterialization::Argument)
2673  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2674  if (!isMaterializationLive)
2675  continue;
 // Already recorded as necessary: its inputs need no reprocessing.
2676  if (!necessaryMaterializations.insert(mat))
2677  continue;
2678 
2679  // Reprocess input materializations to see if they have an updated status.
 // Note: the inner `mat` below shadows the outer loop variable.
2680  for (Value input : inputOperands) {
2681  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2682  if (auto *mat = materializationOps.lookup(parentOp))
2683  worklist.insert(mat);
2684  }
2685  }
2686  }
2687 }
2688 
2689 /// Legalize the given unresolved materialization. Returns success if the
2690 /// materialization was legalized, failure otherwise.
/// Materializations feeding this one are legalized first. The cast is then
/// either replaced by an already remapped value of the right type, or by a
/// value produced by the materialization's type converter. If neither works,
/// an error is emitted on the cast, with a note pointing at a live user when
/// one exists.
/// NOTE(review): several lines are missing from this listing: the function
/// name line (2691), the `materializationOps` parameter (2693), the start of
/// the recursive call (2731) and the first `case` label of the switch (2764).
2692  UnresolvedMaterialization &mat,
2694  ConversionPatternRewriter &rewriter,
2695  ConversionPatternRewriterImpl &rewriterImpl,
2696  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
 // Returns the first user in `users` that is not ignored by the rewriter, or
 // nullptr when all of them are ignored.
2697  auto findLiveUser = [&](auto &&users) {
2698  auto liveUserIt = llvm::find_if_not(
2699  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2700  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2701  };
2702 
 // Returns the value `value` is mapped to if it has type `type`, otherwise a
 // null Value.
2703  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2704  [&](Value value, Type type) {
2705  // Check to see if the input operation was remapped to a variant of the
2706  // output.
2707  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2708  if (remappedValue.getType() == type)
2709  return remappedValue;
2710  return Value();
2711  };
2712 
 // Inserting into `ignoredOps` doubles as a visited check: if the cast was
 // already present there is nothing left to do for it.
2713  UnrealizedConversionCastOp op = mat.getOp();
2714  if (!rewriterImpl.ignoredOps.insert(op))
2715  return success();
2716 
2717  // We currently only handle target materializations here.
2718  OpResult opResult = op->getOpResult(0);
2719  Operation::operand_range inputOperands = op.getOperands();
2720  Type outputType = opResult.getType();
2721 
2722  // If any input to this materialization is another materialization, resolve
2723  // the input first.
2724  for (Value value : op->getOperands()) {
2725  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2726  if (!valueCast)
2727  continue;
2728 
2729  auto matIt = materializationOps.find(valueCast);
2730  if (matIt != materializationOps.end())
2732  *matIt->second, materializationOps, rewriter, rewriterImpl,
2733  inverseMapping)))
2734  return failure();
2735  }
2736 
2737  // Perform a last ditch attempt to avoid materializing a resolved
2738  // materialization if possible.
2739  // Handle the case of a 1-1 materialization.
2740  if (inputOperands.size() == 1) {
2741  // Check to see if the input operation was remapped to a variant of the
2742  // output.
2743  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2744  if (remappedValue && remappedValue != opResult) {
2745  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2746  inverseMapping);
2747  return success();
2748  }
2749  } else {
2750  // TODO: Avoid materializing other types of conversions here.
2751  }
2752 
2753  // Try to materialize the conversion.
 // Without a type converter no materialization can be attempted, and control
 // falls through to the error below.
2754  if (TypeConverter *converter = mat.getConverter()) {
2755  // FIXME: Determine a suitable insertion location when there are multiple
2756  // inputs.
2757  if (inputOperands.size() == 1)
2758  rewriter.setInsertionPointAfterValue(inputOperands.front());
2759  else
2760  rewriter.setInsertionPoint(op);
2761 
2762  Value newMaterialization;
2763  switch (mat.getKind()) {
2765  // Try to materialize an argument conversion.
2766  // FIXME: The current argument materialization hook expects the original
2767  // output type, even though it doesn't use that as the actual output type
2768  // of the generated IR. The output type is just used as an indicator of
2769  // the type of materialization to do. This behavior is really awkward in
2770  // that it diverges from the behavior of the other hooks, and can be
2771  // easily misunderstood. We should clean up the argument hooks to better
2772  // represent the desired invariants we actually care about.
2773  newMaterialization = converter->materializeArgumentConversion(
2774  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2775  if (newMaterialization)
2776  break;
2777 
2778  // If an argument materialization failed, fallback to trying a target
2779  // materialization.
2780  LLVM_FALLTHROUGH;
2781  case UnresolvedMaterialization::Target:
2782  newMaterialization = converter->materializeTargetConversion(
2783  rewriter, op->getLoc(), outputType, inputOperands);
2784  break;
2785  }
2786  if (newMaterialization) {
2787  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2788  inverseMapping);
2789  return success();
2790  }
2791  }
2792 
 // No replacement could be produced: report the failure on the cast itself.
2793  InFlightDiagnostic diag = op->emitError()
2794  << "failed to legalize unresolved materialization "
2795  "from "
2796  << inputOperands.getTypes() << " to " << outputType
2797  << " that remained live after conversion";
2798  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2799  diag.attachNote(liveUser->getLoc())
2800  << "see existing live user here: " << *liveUser;
2801  }
2802  return failure();
2803 }
2804 
/// Legalizes the unresolved materializations recorded on `rewriterImpl`.
/// Returns success immediately when there are none. Otherwise `inverseMapping`
/// is set to the inverse of the rewriter's value mapping, the set of necessary
/// materializations is computed, and each of them is legalized in turn;
/// failure is returned as soon as one cannot be legalized.
/// NOTE(review): the declaration of `materializationOps` (2815) and the start
/// of the call inside the loop (2822) are missing from this listing.
2805 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2806  ConversionPatternRewriter &rewriter,
2807  ConversionPatternRewriterImpl &rewriterImpl,
2808  Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2809  if (rewriterImpl.unresolvedMaterializations.empty())
2810  return success();
2811  inverseMapping = rewriterImpl.mapping.getInverse();
2812 
2813  // As an initial step, compute all of the inserted materializations that we
2814  // expect to persist beyond the conversion process.
2816  SetVector<UnresolvedMaterialization *> necessaryMaterializations;
2817  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2818  *inverseMapping, necessaryMaterializations);
2819 
2820  // Once computed, legalize any necessary materializations.
2821  for (auto *mat : necessaryMaterializations) {
2823  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2824  return failure();
2825  }
2826  return success();
2827 }
2828 
/// Checks that `result`, a result of the erased operation `op`, has no user
/// that survives the conversion. If a user that is not ignored by the rewriter
/// is found, an error is emitted on `op` with a note naming that user and the
/// result number, and failure is returned; otherwise returns success.
2829 LogicalResult OperationConverter::legalizeErasedResult(
2830  Operation *op, OpResult result,
2831  ConversionPatternRewriterImpl &rewriterImpl) {
2832  // If the operation result was replaced with null, all of the uses of this
2833  // value should be replaced.
2834  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2835  return rewriterImpl.isOpIgnored(user);
2836  });
2837  if (liveUserIt != result.user_end()) {
2838  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2839  << op->getName() << "' marked as erased";
2840  diag.attachNote(liveUserIt->getLoc())
2841  << "found live user of result #" << result.getResultNumber() << ": "
2842  << *liveUserIt;
2843  return failure();
2844  }
2845  return success();
2846 }
2847 
2848 /// Finds a user of the given value, or of any other value that the given value
2849 /// replaced, that was not replaced in the conversion process.
/// Starting from `initialValue`, follows `inverseMapping` transitively and
/// returns the first user that is not ignored by the rewriter, or nullptr when
/// no such user exists.
/// NOTE(review): the line with the return type and function name (2850) is
/// missing from this listing.
2851  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2852  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2853  SmallVector<Value> worklist(1, initialValue);
2854  while (!worklist.empty()) {
2855  Value value = worklist.pop_back_val();
2856 
2857  // Walk the users of this value to see if there are any live users that
2858  // weren't replaced during conversion.
2859  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2860  return rewriterImpl.isOpIgnored(user);
2861  });
2862  if (liveUserIt != value.user_end())
2863  return *liveUserIt;
 // Continue with the values recorded in the inverse mapping for `value`.
2864  auto mapIt = inverseMapping.find(value);
2865  if (mapIt != inverseMapping.end())
2866  worklist.append(mapIt->second);
2867  }
2868  return nullptr;
2869 }
2870 
/// Legalizes a result of `op` whose replacement `newValue` has a different
/// type. If no live user of the result (or of values it replaced) remains,
/// nothing needs to be done. Otherwise a source materialization back to the
/// original result type is requested from `replConverter`, and the result is
/// mapped to the materialized value. Emits an error and returns failure when
/// there is no converter or the materialization yields no value.
/// NOTE(review): the line that starts the diagnostic inside
/// `emitConversionError` (2883) is missing from this listing.
2871 LogicalResult OperationConverter::legalizeChangedResultType(
2872  Operation *op, OpResult result, Value newValue,
2873  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2874  ConversionPatternRewriterImpl &rewriterImpl,
2875  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2876  Operation *liveUser =
2877  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2878  if (!liveUser)
2879  return success();
2880 
2881  // Functor used to emit a conversion error for a failed materialization.
2882  auto emitConversionError = [&] {
2884  << "failed to materialize conversion for result #"
2885  << result.getResultNumber() << " of operation '"
2886  << op->getName()
2887  << "' that remained live after conversion";
2888  diag.attachNote(liveUser->getLoc())
2889  << "see existing live user here: " << *liveUser;
2890  return failure();
2891  };
2892 
2893  // If the replacement has a type converter, attempt to materialize a
2894  // conversion back to the original type.
2895  if (!replConverter)
2896  return emitConversionError();
2897 
2898  // Materialize a conversion for this live result value.
2899  Type resultType = result.getType();
2900  Value convertedValue = replConverter->materializeSourceConversion(
2901  rewriter, op->getLoc(), resultType, newValue);
2902  if (!convertedValue)
2903  return emitConversionError();
2904 
2905  rewriterImpl.mapping.map(result, convertedValue);
2906  return success();
2907 }
2908 
2909 //===----------------------------------------------------------------------===//
2910 // Type Conversion
2911 //===----------------------------------------------------------------------===//
2912 
/// Remaps original input `origInputNo` to `types.size()` new inputs placed at
/// the current end of `argTypes`, then appends `types` to the new signature.
/// `types` must not be empty.
/// NOTE(review): the first signature line (2913) is missing from this listing;
/// judging by the call `result.addInputs(inputNo, convertedTypes)` later in
/// this file this is presumably the two-argument `addInputs` overload -
/// confirm.
2914  ArrayRef<Type> types) {
2915  assert(!types.empty() && "expected valid types");
2916  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2917  addInputs(types);
2918 }
2919 
/// Appends `types` to `argTypes` without recording any input remapping.
/// `types` must not be empty.
/// NOTE(review): the signature line (2920) is missing from this listing;
/// presumably the single-argument `addInputs` overload called just above -
/// confirm.
2921  assert(!types.empty() &&
2922  "1->0 type remappings don't need to be added explicitly");
2923  argTypes.append(types.begin(), types.end());
2924 }
2925 
/// Records that original input `origInputNo` is replaced by `newInputCount`
/// inputs of the new signature starting at index `newInputNo`. The input must
/// not have been remapped before and `newInputCount` must be non-zero.
2926 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2927  unsigned newInputNo,
2928  unsigned newInputCount) {
2929  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2930  assert(newInputCount != 0 && "expected valid input count");
2931  remappedInputs[origInputNo] =
2932  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2933 }
2934 
/// Records that original input `origInputNo` is replaced by the existing value
/// `replacementValue` rather than by new inputs (the mapping has size 0). The
/// input must not have been remapped before.
/// NOTE(review): the first signature line (2935) is missing from this listing;
/// presumably a `remapInput` overload - confirm.
2936  Value replacementValue) {
2937  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2938  remappedInputs[origInputNo] =
2939  InputMapping{origInputNo, /*size=*/0, replacementValue};
2940 }
2941 
/// Converts type `t`, appending the converted type(s) to `results`. Cached
/// conversions are consulted first; otherwise the registered conversion
/// callbacks are tried from the most recently registered to the oldest, and
/// the first callback that returns a result decides the outcome, which is then
/// cached. Returns failure when a callback reports failure or when no callback
/// handles the type (the latter case is not cached).
/// NOTE(review): the first signature line (2942) is missing from this listing;
/// the calls `convertType(t, results)` below suggest this is the multi-result
/// `convertType` overload - confirm.
2943  SmallVectorImpl<Type> &results) {
 // A null entry in the direct cache records a previously failed conversion.
2944  auto existingIt = cachedDirectConversions.find(t);
2945  if (existingIt != cachedDirectConversions.end()) {
2946  if (existingIt->second)
2947  results.push_back(existingIt->second);
2948  return success(existingIt->second != nullptr);
2949  }
2950  auto multiIt = cachedMultiConversions.find(t);
2951  if (multiIt != cachedMultiConversions.end()) {
2952  results.append(multiIt->second.begin(), multiIt->second.end());
2953  return success();
2954  }
2955 
2956  // Walk the added converters in reverse order to apply the most recently
2957  // registered first.
 // `currentCount` marks where this conversion's results start in `results`.
2958  size_t currentCount = results.size();
 // Track `t` on the call stack for the duration of this call; the scope-exit
 // object pops it on every return path.
2959  conversionCallStack.push_back(t);
2960  auto popConversionCallStack =
2961  llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
2962  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2963  if (Optional<LogicalResult> result =
2964  converter(t, results, conversionCallStack)) {
2965  if (!succeeded(*result)) {
2966  cachedDirectConversions.try_emplace(t, nullptr);
2967  return failure();
2968  }
 // Exactly one new type goes to the direct cache; any other count (including
 // zero) goes to the multi-result cache.
2969  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2970  if (newTypes.size() == 1)
2971  cachedDirectConversions.try_emplace(t, newTypes.front());
2972  else
2973  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2974  return success();
2975  }
2976  }
2977  return failure();
2978 }
2979 
/// Converts `t` to a single type. Returns a null type when the conversion
/// fails or does not produce exactly one type.
/// NOTE(review): the signature line (2980) is missing from this listing;
/// presumably the single-result `convertType(Type)` overload - confirm.
2981  // Use the multi-type result version to convert the type.
2982  SmallVector<Type, 1> results;
2983  if (failed(convertType(t, results)))
2984  return nullptr;
2985 
2986  // Check to ensure that only one type was produced.
2987  return results.size() == 1 ? results.front() : nullptr;
2988 }
2989 
/// Converts each type in `types` in order, appending the converted types to
/// `results`. Stops and returns failure at the first type that fails to
/// convert; types converted before that point remain in `results`.
/// NOTE(review): the first signature line (2990) is missing from this listing;
/// presumably `convertTypes(TypeRange types, ...)` - confirm.
2991  SmallVectorImpl<Type> &results) {
2992  for (Type type : types)
2993  if (failed(convertType(type, results)))
2994  return failure();
2995  return success();
2996 }
2997 
/// A type is legal when converting it yields the very same type.
2998 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
/// An operation is legal when both its operand types and its result types are
/// legal.
/// NOTE(review): the signature line (2999) is missing from this listing;
/// presumably `isLegal(Operation *op)` - confirm.
3000  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3001 }
3002 
/// A region is legal when the argument types of every block in it are legal.
/// NOTE(review): the signature line (3003) is missing from this listing;
/// presumably `isLegal(Region *region)` - confirm.
3004  return llvm::all_of(*region, [this](Block &block) {
3005  return isLegal(block.getArgumentTypes());
3006  });
3007 }
3008 
/// A function type is legal when all of its input and result types are legal.
3009 bool TypeConverter::isSignatureLegal(FunctionType ty) {
3010  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3011 }
3012 
/// Converts the type of one signature argument and records the outcome in
/// `result`. Returns failure when the type cannot be converted. When the
/// conversion yields no types the argument is dropped and nothing is recorded.
/// NOTE(review): the first signature line (3013) is missing from this listing;
/// the call `convertSignatureArg(origInputOffset + i, types[i], result)` below
/// suggests the parameters are `(unsigned inputNo, Type type, ...)` - confirm.
3014  SignatureConversion &result) {
3015  // Try to convert the given input type.
3016  SmallVector<Type, 1> convertedTypes;
3017  if (failed(convertType(type, convertedTypes)))
3018  return failure();
3019 
3020  // If this argument is being dropped, there is nothing left to do.
3021  if (convertedTypes.empty())
3022  return success();
3023 
3024  // Otherwise, add the new inputs.
3025  result.addInputs(inputNo, convertedTypes);
3026  return success();
3027 }
/// Converts every type in `types` as a signature argument, numbering the
/// original inputs from `origInputOffset`. Returns failure at the first
/// argument that cannot be converted.
/// NOTE(review): the first signature line (3028) is missing from this listing;
/// presumably `convertSignatureArgs(TypeRange types, ...)` - confirm.
3029  SignatureConversion &result,
3030  unsigned origInputOffset) {
3031  for (unsigned i = 0, e = types.size(); i != e; ++i)
3032  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3033  return failure();
3034  return success();
3035 }
3036 
/// Tries the given materialization callbacks from the most recently registered
/// to the oldest and returns the value produced by the first callback that
/// returns a result. Returns a null value when no callback handles the request.
/// NOTE(review): the first parameter line (3038) is missing from this listing;
/// it presumably declares the `materializations` list iterated below.
3037 Value TypeConverter::materializeConversion(
3039  OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
3040  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
3041  if (Optional<Value> result = fn(builder, resultType, inputs, loc))
3042  return result.getValue();
3043  return nullptr;
3044 }
3045 
/// Builds a signature conversion for the arguments of `block`. Returns
/// `llvm::None` when any argument type fails to convert.
/// NOTE(review): the signature lines (3046-3047) are missing from this
/// listing; presumably `convertBlockSignature(Block *block)` - confirm.
3048  SignatureConversion conversion(block->getNumArguments());
3049  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3050  return llvm::None;
3051  return conversion;
3052 }
3053 
3054 //===----------------------------------------------------------------------===//
3055 // FunctionOpInterfaceSignatureConversion
3056 //===----------------------------------------------------------------------===//
3057 
3058 /// Create a default conversion pattern that rewrites the type signature of a
3059 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3060 /// represent their type.
3061 namespace {
3062 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
 /// Registers the pattern for the operation named `functionLikeOpName` with a
 /// benefit of 1.
3063  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3064  MLIRContext *ctx,
3065  TypeConverter &converter)
3066  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3067 
 /// Converts the function's input and result types and the types of its body
 /// region, then updates the function type in place. Fails if any of the
 /// conversions fails.
 /// NOTE(review): the return-type line (3068) is missing from this listing.
3069  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
3070  ConversionPatternRewriter &rewriter) const override {
3071  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3072  FunctionType type = funcOp.getType().cast<FunctionType>();
3073 
3074  // Convert the original function types.
3075  TypeConverter::SignatureConversion result(type.getNumInputs());
3076  SmallVector<Type, 1> newResults;
3077  if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
3078  failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
3079  failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
3080  &result)))
3081  return failure();
3082 
3083  // Update the function signature in-place.
3084  auto newType = FunctionType::get(rewriter.getContext(),
3085  result.getConvertedTypes(), newResults);
3086 
3087  rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); });
3088 
3089  return success();
3090  }
3091 };
3092 } // namespace
3093 
/// Adds a FunctionOpInterfaceSignatureConversion pattern for the operation
/// named `functionLikeOpName`, using `converter`, to `patterns`.
/// NOTE(review): the line with the return type and function name (3094) is
/// missing from this listing.
3095  StringRef functionLikeOpName, RewritePatternSet &patterns,
3096  TypeConverter &converter) {
3097  patterns.add<FunctionOpInterfaceSignatureConversion>(
3098  functionLikeOpName, patterns.getContext(), converter);
3099 }
3100 
3101 //===----------------------------------------------------------------------===//
3102 // ConversionTarget
3103 //===----------------------------------------------------------------------===//
3104 
/// Sets the legalization action recorded for operation `op`, creating the
/// entry in `legalOperations` if it does not exist yet.
/// NOTE(review): the first signature line (3105) is missing from this listing;
/// presumably `ConversionTarget::setOpAction(OperationName op, ...)` - confirm.
3106  LegalizationAction action) {
3107  legalOperations[op].action = action;
3108 }
3109 
/// Sets the legalization action for every dialect named in `dialectNames`.
/// NOTE(review): the first signature line (3110) is missing from this listing;
/// presumably `ConversionTarget::setDialectAction(...)` - confirm.
3111  LegalizationAction action) {
3112  for (StringRef dialect : dialectNames)
3113  legalDialects[dialect] = action;
3114 }
3115 
/// Returns the legalization action that applies to `op`, or an empty optional
/// when no legalization info is available for it.
/// NOTE(review): the signature lines (3116-3117) are missing from this
/// listing; presumably `ConversionTarget::getOpAction(OperationName op)` -
/// confirm.
3118  Optional<LegalizationInfo> info = getOpInfo(op);
3119  return info ? info->action : Optional<LegalizationAction>();
3120 }
3121 
/// Returns legality details for `op` when this operation instance is legal,
/// and `llvm::None` when there is no legalization info for it or it is not
/// legal. The returned details say whether the operation is recursively legal.
/// NOTE(review): the signature lines (3122-3123) are missing from this
/// listing; presumably `ConversionTarget::isLegal(Operation *op)` - confirm.
3124  Optional<LegalizationInfo> info = getOpInfo(op->getName());
3125  if (!info)
3126  return llvm::None;
3127 
3128  // Returns true if this operation instance is known to be legal.
3129  auto isOpLegal = [&] {
3130  // Handle dynamic legality with the provided legality function.
 // When the callback gives no answer, fall through to the check below, which
 // evaluates to false for a Dynamic action.
3131  if (info->action == LegalizationAction::Dynamic) {
3132  Optional<bool> result = info->legalityFn(op);
3133  if (result)
3134  return *result;
3135  }
3136 
3137  // Otherwise, the operation is only legal if it was marked 'Legal'.
3138  return info->action == LegalizationAction::Legal;
3139  };
3140  if (!isOpLegal())
3141  return llvm::None;
3142 
3143  // This operation is legal, compute any additional legality information.
3144  LegalOpDetails legalityDetails;
3145  if (info->isRecursivelyLegal) {
 // A registered recursive-legality callback decides; when it gives no answer,
 // or when no callback is registered, the operation is recursively legal.
3146  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3147  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3148  legalityDetails.isRecursivelyLegal =
3149  legalityFnIt->second(op).getValueOr(true);
3150  } else {
3151  legalityDetails.isRecursivelyLegal = true;
3152  }
3153  }
3154  return legalityDetails;
3155 }
3156 
/// Returns true only when `op` is known to be illegal: either its action is
/// Illegal, or its action is Dynamic and the legality callback explicitly
/// answers false. Returns false when there is no legalization info or the
/// dynamic callback gives no answer.
/// NOTE(review): the signature line (3157) is missing from this listing;
/// presumably `ConversionTarget::isIllegal(Operation *op)` - confirm.
3158  Optional<LegalizationInfo> info = getOpInfo(op->getName());
3159  if (!info)
3160  return false;
3161 
3162  if (info->action == LegalizationAction::Dynamic) {
3163  Optional<bool> result = info->legalityFn(op);
3164  if (!result)
3165  return false;
3166 
3167  return !(*result);
3168  }
3169 
3170  return info->action == LegalizationAction::Illegal;
3171 }
3172 
/// Combines two legality callbacks into one. When there is no old callback the
/// new one is returned unchanged. Otherwise the combined callback asks the new
/// callback first and falls back to the old one only when the new callback
/// gives no answer.
/// NOTE(review): the signature lines (3173-3175) are missing from this
/// listing; the call sites below suggest the name `composeLegalityCallbacks`
/// with parameters `(oldCallback, newCallback)` - confirm.
3176  if (!oldCallback)
3177  return newCallback;
3178 
3179  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3180  Operation *op) -> Optional<bool> {
3181  if (Optional<bool> result = newCl(op))
3182  return *result;
3183 
3184  return oldCl(op);
3185  };
3186  return chain;
3187 }
3188 
/// Registers `callback` as the dynamic legality callback for operation `name`,
/// composing it with any callback already registered for that operation. The
/// operation must already be marked as dynamically legal and `callback` must
/// be valid.
3189 void ConversionTarget::setLegalityCallback(
3190  OperationName name, const DynamicLegalityCallbackFn &callback) {
3191  assert(callback && "expected valid legality callback");
3192  auto infoIt = legalOperations.find(name);
3193  assert(infoIt != legalOperations.end() &&
3194  infoIt->second.action == LegalizationAction::Dynamic &&
3195  "expected operation to already be marked as dynamically legal");
3196  infoIt->second.legalityFn =
3197  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3198 }
3199 
/// Marks operation `name` as recursively legal. The operation must already
/// have an entry whose action is not Illegal. A valid `callback` is composed
/// with any recursive-legality callback already registered for the operation;
/// a null `callback` removes the registered one.
/// NOTE(review): the first signature line (3200) is missing from this listing;
/// presumably `ConversionTarget::markOpRecursivelyLegal(` - confirm.
3201  OperationName name, const DynamicLegalityCallbackFn &callback) {
3202  auto infoIt = legalOperations.find(name);
3203  assert(infoIt != legalOperations.end() &&
3204  infoIt->second.action != LegalizationAction::Illegal &&
3205  "expected operation to already be marked as legal");
3206  infoIt->second.isRecursivelyLegal = true;
3207  if (callback)
3208  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3209  std::move(opRecursiveLegalityFns[name]), callback);
3210  else
3211  opRecursiveLegalityFns.erase(name);
3212 }
3213 
/// Registers `callback` as the dynamic legality callback for each dialect in
/// `dialects`, composing it with any callback already registered for that
/// dialect. `callback` must be valid.
3214 void ConversionTarget::setLegalityCallback(
3215  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3216  assert(callback && "expected valid legality callback");
3217  for (StringRef dialect : dialects)
3218  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3219  std::move(dialectLegalityFns[dialect]), callback);
3220 }
3221 
/// Registers `callback` as the legality callback used for operations that have
/// no operation- or dialect-specific info, composing it with any callback
/// already registered for that purpose. `callback` must be valid.
3222 void ConversionTarget::setLegalityCallback(
3223  const DynamicLegalityCallbackFn &callback) {
3224  assert(callback && "expected valid legality callback");
3225  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3226 }
3227 
/// Looks up the legalization info for operation `op`. Operation-specific info
/// wins; otherwise info is derived from the action registered for the
/// operation's dialect (together with the dialect's legality callback, if
/// any); otherwise, when a callback for unknown operations is set, the
/// operation is treated as dynamically legal. Returns `llvm::None` when none
/// of these apply.
/// NOTE(review): the trailing-return-type line (3229) is missing from this
/// listing.
3228 auto ConversionTarget::getOpInfo(OperationName op) const
3230  // Check for info for this specific operation.
3231  auto it = legalOperations.find(op);
3232  if (it != legalOperations.end())
3233  return it->second;
3234  // Check for info for the parent dialect.
3235  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3236  if (dialectIt != legalDialects.end()) {
3237  DynamicLegalityCallbackFn callback;
3238  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3239  if (dialectFn != dialectLegalityFns.end())
3240  callback = dialectFn->second;
3241  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3242  callback};
3243  }
3244  // Otherwise, check if we mark unknown operations as dynamic.
3245  if (unknownLegalityFn)
3246  return LegalizationInfo{LegalizationAction::Dynamic,
3247  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3248  return llvm::None;
3249 }
3250 
3251 //===----------------------------------------------------------------------===//
3252 // Op Conversion Entry Points
3253 //===----------------------------------------------------------------------===//
3254 
3255 //===----------------------------------------------------------------------===//
3256 // Partial Conversion
3257 
/// Runs an OperationConverter in Partial mode over `ops`, passing
/// `unconvertedOps` to the converter, and returns the result of the conversion.
/// NOTE(review): the first signature lines (3258-3259) are missing from this
/// listing; presumably `applyPartialConversion(ArrayRef<Operation *> ops,` -
/// confirm.
3260  ConversionTarget &target,
3261  const FrozenRewritePatternSet &patterns,
3262  DenseSet<Operation *> *unconvertedOps) {
3263  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3264  unconvertedOps);
3265  return opConverter.convertOperations(ops);
3266 }
/// Single-operation convenience wrapper: forwards `op` as a one-element array
/// to the multi-operation `applyPartialConversion`.
/// NOTE(review): the first signature lines (3267-3268) are missing from this
/// listing.
3269  const FrozenRewritePatternSet &patterns,
3270  DenseSet<Operation *> *unconvertedOps) {
3271  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
3272  unconvertedOps);
3273 }
3274 
3275 //===----------------------------------------------------------------------===//
3276 // Full Conversion
3277 
/// Runs an OperationConverter in Full mode over `ops` and returns the result
/// of the conversion.
/// NOTE(review): the first signature lines (3278-3279) are missing from this
/// listing; presumably `applyFullConversion(ArrayRef<Operation *> ops, ...)` -
/// confirm.
3280  const FrozenRewritePatternSet &patterns) {
3281  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3282  return opConverter.convertOperations(ops);
3283 }
/// Single-operation convenience wrapper: forwards `op` as a one-element array
/// to the multi-operation `applyFullConversion`.
/// NOTE(review): the first signature lines (3284-3285) are missing from this
/// listing.
3286  const FrozenRewritePatternSet &patterns) {
3287  return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
3288 }
3289 
3290 //===----------------------------------------------------------------------===//
3291 // Analysis Conversion
3292 
/// Runs an OperationConverter in Analysis mode over `ops`, passing
/// `convertedOps` to the converter and `notifyCallback` to the conversion, and
/// returns the result of the conversion.
/// NOTE(review): the first signature lines (3293-3294) are missing from this
/// listing; presumably `applyAnalysisConversion(ArrayRef<Operation *> ops,` -
/// confirm.
3295  ConversionTarget &target,
3296  const FrozenRewritePatternSet &patterns,
3297  DenseSet<Operation *> &convertedOps,
3298  function_ref<void(Diagnostic &)> notifyCallback) {
3299  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3300  &convertedOps);
3301  return opConverter.convertOperations(ops, notifyCallback);
3302 }
/// Single-operation convenience wrapper: forwards `op` as a one-element array
/// to the multi-operation `applyAnalysisConversion`.
/// NOTE(review): the first signature lines (3303-3304) are missing from this
/// listing.
3305  const FrozenRewritePatternSet &patterns,
3306  DenseSet<Operation *> &convertedOps,
3307  function_ref<void(Diagnostic &)> notifyCallback) {
3308  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
3309  convertedOps, notifyCallback);
3310 }
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
Kind
Tensor expression kind.
Definition: Merger.h:27
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
static std::string diag(llvm::Value &v)
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Definition: Builders.cpp:387
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
MLIRContext * getContext() const
Definition: Builders.h:54
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:373
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
Operation & back()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
static Value min(ImplicitLocOpBuilder &builder, Value a, Value b)
BlockListType & getBlocks()
Definition: Region.h:45
This is a value defined by a result of an operation.
Definition: Value.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:420
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:200
This class represents a frozen set of patterns that can be processed by a pattern applicator...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:148
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block...
Definition: Block.cpp:47
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
OpListType & getOperations()
Definition: Block.h:128
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
Base class for the conversion patterns.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
BlockListType::iterator iterator
Definition: Region.h:52
operand_type_range getOperandTypes()
Definition: Operation.h:266
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:225
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:271
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class implements the result iterators for the Operation class.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
type_range getTypes() const
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
LogicalResult notifyMatchFailure(Operation *op, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
T lookup(T from) const
Lookup a mapped value within the map.
This struct represents a range of new types or a single value that remaps an existing signature input...
A structure containing additional information describing a specific legal operation instance...
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
void dropAllUses() const
Drop all uses of this object from their respective owners.
Definition: Value.h:156
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
user_range getUsers() const
Definition: Value.h:212
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:161
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
llvm::PointerUnion< NamedAttribute *, NamedTypeConstraint * > Argument
Definition: Argument.h:62
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:490
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region...
std::function< Optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target...
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:174
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:307
LogicalResult applyFullConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::enable_if< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT >::type walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one)...
Definition: Operation.h:515
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:41
This class provides all of the information necessary to convert a type signature. ...
OpPrintingFlags & printGenericOpForm()
Always print operations in the generic form.
Definition: AsmPrinter.cpp:194
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing the results of an operation when the given functor returns true...
OpListType::iterator iterator
Definition: Block.h:131
succ_iterator successor_end()
Definition: Operation.h:445
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
bool empty()
Definition: Region.h:60
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override
PatternRewriter hook for merging a block into another.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization *> &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:344
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:183
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:335
iterator end()
Definition: Block.h:135
unsigned getNumArguments()
Definition: Block.h:119
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
U dyn_cast() const
Definition: Value.h:99
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:435
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:88
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
Block & back()
Definition: Region.h:64
BlockActionKind
The kind of the block action performed during the rewrite.
succ_iterator successor_begin()
Definition: Operation.h:444
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:352
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands...
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
LogicalResult applyAnalysisConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
BlockArgListType getArguments()
Definition: Block.h:76
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
This class represents an argument of a Block.
Definition: Value.h:298
auto getType() const
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:127
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:82
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void notifyBlockCreated(Block *block) override
PatternRewriter hook creating a new block.
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:552
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
bool empty()
Definition: Block.h:139
void print(raw_ostream &os, const OpPrintingFlags &flags=llvm::None)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:311
user_iterator user_end() const
Definition: Value.h:211
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:202
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void notifyRegionWasClonedBefore(iterator_range< Region::iterator > &blocks, Location origRegionLoc)
Notifies that the blocks of a region were cloned into another.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
operand_iterator operand_begin()
Definition: Operation.h:243
RewriterState getCurrentState()
Return the current state of the rewriter.
bool isa() const
Definition: Value.h:89
Optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced, that was not replaced in the conversion process.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Set of flags used to control the behavior of the various IR print methods (e.g.
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
void applyRewrites()
Apply all requested operation rewrites.
void push_back(Operation *op)
Definition: Block.h:140
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:109
FailureOr< Block * > convertBlockSignature(Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
iterator end()
Definition: Region.h:56
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Type conversion class.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
Optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:37
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class manages the application of a group of rewrite patterns, with a user-provided cost model...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class represents an operand of an operation.
Definition: Value.h:249
This class implements a pattern rewriter for use with ConversionPatterns.
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...
This class implements the operand iterators for the Operation class.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
U cast() const
Definition: Value.h:107
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization *> &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization *> &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process...
ArgConverter argConverter
Utility used to convert block arguments.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
This class describes a specific conversion target.
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
void notifyBlocksBeingMerged(Block *block, Block *srcBlock)
Notifies that block is being merged with srcBlock.
operand_iterator operand_end()
Definition: Operation.h:244
SuccessorRange getSuccessors()
Definition: Block.h:255
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:376
static LogicalResult computeConversionSet(iterator_range< Region::iterator > region, Location regionLoc, SmallVectorImpl< Operation *> &toConvert, ConversionTarget *target=nullptr)
Recursively collect all of the operations to convert from within 'region'.
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null...
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void setAttrs(DictionaryAttr newAttrs)
Set the attribute dictionary on this operation.
Definition: Operation.h:314
result_range getResults()
Definition: Operation.h:284
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
Optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
OpConversionMode
result_type_range getResultTypes()
Definition: Operation.h:297
LogicalResult remapValues(StringRef valueDiagTag, Optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
Location getLoc() const
Return the location for this argument.
Definition: Value.h:313
MLIRContext * getContext() const
Definition: PatternMatch.h:906
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
std::string str() const
Converts the diagnostic to a string.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:92
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.