MLIR  18.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/IR/Iterators.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 #define DEBUG_TYPE "dialect-conversion"
30 
31 /// A utility function to log a successful result for the given reason.
32 template <typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
34  LLVM_DEBUG({
35  os.unindent();
36  os.startLine() << "} -> SUCCESS";
37  if (!fmt.empty())
38  os.getOStream() << " : "
39  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40  os.getOStream() << "\n";
41  });
42 }
43 
44 /// A utility function to log a failure result for the given reason.
45 template <typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
47  LLVM_DEBUG({
48  os.unindent();
49  os.startLine() << "} -> FAILURE : "
50  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
51  << "\n";
52  });
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // ConversionValueMapping
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 /// This class wraps a IRMapping to provide recursive lookup
61 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
62 struct ConversionValueMapping {
63  /// Lookup a mapped value within the map. If a mapping for the provided value
64  /// does not exist then return the provided value. If `desiredType` is
65  /// non-null, returns the most recently mapped value with that type. If an
66  /// operand of that type does not exist, defaults to normal behavior.
67  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
68 
69  /// Lookup a mapped value within the map, or return null if a mapping does not
70  /// exist. If a mapping exists, this follows the same behavior of
71  /// `lookupOrDefault`.
72  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
73 
74  /// Map a value to the one provided.
75  void map(Value oldVal, Value newVal) {
76  LLVM_DEBUG({
77  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
78  assert(it != oldVal && "inserting cyclic mapping");
79  });
80  mapping.map(oldVal, newVal);
81  }
82 
83  /// Try to map a value to the one provided. Returns false if a transitive
84  /// mapping from the new value to the old value already exists, true if the
85  /// map was updated.
86  bool tryMap(Value oldVal, Value newVal);
87 
88  /// Drop the last mapping for the given value.
89  void erase(Value value) { mapping.erase(value); }
90 
91  /// Returns the inverse raw value mapping (without recursive query support).
92  DenseMap<Value, SmallVector<Value>> getInverse() const {
94  for (auto &it : mapping.getValueMap())
95  inverse[it.second].push_back(it.first);
96  return inverse;
97  }
98 
99 private:
100  /// Current value mappings.
101  IRMapping mapping;
102 };
103 } // namespace
104 
105 Value ConversionValueMapping::lookupOrDefault(Value from,
106  Type desiredType) const {
107  // If there was no desired type, simply find the leaf value.
108  if (!desiredType) {
109  // If this value had a valid mapping, unmap that value as well in the case
110  // that it was also replaced.
111  while (auto mappedValue = mapping.lookupOrNull(from))
112  from = mappedValue;
113  return from;
114  }
115 
116  // Otherwise, try to find the deepest value that has the desired type.
117  Value desiredValue;
118  do {
119  if (from.getType() == desiredType)
120  desiredValue = from;
121 
122  Value mappedValue = mapping.lookupOrNull(from);
123  if (!mappedValue)
124  break;
125  from = mappedValue;
126  } while (true);
127 
128  // If the desired value was found use it, otherwise default to the leaf value.
129  return desiredValue ? desiredValue : from;
130 }
131 
132 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
133  Value result = lookupOrDefault(from, desiredType);
134  if (result == from || (desiredType && result.getType() != desiredType))
135  return nullptr;
136  return result;
137 }
138 
139 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
140  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
141  if (it == oldVal)
142  return false;
143  map(oldVal, newVal);
144  return true;
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Rewriter and Translation State
149 //===----------------------------------------------------------------------===//
150 namespace {
151 /// This class contains a snapshot of the current conversion rewriter state.
152 /// This is useful when saving and undoing a set of rewrites.
153 struct RewriterState {
154  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
155  unsigned numReplacements, unsigned numArgReplacements,
156  unsigned numBlockActions, unsigned numIgnoredOperations,
157  unsigned numRootUpdates)
158  : numCreatedOps(numCreatedOps),
159  numUnresolvedMaterializations(numUnresolvedMaterializations),
160  numReplacements(numReplacements),
161  numArgReplacements(numArgReplacements),
162  numBlockActions(numBlockActions),
163  numIgnoredOperations(numIgnoredOperations),
164  numRootUpdates(numRootUpdates) {}
165 
166  /// The current number of created operations.
167  unsigned numCreatedOps;
168 
169  /// The current number of unresolved materializations.
170  unsigned numUnresolvedMaterializations;
171 
172  /// The current number of replacements queued.
173  unsigned numReplacements;
174 
175  /// The current number of argument replacements queued.
176  unsigned numArgReplacements;
177 
178  /// The current number of block actions performed.
179  unsigned numBlockActions;
180 
181  /// The current number of ignored operations.
182  unsigned numIgnoredOperations;
183 
184  /// The current number of operations that were updated in place.
185  unsigned numRootUpdates;
186 };
187 
188 //===----------------------------------------------------------------------===//
189 // OperationTransactionState
190 
191 /// The state of an operation that was updated by a pattern in-place. This
192 /// contains all of the necessary information to reconstruct an operation that
193 /// was updated in place.
194 class OperationTransactionState {
195 public:
196  OperationTransactionState() = default;
197  OperationTransactionState(Operation *op)
198  : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
199  operands(op->operand_begin(), op->operand_end()),
200  successors(op->successor_begin(), op->successor_end()) {}
201 
202  /// Discard the transaction state and reset the state of the original
203  /// operation.
204  void resetOperation() const {
205  op->setLoc(loc);
206  op->setAttrs(attrs);
207  op->setOperands(operands);
208  for (const auto &it : llvm::enumerate(successors))
209  op->setSuccessor(it.value(), it.index());
210  }
211 
212  /// Return the original operation of this state.
213  Operation *getOperation() const { return op; }
214 
215 private:
216  Operation *op;
217  LocationAttr loc;
218  DictionaryAttr attrs;
219  SmallVector<Value, 8> operands;
220  SmallVector<Block *, 2> successors;
221 };
222 
223 //===----------------------------------------------------------------------===//
224 // OpReplacement
225 
226 /// This class represents one requested operation replacement via 'replaceOp' or
227 /// 'eraseOp`.
228 struct OpReplacement {
229  OpReplacement(const TypeConverter *converter = nullptr)
230  : converter(converter) {}
231 
232  /// An optional type converter that can be used to materialize conversions
233  /// between the new and old values if necessary.
234  const TypeConverter *converter;
235 };
236 
237 //===----------------------------------------------------------------------===//
238 // BlockAction
239 
240 /// The kind of the block action performed during the rewrite. Actions can be
241 /// undone if the conversion fails.
242 enum class BlockActionKind {
243  Create,
244  Erase,
245  Inline,
246  Move,
247  Split,
248  TypeConversion
249 };
250 
251 /// Original position of the given block in its parent region. During undo
252 /// actions, the block needs to be placed after `insertAfterBlock`.
253 struct BlockPosition {
254  Region *region;
255  Block *insertAfterBlock;
256 };
257 
258 /// Information needed to undo inlining actions.
259 /// - the source block
260 /// - the first inlined operation (could be null if the source block was empty)
261 /// - the last inlined operation (could be null if the source block was empty)
262 struct InlineInfo {
263  Block *sourceBlock;
264  Operation *firstInlinedInst;
265  Operation *lastInlinedInst;
266 };
267 
268 /// The storage class for an undoable block action (one of BlockActionKind),
269 /// contains the information necessary to undo this action.
270 struct BlockAction {
271  static BlockAction getCreate(Block *block) {
272  return {BlockActionKind::Create, block, {}};
273  }
274  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
275  return {BlockActionKind::Erase, block, {originalPosition}};
276  }
277  static BlockAction getInline(Block *block, Block *srcBlock,
278  Block::iterator before) {
279  BlockAction action{BlockActionKind::Inline, block, {}};
280  action.inlineInfo = {srcBlock,
281  srcBlock->empty() ? nullptr : &srcBlock->front(),
282  srcBlock->empty() ? nullptr : &srcBlock->back()};
283  return action;
284  }
285  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
286  return {BlockActionKind::Move, block, {originalPosition}};
287  }
288  static BlockAction getSplit(Block *block, Block *originalBlock) {
289  BlockAction action{BlockActionKind::Split, block, {}};
290  action.originalBlock = originalBlock;
291  return action;
292  }
293  static BlockAction getTypeConversion(Block *block) {
294  return BlockAction{BlockActionKind::TypeConversion, block, {}};
295  }
296 
297  // The action kind.
298  BlockActionKind kind;
299 
300  // A pointer to the block that was created by the action.
301  Block *block;
302 
303  union {
304  // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
305  // contains a pointer to the region that originally contained the block as
306  // well as the position of the block in that region.
307  BlockPosition originalPosition;
308  // In use if kind == BlockActionKind::Split and contains a pointer to the
309  // block that was split into two parts.
310  Block *originalBlock;
311  // In use if kind == BlockActionKind::Inline, and contains the information
312  // needed to undo the inlining.
313  InlineInfo inlineInfo;
314  };
315 };
316 
317 //===----------------------------------------------------------------------===//
318 // UnresolvedMaterialization
319 
320 /// This class represents an unresolved materialization, i.e. a materialization
321 /// that was inserted during conversion that needs to be legalized at the end of
322 /// the conversion process.
323 class UnresolvedMaterialization {
324 public:
325  /// The type of materialization.
326  enum Kind {
327  /// This materialization materializes a conversion for an illegal block
328  /// argument type, to a legal one.
329  Argument,
330 
331  /// This materialization materializes a conversion from an illegal type to a
332  /// legal one.
333  Target
334  };
335 
336  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
337  const TypeConverter *converter = nullptr,
338  Kind kind = Target, Type origOutputType = nullptr)
339  : op(op), converterAndKind(converter, kind),
340  origOutputType(origOutputType) {}
341 
342  /// Return the temporary conversion operation inserted for this
343  /// materialization.
344  UnrealizedConversionCastOp getOp() const { return op; }
345 
346  /// Return the type converter of this materialization (which may be null).
347  const TypeConverter *getConverter() const {
348  return converterAndKind.getPointer();
349  }
350 
351  /// Return the kind of this materialization.
352  Kind getKind() const { return converterAndKind.getInt(); }
353 
354  /// Set the kind of this materialization.
355  void setKind(Kind kind) { converterAndKind.setInt(kind); }
356 
357  /// Return the original illegal output type of the input values.
358  Type getOrigOutputType() const { return origOutputType; }
359 
360 private:
361  /// The unresolved materialization operation created during conversion.
362  UnrealizedConversionCastOp op;
363 
364  /// The corresponding type converter to use when resolving this
365  /// materialization, and the kind of this materialization.
366  llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
367 
368  /// The original output type. This is only used for argument conversions.
369  Type origOutputType;
370 };
371 } // namespace
372 
373 /// Build an unresolved materialization operation given an output type and set
374 /// of input operands.
376  UnresolvedMaterialization::Kind kind, Block *insertBlock,
377  Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
378  Type origOutputType, const TypeConverter *converter,
379  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
380  // Avoid materializing an unnecessary cast.
381  if (inputs.size() == 1 && inputs.front().getType() == outputType)
382  return inputs.front();
383 
384  // Create an unresolved materialization. We use a new OpBuilder to avoid
385  // tracking the materialization like we do for other operations.
386  OpBuilder builder(insertBlock, insertPt);
387  auto convertOp =
388  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
389  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
390  origOutputType);
391  return convertOp.getResult(0);
392 }
394  PatternRewriter &rewriter, Location loc, ValueRange inputs,
395  Type origOutputType, Type outputType, const TypeConverter *converter,
396  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
399  rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
400  converter, unresolvedMaterializations);
401 }
403  Location loc, Value input, Type outputType, const TypeConverter *converter,
404  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
405  Block *insertBlock = input.getParentBlock();
406  Block::iterator insertPt = insertBlock->begin();
407  if (OpResult inputRes = dyn_cast<OpResult>(input))
408  insertPt = ++inputRes.getOwner()->getIterator();
409 
411  UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
412  outputType, outputType, converter, unresolvedMaterializations);
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // ArgConverter
417 //===----------------------------------------------------------------------===//
418 namespace {
419 /// This class provides a simple interface for converting the types of block
420 /// arguments. This is done by creating a new block that contains the new legal
421 /// types and extracting the block that contains the old illegal types to allow
422 /// for undoing pending rewrites in the case of failure.
423 struct ArgConverter {
424  ArgConverter(
425  PatternRewriter &rewriter,
426  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
427  : rewriter(rewriter),
428  unresolvedMaterializations(unresolvedMaterializations) {}
429 
430  /// This structure contains the information pertaining to an argument that has
431  /// been converted.
432  struct ConvertedArgInfo {
433  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
434  Value castValue = nullptr)
435  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
436 
437  /// The start index of in the new argument list that contains arguments that
438  /// replace the original.
439  unsigned newArgIdx;
440 
441  /// The number of arguments that replaced the original argument.
442  unsigned newArgSize;
443 
444  /// The cast value that was created to cast from the new arguments to the
445  /// old. This only used if 'newArgSize' > 1.
446  Value castValue;
447  };
448 
449  /// This structure contains information pertaining to a block that has had its
450  /// signature converted.
451  struct ConvertedBlockInfo {
452  ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
453  : origBlock(origBlock), converter(converter) {}
454 
455  /// The original block that was requested to have its signature converted.
456  Block *origBlock;
457 
458  /// The conversion information for each of the arguments. The information is
459  /// std::nullopt if the argument was dropped during conversion.
461 
462  /// The type converter used to convert the arguments.
463  const TypeConverter *converter;
464  };
465 
466  /// Return if the signature of the given block has already been converted.
467  bool hasBeenConverted(Block *block) const {
468  return conversionInfo.count(block) || convertedBlocks.count(block);
469  }
470 
471  /// Set the type converter to use for the given region.
472  void setConverter(Region *region, const TypeConverter *typeConverter) {
473  assert(typeConverter && "expected valid type converter");
474  regionToConverter[region] = typeConverter;
475  }
476 
477  /// Return the type converter to use for the given region, or null if there
478  /// isn't one.
479  const TypeConverter *getConverter(Region *region) {
480  return regionToConverter.lookup(region);
481  }
482 
483  //===--------------------------------------------------------------------===//
484  // Rewrite Application
485  //===--------------------------------------------------------------------===//
486 
487  /// Erase any rewrites registered for the blocks within the given operation
488  /// which is about to be removed. This merely drops the rewrites without
489  /// undoing them.
490  void notifyOpRemoved(Operation *op);
491 
492  /// Cleanup and undo any generated conversions for the arguments of block.
493  /// This method replaces the new block with the original, reverting the IR to
494  /// its original state.
495  void discardRewrites(Block *block);
496 
497  /// Fully replace uses of the old arguments with the new.
498  void applyRewrites(ConversionValueMapping &mapping);
499 
500  /// Materialize any necessary conversions for converted arguments that have
501  /// live users, using the provided `findLiveUser` to search for a user that
502  /// survives the conversion process.
504  materializeLiveConversions(ConversionValueMapping &mapping,
505  OpBuilder &builder,
506  function_ref<Operation *(Value)> findLiveUser);
507 
508  //===--------------------------------------------------------------------===//
509  // Conversion
510  //===--------------------------------------------------------------------===//
511 
512  /// Attempt to convert the signature of the given block, if successful a new
513  /// block is returned containing the new arguments. Returns `block` if it did
514  /// not require conversion.
516  convertSignature(Block *block, const TypeConverter *converter,
517  ConversionValueMapping &mapping,
518  SmallVectorImpl<BlockArgument> &argReplacements);
519 
520  /// Apply the given signature conversion on the given block. The new block
521  /// containing the updated signature is returned. If no conversions were
522  /// necessary, e.g. if the block has no arguments, `block` is returned.
523  /// `converter` is used to generate any necessary cast operations that
524  /// translate between the origin argument types and those specified in the
525  /// signature conversion.
526  Block *applySignatureConversion(
527  Block *block, const TypeConverter *converter,
528  TypeConverter::SignatureConversion &signatureConversion,
529  ConversionValueMapping &mapping,
530  SmallVectorImpl<BlockArgument> &argReplacements);
531 
532  /// Insert a new conversion into the cache.
533  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
534 
535  /// A collection of blocks that have had their arguments converted. This is a
536  /// map from the new replacement block, back to the original block.
537  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
538 
539  /// The set of original blocks that were converted.
540  DenseSet<Block *> convertedBlocks;
541 
542  /// A mapping from valid regions, to those containing the original blocks of a
543  /// conversion.
545 
546  /// A mapping of regions to type converters that should be used when
547  /// converting the arguments of blocks within that region.
549 
550  /// The pattern rewriter to use when materializing conversions.
551  PatternRewriter &rewriter;
552 
553  /// An ordered set of unresolved materializations during conversion.
554  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
555 };
556 } // namespace
557 
558 //===----------------------------------------------------------------------===//
559 // Rewrite Application
560 
561 void ArgConverter::notifyOpRemoved(Operation *op) {
562  if (conversionInfo.empty())
563  return;
564 
565  for (Region &region : op->getRegions()) {
566  for (Block &block : region) {
567  // Drop any rewrites from within.
568  for (Operation &nestedOp : block)
569  if (nestedOp.getNumRegions())
570  notifyOpRemoved(&nestedOp);
571 
572  // Check if this block was converted.
573  auto it = conversionInfo.find(&block);
574  if (it == conversionInfo.end())
575  continue;
576 
577  // Drop all uses of the original arguments and delete the original block.
578  Block *origBlock = it->second.origBlock;
579  for (BlockArgument arg : origBlock->getArguments())
580  arg.dropAllUses();
581  conversionInfo.erase(it);
582  }
583  }
584 }
585 
586 void ArgConverter::discardRewrites(Block *block) {
587  auto it = conversionInfo.find(block);
588  if (it == conversionInfo.end())
589  return;
590  Block *origBlock = it->second.origBlock;
591 
592  // Drop all uses of the new block arguments and replace uses of the new block.
593  for (int i = block->getNumArguments() - 1; i >= 0; --i)
594  block->getArgument(i).dropAllUses();
595  block->replaceAllUsesWith(origBlock);
596 
597  // Move the operations back the original block and the delete the new block.
598  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
599  origBlock->moveBefore(block);
600  block->erase();
601 
602  convertedBlocks.erase(origBlock);
603  conversionInfo.erase(it);
604 }
605 
606 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
607  for (auto &info : conversionInfo) {
608  ConvertedBlockInfo &blockInfo = info.second;
609  Block *origBlock = blockInfo.origBlock;
610 
611  // Process the remapping for each of the original arguments.
612  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
613  std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
614  BlockArgument origArg = origBlock->getArgument(i);
615 
616  // Handle the case of a 1->0 value mapping.
617  if (!argInfo) {
618  if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
619  origArg.replaceAllUsesWith(newArg);
620  continue;
621  }
622 
623  // Otherwise this is a 1->1+ value mapping.
624  Value castValue = argInfo->castValue;
625  assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
626 
627  // If the argument is still used, replace it with the generated cast.
628  if (!origArg.use_empty()) {
629  origArg.replaceAllUsesWith(
630  mapping.lookupOrDefault(castValue, origArg.getType()));
631  }
632  }
633  }
634 }
635 
636 LogicalResult ArgConverter::materializeLiveConversions(
637  ConversionValueMapping &mapping, OpBuilder &builder,
638  function_ref<Operation *(Value)> findLiveUser) {
639  for (auto &info : conversionInfo) {
640  Block *newBlock = info.first;
641  ConvertedBlockInfo &blockInfo = info.second;
642  Block *origBlock = blockInfo.origBlock;
643 
644  // Process the remapping for each of the original arguments.
645  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
646  // If the type of this argument changed and the argument is still live, we
647  // need to materialize a conversion.
648  BlockArgument origArg = origBlock->getArgument(i);
649  if (mapping.lookupOrNull(origArg, origArg.getType()))
650  continue;
651  Operation *liveUser = findLiveUser(origArg);
652  if (!liveUser)
653  continue;
654 
655  Value replacementValue = mapping.lookupOrDefault(origArg);
656  bool isDroppedArg = replacementValue == origArg;
657  if (isDroppedArg)
658  rewriter.setInsertionPointToStart(newBlock);
659  else
660  rewriter.setInsertionPointAfterValue(replacementValue);
661  Value newArg;
662  if (blockInfo.converter) {
663  newArg = blockInfo.converter->materializeSourceConversion(
664  rewriter, origArg.getLoc(), origArg.getType(),
665  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
666  assert((!newArg || newArg.getType() == origArg.getType()) &&
667  "materialization hook did not provide a value of the expected "
668  "type");
669  }
670  if (!newArg) {
672  emitError(origArg.getLoc())
673  << "failed to materialize conversion for block argument #" << i
674  << " that remained live after conversion, type was "
675  << origArg.getType();
676  if (!isDroppedArg)
677  diag << ", with target type " << replacementValue.getType();
678  diag.attachNote(liveUser->getLoc())
679  << "see existing live user here: " << *liveUser;
680  return failure();
681  }
682  mapping.map(origArg, newArg);
683  }
684  }
685  return success();
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Conversion
690 
691 FailureOr<Block *> ArgConverter::convertSignature(
692  Block *block, const TypeConverter *converter,
693  ConversionValueMapping &mapping,
694  SmallVectorImpl<BlockArgument> &argReplacements) {
695  // Check if the block was already converted. If the block is detached,
696  // conservatively assume it is going to be deleted.
697  if (hasBeenConverted(block) || !block->getParent())
698  return block;
699  // If a converter wasn't provided, and the block wasn't already converted,
700  // there is nothing we can do.
701  if (!converter)
702  return failure();
703 
704  // Try to convert the signature for the block with the provided converter.
705  if (auto conversion = converter->convertBlockSignature(block))
706  return applySignatureConversion(block, converter, *conversion, mapping,
707  argReplacements);
708  return failure();
709 }
710 
711 Block *ArgConverter::applySignatureConversion(
712  Block *block, const TypeConverter *converter,
713  TypeConverter::SignatureConversion &signatureConversion,
714  ConversionValueMapping &mapping,
715  SmallVectorImpl<BlockArgument> &argReplacements) {
716  // If no arguments are being changed or added, there is nothing to do.
717  unsigned origArgCount = block->getNumArguments();
718  auto convertedTypes = signatureConversion.getConvertedTypes();
719  if (origArgCount == 0 && convertedTypes.empty())
720  return block;
721 
722  // Split the block at the beginning to get a new block to use for the updated
723  // signature.
724  Block *newBlock = block->splitBlock(block->begin());
725  block->replaceAllUsesWith(newBlock);
726 
727  // Map all new arguments to the location of the argument they originate from.
728  SmallVector<Location> newLocs(convertedTypes.size(),
729  rewriter.getUnknownLoc());
730  for (unsigned i = 0; i < origArgCount; ++i) {
731  auto inputMap = signatureConversion.getInputMapping(i);
732  if (!inputMap || inputMap->replacementValue)
733  continue;
734  Location origLoc = block->getArgument(i).getLoc();
735  for (unsigned j = 0; j < inputMap->size; ++j)
736  newLocs[inputMap->inputNo + j] = origLoc;
737  }
738 
739  SmallVector<Value, 4> newArgRange(
740  newBlock->addArguments(convertedTypes, newLocs));
741  ArrayRef<Value> newArgs(newArgRange);
742 
743  // Remap each of the original arguments as determined by the signature
744  // conversion.
745  ConvertedBlockInfo info(block, converter);
746  info.argInfo.resize(origArgCount);
747 
748  OpBuilder::InsertionGuard guard(rewriter);
749  rewriter.setInsertionPointToStart(newBlock);
750  for (unsigned i = 0; i != origArgCount; ++i) {
751  auto inputMap = signatureConversion.getInputMapping(i);
752  if (!inputMap)
753  continue;
754  BlockArgument origArg = block->getArgument(i);
755 
756  // If inputMap->replacementValue is not nullptr, then the argument is
757  // dropped and a replacement value is provided to be the remappedValue.
758  if (inputMap->replacementValue) {
759  assert(inputMap->size == 0 &&
760  "invalid to provide a replacement value when the argument isn't "
761  "dropped");
762  mapping.map(origArg, inputMap->replacementValue);
763  argReplacements.push_back(origArg);
764  continue;
765  }
766 
767  // Otherwise, this is a 1->1+ mapping.
768  auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
769  Value newArg;
770 
771  // If this is a 1->1 mapping and the types of new and replacement arguments
772  // match (i.e. it's an identity map), then the argument is mapped to its
773  // original type.
774  // FIXME: We simply pass through the replacement argument if there wasn't a
775  // converter, which isn't great as it allows implicit type conversions to
776  // appear. We should properly restructure this code to handle cases where a
777  // converter isn't provided and also to properly handle the case where an
778  // argument materialization is actually a temporary source materialization
779  // (e.g. in the case of 1->N).
780  if (replArgs.size() == 1 &&
781  (!converter || replArgs[0].getType() == origArg.getType())) {
782  newArg = replArgs.front();
783  } else {
784  Type origOutputType = origArg.getType();
785 
786  // Legalize the argument output type.
787  Type outputType = origOutputType;
788  if (Type legalOutputType = converter->convertType(outputType))
789  outputType = legalOutputType;
790 
792  rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
793  converter, unresolvedMaterializations);
794  }
795 
796  mapping.map(origArg, newArg);
797  argReplacements.push_back(origArg);
798  info.argInfo[i] =
799  ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
800  }
801 
802  // Remove the original block from the region and return the new one.
803  insertConversion(newBlock, std::move(info));
804  return newBlock;
805 }
806 
807 void ArgConverter::insertConversion(Block *newBlock,
808  ConvertedBlockInfo &&info) {
809  // Get a region to insert the old block.
810  Region *region = newBlock->getParent();
811  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
812  if (!mappedRegion)
813  mappedRegion = std::make_unique<Region>(region->getParentOp());
814 
815  // Move the original block to the mapped region and emplace the conversion.
816  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
817  info.origBlock->getIterator());
818  convertedBlocks.insert(info.origBlock);
819  conversionInfo.insert({newBlock, std::move(info)});
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // ConversionPatternRewriterImpl
824 //===----------------------------------------------------------------------===//
825 namespace mlir {
826 namespace detail {
829  : argConverter(rewriter, unresolvedMaterializations),
830  notifyCallback(nullptr) {}
831 
832  /// Cleanup and destroy any generated rewrite operations. This method is
833  /// invoked when the conversion process fails.
834  void discardRewrites();
835 
836  /// Apply all requested operation rewrites. This method is invoked when the
837  /// conversion process succeeds.
838  void applyRewrites();
839 
840  //===--------------------------------------------------------------------===//
841  // State Management
842  //===--------------------------------------------------------------------===//
843 
844  /// Return the current state of the rewriter.
845  RewriterState getCurrentState();
846 
847  /// Reset the state of the rewriter to a previously saved point.
848  void resetState(RewriterState state);
849 
850  /// Erase any blocks that were unlinked from their regions and stored in block
851  /// actions.
852  void eraseDanglingBlocks();
853 
854  /// Undo the block actions (motions, splits) one by one in reverse order until
855  /// "numActionsToKeep" actions remains.
856  void undoBlockActions(unsigned numActionsToKeep = 0);
857 
858  /// Remap the given values to those with potentially different types. Returns
859  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
860  /// is the tag used when describing a value within a diagnostic, e.g.
861  /// "operand".
862  LogicalResult remapValues(StringRef valueDiagTag,
863  std::optional<Location> inputLoc,
864  PatternRewriter &rewriter, ValueRange values,
865  SmallVectorImpl<Value> &remapped);
866 
867  /// Returns true if the given operation is ignored, and does not need to be
868  /// converted.
869  bool isOpIgnored(Operation *op) const;
870 
871  /// Recursively marks the nested operations under 'op' as ignored. This
872  /// removes them from being considered for legalization.
873  void markNestedOpsIgnored(Operation *op);
874 
875  //===--------------------------------------------------------------------===//
876  // Type Conversion
877  //===--------------------------------------------------------------------===//
878 
879  /// Convert the signature of the given block.
880  FailureOr<Block *> convertBlockSignature(
881  Block *block, const TypeConverter *converter,
882  TypeConverter::SignatureConversion *conversion = nullptr);
883 
884  /// Apply a signature conversion on the given region, using `converter` for
885  /// materializations if not null.
886  Block *
887  applySignatureConversion(Region *region,
889  const TypeConverter *converter);
890 
891  /// Convert the types of block arguments within the given region.
893  convertRegionTypes(Region *region, const TypeConverter &converter,
894  TypeConverter::SignatureConversion *entryConversion);
895 
896  /// Convert the types of non-entry block arguments within the given region.
897  LogicalResult convertNonEntryRegionTypes(
898  Region *region, const TypeConverter &converter,
899  ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
900 
901  //===--------------------------------------------------------------------===//
902  // Rewriter Notification Hooks
903  //===--------------------------------------------------------------------===//
904 
905  /// PatternRewriter hook for replacing the results of an operation.
906  void notifyOpReplaced(Operation *op, ValueRange newValues);
907 
908  /// Notifies that a block is about to be erased.
909  void notifyBlockIsBeingErased(Block *block);
910 
911  /// Notifies that a block was created.
912  void notifyCreatedBlock(Block *block);
913 
914  /// Notifies that a block was split.
915  void notifySplitBlock(Block *block, Block *continuation);
916 
917  /// Notifies that a block is being inlined into another block.
918  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
919  Block::iterator before);
920 
921  /// Notifies that the blocks of a region are about to be moved.
922  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
923  Region::iterator before);
924 
925  /// Notifies that a pattern match failed for the given reason.
927  notifyMatchFailure(Location loc,
928  function_ref<void(Diagnostic &)> reasonCallback);
929 
930  //===--------------------------------------------------------------------===//
931  // State
932  //===--------------------------------------------------------------------===//
933 
934  // Mapping between replaced values that differ in type. This happens when
935  // replacing a value with one of a different type.
936  ConversionValueMapping mapping;
937 
938  /// Utility used to convert block arguments.
939  ArgConverter argConverter;
940 
941  /// Ordered vector of all of the newly created operations during conversion.
943 
944  /// Ordered vector of all unresolved type conversion materializations during
945  /// conversion.
947 
948  /// Ordered map of requested operation replacements.
949  llvm::MapVector<Operation *, OpReplacement> replacements;
950 
951  /// Ordered vector of any requested block argument replacements.
953 
954  /// Ordered list of block operations (creations, splits, motions).
956 
957  /// A set of operations that should no longer be considered for legalization,
958  /// but were not directly replace/erased/etc. by a pattern. These are
959  /// generally child operations of other operations who were
960  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
961  /// operations, but the minimal set that can be used to detect if a given
962  /// operation should be `ignored`. For example, we may add the operations that
963  /// define non-empty regions to the set, but not any of the others. This
964  /// simplifies the amount of memory needed as we can query if the parent
965  /// operation was ignored.
967 
968  /// A transaction state for each of operations that were updated in-place.
970 
971  /// A vector of indices into `replacements` of operations that were replaced
972  /// with values with different result types than the original operation, e.g.
973  /// 1->N conversion of some kind.
975 
976  /// The current type converter, or nullptr if no type converter is currently
977  /// active.
978  const TypeConverter *currentTypeConverter = nullptr;
979 
980  /// This allows the user to collect the match failure message.
982 
983 #ifndef NDEBUG
984  /// A set of operations that have pending updates. This tracking isn't
985  /// strictly necessary, and is thus only active during debug builds for extra
986  /// verification.
988 
989  /// A logger used to emit diagnostics during the conversion process.
990  llvm::ScopedPrinter logger{llvm::dbgs()};
991 #endif
992 };
993 } // namespace detail
994 } // namespace mlir
995 
996 /// Detach any operations nested in the given operation from their parent
997 /// blocks, and erase the given operation. This can be used when the nested
998 /// operations are scheduled for erasure themselves, so deleting the regions of
999 /// the given operation together with their content would result in double-free.
1000 /// This happens, for example, when rolling back op creation in the reverse
1001 /// order and if the nested ops were created before the parent op. This function
1002 /// does not need to collect nested ops recursively because it is expected to
1003 /// also be called for each nested op when it is about to be deleted.
1005  for (Region &region : op->getRegions()) {
1006  for (Block &block : region.getBlocks()) {
1007  while (!block.getOperations().empty())
1008  block.getOperations().remove(block.getOperations().begin());
1009  block.dropAllDefinedValueUses();
1010  }
1011  }
1012  op->dropAllUses();
1013  op->erase();
1014 }
1015 
1017  // Reset any operations that were updated in place.
1018  for (auto &state : rootUpdates)
1019  state.resetOperation();
1020 
1021  undoBlockActions();
1022 
1023  // Remove any newly created ops.
1024  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1025  detachNestedAndErase(materialization.getOp());
1026  for (auto *op : llvm::reverse(createdOps))
1028 }
1029 
1031  // Apply all of the rewrites replacements requested during conversion.
1032  for (auto &repl : replacements) {
1033  for (OpResult result : repl.first->getResults())
1034  if (Value newValue = mapping.lookupOrNull(result, result.getType()))
1035  result.replaceAllUsesWith(newValue);
1036 
1037  // If this operation defines any regions, drop any pending argument
1038  // rewrites.
1039  if (repl.first->getNumRegions())
1040  argConverter.notifyOpRemoved(repl.first);
1041  }
1042 
1043  // Apply all of the requested argument replacements.
1044  for (BlockArgument arg : argReplacements) {
1045  Value repl = mapping.lookupOrNull(arg, arg.getType());
1046  if (!repl)
1047  continue;
1048 
1049  if (isa<BlockArgument>(repl)) {
1050  arg.replaceAllUsesWith(repl);
1051  continue;
1052  }
1053 
1054  // If the replacement value is an operation, we check to make sure that we
1055  // don't replace uses that are within the parent operation of the
1056  // replacement value.
1057  Operation *replOp = cast<OpResult>(repl).getOwner();
1058  Block *replBlock = replOp->getBlock();
1059  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
1060  Operation *user = operand.getOwner();
1061  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1062  });
1063  }
1064 
1065  // Drop all of the unresolved materialization operations created during
1066  // conversion.
1067  for (auto &mat : unresolvedMaterializations) {
1068  mat.getOp()->dropAllUses();
1069  mat.getOp()->erase();
1070  }
1071 
1072  // In a second pass, erase all of the replaced operations in reverse. This
1073  // allows processing nested operations before their parent region is
1074  // destroyed. Because we process in reverse order, producers may be deleted
1075  // before their users (a pattern deleting a producer and then the consumer)
1076  // so we first drop all uses explicitly.
1077  for (auto &repl : llvm::reverse(replacements)) {
1078  repl.first->dropAllUses();
1079  repl.first->erase();
1080  }
1081 
1082  argConverter.applyRewrites(mapping);
1083 
1084  // Now that the ops have been erased, also erase dangling blocks.
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // State Management
1090 
1092  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1093  replacements.size(), argReplacements.size(),
1094  blockActions.size(), ignoredOps.size(),
1095  rootUpdates.size());
1096 }
1097 
1099  // Reset any operations that were updated in place.
1100  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
1101  rootUpdates[i].resetOperation();
1102  rootUpdates.resize(state.numRootUpdates);
1103 
1104  // Reset any replaced arguments.
1105  for (BlockArgument replacedArg :
1106  llvm::drop_begin(argReplacements, state.numArgReplacements))
1107  mapping.erase(replacedArg);
1108  argReplacements.resize(state.numArgReplacements);
1109 
1110  // Undo any block actions.
1111  undoBlockActions(state.numBlockActions);
1112 
1113  // Reset any replaced operations and undo any saved mappings.
1114  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
1115  for (auto result : repl.first->getResults())
1116  mapping.erase(result);
1117  while (replacements.size() != state.numReplacements)
1118  replacements.pop_back();
1119 
1120  // Pop all of the newly inserted materializations.
1121  while (unresolvedMaterializations.size() !=
1122  state.numUnresolvedMaterializations) {
1123  UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
1124  UnrealizedConversionCastOp op = mat.getOp();
1125 
1126  // If this was a target materialization, drop the mapping that was inserted.
1127  if (mat.getKind() == UnresolvedMaterialization::Target) {
1128  for (Value input : op->getOperands())
1129  mapping.erase(input);
1130  }
1132  }
1133 
1134  // Pop all of the newly created operations.
1135  while (createdOps.size() != state.numCreatedOps) {
1137  createdOps.pop_back();
1138  }
1139 
1140  // Pop all of the recorded ignored operations that are no longer valid.
1141  while (ignoredOps.size() != state.numIgnoredOperations)
1142  ignoredOps.pop_back();
1143 
1144  // Reset operations with changed results.
1145  while (!operationsWithChangedResults.empty() &&
1146  operationsWithChangedResults.back() >= state.numReplacements)
1147  operationsWithChangedResults.pop_back();
1148 }
1149 
1151  for (auto &action : blockActions)
1152  if (action.kind == BlockActionKind::Erase)
1153  delete action.block;
1154 }
1155 
1157  unsigned numActionsToKeep) {
1158  for (auto &action :
1159  llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
1160  switch (action.kind) {
1161  // Delete the created block.
1162  case BlockActionKind::Create: {
1163  // Unlink all of the operations within this block, they will be deleted
1164  // separately.
1165  auto &blockOps = action.block->getOperations();
1166  while (!blockOps.empty())
1167  blockOps.remove(blockOps.begin());
1168  action.block->dropAllDefinedValueUses();
1169  action.block->erase();
1170  break;
1171  }
1172  // Put the block (owned by action) back into its original position.
1173  case BlockActionKind::Erase: {
1174  auto &blockList = action.originalPosition.region->getBlocks();
1175  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1176  blockList.insert((insertAfterBlock
1177  ? std::next(Region::iterator(insertAfterBlock))
1178  : blockList.begin()),
1179  action.block);
1180  break;
1181  }
1182  // Put the instructions from the destination block (owned by the action)
1183  // back into the source block.
1184  case BlockActionKind::Inline: {
1185  Block *sourceBlock = action.inlineInfo.sourceBlock;
1186  if (action.inlineInfo.firstInlinedInst) {
1187  assert(action.inlineInfo.lastInlinedInst && "expected operation");
1188  sourceBlock->getOperations().splice(
1189  sourceBlock->begin(), action.block->getOperations(),
1190  Block::iterator(action.inlineInfo.firstInlinedInst),
1191  ++Block::iterator(action.inlineInfo.lastInlinedInst));
1192  }
1193  break;
1194  }
1195  // Move the block back to its original position.
1196  case BlockActionKind::Move: {
1197  Region *originalRegion = action.originalPosition.region;
1198  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1199  originalRegion->getBlocks().splice(
1200  (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1201  : originalRegion->end()),
1202  action.block->getParent()->getBlocks(), action.block);
1203  break;
1204  }
1205  // Merge back the block that was split out.
1206  case BlockActionKind::Split: {
1207  action.originalBlock->getOperations().splice(
1208  action.originalBlock->end(), action.block->getOperations());
1209  action.block->dropAllDefinedValueUses();
1210  action.block->erase();
1211  break;
1212  }
1213  // Undo the type conversion.
1214  case BlockActionKind::TypeConversion: {
1215  argConverter.discardRewrites(action.block);
1216  break;
1217  }
1218  }
1219  }
1220  blockActions.resize(numActionsToKeep);
1221 }
1222 
1224  StringRef valueDiagTag, std::optional<Location> inputLoc,
1225  PatternRewriter &rewriter, ValueRange values,
1226  SmallVectorImpl<Value> &remapped) {
1227  remapped.reserve(llvm::size(values));
1228 
1229  SmallVector<Type, 1> legalTypes;
1230  for (const auto &it : llvm::enumerate(values)) {
1231  Value operand = it.value();
1232  Type origType = operand.getType();
1233 
1234  // If a converter was provided, get the desired legal types for this
1235  // operand.
1236  Type desiredType;
1237  if (currentTypeConverter) {
1238  // If there is no legal conversion, fail to match this pattern.
1239  legalTypes.clear();
1240  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1241  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1242  return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1243  diag << "unable to convert type for " << valueDiagTag << " #"
1244  << it.index() << ", type was " << origType;
1245  });
1246  }
1247  // TODO: There currently isn't any mechanism to do 1->N type conversion
1248  // via the PatternRewriter replacement API, so for now we just ignore it.
1249  if (legalTypes.size() == 1)
1250  desiredType = legalTypes.front();
1251  } else {
1252  // TODO: What we should do here is just set `desiredType` to `origType`
1253  // and then handle the necessary type conversions after the conversion
1254  // process has finished. Unfortunately a lot of patterns currently rely on
1255  // receiving the new operands even if the types change, so we keep the
1256  // original behavior here for now until all of the patterns relying on
1257  // this get updated.
1258  }
1259  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1260 
1261  // Handle the case where the conversion was 1->1 and the new operand type
1262  // isn't legal.
1263  Type newOperandType = newOperand.getType();
1264  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1265  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1267  operandLoc, newOperand, desiredType, currentTypeConverter,
1269  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1270  newOperand = castValue;
1271  }
1272  remapped.push_back(newOperand);
1273  }
1274  return success();
1275 }
1276 
1278  // Check to see if this operation was replaced or its parent ignored.
1279  return replacements.count(op) || ignoredOps.count(op->getParentOp());
1280 }
1281 
1283  // Walk this operation and collect nested operations that define non-empty
1284  // regions. We mark such operations as 'ignored' so that we know we don't have
1285  // to convert them, or their nested ops.
1286  if (op->getNumRegions() == 0)
1287  return;
1288  op->walk([&](Operation *op) {
1289  if (llvm::any_of(op->getRegions(),
1290  [](Region &region) { return !region.empty(); }))
1291  ignoredOps.insert(op);
1292  });
1293 }
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Type Conversion
1297 
1299  Block *block, const TypeConverter *converter,
1300  TypeConverter::SignatureConversion *conversion) {
1301  FailureOr<Block *> result =
1302  conversion ? argConverter.applySignatureConversion(
1303  block, converter, *conversion, mapping, argReplacements)
1304  : argConverter.convertSignature(block, converter, mapping,
1305  argReplacements);
1306  if (failed(result))
1307  return failure();
1308  if (Block *newBlock = *result) {
1309  if (newBlock != block)
1310  blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1311  }
1312  return result;
1313 }
1314 
1316  Region *region, TypeConverter::SignatureConversion &conversion,
1317  const TypeConverter *converter) {
1318  if (!region->empty())
1319  return *convertBlockSignature(&region->front(), converter, &conversion);
1320  return nullptr;
1321 }
1322 
1324  Region *region, const TypeConverter &converter,
1325  TypeConverter::SignatureConversion *entryConversion) {
1326  argConverter.setConverter(region, &converter);
1327  if (region->empty())
1328  return nullptr;
1329 
1330  if (failed(convertNonEntryRegionTypes(region, converter)))
1331  return failure();
1332 
1333  FailureOr<Block *> newEntry =
1334  convertBlockSignature(&region->front(), &converter, entryConversion);
1335  return newEntry;
1336 }
1337 
1339  Region *region, const TypeConverter &converter,
1341  argConverter.setConverter(region, &converter);
1342  if (region->empty())
1343  return success();
1344 
1345  // Convert the arguments of each block within the region.
1346  int blockIdx = 0;
1347  assert((blockConversions.empty() ||
1348  blockConversions.size() == region->getBlocks().size() - 1) &&
1349  "expected either to provide no SignatureConversions at all or to "
1350  "provide a SignatureConversion for each non-entry block");
1351 
1352  for (Block &block :
1353  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1354  TypeConverter::SignatureConversion *blockConversion =
1355  blockConversions.empty()
1356  ? nullptr
1357  : const_cast<TypeConverter::SignatureConversion *>(
1358  &blockConversions[blockIdx++]);
1359 
1360  if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1361  return failure();
1362  }
1363  return success();
1364 }
1365 
1366 //===----------------------------------------------------------------------===//
1367 // Rewriter Notification Hooks
1368 
1370  ValueRange newValues) {
1371  assert(newValues.size() == op->getNumResults());
1372  assert(!replacements.count(op) && "operation was already replaced");
1373 
1374  // Track if any of the results changed, e.g. erased and replaced with null.
1375  bool resultChanged = false;
1376 
1377  // Create mappings for each of the new result values.
1378  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1379  if (!newValue) {
1380  resultChanged = true;
1381  continue;
1382  }
1383  // Remap, and check for any result type changes.
1384  mapping.map(result, newValue);
1385  resultChanged |= (newValue.getType() != result.getType());
1386  }
1387  if (resultChanged)
1388  operationsWithChangedResults.push_back(replacements.size());
1389 
1390  // Record the requested operation replacement.
1391  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
1392 
1393  // Mark this operation as recursively ignored so that we don't need to
1394  // convert any nested operations.
1396 }
1397 
1399  Region *region = block->getParent();
1400  Block *origPrevBlock = block->getPrevNode();
1401  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1402 }
1403 
1405  blockActions.push_back(BlockAction::getCreate(block));
1406 }
1407 
1409  Block *continuation) {
1410  blockActions.push_back(BlockAction::getSplit(continuation, block));
1411 }
1412 
1414  Block *block, Block *srcBlock, Block::iterator before) {
1415  blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
1416 }
1417 
1419  Region &region, Region &parent, Region::iterator before) {
1420  if (region.empty())
1421  return;
1422  Block *laterBlock = &region.back();
1423  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1424  blockActions.push_back(
1425  BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1426  laterBlock = &earlierBlock;
1427  }
1428  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1429 }
1430 
1432  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1433  LLVM_DEBUG({
1435  reasonCallback(diag);
1436  logger.startLine() << "** Failure : " << diag.str() << "\n";
1437  if (notifyCallback)
1439  });
1440  return failure();
1441 }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // ConversionPatternRewriter
1445 //===----------------------------------------------------------------------===//
1446 
1448  : PatternRewriter(ctx),
1449  impl(new detail::ConversionPatternRewriterImpl(*this)) {
1450  setListener(this);
1451 }
1452 
1454 
1456  Operation *op, ValueRange newValues, bool *allUsesReplaced,
1457  llvm::unique_function<bool(OpOperand &) const> functor) {
1458  // TODO: To support this we will need to rework a bit of how replacements are
1459  // tracked, given that this isn't guranteed to replace all of the uses of an
1460  // operation. The main change is that now an operation can be replaced
1461  // multiple times, in parts. The current "set" based tracking is mainly useful
1462  // for tracking if a replaced operation should be ignored, i.e. if all of the
1463  // uses will be replaced.
1464  llvm_unreachable(
1465  "replaceOpWithIf is currently not supported by DialectConversion");
1466 }
1467 
1469  assert(op && newOp && "expected non-null op");
1470  replaceOp(op, newOp->getResults());
1471 }
1472 
1474  assert(op->getNumResults() == newValues.size() &&
1475  "incorrect # of replacement values");
1476  LLVM_DEBUG({
1477  impl->logger.startLine()
1478  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1479  });
1480  impl->notifyOpReplaced(op, newValues);
1481 }
1482 
1484  LLVM_DEBUG({
1485  impl->logger.startLine()
1486  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1487  });
1488  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1489  impl->notifyOpReplaced(op, nullRepls);
1490 }
1491 
1493  impl->notifyBlockIsBeingErased(block);
1494 
1495  // Mark all ops for erasure.
1496  for (Operation &op : *block)
1497  eraseOp(&op);
1498 
1499  // Unlink the block from its parent region. The block is kept in the block
1500  // action and will be actually destroyed when rewrites are applied. This
1501  // allows us to keep the operations in the block live and undo the removal by
1502  // re-inserting the block.
1503  block->getParent()->getBlocks().remove(block);
1504 }
1505 
1507  Region *region, TypeConverter::SignatureConversion &conversion,
1508  const TypeConverter *converter) {
1509  return impl->applySignatureConversion(region, conversion, converter);
1510 }
1511 
1513  Region *region, const TypeConverter &converter,
1514  TypeConverter::SignatureConversion *entryConversion) {
1515  return impl->convertRegionTypes(region, converter, entryConversion);
1516 }
1517 
1519  Region *region, const TypeConverter &converter,
1521  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1522 }
1523 
1525  Value to) {
1526  LLVM_DEBUG({
1527  Operation *parentOp = from.getOwner()->getParentOp();
1528  impl->logger.startLine() << "** Replace Argument : '" << from
1529  << "'(in region of '" << parentOp->getName()
1530  << "'(" << from.getOwner()->getParentOp() << ")\n";
1531  });
1532  impl->argReplacements.push_back(from);
1533  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1534 }
1535 
1537  SmallVector<Value> remappedValues;
1538  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1539  remappedValues)))
1540  return nullptr;
1541  return remappedValues.front();
1542 }
1543 
1546  SmallVectorImpl<Value> &results) {
1547  if (keys.empty())
1548  return success();
1549  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1550  results);
1551 }
1552 
1554  impl->notifyCreatedBlock(block);
1555 }
1556 
1558  Block::iterator before) {
1559  auto *continuation = PatternRewriter::splitBlock(block, before);
1560  impl->notifySplitBlock(block, continuation);
1561  return continuation;
1562 }
1563 
1565  Block::iterator before,
1566  ValueRange argValues) {
1567  assert(argValues.size() == source->getNumArguments() &&
1568  "incorrect # of argument replacement values");
1569 #ifndef NDEBUG
1570  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1571 #endif // NDEBUG
1572  // The source block will be deleted, so it should not have any users (i.e.,
1573  // there should be no predecessors).
1574  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1575  "expected 'source' to have no predecessors");
1576 
1577  impl->notifyBlockBeingInlined(dest, source, before);
1578  for (auto it : llvm::zip(source->getArguments(), argValues))
1579  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1580  dest->getOperations().splice(before, source->getOperations());
1581  eraseBlock(source);
1582 }
1583 
1585  Region &parent,
1586  Region::iterator before) {
1587  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1588  PatternRewriter::inlineRegionBefore(region, parent, before);
1589 }
1590 
1592  Region &parent,
1593  Region::iterator before,
1594  IRMapping &mapping) {
1595  if (region.empty())
1596  return;
1597 
1598  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1599 
1600  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
1601  Block *cloned = mapping.lookup(&b);
1602  impl->notifyCreatedBlock(cloned);
1604  [&](Operation *op) { notifyOperationInserted(op); });
1605  }
1606 }
1607 
1609  LLVM_DEBUG({
1610  impl->logger.startLine()
1611  << "** Insert : '" << op->getName() << "'(" << op << ")\n";
1612  });
1613  impl->createdOps.push_back(op);
1614 }
1615 
1617 #ifndef NDEBUG
1618  impl->pendingRootUpdates.insert(op);
1619 #endif
1620  impl->rootUpdates.emplace_back(op);
1621 }
1622 
1625  // There is nothing to do here, we only need to track the operation at the
1626  // start of the update.
1627 #ifndef NDEBUG
1628  assert(impl->pendingRootUpdates.erase(op) &&
1629  "operation did not have a pending in-place update");
1630 #endif
1631 }
1632 
1634 #ifndef NDEBUG
1635  assert(impl->pendingRootUpdates.erase(op) &&
1636  "operation did not have a pending in-place update");
1637 #endif
1638  // Erase the last update for this operation.
1639  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1640  auto &rootUpdates = impl->rootUpdates;
1641  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1642  assert(it != rootUpdates.rend() && "no root update started on op");
1643  (*it).resetOperation();
1644  int updateIdx = std::prev(rootUpdates.rend()) - it;
1645  rootUpdates.erase(rootUpdates.begin() + updateIdx);
1646 }
1647 
1649  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1650  return impl->notifyMatchFailure(loc, reasonCallback);
1651 }
1652 
1654  return *impl;
1655 }
1656 
1657 //===----------------------------------------------------------------------===//
1658 // ConversionPattern
1659 //===----------------------------------------------------------------------===//
1660 
1663  PatternRewriter &rewriter) const {
1664  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1665  auto &rewriterImpl = dialectRewriter.getImpl();
1666 
1667  // Track the current conversion pattern type converter in the rewriter.
1668  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1669  getTypeConverter());
1670 
1671  // Remap the operands of the operation.
1672  SmallVector<Value, 4> operands;
1673  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1674  op->getOperands(), operands))) {
1675  return failure();
1676  }
1677  return matchAndRewrite(op, operands, dialectRewriter);
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // OperationLegalizer
1682 //===----------------------------------------------------------------------===//
1683 
1684 namespace {
1685 /// A set of rewrite patterns that can be used to legalize a given operation.
1686 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1687 
1688 /// This class defines a recursive operation legalizer.
1689 class OperationLegalizer {
1690 public:
1691  using LegalizationAction = ConversionTarget::LegalizationAction;
1692 
1693  OperationLegalizer(const ConversionTarget &targetInfo,
1694  const FrozenRewritePatternSet &patterns);
1695 
1696  /// Returns true if the given operation is known to be illegal on the target.
1697  bool isIllegal(Operation *op) const;
1698 
1699  /// Attempt to legalize the given operation. Returns success if the operation
1700  /// was legalized, failure otherwise.
1701  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1702 
1703  /// Returns the conversion target in use by the legalizer.
1704  const ConversionTarget &getTarget() { return target; }
1705 
1706 private:
1707  /// Attempt to legalize the given operation by folding it.
1708  LogicalResult legalizeWithFold(Operation *op,
1709  ConversionPatternRewriter &rewriter);
1710 
1711  /// Attempt to legalize the given operation by applying a pattern. Returns
1712  /// success if the operation was legalized, failure otherwise.
1713  LogicalResult legalizeWithPattern(Operation *op,
1714  ConversionPatternRewriter &rewriter);
1715 
1716  /// Return true if the given pattern may be applied to the given operation,
1717  /// false otherwise.
1718  bool canApplyPattern(Operation *op, const Pattern &pattern,
1719  ConversionPatternRewriter &rewriter);
1720 
1721  /// Legalize the resultant IR after successfully applying the given pattern.
1722  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1723  ConversionPatternRewriter &rewriter,
1724  RewriterState &curState);
1725 
1726  /// Legalizes the actions registered during the execution of a pattern.
1727  LogicalResult legalizePatternBlockActions(Operation *op,
1728  ConversionPatternRewriter &rewriter,
1730  RewriterState &state,
1731  RewriterState &newState);
1732  LogicalResult legalizePatternCreatedOperations(
1734  RewriterState &state, RewriterState &newState);
1735  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1737  RewriterState &state,
1738  RewriterState &newState);
1739 
1740  //===--------------------------------------------------------------------===//
1741  // Cost Model
1742  //===--------------------------------------------------------------------===//
1743 
1744  /// Build an optimistic legalization graph given the provided patterns. This
1745  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1746  /// patterns for operations that are not directly legal, but may be
1747  /// transitively legal for the current target given the provided patterns.
1748  void buildLegalizationGraph(
1749  LegalizationPatterns &anyOpLegalizerPatterns,
1751 
1752  /// Compute the benefit of each node within the computed legalization graph.
1753  /// This orders the patterns within 'legalizerPatterns' based upon two
1754  /// criteria:
1755  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1756  /// represent the more direct mapping to the target.
1757  /// 2) When comparing patterns with the same legalization depth, prefer the
1758  /// pattern with the highest PatternBenefit. This allows for users to
1759  /// prefer specific legalizations over others.
1760  void computeLegalizationGraphBenefit(
1761  LegalizationPatterns &anyOpLegalizerPatterns,
1763 
1764  /// Compute the legalization depth when legalizing an operation of the given
1765  /// type.
1766  unsigned computeOpLegalizationDepth(
1767  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1769 
1770  /// Apply the conversion cost model to the given set of patterns, and return
1771  /// the smallest legalization depth of any of the patterns. See
1772  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1773  unsigned applyCostModelToPatterns(
1774  LegalizationPatterns &patterns,
1775  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1777 
1778  /// The current set of patterns that have been applied.
1779  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1780 
1781  /// The legalization information provided by the target.
1782  const ConversionTarget &target;
1783 
1784  /// The pattern applicator to use for conversions.
1785  PatternApplicator applicator;
1786 };
1787 } // namespace
1788 
1789 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1790  const FrozenRewritePatternSet &patterns)
1791  : target(targetInfo), applicator(patterns) {
1792  // The set of patterns that can be applied to illegal operations to transform
1793  // them into legal ones.
1795  LegalizationPatterns anyOpLegalizerPatterns;
1796 
1797  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1798  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1799 }
1800 
1801 bool OperationLegalizer::isIllegal(Operation *op) const {
1802  return target.isIllegal(op);
1803 }
1804 
1806 OperationLegalizer::legalize(Operation *op,
1807  ConversionPatternRewriter &rewriter) {
1808 #ifndef NDEBUG
1809  const char *logLineComment =
1810  "//===-------------------------------------------===//\n";
1811 
1812  auto &logger = rewriter.getImpl().logger;
1813 #endif
1814  LLVM_DEBUG({
1815  logger.getOStream() << "\n";
1816  logger.startLine() << logLineComment;
1817  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1818  << op << ") {\n";
1819  logger.indent();
1820 
1821  // If the operation has no regions, just print it here.
1822  if (op->getNumRegions() == 0) {
1823  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1824  logger.getOStream() << "\n\n";
1825  }
1826  });
1827 
1828  // Check if this operation is legal on the target.
1829  if (auto legalityInfo = target.isLegal(op)) {
1830  LLVM_DEBUG({
1831  logSuccess(
1832  logger, "operation marked legal by the target{0}",
1833  legalityInfo->isRecursivelyLegal
1834  ? "; NOTE: operation is recursively legal; skipping internals"
1835  : "");
1836  logger.startLine() << logLineComment;
1837  });
1838 
1839  // If this operation is recursively legal, mark its children as ignored so
1840  // that we don't consider them for legalization.
1841  if (legalityInfo->isRecursivelyLegal)
1842  rewriter.getImpl().markNestedOpsIgnored(op);
1843  return success();
1844  }
1845 
1846  // Check to see if the operation is ignored and doesn't need to be converted.
1847  if (rewriter.getImpl().isOpIgnored(op)) {
1848  LLVM_DEBUG({
1849  logSuccess(logger, "operation marked 'ignored' during conversion");
1850  logger.startLine() << logLineComment;
1851  });
1852  return success();
1853  }
1854 
1855  // If the operation isn't legal, try to fold it in-place.
1856  // TODO: Should we always try to do this, even if the op is
1857  // already legal?
1858  if (succeeded(legalizeWithFold(op, rewriter))) {
1859  LLVM_DEBUG({
1860  logSuccess(logger, "operation was folded");
1861  logger.startLine() << logLineComment;
1862  });
1863  return success();
1864  }
1865 
1866  // Otherwise, we need to apply a legalization pattern to this operation.
1867  if (succeeded(legalizeWithPattern(op, rewriter))) {
1868  LLVM_DEBUG({
1869  logSuccess(logger, "");
1870  logger.startLine() << logLineComment;
1871  });
1872  return success();
1873  }
1874 
1875  LLVM_DEBUG({
1876  logFailure(logger, "no matched legalization pattern");
1877  logger.startLine() << logLineComment;
1878  });
1879  return failure();
1880 }
1881 
1883 OperationLegalizer::legalizeWithFold(Operation *op,
1884  ConversionPatternRewriter &rewriter) {
1885  auto &rewriterImpl = rewriter.getImpl();
1886  RewriterState curState = rewriterImpl.getCurrentState();
1887 
1888  LLVM_DEBUG({
1889  rewriterImpl.logger.startLine() << "* Fold {\n";
1890  rewriterImpl.logger.indent();
1891  });
1892 
1893  // Try to fold the operation.
1894  SmallVector<Value, 2> replacementValues;
1895  rewriter.setInsertionPoint(op);
1896  if (failed(rewriter.tryFold(op, replacementValues))) {
1897  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1898  return failure();
1899  }
1900 
1901  // Insert a replacement for 'op' with the folded replacement values.
1902  rewriter.replaceOp(op, replacementValues);
1903 
1904  // Recursively legalize any new constant operations.
1905  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1906  i != e; ++i) {
1907  Operation *cstOp = rewriterImpl.createdOps[i];
1908  if (failed(legalize(cstOp, rewriter))) {
1909  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1910  "failed to legalize generated constant '{0}'",
1911  cstOp->getName()));
1912  rewriterImpl.resetState(curState);
1913  return failure();
1914  }
1915  }
1916 
1917  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1918  return success();
1919 }
1920 
1922 OperationLegalizer::legalizeWithPattern(Operation *op,
1923  ConversionPatternRewriter &rewriter) {
1924  auto &rewriterImpl = rewriter.getImpl();
1925 
1926  // Functor that returns if the given pattern may be applied.
1927  auto canApply = [&](const Pattern &pattern) {
1928  return canApplyPattern(op, pattern, rewriter);
1929  };
1930 
1931  // Functor that cleans up the rewriter state after a pattern failed to match.
1932  RewriterState curState = rewriterImpl.getCurrentState();
1933  auto onFailure = [&](const Pattern &pattern) {
1934  LLVM_DEBUG({
1935  logFailure(rewriterImpl.logger, "pattern failed to match");
1936  if (rewriterImpl.notifyCallback) {
1938  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1939  << "\" on op:\n"
1940  << *op;
1941  rewriterImpl.notifyCallback(diag);
1942  }
1943  });
1944  rewriterImpl.resetState(curState);
1945  appliedPatterns.erase(&pattern);
1946  };
1947 
1948  // Functor that performs additional legalization when a pattern is
1949  // successfully applied.
1950  auto onSuccess = [&](const Pattern &pattern) {
1951  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1952  appliedPatterns.erase(&pattern);
1953  if (failed(result))
1954  rewriterImpl.resetState(curState);
1955  return result;
1956  };
1957 
1958  // Try to match and rewrite a pattern on this operation.
1959  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1960  onSuccess);
1961 }
1962 
1963 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1964  ConversionPatternRewriter &rewriter) {
1965  LLVM_DEBUG({
1966  auto &os = rewriter.getImpl().logger;
1967  os.getOStream() << "\n";
1968  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1969  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1970  os.getOStream() << ")' {\n";
1971  os.indent();
1972  });
1973 
1974  // Ensure that we don't cycle by not allowing the same pattern to be
1975  // applied twice in the same recursion stack if it is not known to be safe.
1976  if (!pattern.hasBoundedRewriteRecursion() &&
1977  !appliedPatterns.insert(&pattern).second) {
1978  LLVM_DEBUG(
1979  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1980  return false;
1981  }
1982  return true;
1983 }
1984 
1986 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1987  ConversionPatternRewriter &rewriter,
1988  RewriterState &curState) {
1989  auto &impl = rewriter.getImpl();
1990 
1991 #ifndef NDEBUG
1992  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1993 #endif
1994 
1995  // Check that the root was either replaced or updated in place.
1996  auto replacedRoot = [&] {
1997  return llvm::any_of(
1998  llvm::drop_begin(impl.replacements, curState.numReplacements),
1999  [op](auto &it) { return it.first == op; });
2000  };
2001  auto updatedRootInPlace = [&] {
2002  return llvm::any_of(
2003  llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
2004  [op](auto &state) { return state.getOperation() == op; });
2005  };
2006  (void)replacedRoot;
2007  (void)updatedRootInPlace;
2008  assert((replacedRoot() || updatedRootInPlace()) &&
2009  "expected pattern to replace the root operation");
2010 
2011  // Legalize each of the actions registered during application.
2012  RewriterState newState = impl.getCurrentState();
2013  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
2014  newState)) ||
2015  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2016  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2017  newState))) {
2018  return failure();
2019  }
2020 
2021  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2022  return success();
2023 }
2024 
2025 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2026  Operation *op, ConversionPatternRewriter &rewriter,
2027  ConversionPatternRewriterImpl &impl, RewriterState &state,
2028  RewriterState &newState) {
2029  SmallPtrSet<Operation *, 16> operationsToIgnore;
2030 
2031  // If the pattern moved or created any blocks, make sure the types of block
2032  // arguments get legalized.
2033  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2034  ++i) {
2035  auto &action = impl.blockActions[i];
2036  if (action.kind == BlockActionKind::TypeConversion ||
2037  action.kind == BlockActionKind::Erase)
2038  continue;
2039  // Only check blocks outside of the current operation.
2040  Operation *parentOp = action.block->getParentOp();
2041  if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2042  continue;
2043 
2044  // If the region of the block has a type converter, try to convert the block
2045  // directly.
2046  if (auto *converter =
2047  impl.argConverter.getConverter(action.block->getParent())) {
2048  if (failed(impl.convertBlockSignature(action.block, converter))) {
2049  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2050  "block"));
2051  return failure();
2052  }
2053  continue;
2054  }
2055 
2056  // Otherwise, check that this operation isn't one generated by this pattern.
2057  // This is because we will attempt to legalize the parent operation, and
2058  // blocks in regions created by this pattern will already be legalized later
2059  // on. If we haven't built the set yet, build it now.
2060  if (operationsToIgnore.empty()) {
2061  auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2062  .drop_front(state.numCreatedOps);
2063  operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2064  }
2065 
2066  // If this operation should be considered for re-legalization, try it.
2067  if (operationsToIgnore.insert(parentOp).second &&
2068  failed(legalize(parentOp, rewriter))) {
2069  LLVM_DEBUG(logFailure(
2070  impl.logger, "operation '{0}'({1}) became illegal after block action",
2071  parentOp->getName(), parentOp));
2072  return failure();
2073  }
2074  }
2075  return success();
2076 }
2077 
2078 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2080  RewriterState &state, RewriterState &newState) {
2081  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2082  Operation *op = impl.createdOps[i];
2083  if (failed(legalize(op, rewriter))) {
2084  LLVM_DEBUG(logFailure(impl.logger,
2085  "failed to legalize generated operation '{0}'({1})",
2086  op->getName(), op));
2087  return failure();
2088  }
2089  }
2090  return success();
2091 }
2092 
2093 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2095  RewriterState &state, RewriterState &newState) {
2096  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2097  Operation *op = impl.rootUpdates[i].getOperation();
2098  if (failed(legalize(op, rewriter))) {
2099  LLVM_DEBUG(logFailure(
2100  impl.logger, "failed to legalize operation updated in-place '{0}'",
2101  op->getName()));
2102  return failure();
2103  }
2104  }
2105  return success();
2106 }
2107 
2108 //===----------------------------------------------------------------------===//
2109 // Cost Model
2110 
2111 void OperationLegalizer::buildLegalizationGraph(
2112  LegalizationPatterns &anyOpLegalizerPatterns,
2113  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2114  // A mapping between an operation and a set of operations that can be used to
2115  // generate it.
2117  // A mapping between an operation and any currently invalid patterns it has.
2119  // A worklist of patterns to consider for legality.
2120  SetVector<const Pattern *> patternWorklist;
2121 
2122  // Build the mapping from operations to the parent ops that may generate them.
2123  applicator.walkAllPatterns([&](const Pattern &pattern) {
2124  std::optional<OperationName> root = pattern.getRootKind();
2125 
2126  // If the pattern has no specific root, we can't analyze the relationship
2127  // between the root op and generated operations. Given that, add all such
2128  // patterns to the legalization set.
2129  if (!root) {
2130  anyOpLegalizerPatterns.push_back(&pattern);
2131  return;
2132  }
2133 
2134  // Skip operations that are always known to be legal.
2135  if (target.getOpAction(*root) == LegalizationAction::Legal)
2136  return;
2137 
2138  // Add this pattern to the invalid set for the root op and record this root
2139  // as a parent for any generated operations.
2140  invalidPatterns[*root].insert(&pattern);
2141  for (auto op : pattern.getGeneratedOps())
2142  parentOps[op].insert(*root);
2143 
2144  // Add this pattern to the worklist.
2145  patternWorklist.insert(&pattern);
2146  });
2147 
2148  // If there are any patterns that don't have a specific root kind, we can't
2149  // make direct assumptions about what operations will never be legalized.
2150  // Note: Technically we could, but it would require an analysis that may
2151  // recurse into itself. It would be better to perform this kind of filtering
2152  // at a higher level than here anyways.
2153  if (!anyOpLegalizerPatterns.empty()) {
2154  for (const Pattern *pattern : patternWorklist)
2155  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2156  return;
2157  }
2158 
2159  while (!patternWorklist.empty()) {
2160  auto *pattern = patternWorklist.pop_back_val();
2161 
2162  // Check to see if any of the generated operations are invalid.
2163  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2164  std::optional<LegalizationAction> action = target.getOpAction(op);
2165  return !legalizerPatterns.count(op) &&
2166  (!action || action == LegalizationAction::Illegal);
2167  }))
2168  continue;
2169 
2170  // Otherwise, if all of the generated operation are valid, this op is now
2171  // legal so add all of the child patterns to the worklist.
2172  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2173  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2174 
2175  // Add any invalid patterns of the parent operations to see if they have now
2176  // become legal.
2177  for (auto op : parentOps[*pattern->getRootKind()])
2178  patternWorklist.set_union(invalidPatterns[op]);
2179  }
2180 }
2181 
2182 void OperationLegalizer::computeLegalizationGraphBenefit(
2183  LegalizationPatterns &anyOpLegalizerPatterns,
2184  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2185  // The smallest pattern depth, when legalizing an operation.
2186  DenseMap<OperationName, unsigned> minOpPatternDepth;
2187 
2188  // For each operation that is transitively legal, compute a cost for it.
2189  for (auto &opIt : legalizerPatterns)
2190  if (!minOpPatternDepth.count(opIt.first))
2191  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2192  legalizerPatterns);
2193 
2194  // Apply the cost model to the patterns that can match any operation. Those
2195  // with a specific operation type are already resolved when computing the op
2196  // legalization depth.
2197  if (!anyOpLegalizerPatterns.empty())
2198  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2199  legalizerPatterns);
2200 
2201  // Apply a cost model to the pattern applicator. We order patterns first by
2202  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2203  // decreasing benefit.
2204  applicator.applyCostModel([&](const Pattern &pattern) {
2205  ArrayRef<const Pattern *> orderedPatternList;
2206  if (std::optional<OperationName> rootName = pattern.getRootKind())
2207  orderedPatternList = legalizerPatterns[*rootName];
2208  else
2209  orderedPatternList = anyOpLegalizerPatterns;
2210 
2211  // If the pattern is not found, then it was removed and cannot be matched.
2212  auto *it = llvm::find(orderedPatternList, &pattern);
2213  if (it == orderedPatternList.end())
2215 
2216  // Patterns found earlier in the list have higher benefit.
2217  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2218  });
2219 }
2220 
2221 unsigned OperationLegalizer::computeOpLegalizationDepth(
2222  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2223  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2224  // Check for existing depth.
2225  auto depthIt = minOpPatternDepth.find(op);
2226  if (depthIt != minOpPatternDepth.end())
2227  return depthIt->second;
2228 
2229  // If a mapping for this operation does not exist, then this operation
2230  // is always legal. Return 0 as the depth for a directly legal operation.
2231  auto opPatternsIt = legalizerPatterns.find(op);
2232  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2233  return 0u;
2234 
2235  // Record this initial depth in case we encounter this op again when
2236  // recursively computing the depth.
2237  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2238 
2239  // Apply the cost model to the operation patterns, and update the minimum
2240  // depth.
2241  unsigned minDepth = applyCostModelToPatterns(
2242  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2243  minOpPatternDepth[op] = minDepth;
2244  return minDepth;
2245 }
2246 
2247 unsigned OperationLegalizer::applyCostModelToPatterns(
2248  LegalizationPatterns &patterns,
2249  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2250  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2251  unsigned minDepth = std::numeric_limits<unsigned>::max();
2252 
2253  // Compute the depth for each pattern within the set.
2255  patternsByDepth.reserve(patterns.size());
2256  for (const Pattern *pattern : patterns) {
2257  unsigned depth = 1;
2258  for (auto generatedOp : pattern->getGeneratedOps()) {
2259  unsigned generatedOpDepth = computeOpLegalizationDepth(
2260  generatedOp, minOpPatternDepth, legalizerPatterns);
2261  depth = std::max(depth, generatedOpDepth + 1);
2262  }
2263  patternsByDepth.emplace_back(pattern, depth);
2264 
2265  // Update the minimum depth of the pattern list.
2266  minDepth = std::min(minDepth, depth);
2267  }
2268 
2269  // If the operation only has one legalization pattern, there is no need to
2270  // sort them.
2271  if (patternsByDepth.size() == 1)
2272  return minDepth;
2273 
2274  // Sort the patterns by those likely to be the most beneficial.
2275  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2276  [](const std::pair<const Pattern *, unsigned> &lhs,
2277  const std::pair<const Pattern *, unsigned> &rhs) {
2278  // First sort by the smaller pattern legalization
2279  // depth.
2280  if (lhs.second != rhs.second)
2281  return lhs.second < rhs.second;
2282 
2283  // Then sort by the larger pattern benefit.
2284  auto lhsBenefit = lhs.first->getBenefit();
2285  auto rhsBenefit = rhs.first->getBenefit();
2286  return lhsBenefit > rhsBenefit;
2287  });
2288 
2289  // Update the legalization pattern to use the new sorted list.
2290  patterns.clear();
2291  for (auto &patternIt : patternsByDepth)
2292  patterns.push_back(patternIt.first);
2293  return minDepth;
2294 }
2295 
2296 //===----------------------------------------------------------------------===//
2297 // OperationConverter
2298 //===----------------------------------------------------------------------===//
2299 namespace {
/// Selects how OperationConverter reacts to operations it fails to legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR after the conversion.
  Partial,

  /// Every operation must become legal for the target; otherwise the whole
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; on success the rewrites are
  /// discarded rather than applied to the IR.
  Analysis,
};
2313 
2314 // This class converts operations to a given conversion target via a set of
2315 // rewrite patterns. The conversion behaves differently depending on the
2316 // conversion mode.
2317 struct OperationConverter {
2318  explicit OperationConverter(const ConversionTarget &target,
2319  const FrozenRewritePatternSet &patterns,
2320  OpConversionMode mode,
2321  DenseSet<Operation *> *trackedOps = nullptr)
2322  : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2323 
2324  /// Converts the given operations to the conversion target.
2326  convertOperations(ArrayRef<Operation *> ops,
2327  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2328 
2329 private:
2330  /// Converts an operation with the given rewriter.
2331  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2332 
2333  /// This method is called after the conversion process to legalize any
2334  /// remaining artifacts and complete the conversion.
2335  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2336 
2337  /// Legalize the types of converted block arguments.
2339  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2340  ConversionPatternRewriterImpl &rewriterImpl);
2341 
2342  /// Legalize any unresolved type materializations.
2343  LogicalResult legalizeUnresolvedMaterializations(
2344  ConversionPatternRewriter &rewriter,
2345  ConversionPatternRewriterImpl &rewriterImpl,
2346  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2347 
2348  /// Legalize an operation result that was marked as "erased".
2350  legalizeErasedResult(Operation *op, OpResult result,
2351  ConversionPatternRewriterImpl &rewriterImpl);
2352 
2353  /// Legalize an operation result that was replaced with a value of a different
2354  /// type.
2355  LogicalResult legalizeChangedResultType(
2356  Operation *op, OpResult result, Value newValue,
2357  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2358  ConversionPatternRewriterImpl &rewriterImpl,
2359  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2360 
2361  /// The legalizer to use when converting operations.
2362  OperationLegalizer opLegalizer;
2363 
2364  /// The conversion mode to use when legalizing operations.
2365  OpConversionMode mode;
2366 
2367  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2368  /// this is populated with ops found to be legalizable to the target.
2369  /// When mode == OpConversionMode::Partial, this is populated with ops found
2370  /// *not* to be legalizable to the target.
2371  DenseSet<Operation *> *trackedOps;
2372 };
2373 } // namespace
2374 
2375 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2376  Operation *op) {
2377  // Legalize the given operation.
2378  if (failed(opLegalizer.legalize(op, rewriter))) {
2379  // Handle the case of a failed conversion for each of the different modes.
2380  // Full conversions expect all operations to be converted.
2381  if (mode == OpConversionMode::Full)
2382  return op->emitError()
2383  << "failed to legalize operation '" << op->getName() << "'";
2384  // Partial conversions allow conversions to fail iff the operation was not
2385  // explicitly marked as illegal. If the user provided a nonlegalizableOps
2386  // set, non-legalizable ops are included.
2387  if (mode == OpConversionMode::Partial) {
2388  if (opLegalizer.isIllegal(op))
2389  return op->emitError()
2390  << "failed to legalize operation '" << op->getName()
2391  << "' that was explicitly marked illegal";
2392  if (trackedOps)
2393  trackedOps->insert(op);
2394  }
2395  } else if (mode == OpConversionMode::Analysis) {
2396  // Analysis conversions don't fail if any operations fail to legalize,
2397  // they are only interested in the operations that were successfully
2398  // legalized.
2399  trackedOps->insert(op);
2400  }
2401  return success();
2402 }
2403 
2404 LogicalResult OperationConverter::convertOperations(
2406  function_ref<void(Diagnostic &)> notifyCallback) {
2407  if (ops.empty())
2408  return success();
2409  const ConversionTarget &target = opLegalizer.getTarget();
2410 
2411  // Compute the set of operations and blocks to convert.
2412  SmallVector<Operation *> toConvert;
2413  for (auto *op : ops) {
2415  [&](Operation *op) {
2416  toConvert.push_back(op);
2417  // Don't check this operation's children for conversion if the
2418  // operation is recursively legal.
2419  auto legalityInfo = target.isLegal(op);
2420  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2421  return WalkResult::skip();
2422  return WalkResult::advance();
2423  });
2424  }
2425 
2426  // Convert each operation and discard rewrites on failure.
2427  ConversionPatternRewriter rewriter(ops.front()->getContext());
2428  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2429  rewriterImpl.notifyCallback = notifyCallback;
2430 
2431  for (auto *op : toConvert)
2432  if (failed(convert(rewriter, op)))
2433  return rewriterImpl.discardRewrites(), failure();
2434 
2435  // Now that all of the operations have been converted, finalize the conversion
2436  // process to ensure any lingering conversion artifacts are cleaned up and
2437  // legalized.
2438  if (failed(finalize(rewriter)))
2439  return rewriterImpl.discardRewrites(), failure();
2440 
2441  // After a successful conversion, apply rewrites if this is not an analysis
2442  // conversion.
2443  if (mode == OpConversionMode::Analysis) {
2444  rewriterImpl.discardRewrites();
2445  } else {
2446  rewriterImpl.applyRewrites();
2447 
2448  // It is possible for a later pattern to erase an op that was originally
2449  // identified as illegal and added to the trackedOps, remove it now after
2450  // replacements have been computed.
2451  if (trackedOps)
2452  for (auto &repl : rewriterImpl.replacements)
2453  trackedOps->erase(repl.first);
2454  }
2455  return success();
2456 }
2457 
2459 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2460  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2461  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2462  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2463  inverseMapping)) ||
2464  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2465  return failure();
2466 
2467  if (rewriterImpl.operationsWithChangedResults.empty())
2468  return success();
2469 
2470  // Process requested operation replacements.
2471  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2472  i != e; ++i) {
2473  unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2474  auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2475  for (OpResult result : repl.first->getResults()) {
2476  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2477 
2478  // If the operation result was replaced with null, all of the uses of this
2479  // value should be replaced.
2480  if (!newValue) {
2481  if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2482  return failure();
2483  continue;
2484  }
2485 
2486  // Otherwise, check to see if the type of the result changed.
2487  if (result.getType() == newValue.getType())
2488  continue;
2489 
2490  // Compute the inverse mapping only if it is really needed.
2491  if (!inverseMapping)
2492  inverseMapping = rewriterImpl.mapping.getInverse();
2493 
2494  // Legalize this result.
2495  rewriter.setInsertionPoint(repl.first);
2496  if (failed(legalizeChangedResultType(repl.first, result, newValue,
2497  repl.second.converter, rewriter,
2498  rewriterImpl, *inverseMapping)))
2499  return failure();
2500 
2501  // Update the end iterator for this loop in the case it was updated
2502  // when legalizing generated conversion operations.
2503  e = rewriterImpl.operationsWithChangedResults.size();
2504  }
2505  }
2506  return success();
2507 }
2508 
2509 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2510  ConversionPatternRewriter &rewriter,
2511  ConversionPatternRewriterImpl &rewriterImpl) {
2512  // Functor used to check if all users of a value will be dead after
2513  // conversion.
2514  auto findLiveUser = [&](Value val) {
2515  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2516  return rewriterImpl.isOpIgnored(user);
2517  });
2518  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2519  };
2520  return rewriterImpl.argConverter.materializeLiveConversions(
2521  rewriterImpl.mapping, rewriter, findLiveUser);
2522 }
2523 
2524 /// Replace the results of a materialization operation with the given values.
///
/// All uses of `matResults` are redirected to `values`. Then every value that
/// `inverseMapping` records as having been replaced by one of the
/// materialization results is remapped to the corresponding new value, or
/// unmapped when remapping would create a cycle.
/// NOTE(review): `llvm::zip` stops at the shorter range; this assumes callers
/// always pass `matResults` and `values` of equal length — confirm.
2525 static void
2527  ResultRange matResults, ValueRange values,
2528  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2529  matResults.replaceAllUsesWith(values);
2530 
2531  // For each of the materialization results, update the inverse mappings to
2532  // point to the replacement values.
2533  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2534  auto inverseMapIt = inverseMapping.find(matResult);
2535  if (inverseMapIt == inverseMapping.end())
2536  continue;
2537 
2538  // Update the reverse mapping, or remove the mapping if we couldn't update
2539  // it. Not being able to update signals that the mapping would have become
2540  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2541  // propagated through temporary materializations. We simply drop the
2542  // mapping, and let the post-conversion replacement logic handle updating
2543  // uses.
2544  for (Value inverseMapVal : inverseMapIt->second)
2545  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2546  rewriterImpl.mapping.erase(inverseMapVal);
2547  }
2548 }
2549 
2550 /// Compute all of the unresolved materializations that will persist beyond the
2551 /// conversion process, and require inserting a proper user materialization for.
///
/// Every entry of `rewriterImpl.unresolvedMaterializations` is recorded in
/// `materializationOps` (cast op -> record) and pushed onto a worklist. For each
/// materialization this:
///   1. folds user casts whose result types equal this cast's input types,
///   2. for a 1:1 cast, forwards an already-mapped value of the output type,
///   3. otherwise inserts it into `necessaryMaterializations` if its result
///      (or, for argument materializations, a value it replaced) has a live
///      user, and re-queues the materializations producing its inputs, since
///      their liveness may have changed.
/// NOTE(review): the function-name line, the `materializationOps` parameter and
/// the `worklist` declaration are not shown in this listing.
2554  ConversionPatternRewriter &rewriter,
2555  ConversionPatternRewriterImpl &rewriterImpl,
2556  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2557  SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
// A value is live if it, or a value it replaced (per `inverseMapping`), has a
// user that is neither ignored by the rewriter nor a materialization that is
// not (yet) in `necessaryMaterializations`.
2558  auto isLive = [&](Value value) {
2559  auto findFn = [&](Operation *user) {
2560  auto matIt = materializationOps.find(user);
2561  if (matIt != materializationOps.end())
2562  return !necessaryMaterializations.count(matIt->second);
2563  return rewriterImpl.isOpIgnored(user);
2564  };
2565  // This value may be replacing another value that has a live user.
2566  for (Value inv : inverseMapping.lookup(value))
2567  if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2568  return true;
2569  // Or have live users itself.
2570  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2571  };
2572 
// Finds a value of `type` for `value` via `mapping.lookupOrDefault`, looking
// recursively through single-operand UnrealizedConversionCastOps. Returns null
// if none is found; `invalidRoot` is rejected so a cast never maps to itself.
2573  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2574  [&](Value invalidRoot, Value value, Type type) {
2575  // Check to see if the input operation was remapped to a variant of the
2576  // output.
2577  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2578  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2579  return remappedValue;
2580 
2581  // Check to see if the input is a materialization operation that
2582  // provides an inverse conversion. We just check blindly for
2583  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2584  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2585  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2586  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2587  type);
2588 
2589  return Value();
2590  };
2591 
// Seed the worklist with every unresolved materialization.
2593  for (auto &mat : rewriterImpl.unresolvedMaterializations) {
2594  materializationOps.try_emplace(mat.getOp(), &mat);
2595  worklist.insert(&mat);
2596  }
2597  while (!worklist.empty()) {
2598  UnresolvedMaterialization *mat = worklist.pop_back_val();
2599  UnrealizedConversionCastOp op = mat->getOp();
2600 
2601  // We currently only handle target materializations here.
2602  assert(op->getNumResults() == 1 && "unexpected materialization type");
2603  OpResult opResult = op->getOpResult(0);
2604  Type outputType = opResult.getType();
2605  Operation::operand_range inputOperands = op.getOperands();
2606 
2607  // Try to forward propagate operands for user conversion casts that result
2608  // in the input types of the current cast.
// NOTE(review): the call below replaces the uses of `opResult` itself (of type
// `outputType`) with `inputOperands`, not the results of `castOp`. Only
// `castOp`'s result types are known to match `inputOperands`, so this looks
// like it should pass `castOp->getResults()`. As written, the range sizes also
// differ whenever the cast has more than one input. Verify.
2609  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2610  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2611  if (!castOp)
2612  continue;
2613  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2614  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2615  inverseMapping);
2616  necessaryMaterializations.remove(materializationOps.lookup(user));
2617  }
2618  }
2619 
2620  // Try to avoid materializing a resolved materialization if possible.
2621  // Handle the case of a 1-1 materialization.
2622  if (inputOperands.size() == 1) {
2623  // Check to see if the input operation was remapped to a variant of the
2624  // output.
2625  Value remappedValue =
2626  lookupRemappedValue(opResult, inputOperands[0], outputType);
2627  if (remappedValue && remappedValue != opResult) {
2628  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2629  inverseMapping);
2630  necessaryMaterializations.remove(mat);
2631  continue;
2632  }
2633  } else {
2634  // TODO: Avoid materializing other types of conversions here.
2635  }
2636 
2637  // Check to see if this is an argument materialization.
2638  auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
2639  if (llvm::any_of(op->getOperands(), isBlockArg) ||
2640  llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
// NOTE(review): the statement in this branch is not shown in this listing.
// Given the comment above and the `getKind()` check below, it presumably
// marks `mat` as an argument materialization. Confirm.
2642  }
2643 
2644  // If the materialization does not have any live users, we don't need to
2645  // generate a user materialization for it.
2646  // FIXME: For argument materializations, we currently need to check if any
2647  // of the inverse mapped values are used because some patterns expect blind
2648  // value replacement even if the types differ in some cases. When those
2649  // patterns are fixed, we can drop the argument special case here.
2650  bool isMaterializationLive = isLive(opResult);
2651  if (mat->getKind() == UnresolvedMaterialization::Argument)
2652  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2653  if (!isMaterializationLive)
2654  continue;
2655  if (!necessaryMaterializations.insert(mat))
2656  continue;
2657 
2658  // Reprocess input materializations to see if they have an updated status.
// (The inner `mat` below shadows the outer loop variable.)
2659  for (Value input : inputOperands) {
2660  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2661  if (auto *mat = materializationOps.lookup(parentOp))
2662  worklist.insert(mat);
2663  }
2664  }
2665  }
2666 }
2667 
2668 /// Legalize the given unresolved materialization. Returns success if the
2669 /// materialization was legalized, failure otherwise.
///
/// Steps:
///   1. Materializations that produce this cast's inputs (and are found in
///      `materializationOps`) are legalized first.
///   2. For a 1:1 cast, an existing mapped value of the output type is reused
///      if one is available.
///   3. Otherwise the materialization's type converter is asked for an
///      argument materialization (argument kind), falling back to a target
///      materialization.
///
/// On success, uses of the cast result are redirected through
/// `replaceMaterialization`. If nothing works, an error is emitted on the cast
/// with a note at a live user. A cast already present in
/// `rewriterImpl.ignoredOps` returns success immediately.
2671  UnresolvedMaterialization &mat,
2673  ConversionPatternRewriter &rewriter,
2674  ConversionPatternRewriterImpl &rewriterImpl,
2675  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Returns the first user in `users` not ignored by the rewriter, or nullptr.
2676  auto findLiveUser = [&](auto &&users) {
2677  auto liveUserIt = llvm::find_if_not(
2678  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2679  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2680  };
2681 
// Returns what `value` maps to if that has type `type`, otherwise null.
2682  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2683  [&](Value value, Type type) {
2684  // Check to see if the input operation was remapped to a variant of the
2685  // output.
2686  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2687  if (remappedValue.getType() == type)
2688  return remappedValue;
2689  return Value();
2690  };
2691 
2692  UnrealizedConversionCastOp op = mat.getOp();
// Inserting into `ignoredOps` also acts as the visited check for recursion.
2693  if (!rewriterImpl.ignoredOps.insert(op))
2694  return success();
2695 
2696  // We currently only handle target materializations here.
2697  OpResult opResult = op->getOpResult(0);
2698  Operation::operand_range inputOperands = op.getOperands();
2699  Type outputType = opResult.getType();
2700 
2701  // If any input to this materialization is another materialization, resolve
2702  // the input first.
2703  for (Value value : op->getOperands()) {
2704  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2705  if (!valueCast)
2706  continue;
2707 
// Recursively legalize the producing materialization; the line that opens
// this `if (failed(...))` call is not shown in this listing.
2708  auto matIt = materializationOps.find(valueCast);
2709  if (matIt != materializationOps.end())
2711  *matIt->second, materializationOps, rewriter, rewriterImpl,
2712  inverseMapping)))
2713  return failure();
2714  }
2715 
2716  // Perform a last ditch attempt to avoid materializing a resolved
2717  // materialization if possible.
2718  // Handle the case of a 1-1 materialization.
2719  if (inputOperands.size() == 1) {
2720  // Check to see if the input operation was remapped to a variant of the
2721  // output.
2722  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2723  if (remappedValue && remappedValue != opResult) {
2724  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2725  inverseMapping);
2726  return success();
2727  }
2728  } else {
2729  // TODO: Avoid materializing other types of conversions here.
2730  }
2731 
2732  // Try to materialize the conversion.
2733  if (const TypeConverter *converter = mat.getConverter()) {
2734  // FIXME: Determine a suitable insertion location when there are multiple
2735  // inputs.
2736  if (inputOperands.size() == 1)
2737  rewriter.setInsertionPointAfterValue(inputOperands.front());
2738  else
2739  rewriter.setInsertionPoint(op);
2740 
2741  Value newMaterialization;
2742  switch (mat.getKind()) {
// NOTE(review): the first `case` label is not shown in this listing; it is
// presumably `case UnresolvedMaterialization::Argument:`. Confirm.
2744  // Try to materialize an argument conversion.
2745  // FIXME: The current argument materialization hook expects the original
2746  // output type, even though it doesn't use that as the actual output type
2747  // of the generated IR. The output type is just used as an indicator of
2748  // the type of materialization to do. This behavior is really awkward in
2749  // that it diverges from the behavior of the other hooks, and can be
2750  // easily misunderstood. We should clean up the argument hooks to better
2751  // represent the desired invariants we actually care about.
2752  newMaterialization = converter->materializeArgumentConversion(
2753  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2754  if (newMaterialization)
2755  break;
2756 
2757  // If an argument materialization failed, fallback to trying a target
2758  // materialization.
2759  [[fallthrough]];
2760  case UnresolvedMaterialization::Target:
2761  newMaterialization = converter->materializeTargetConversion(
2762  rewriter, op->getLoc(), outputType, inputOperands);
2763  break;
2764  }
2765  if (newMaterialization) {
2766  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2767  inverseMapping);
2768  return success();
2769  }
2770  }
2771 
// Every strategy failed. `diag` is an error diagnostic whose declaration line
// is not shown in this listing.
2773  << "failed to legalize unresolved materialization "
2774  "from "
2775  << inputOperands.getTypes() << " to " << outputType
2776  << " that remained live after conversion";
2777  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2778  diag.attachNote(liveUser->getLoc())
2779  << "see existing live user here: " << *liveUser;
2780  }
2781  return failure();
2782 }
2783 
/// Resolves the unresolved materializations recorded by the rewriter.
///
/// If there are none, returns success without touching `inverseMapping`.
/// Otherwise `inverseMapping` is computed from the current value mapping, the
/// materializations that must persist are determined, and each of them is
/// legalized in turn. Returns failure on the first one that cannot be
/// legalized.
2784 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2785  ConversionPatternRewriter &rewriter,
2786  ConversionPatternRewriterImpl &rewriterImpl,
2787  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2788  if (rewriterImpl.unresolvedMaterializations.empty())
2789  return success();
2790  inverseMapping = rewriterImpl.mapping.getInverse();
2791 
2792  // As an initial step, compute all of the inserted materializations that we
2793  // expect to persist beyond the conversion process.
// NOTE(review): the declaration of `materializationOps` is not shown in this
// listing.
2795  SetVector<UnresolvedMaterialization *> necessaryMaterializations;
2796  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2797  *inverseMapping, necessaryMaterializations);
2798 
2799  // Once computed, legalize any necessary materializations.
2800  for (auto *mat : necessaryMaterializations) {
2802  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2803  return failure();
2804  }
2805  return success();
2806 }
2807 
2808 LogicalResult OperationConverter::legalizeErasedResult(
2809  Operation *op, OpResult result,
2810  ConversionPatternRewriterImpl &rewriterImpl) {
2811  // If the operation result was replaced with null, all of the uses of this
2812  // value should be replaced.
2813  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2814  return rewriterImpl.isOpIgnored(user);
2815  });
2816  if (liveUserIt != result.user_end()) {
2817  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2818  << op->getName() << "' marked as erased";
2819  diag.attachNote(liveUserIt->getLoc())
2820  << "found live user of result #" << result.getResultNumber() << ": "
2821  << *liveUserIt;
2822  return failure();
2823  }
2824  return success();
2825 }
2826 
2827 /// Finds a user of the given value, or of any other value that the given value
2828 /// replaced, that was not replaced in the conversion process.
///
/// Walks `inverseMapping` transitively from `initialValue` using a worklist.
/// Returns the first user that `rewriterImpl.isOpIgnored` rejects, or nullptr
/// if there is none.
/// NOTE(review): there is no visited set, so this assumes `inverseMapping`
/// contains no cycles.
2830  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2831  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2832  SmallVector<Value> worklist(1, initialValue);
2833  while (!worklist.empty()) {
2834  Value value = worklist.pop_back_val();
2835 
2836  // Walk the users of this value to see if there are any live users that
2837  // weren't replaced during conversion.
2838  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2839  return rewriterImpl.isOpIgnored(user);
2840  });
2841  if (liveUserIt != value.user_end())
2842  return *liveUserIt;
// No live user here; continue with the values that `value` replaced.
2843  auto mapIt = inverseMapping.find(value);
2844  if (mapIt != inverseMapping.end())
2845  worklist.append(mapIt->second);
2846  }
2847  return nullptr;
2848 }
2849 
/// Handles a `result` whose replacement `newValue` has a different type.
///
/// If neither `result` nor any value it replaced has a live user, nothing is
/// done. Otherwise `replConverter` is used to materialize a source conversion
/// from `newValue` back to `result`'s type, and `result` is mapped to it.
/// Emits an error and fails when there is no converter or materialization
/// fails.
2850 LogicalResult OperationConverter::legalizeChangedResultType(
2851  Operation *op, OpResult result, Value newValue,
2852  const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2853  ConversionPatternRewriterImpl &rewriterImpl,
2854  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2855  Operation *liveUser =
2856  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2857  if (!liveUser)
2858  return success();
2859 
2860  // Functor used to emit a conversion error for a failed materialization.
// (The line declaring `diag` inside this lambda is not shown in this listing.)
2861  auto emitConversionError = [&] {
2863  << "failed to materialize conversion for result #"
2864  << result.getResultNumber() << " of operation '"
2865  << op->getName()
2866  << "' that remained live after conversion";
2867  diag.attachNote(liveUser->getLoc())
2868  << "see existing live user here: " << *liveUser;
2869  return failure();
2870  };
2871 
2872  // If the replacement has a type converter, attempt to materialize a
2873  // conversion back to the original type.
2874  if (!replConverter)
2875  return emitConversionError();
2876 
2877  // Materialize a conversion for this live result value.
2878  Type resultType = result.getType();
2879  Value convertedValue = replConverter->materializeSourceConversion(
2880  rewriter, op->getLoc(), resultType, newValue);
2881  if (!convertedValue)
2882  return emitConversionError();
2883 
2884  rewriterImpl.mapping.map(result, convertedValue);
2885  return success();
2886 }
2887 
2888 //===----------------------------------------------------------------------===//
2889 // Type Conversion
2890 //===----------------------------------------------------------------------===//
2891 
/// Maps original input `origInputNo` to `types`. The types are appended as new
/// inputs starting at the current end of `argTypes`. (The first line of this
/// signature is not shown in this listing.)
2893  ArrayRef<Type> types) {
2894  assert(!types.empty() && "expected valid types");
2895  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2896  addInputs(types);
2897 }
2898 
/// Appends `types` to the converted argument types without tying them to any
/// original input. `types` must be non-empty.
2900  assert(!types.empty() &&
2901  "1->0 type remappings don't need to be added explicitly");
2902  argTypes.append(types.begin(), types.end());
2903 }
2904 
2905 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2906  unsigned newInputNo,
2907  unsigned newInputCount) {
2908  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2909  assert(newInputCount != 0 && "expected valid input count");
2910  remappedInputs[origInputNo] =
2911  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2912 }
2913 
/// Remaps original input `origInputNo` to `replacementValue` instead of to new
/// inputs. The recorded mapping has size 0. An input may only be remapped
/// once. (The first line of this signature is not shown in this listing.)
2915  Value replacementValue) {
2916  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2917  remappedInputs[origInputNo] =
2918  InputMapping{origInputNo, /*size=*/0, replacementValue};
2919 }
2920 
/// Converts `t`, appending the converted type(s) to `results`.
///
/// Outcomes are cached:
///   - failures and 1:1 conversions go in `cachedDirectConversions` (a null
///     entry means failure);
///   - other successful conversions go in `cachedMultiConversions`.
///
/// Converters are tried most-recently-registered first; the first one that
/// returns a value decides the outcome. If no converter applies, failure is
/// returned and nothing is cached.
/// NOTE(review): if a converter appends types and then reports failure, those
/// types are left in `results`.
2922  SmallVectorImpl<Type> &results) const {
2923  {
// Both locks are created deferred. The lines that guard the `lock()` calls
// are not shown in this listing; presumably locking is conditional. Confirm.
2924  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2925  std::defer_lock);
2927  cacheReadLock.lock();
2928  auto existingIt = cachedDirectConversions.find(t);
2929  if (existingIt != cachedDirectConversions.end()) {
2930  if (existingIt->second)
2931  results.push_back(existingIt->second);
2932  return success(existingIt->second != nullptr);
2933  }
2934  auto multiIt = cachedMultiConversions.find(t);
2935  if (multiIt != cachedMultiConversions.end()) {
2936  results.append(multiIt->second.begin(), multiIt->second.end());
2937  return success();
2938  }
2939  }
2940  // Walk the added converters in reverse order to apply the most recently
2941  // registered first.
2942  size_t currentCount = results.size();
2943 
2944  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2945  std::defer_lock);
2946 
2947  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2948  if (std::optional<LogicalResult> result = converter(t, results)) {
2950  cacheWriteLock.lock();
2951  if (!succeeded(*result)) {
2952  cachedDirectConversions.try_emplace(t, nullptr);
2953  return failure();
2954  }
// Only the types appended by this converter describe the conversion of `t`.
2955  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2956  if (newTypes.size() == 1)
2957  cachedDirectConversions.try_emplace(t, newTypes.front());
2958  else
2959  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2960  return success();
2961  }
2962  }
2963  return failure();
2964 }
2965 
/// Single-type form of convertType: returns the converted type, or null when
/// conversion fails or does not produce exactly one type.
2967  // Use the multi-type result version to convert the type.
2968  SmallVector<Type, 1> results;
2969  if (failed(convertType(t, results)))
2970  return nullptr;
2971 
2972  // Check to ensure that only one type was produced.
2973  return results.size() == 1 ? results.front() : nullptr;
2974 }
2975 
/// Converts each of `types` in order, appending everything to `results`. Stops
/// at the first failure; types converted before it remain in `results`.
2978  SmallVectorImpl<Type> &results) const {
2979  for (Type type : types)
2980  if (failed(convertType(type, results)))
2981  return failure();
2982  return success();
2983 }
2984 
2985 bool TypeConverter::isLegal(Type type) const {
2986  return convertType(type) == type;
2987 }
/// An operation is legal when all of its operand types and result types are
/// legal.
2989  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2990 }
2991 
2992 bool TypeConverter::isLegal(Region *region) const {
2993  return llvm::all_of(*region, [this](Block &block) {
2994  return isLegal(block.getArgumentTypes());
2995  });
2996 }
2997 
2998 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2999  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3000 }
3001 
/// Converts the type of original input `inputNo` and records the resulting
/// new inputs in `result`. If the conversion yields no types, nothing is
/// recorded (the argument is dropped).
3004  SignatureConversion &result) const {
3005  // Try to convert the given input type.
3006  SmallVector<Type, 1> convertedTypes;
3007  if (failed(convertType(type, convertedTypes)))
3008  return failure();
3009 
3010  // If this argument is being dropped, there is nothing left to do.
3011  if (convertedTypes.empty())
3012  return success();
3013 
3014  // Otherwise, add the new inputs.
3015  result.addInputs(inputNo, convertedTypes);
3016  return success();
3017 }
/// Converts each of `types` as original inputs `origInputOffset + i`, stopping
/// at the first failure.
3020  SignatureConversion &result,
3021  unsigned origInputOffset) const {
3022  for (unsigned i = 0, e = types.size(); i != e; ++i)
3023  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3024  return failure();
3025  return success();
3026 }
3027 
3028 Value TypeConverter::materializeConversion(
3029  ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
3030  Location loc, Type resultType, ValueRange inputs) const {
3031  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3032  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3033  return *result;
3034  return nullptr;
3035 }
3036 
/// Builds a SignatureConversion for all arguments of `block`, or returns
/// std::nullopt if any argument type fails to convert.
3037 std::optional<TypeConverter::SignatureConversion>
3039  SignatureConversion conversion(block->getNumArguments());
3040  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3041  return std::nullopt;
3042  return conversion;
3043 }
3044 
3045 //===----------------------------------------------------------------------===//
3046 // Type attribute conversion
3047 //===----------------------------------------------------------------------===//
// Builds a result tagged `resultTag` that carries `attr`.
3050  return AttributeConversionResult(attr, resultTag);
3051 }
3052 
// Builds an attribute-less result tagged `naTag`.
3055  return AttributeConversionResult(nullptr, naTag);
3056 }
3057 
// Builds an attribute-less result tagged `abortTag`.
3060  return AttributeConversionResult(nullptr, abortTag);
3061 }
3062 
// True when the stored tag is `resultTag`.
3064  return impl.getInt() == resultTag;
3065 }
3066 
// True when the stored tag is `naTag`.
3068  return impl.getInt() == naTag;
3069 }
3070 
// True when the stored tag is `abortTag`.
3072  return impl.getInt() == abortTag;
3073 }
3074 
// Returns the carried attribute; only valid when `hasResult()` holds.
3076  assert(hasResult() && "Cannot get result from N/A or abort");
3077  return impl.getPointer();
3078 }
3079 
/// Tries the registered type-attribute conversion callbacks, most recently
/// registered first. The first callback that produces a result wins, and a
/// callback that aborts ends the search. Returns std::nullopt when no callback
/// produced a result.
3080 std::optional<Attribute>
3082  for (const TypeAttributeConversionCallbackFn &fn :
3083  llvm::reverse(typeAttributeConversions)) {
3084  AttributeConversionResult res = fn(type, attr);
3085  if (res.hasResult())
3086  return res.getResult();
3087  if (res.isAbort())
3088  return std::nullopt;
3089  }
3090  return std::nullopt;
3091 }
3092 
3093 //===----------------------------------------------------------------------===//
3094 // FunctionOpInterfaceSignatureConversion
3095 //===----------------------------------------------------------------------===//
3096 
3097 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3098  const TypeConverter &typeConverter,
3099  ConversionPatternRewriter &rewriter) {
3100  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3101  if (!type)
3102  return failure();
3103 
3104  // Convert the original function types.
3105  TypeConverter::SignatureConversion result(type.getNumInputs());
3106  SmallVector<Type, 1> newResults;
3107  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3108  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3109  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3110  typeConverter, &result)))
3111  return failure();
3112 
3113  // Update the function signature in-place.
3114  auto newType = FunctionType::get(rewriter.getContext(),
3115  result.getConvertedTypes(), newResults);
3116 
3117  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
3118 
3119  return success();
3120 }
3121 
3122 /// Create a default conversion pattern that rewrites the type signature of a
3123 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3124 /// represent their type.
3125 namespace {
/// Matches ops with the given name (benefit 1) and converts their signature.
/// `cast<FunctionOpInterface>` assumes the named op implements the interface.
3126 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3127  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3128  MLIRContext *ctx,
3129  const TypeConverter &converter)
3130  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3131 
3133  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3134  ConversionPatternRewriter &rewriter) const override {
3135  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3136  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3137  }
3138 };
3139 
/// Matches any op implementing FunctionOpInterface and converts its signature.
3140 struct AnyFunctionOpInterfaceSignatureConversion
3141  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3143 
3145  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3146  ConversionPatternRewriter &rewriter) const override {
3147  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3148  }
3149 };
3150 } // namespace
3151 
/// Adds a signature-conversion pattern for ops named `functionLikeOpName`.
3153  StringRef functionLikeOpName, RewritePatternSet &patterns,
3154  const TypeConverter &converter) {
3155  patterns.add<FunctionOpInterfaceSignatureConversion>(
3156  functionLikeOpName, patterns.getContext(), converter);
3157 }
3158 
/// Adds a signature-conversion pattern for every FunctionOpInterface op.
3160  RewritePatternSet &patterns, const TypeConverter &converter) {
3161  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3162  converter, patterns.getContext());
3163 }
3164 
3165 //===----------------------------------------------------------------------===//
3166 // ConversionTarget
3167 //===----------------------------------------------------------------------===//
3168 
/// Sets the legalization action for operation `op`, creating its entry if
/// needed.
3170  LegalizationAction action) {
3171  legalOperations[op].action = action;
3172 }
3173 
/// Sets the legalization action for every dialect in `dialectNames`.
3175  LegalizationAction action) {
3176  for (StringRef dialect : dialectNames)
3177  legalDialects[dialect] = action;
3178 }
3179 
/// Returns the action that `getOpInfo` reports for `op`, if any.
3181  -> std::optional<LegalizationAction> {
3182  std::optional<LegalizationInfo> info = getOpInfo(op);
3183  return info ? info->action : std::optional<LegalizationAction>();
3184 }
3185 
/// Returns legality details if `op` is legal, otherwise std::nullopt.
///
/// For a Dynamic action the legality callback decides, falling back to "not
/// legal" when it returns no answer. For other actions only Legal counts. When
/// the op is marked recursively legal, its recursive-legality callback (if
/// any) decides, defaulting to true.
3187  -> std::optional<LegalOpDetails> {
3188  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3189  if (!info)
3190  return std::nullopt;
3191 
3192  // Returns true if this operation instance is known to be legal.
3193  auto isOpLegal = [&] {
3194  // Handle dynamic legality with the provided legality function.
// NOTE(review): `getOpInfo` can return Dynamic dialect info whose callback is
// empty. This assumes a callback is always registered for Dynamic entries.
// Verify.
3195  if (info->action == LegalizationAction::Dynamic) {
3196  std::optional<bool> result = info->legalityFn(op);
3197  if (result)
3198  return *result;
3199  }
3200 
3201  // Otherwise, the operation is only legal if it was marked 'Legal'.
3202  return info->action == LegalizationAction::Legal;
3203  };
3204  if (!isOpLegal())
3205  return std::nullopt;
3206 
3207  // This operation is legal, compute any additional legality information.
3208  LegalOpDetails legalityDetails;
3209  if (info->isRecursivelyLegal) {
3210  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3211  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3212  legalityDetails.isRecursivelyLegal =
3213  legalityFnIt->second(op).value_or(true);
3214  } else {
3215  legalityDetails.isRecursivelyLegal = true;
3216  }
3217  }
3218  return legalityDetails;
3219 }
3220 
/// Returns true only when `op` is known to be illegal.
///
/// Ops without legality info are not illegal. For a Dynamic action, illegal
/// means the callback answered `false`; no answer counts as not illegal.
/// Otherwise the action must be Illegal.
3222  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3223  if (!info)
3224  return false;
3225 
3226  if (info->action == LegalizationAction::Dynamic) {
3227  std::optional<bool> result = info->legalityFn(op);
3228  if (!result)
3229  return false;
3230 
3231  return !(*result);
3232  }
3233 
3234  return info->action == LegalizationAction::Illegal;
3235 }
3236 
/// Returns a callback that consults `newCallback` first and falls back to
/// `oldCallback` when the new one returns std::nullopt. If `oldCallback` is
/// empty, `newCallback` is returned unchanged. `newCallback` is invoked
/// unconditionally, so it is expected to be non-empty.
3240  if (!oldCallback)
3241  return newCallback;
3242 
3243  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3244  Operation *op) -> std::optional<bool> {
3245  if (std::optional<bool> result = newCl(op))
3246  return *result;
3247 
3248  return oldCl(op);
3249  };
3250  return chain;
3251 }
3252 
3253 void ConversionTarget::setLegalityCallback(
3254  OperationName name, const DynamicLegalityCallbackFn &callback) {
3255  assert(callback && "expected valid legality callback");
3256  auto infoIt = legalOperations.find(name);
3257  assert(infoIt != legalOperations.end() &&
3258  infoIt->second.action == LegalizationAction::Dynamic &&
3259  "expected operation to already be marked as dynamically legal");
3260  infoIt->second.legalityFn =
3261  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3262 }
3263 
/// Marks operation `name`, which must already have a non-Illegal entry, as
/// recursively legal. A non-empty `callback` is composed in front of any
/// existing recursive-legality callback. An empty one removes the callback,
/// making the op unconditionally recursively legal.
3265  OperationName name, const DynamicLegalityCallbackFn &callback) {
3266  auto infoIt = legalOperations.find(name);
3267  assert(infoIt != legalOperations.end() &&
3268  infoIt->second.action != LegalizationAction::Illegal &&
3269  "expected operation to already be marked as legal");
3270  infoIt->second.isRecursivelyLegal = true;
3271  if (callback)
3272  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3273  std::move(opRecursiveLegalityFns[name]), callback);
3274  else
3275  opRecursiveLegalityFns.erase(name);
3276 }
3277 
3278 void ConversionTarget::setLegalityCallback(
3279  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3280  assert(callback && "expected valid legality callback");
3281  for (StringRef dialect : dialects)
3282  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3283  std::move(dialectLegalityFns[dialect]), callback);
3284 }
3285 
3286 void ConversionTarget::setLegalityCallback(
3287  const DynamicLegalityCallbackFn &callback) {
3288  assert(callback && "expected valid legality callback");
3289  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3290 }
3291 
3292 auto ConversionTarget::getOpInfo(OperationName op) const
3293  -> std::optional<LegalizationInfo> {
3294  // Check for info for this specific operation.
3295  auto it = legalOperations.find(op);
3296  if (it != legalOperations.end())
3297  return it->second;
3298  // Check for info for the parent dialect.
3299  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3300  if (dialectIt != legalDialects.end()) {
3301  DynamicLegalityCallbackFn callback;
3302  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3303  if (dialectFn != dialectLegalityFns.end())
3304  callback = dialectFn->second;
3305  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3306  callback};
3307  }
3308  // Otherwise, check if we mark unknown operations as dynamic.
3309  if (unknownLegalityFn)
3310  return LegalizationInfo{LegalizationAction::Dynamic,
3311  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3312  return std::nullopt;
3313 }
3314 
3315 //===----------------------------------------------------------------------===//
3316 // PDL Configuration
3317 //===----------------------------------------------------------------------===//
3318 
// Installs this config's type converter as the rewriter's current type
// converter.
3320  auto &rewriterImpl =
3321  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3322  rewriterImpl.currentTypeConverter = getTypeConverter();
3323 }
3324 
// Clears the rewriter's current type converter.
3326  auto &rewriterImpl =
3327  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3328  rewriterImpl.currentTypeConverter = nullptr;
3329 }
3330 
3331 /// Remap the given value using the rewriter and the type converter in the
3332 /// provided config.
/// Returns the values produced by `rewriter.getRemappedValues`, or failure if
/// remapping fails.
3335  SmallVector<Value> mappedValues;
3336  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3337  return failure();
3338  return std::move(mappedValues);
3339 }
3340 
// "convertValue": remaps a single value through the conversion rewriter and
// returns the first remapped value, or failure.
3343  "convertValue",
3344  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3345  auto results = pdllConvertValues(
3346  static_cast<ConversionPatternRewriter &>(rewriter), value);
3347  if (failed(results))
3348  return failure();
3349  return results->front();
3350  });
// "convertValues": remaps a range of values through the conversion rewriter.
3352  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3353  return pdllConvertValues(
3354  static_cast<ConversionPatternRewriter &>(rewriter), values);
3355  });
// "convertType": converts with `rewriterImpl.currentTypeConverter` when one
// is set (failing if conversion yields no single type); otherwise returns the
// type unchanged.
3357  "convertType",
3358  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3359  auto &rewriterImpl =
3360  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3361  if (const TypeConverter *converter =
3362  rewriterImpl.currentTypeConverter) {
3363  if (Type newType = converter->convertType(type))
3364  return newType;
3365  return failure();
3366  }
3367  return type;
3368  });
// "convertTypes": range form of "convertType"; without a current type
// converter, the input types are returned as-is.
3370  "convertTypes",
3371  [](PatternRewriter &rewriter,
3373  auto &rewriterImpl =
3374  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3375  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3376  if (!converter)
3377  return SmallVector<Type>(types);
3378 
3379  SmallVector<Type> remappedTypes;
3380  if (failed(converter->convertTypes(types, remappedTypes)))
3381  return failure();
3382  return std::move(remappedTypes);
3383  });
3384 }
3385 
3386 //===----------------------------------------------------------------------===//
3387 // Op Conversion Entry Points
3388 //===----------------------------------------------------------------------===//
3389 
3390 //===----------------------------------------------------------------------===//
3391 // Partial Conversion
3392 
/// Runs an OperationConverter in Partial mode over `ops`. `unconvertedOps` is
/// forwarded to the converter.
3395  const ConversionTarget &target,
3396  const FrozenRewritePatternSet &patterns,
3397  DenseSet<Operation *> *unconvertedOps) {
3398  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3399  unconvertedOps);
3400  return opConverter.convertOperations(ops);
3401 }
/// Single-operation convenience overload of the range form above.
3404  const FrozenRewritePatternSet &patterns,
3405  DenseSet<Operation *> *unconvertedOps) {
3406  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3407  unconvertedOps);
3408 }
3408 }
3409 
3410 //===----------------------------------------------------------------------===//
3411 // Full Conversion
3412 
/// Runs an OperationConverter in Full mode over `ops`.
3415  const FrozenRewritePatternSet &patterns) {
3416  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3417  return opConverter.convertOperations(ops);
3418 }
/// Single-operation convenience overload of the range form above.
3421  const FrozenRewritePatternSet &patterns) {
3422  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
3423 }
3424 
3425 //===----------------------------------------------------------------------===//
3426 // Analysis Conversion
3427 
/// Runs an OperationConverter in Analysis mode over `ops`. `convertedOps` is
/// passed to the converter and `notifyCallback` to `convertOperations`.
3430  ConversionTarget &target,
3431  const FrozenRewritePatternSet &patterns,
3432  DenseSet<Operation *> &convertedOps,
3433  function_ref<void(Diagnostic &)> notifyCallback) {
3434  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3435  &convertedOps);
3436  return opConverter.convertOperations(ops, notifyCallback);
3437 }
/// Single-operation convenience overload of the range form above.
3440  const FrozenRewritePatternSet &patterns,
3441  DenseSet<Operation *> &convertedOps,
3442  function_ref<void(Diagnostic &)> notifyCallback) {
3443  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3444  convertedOps, notifyCallback);
3445 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:310
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:325
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:143
bool empty()
Definition: Block.h:141
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation & back()
Definition: Block.h:145
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:60
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:154
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:302
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:88
RetT walk(FnT &&callback)
Walk the operations in this block.
Definition: Block.h:280
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator end()
Definition: Block.h:137
iterator begin()
Definition: Block.h:136
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:53
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing an operation when the given functor returns "true".
void notifyBlockCreated(Block *block) override
PatternRewriter hook creating a new block.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region, except for those of the entry block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:264
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:192
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:201
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:427
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:261
This is a value defined by a result of an operation.
Definition: Value.h:448
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:460
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:813
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:604
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:42
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:128
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:93
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:89
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:233
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:272
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, found within the type `type`, using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:212
void dropAllUses() const
Drop all uses of this object from their respective owners.
Definition: Value.h:161
Type getType() const
Return the type of this value.
Definition: Value.h:122
user_iterator user_end() const
Definition: Value.h:221
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:166
user_range getUsers() const
Definition: Value.h:222
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This struct represents a range of new types or a single value that remaps an existing signature input...
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
void applyRewrites()
Apply all requested operation rewrites.
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
ArgConverter argConverter
Utility used to convert block arguments.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.
FailureOr< Block * > convertBlockSignature(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.