MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
59  Block *insertBlock = value.getParentBlock();
60  Block::iterator insertPt = insertBlock->begin();
61  if (OpResult inputRes = dyn_cast<OpResult>(value))
62  insertPt = ++inputRes.getOwner()->getIterator();
63  return OpBuilder::InsertPoint(insertBlock, insertPt);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
69 
70 /// A list of replacement SSA values. Optimized for the common case of a single
71 /// SSA value.
72 using ReplacementValues = SmallVector<Value, 1>;
73 
74 namespace {
75 /// This class wraps a IRMapping to provide recursive lookup
76 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
77 struct ConversionValueMapping {
78  /// Return "true" if an SSA value is mapped to the given value. May return
79  /// false positives.
80  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
81 
82  /// Lookup the most recently mapped value with the desired type in the
83  /// mapping.
84  ///
85  /// Special cases:
86  /// - If the desired type is "null", simply return the most recently mapped
87  /// value.
88  /// - If there is no mapping to the desired type, also return the most
89  /// recently mapped value.
90  /// - If there is no mapping for the given value at all, return the given
91  /// value.
92  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
93 
94  /// Lookup a mapped value within the map, or return null if a mapping does not
95  /// exist. If a mapping exists, this follows the same behavior of
96  /// `lookupOrDefault`.
97  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
98 
99  /// Map a value to the one provided.
100  void map(Value oldVal, Value newVal) {
101  LLVM_DEBUG({
102  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
103  assert(it != oldVal && "inserting cyclic mapping");
104  });
105  mapping.map(oldVal, newVal);
106  mappedTo.insert(newVal);
107  }
108 
109  /// Drop the last mapping for the given value.
110  void erase(Value value) { mapping.erase(value); }
111 
112 private:
113  /// Current value mappings.
114  IRMapping mapping;
115 
116  /// All SSA values that are mapped to. May contain false positives.
117  DenseSet<Value> mappedTo;
118 };
119 } // namespace
120 
121 Value ConversionValueMapping::lookupOrDefault(Value from,
122  Type desiredType) const {
123  // Try to find the deepest value that has the desired type. If there is no
124  // such value, simply return the deepest value.
125  Value desiredValue;
126  do {
127  if (!desiredType || from.getType() == desiredType)
128  desiredValue = from;
129 
130  Value mappedValue = mapping.lookupOrNull(from);
131  if (!mappedValue)
132  break;
133  from = mappedValue;
134  } while (true);
135 
136  // If the desired value was found use it, otherwise default to the leaf value.
137  return desiredValue ? desiredValue : from;
138 }
139 
140 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
141  Value result = lookupOrDefault(from, desiredType);
142  if (result == from || (desiredType && result.getType() != desiredType))
143  return nullptr;
144  return result;
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Rewriter and Translation State
149 //===----------------------------------------------------------------------===//
150 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// before a group of rewrites makes it possible to undo that group later.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites that had been recorded when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
168 
169 //===----------------------------------------------------------------------===//
170 // IR rewrites
171 //===----------------------------------------------------------------------===//
172 
173 /// An IR rewrite that can be committed (upon success) or rolled back (upon
174 /// failure).
175 ///
176 /// The dialect conversion keeps track of IR modifications (requested by the
177 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
178 /// are directly applied to the IR as the rewriter API is used, some are applied
179 /// partially, and some are delayed until the `IRRewrite` objects are committed.
180 class IRRewrite {
181 public:
182  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
183  /// Enum values are ordered, so that they can be used in `classof`: first all
184  /// block rewrites, then all operation rewrites.
185  enum class Kind {
186  // Block rewrites
187  CreateBlock,
188  EraseBlock,
189  InlineBlock,
190  MoveBlock,
191  BlockTypeConversion,
192  ReplaceBlockArg,
193  // Operation rewrites
194  MoveOperation,
195  ModifyOperation,
196  ReplaceOperation,
197  CreateOperation,
198  UnresolvedMaterialization
199  };
200 
201  virtual ~IRRewrite() = default;
202 
203  /// Roll back the rewrite. Operations may be erased during rollback.
204  virtual void rollback() = 0;
205 
206  /// Commit the rewrite. At this point, it is certain that the dialect
207  /// conversion will succeed. All IR modifications, except for operation/block
208  /// erasure, must be performed through the given rewriter.
209  ///
210  /// Instead of erasing operations/blocks, they should merely be unlinked
211  /// commit phase and finally be erased during the cleanup phase. This is
212  /// because internal dialect conversion state (such as `mapping`) may still
213  /// be using them.
214  ///
215  /// Any IR modification that was already performed before the commit phase
216  /// (e.g., insertion of an op) must be communicated to the listener that may
217  /// be attached to the given rewriter.
218  virtual void commit(RewriterBase &rewriter) {}
219 
220  /// Cleanup operations/blocks. Cleanup is called after commit.
221  virtual void cleanup(RewriterBase &rewriter) {}
222 
223  Kind getKind() const { return kind; }
224 
225  static bool classof(const IRRewrite *rewrite) { return true; }
226 
227 protected:
228  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
229  : kind(kind), rewriterImpl(rewriterImpl) {}
230 
231  const ConversionConfig &getConfig() const;
232 
233  const Kind kind;
234  ConversionPatternRewriterImpl &rewriterImpl;
235 };
236 
237 /// A block rewrite.
238 class BlockRewrite : public IRRewrite {
239 public:
240  /// Return the block that this rewrite operates on.
241  Block *getBlock() const { return block; }
242 
243  static bool classof(const IRRewrite *rewrite) {
244  return rewrite->getKind() >= Kind::CreateBlock &&
245  rewrite->getKind() <= Kind::ReplaceBlockArg;
246  }
247 
248 protected:
249  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
250  Block *block)
251  : IRRewrite(kind, rewriterImpl), block(block) {}
252 
253  // The block that this rewrite operates on.
254  Block *block;
255 };
256 
257 /// Creation of a block. Block creations are immediately reflected in the IR.
258 /// There is no extra work to commit the rewrite. During rollback, the newly
259 /// created block is erased.
260 class CreateBlockRewrite : public BlockRewrite {
261 public:
262  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
263  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
264 
265  static bool classof(const IRRewrite *rewrite) {
266  return rewrite->getKind() == Kind::CreateBlock;
267  }
268 
269  void commit(RewriterBase &rewriter) override {
270  // The block was already created and inserted. Just inform the listener.
271  if (auto *listener = rewriter.getListener())
272  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
273  }
274 
275  void rollback() override {
276  // Unlink all of the operations within this block, they will be deleted
277  // separately.
278  auto &blockOps = block->getOperations();
279  while (!blockOps.empty())
280  blockOps.remove(blockOps.begin());
281  block->dropAllUses();
282  if (block->getParent())
283  block->erase();
284  else
285  delete block;
286  }
287 };
288 
289 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
290 /// blocks are immediately unlinked, but only erased during cleanup. This makes
291 /// it easier to rollback a block erasure: the block is simply inserted into its
292 /// original location.
293 class EraseBlockRewrite : public BlockRewrite {
294 public:
295  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
296  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
297  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
298 
299  static bool classof(const IRRewrite *rewrite) {
300  return rewrite->getKind() == Kind::EraseBlock;
301  }
302 
303  ~EraseBlockRewrite() override {
304  assert(!block &&
305  "rewrite was neither rolled back nor committed/cleaned up");
306  }
307 
308  void rollback() override {
309  // The block (owned by this rewrite) was not actually erased yet. It was
310  // just unlinked. Put it back into its original position.
311  assert(block && "expected block");
312  auto &blockList = region->getBlocks();
313  Region::iterator before = insertBeforeBlock
314  ? Region::iterator(insertBeforeBlock)
315  : blockList.end();
316  blockList.insert(before, block);
317  block = nullptr;
318  }
319 
320  void commit(RewriterBase &rewriter) override {
321  // Erase the block.
322  assert(block && "expected block");
323  assert(block->empty() && "expected empty block");
324 
325  // Notify the listener that the block is about to be erased.
326  if (auto *listener =
327  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
328  listener->notifyBlockErased(block);
329  }
330 
331  void cleanup(RewriterBase &rewriter) override {
332  // Erase the block.
333  block->dropAllDefinedValueUses();
334  delete block;
335  block = nullptr;
336  }
337 
338 private:
339  // The region in which this block was previously contained.
340  Region *region;
341 
342  // The original successor of this block before it was unlinked. "nullptr" if
343  // this block was the only block in the region.
344  Block *insertBeforeBlock;
345 };
346 
347 /// Inlining of a block. This rewrite is immediately reflected in the IR.
348 /// Note: This rewrite represents only the inlining of the operations. The
349 /// erasure of the inlined block is a separate rewrite.
350 class InlineBlockRewrite : public BlockRewrite {
351 public:
352  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
353  Block *sourceBlock, Block::iterator before)
354  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
355  sourceBlock(sourceBlock),
356  firstInlinedInst(sourceBlock->empty() ? nullptr
357  : &sourceBlock->front()),
358  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
359  // If a listener is attached to the dialect conversion, ops must be moved
360  // one-by-one. When they are moved in bulk, notifications cannot be sent
361  // because the ops that used to be in the source block at the time of the
362  // inlining (before the "commit" phase) are unknown at the time when
363  // notifications are sent (which is during the "commit" phase).
364  assert(!getConfig().listener &&
365  "InlineBlockRewrite not supported if listener is attached");
366  }
367 
368  static bool classof(const IRRewrite *rewrite) {
369  return rewrite->getKind() == Kind::InlineBlock;
370  }
371 
372  void rollback() override {
373  // Put the operations from the destination block (owned by the rewrite)
374  // back into the source block.
375  if (firstInlinedInst) {
376  assert(lastInlinedInst && "expected operation");
377  sourceBlock->getOperations().splice(sourceBlock->begin(),
378  block->getOperations(),
379  Block::iterator(firstInlinedInst),
380  ++Block::iterator(lastInlinedInst));
381  }
382  }
383 
384 private:
385  // The block that originally contained the operations.
386  Block *sourceBlock;
387 
388  // The first inlined operation.
389  Operation *firstInlinedInst;
390 
391  // The last inlined operation.
392  Operation *lastInlinedInst;
393 };
394 
395 /// Moving of a block. This rewrite is immediately reflected in the IR.
396 class MoveBlockRewrite : public BlockRewrite {
397 public:
398  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
399  Region *region, Block *insertBeforeBlock)
400  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
401  insertBeforeBlock(insertBeforeBlock) {}
402 
403  static bool classof(const IRRewrite *rewrite) {
404  return rewrite->getKind() == Kind::MoveBlock;
405  }
406 
407  void commit(RewriterBase &rewriter) override {
408  // The block was already moved. Just inform the listener.
409  if (auto *listener = rewriter.getListener()) {
410  // Note: `previousIt` cannot be passed because this is a delayed
411  // notification and iterators into past IR state cannot be represented.
412  listener->notifyBlockInserted(block, /*previous=*/region,
413  /*previousIt=*/{});
414  }
415  }
416 
417  void rollback() override {
418  // Move the block back to its original position.
419  Region::iterator before =
420  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
421  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
422  }
423 
424 private:
425  // The region in which this block was previously contained.
426  Region *region;
427 
428  // The original successor of this block before it was moved. "nullptr" if
429  // this block was the only block in the region.
430  Block *insertBeforeBlock;
431 };
432 
433 /// Block type conversion. This rewrite is partially reflected in the IR.
434 class BlockTypeConversionRewrite : public BlockRewrite {
435 public:
436  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437  Block *origBlock, Block *newBlock)
438  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
439  newBlock(newBlock) {}
440 
441  static bool classof(const IRRewrite *rewrite) {
442  return rewrite->getKind() == Kind::BlockTypeConversion;
443  }
444 
445  Block *getOrigBlock() const { return block; }
446 
447  Block *getNewBlock() const { return newBlock; }
448 
449  void commit(RewriterBase &rewriter) override;
450 
451  void rollback() override;
452 
453 private:
454  /// The new block that was created as part of this signature conversion.
455  Block *newBlock;
456 };
457 
458 /// Replacing a block argument. This rewrite is not immediately reflected in the
459 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
460 /// until the rewrite is committed.
461 class ReplaceBlockArgRewrite : public BlockRewrite {
462 public:
463  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
464  Block *block, BlockArgument arg,
465  const TypeConverter *converter)
466  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
467  converter(converter) {}
468 
469  static bool classof(const IRRewrite *rewrite) {
470  return rewrite->getKind() == Kind::ReplaceBlockArg;
471  }
472 
473  void commit(RewriterBase &rewriter) override;
474 
475  void rollback() override;
476 
477 private:
478  BlockArgument arg;
479 
480  /// The current type converter when the block argument was replaced.
481  const TypeConverter *converter;
482 };
483 
484 /// An operation rewrite.
485 class OperationRewrite : public IRRewrite {
486 public:
487  /// Return the operation that this rewrite operates on.
488  Operation *getOperation() const { return op; }
489 
490  static bool classof(const IRRewrite *rewrite) {
491  return rewrite->getKind() >= Kind::MoveOperation &&
492  rewrite->getKind() <= Kind::UnresolvedMaterialization;
493  }
494 
495 protected:
496  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
497  Operation *op)
498  : IRRewrite(kind, rewriterImpl), op(op) {}
499 
500  // The operation that this rewrite operates on.
501  Operation *op;
502 };
503 
504 /// Moving of an operation. This rewrite is immediately reflected in the IR.
505 class MoveOperationRewrite : public OperationRewrite {
506 public:
507  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
508  Operation *op, Block *block, Operation *insertBeforeOp)
509  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
510  insertBeforeOp(insertBeforeOp) {}
511 
512  static bool classof(const IRRewrite *rewrite) {
513  return rewrite->getKind() == Kind::MoveOperation;
514  }
515 
516  void commit(RewriterBase &rewriter) override {
517  // The operation was already moved. Just inform the listener.
518  if (auto *listener = rewriter.getListener()) {
519  // Note: `previousIt` cannot be passed because this is a delayed
520  // notification and iterators into past IR state cannot be represented.
521  listener->notifyOperationInserted(
522  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
523  /*insertPt=*/{}));
524  }
525  }
526 
527  void rollback() override {
528  // Move the operation back to its original position.
529  Block::iterator before =
530  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
531  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
532  }
533 
534 private:
535  // The block in which this operation was previously contained.
536  Block *block;
537 
538  // The original successor of this operation before it was moved. "nullptr"
539  // if this operation was the only operation in the region.
540  Operation *insertBeforeOp;
541 };
542 
543 /// In-place modification of an op. This rewrite is immediately reflected in
544 /// the IR. The previous state of the operation is stored in this object.
545 class ModifyOperationRewrite : public OperationRewrite {
546 public:
547  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
548  Operation *op)
549  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
550  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
551  operands(op->operand_begin(), op->operand_end()),
552  successors(op->successor_begin(), op->successor_end()) {
553  if (OpaqueProperties prop = op->getPropertiesStorage()) {
554  // Make a copy of the properties.
555  propertiesStorage = operator new(op->getPropertiesStorageSize());
556  OpaqueProperties propCopy(propertiesStorage);
557  name.initOpProperties(propCopy, /*init=*/prop);
558  }
559  }
560 
561  static bool classof(const IRRewrite *rewrite) {
562  return rewrite->getKind() == Kind::ModifyOperation;
563  }
564 
565  ~ModifyOperationRewrite() override {
566  assert(!propertiesStorage &&
567  "rewrite was neither committed nor rolled back");
568  }
569 
570  void commit(RewriterBase &rewriter) override {
571  // Notify the listener that the operation was modified in-place.
572  if (auto *listener =
573  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
574  listener->notifyOperationModified(op);
575 
576  if (propertiesStorage) {
577  OpaqueProperties propCopy(propertiesStorage);
578  // Note: The operation may have been erased in the mean time, so
579  // OperationName must be stored in this object.
580  name.destroyOpProperties(propCopy);
581  operator delete(propertiesStorage);
582  propertiesStorage = nullptr;
583  }
584  }
585 
586  void rollback() override {
587  op->setLoc(loc);
588  op->setAttrs(attrs);
589  op->setOperands(operands);
590  for (const auto &it : llvm::enumerate(successors))
591  op->setSuccessor(it.value(), it.index());
592  if (propertiesStorage) {
593  OpaqueProperties propCopy(propertiesStorage);
594  op->copyProperties(propCopy);
595  name.destroyOpProperties(propCopy);
596  operator delete(propertiesStorage);
597  propertiesStorage = nullptr;
598  }
599  }
600 
601 private:
602  OperationName name;
603  LocationAttr loc;
604  DictionaryAttr attrs;
605  SmallVector<Value, 8> operands;
606  SmallVector<Block *, 2> successors;
607  void *propertiesStorage = nullptr;
608 };
609 
610 /// Replacing an operation. Erasing an operation is treated as a special case
611 /// with "null" replacements. This rewrite is not immediately reflected in the
612 /// IR. An internal IR mapping is updated, but values are not replaced and the
613 /// original op is not erased until the rewrite is committed.
614 class ReplaceOperationRewrite : public OperationRewrite {
615 public:
616  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
617  Operation *op, const TypeConverter *converter)
618  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
619  converter(converter) {}
620 
621  static bool classof(const IRRewrite *rewrite) {
622  return rewrite->getKind() == Kind::ReplaceOperation;
623  }
624 
625  void commit(RewriterBase &rewriter) override;
626 
627  void rollback() override;
628 
629  void cleanup(RewriterBase &rewriter) override;
630 
631 private:
632  /// An optional type converter that can be used to materialize conversions
633  /// between the new and old values if necessary.
634  const TypeConverter *converter;
635 };
636 
637 class CreateOperationRewrite : public OperationRewrite {
638 public:
639  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
640  Operation *op)
641  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
642 
643  static bool classof(const IRRewrite *rewrite) {
644  return rewrite->getKind() == Kind::CreateOperation;
645  }
646 
647  void commit(RewriterBase &rewriter) override {
648  // The operation was already created and inserted. Just inform the listener.
649  if (auto *listener = rewriter.getListener())
650  listener->notifyOperationInserted(op, /*previous=*/{});
651  }
652 
653  void rollback() override;
654 };
655 
/// The kind of a materialization. The values are stored in a 2-bit
/// `llvm::PointerIntPair` field by `UnresolvedMaterializationRewrite`, so they
/// must remain below 4.
enum MaterializationKind {
  /// Materializes a conversion from an illegal block argument type back to
  /// the original one.
  Argument = 0,

  /// Materializes a conversion from an illegal type to a legal one.
  Target = 1,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 2
};
670 
671 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
672 /// op. Unresolved materializations are erased at the end of the dialect
673 /// conversion.
674 class UnresolvedMaterializationRewrite : public OperationRewrite {
675 public:
676  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
677  UnrealizedConversionCastOp op,
678  const TypeConverter *converter,
679  MaterializationKind kind, Type originalType);
680 
681  static bool classof(const IRRewrite *rewrite) {
682  return rewrite->getKind() == Kind::UnresolvedMaterialization;
683  }
684 
685  void rollback() override;
686 
687  UnrealizedConversionCastOp getOperation() const {
688  return cast<UnrealizedConversionCastOp>(op);
689  }
690 
691  /// Return the type converter of this materialization (which may be null).
692  const TypeConverter *getConverter() const {
693  return converterAndKind.getPointer();
694  }
695 
696  /// Return the kind of this materialization.
697  MaterializationKind getMaterializationKind() const {
698  return converterAndKind.getInt();
699  }
700 
701  /// Return the original type of the SSA value.
702  Type getOriginalType() const { return originalType; }
703 
704 private:
705  /// The corresponding type converter to use when resolving this
706  /// materialization, and the kind of this materialization.
707  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
708  converterAndKind;
709 
710  /// The original type of the SSA value. Only used for target
711  /// materializations.
712  Type originalType;
713 };
714 } // namespace
715 
716 /// Return "true" if there is an operation rewrite that matches the specified
717 /// rewrite type and operation among the given rewrites.
718 template <typename RewriteTy, typename R>
719 static bool hasRewrite(R &&rewrites, Operation *op) {
720  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
721  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
722  return rewriteTy && rewriteTy->getOperation() == op;
723  });
724 }
725 
726 #ifndef NDEBUG
727 /// Return "true" if there is a block rewrite that matches the specified
728 /// rewrite type and block among the given rewrites.
729 template <typename RewriteTy, typename R>
730 static bool hasRewrite(R &&rewrites, Block *block) {
731  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
732  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
733  return rewriteTy && rewriteTy->getBlock() == block;
734  });
735 }
736 #endif // NDEBUG
737 
738 //===----------------------------------------------------------------------===//
739 // ConversionPatternRewriterImpl
740 //===----------------------------------------------------------------------===//
741 namespace mlir {
742 namespace detail {
745  const ConversionConfig &config)
746  : context(ctx), eraseRewriter(ctx), config(config) {}
747 
748  //===--------------------------------------------------------------------===//
749  // State Management
750  //===--------------------------------------------------------------------===//
751 
752  /// Return the current state of the rewriter.
753  RewriterState getCurrentState();
754 
755  /// Apply all requested operation rewrites. This method is invoked when the
756  /// conversion process succeeds.
757  void applyRewrites();
758 
759  /// Reset the state of the rewriter to a previously saved point.
760  void resetState(RewriterState state);
761 
762  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
763  /// failure.
764  template <typename RewriteTy, typename... Args>
765  void appendRewrite(Args &&...args) {
766  rewrites.push_back(
767  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
768  }
769 
770  /// Undo the rewrites (motions, splits) one by one in reverse order until
771  /// "numRewritesToKeep" rewrites remains.
772  void undoRewrites(unsigned numRewritesToKeep = 0);
773 
774  /// Remap the given values to those with potentially different types. Returns
775  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
776  /// is the tag used when describing a value within a diagnostic, e.g.
777  /// "operand".
778  LogicalResult remapValues(StringRef valueDiagTag,
779  std::optional<Location> inputLoc,
780  PatternRewriter &rewriter, ValueRange values,
781  SmallVectorImpl<Value> &remapped);
782 
783  /// Return "true" if the given operation is ignored, and does not need to be
784  /// converted.
785  bool isOpIgnored(Operation *op) const;
786 
787  /// Return "true" if the given operation was replaced or erased.
788  bool wasOpReplaced(Operation *op) const;
789 
790  //===--------------------------------------------------------------------===//
791  // Type Conversion
792  //===--------------------------------------------------------------------===//
793 
794  /// Convert the types of block arguments within the given region.
795  FailureOr<Block *>
796  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
797  const TypeConverter &converter,
798  TypeConverter::SignatureConversion *entryConversion);
799 
800  /// Apply the given signature conversion on the given block. The new block
801  /// containing the updated signature is returned. If no conversions were
802  /// necessary, e.g. if the block has no arguments, `block` is returned.
803  /// `converter` is used to generate any necessary cast operations that
804  /// translate between the origin argument types and those specified in the
805  /// signature conversion.
806  Block *applySignatureConversion(
807  ConversionPatternRewriter &rewriter, Block *block,
808  const TypeConverter *converter,
809  TypeConverter::SignatureConversion &signatureConversion);
810 
811  //===--------------------------------------------------------------------===//
812  // Materializations
813  //===--------------------------------------------------------------------===//
814 
815  /// Build an unresolved materialization operation given an output type and set
816  /// of input operands.
817  Value buildUnresolvedMaterialization(MaterializationKind kind,
819  ValueRange inputs, Type outputType,
820  Type originalType,
821  const TypeConverter *converter);
822 
/// Build an N:1 materialization for the given original value that was
/// replaced with the given replacement values.
///
/// This is a workaround for incomplete 1:N support in the dialect
/// conversion driver. The conversion mapping can store only 1:1 replacements
/// and the conversion patterns only support single Value replacements in the
/// adaptor, so N values must be converted back to a single value. This
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
/// type, followed by a target materialization to the legalized type (if
/// the legalized type differs from the original type). Both are recorded in
/// the value mapping: originalValue -> argument mat -> target mat. Without a
/// `converter`, the type of a single replacement value is assumed to be legal.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                               ValueRange replacements, Value originalValue,
                               const TypeConverter *converter);

/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
/// value. If there is no replacement value with the correct type, find the
/// latest replacement value (regardless of the type) and build a source
/// materialization.
///
/// Returns a null Value if `value` is (approximately) determined to be dead,
/// or if no replacement value was registered for it at all.
Value findOrBuildReplacementValue(Value value,
                                  const TypeConverter *converter);
846 
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//

/// Notifies that an op was inserted. A newly created op (no previous
/// insertion point) and a moved op are recorded as different rewrites.
void notifyOperationInserted(Operation *op,
                             OpBuilder::InsertPoint previous) override;

/// Notifies that an op is about to be replaced with the given values.
/// `newValues` has one entry per result of `op`; an empty entry means that
/// the result was dropped without a replacement value.
void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);

/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);

/// Notifies that a block was inserted. `previous` is null for a newly
/// created block; otherwise the block was moved out of `previous`, and
/// `previousIt` denotes its former position in that region.
void notifyBlockInserted(Block *block, Region *previous,
                         Region::iterator previousIt) override;

/// Notifies that a block is being inlined into another block.
void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                             Block::iterator before);

/// Notifies that a pattern match failed for the given reason. The reason is
/// only built and logged in debug builds.
void
notifyMatchFailure(Location loc,
                   function_ref<void(Diagnostic &)> reasonCallback) override;
873 
874  //===--------------------------------------------------------------------===//
875  // IR Erasure
876  //===--------------------------------------------------------------------===//
877 
878  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
879  /// operation or block is erased multiple times. This rewriter assumes that
880  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
882  public:
884  : RewriterBase(context, /*listener=*/this) {}
885 
886  /// Erase the given op (unless it was already erased).
887  void eraseOp(Operation *op) override {
888  if (wasErased(op))
889  return;
890  op->dropAllUses();
892  }
893 
894  /// Erase the given block (unless it was already erased).
895  void eraseBlock(Block *block) override {
896  if (wasErased(block))
897  return;
898  assert(block->empty() && "expected empty block");
899  block->dropAllDefinedValueUses();
901  }
902 
903  bool wasErased(void *ptr) const { return erased.contains(ptr); }
904 
905  void notifyOperationErased(Operation *op) override { erased.insert(op); }
906 
907  void notifyBlockErased(Block *block) override { erased.insert(block); }
908 
909  private:
910  /// Pointers to all erased operations and blocks.
911  DenseSet<void *> erased;
912  };
913 
914  //===--------------------------------------------------------------------===//
915  // State
916  //===--------------------------------------------------------------------===//
917 
918  /// MLIR context.
920 
921  /// A rewriter that keeps track of ops/block that were already erased and
922  /// skips duplicate op/block erasures. This rewriter is used during the
923  /// "cleanup" phase.
925 
926  // Mapping between replaced values that differ in type. This happens when
927  // replacing a value with one of a different type.
928  ConversionValueMapping mapping;
929 
930  /// Ordered list of block operations (creations, splits, motions).
932 
933  /// A set of operations that should no longer be considered for legalization.
934  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
935  /// tracked separately.
937 
938  /// A set of operations that were replaced/erased. Such ops are not erased
939  /// immediately but only when the dialect conversion succeeds. In the mean
940  /// time, they should no longer be considered for legalization and any attempt
941  /// to modify/access them is invalid rewriter API usage.
943 
944  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
945  /// to the corresponding rewrite objects.
948 
949  /// The current type converter, or nullptr if no type converter is currently
950  /// active.
951  const TypeConverter *currentTypeConverter = nullptr;
952 
953  /// A mapping of regions to type converters that should be used when
954  /// converting the arguments of blocks within that region.
956 
957  /// Dialect conversion configuration.
959 
960 #ifndef NDEBUG
961  /// A set of operations that have pending updates. This tracking isn't
962  /// strictly necessary, and is thus only active during debug builds for extra
963  /// verification.
965 
966  /// A logger used to emit diagnostics during the conversion process.
967  llvm::ScopedPrinter logger{llvm::dbgs()};
968 #endif
969 };
970 } // namespace detail
971 } // namespace mlir
972 
973 const ConversionConfig &IRRewrite::getConfig() const {
974  return rewriterImpl.config;
975 }
976 
977 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
978  // Inform the listener about all IR modifications that have already taken
979  // place: References to the original block have been replaced with the new
980  // block.
981  if (auto *listener =
982  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
983  for (Operation *op : getNewBlock()->getUsers())
984  listener->notifyOperationModified(op);
985 }
986 
987 void BlockTypeConversionRewrite::rollback() {
988  getNewBlock()->replaceAllUsesWith(getOrigBlock());
989 }
990 
991 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
992  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
993  if (!repl)
994  return;
995 
996  if (isa<BlockArgument>(repl)) {
997  rewriter.replaceAllUsesWith(arg, repl);
998  return;
999  }
1000 
1001  // If the replacement value is an operation, we check to make sure that we
1002  // don't replace uses that are within the parent operation of the
1003  // replacement value.
1004  Operation *replOp = cast<OpResult>(repl).getOwner();
1005  Block *replBlock = replOp->getBlock();
1006  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1007  Operation *user = operand.getOwner();
1008  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1009  });
1010 }
1011 
1012 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1013 
1014 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1015  auto *listener =
1016  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1017 
1018  // Compute replacement values.
1019  SmallVector<Value> replacements =
1020  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1021  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1022  });
1023 
1024  // Notify the listener that the operation is about to be replaced.
1025  if (listener)
1026  listener->notifyOperationReplaced(op, replacements);
1027 
1028  // Replace all uses with the new values.
1029  for (auto [result, newValue] :
1030  llvm::zip_equal(op->getResults(), replacements))
1031  if (newValue)
1032  rewriter.replaceAllUsesWith(result, newValue);
1033 
1034  // The original op will be erased, so remove it from the set of unlegalized
1035  // ops.
1036  if (getConfig().unlegalizedOps)
1037  getConfig().unlegalizedOps->erase(op);
1038 
1039  // Notify the listener that the operation (and its nested operations) was
1040  // erased.
1041  if (listener) {
1043  [&](Operation *op) { listener->notifyOperationErased(op); });
1044  }
1045 
1046  // Do not erase the operation yet. It may still be referenced in `mapping`.
1047  // Just unlink it for now and erase it during cleanup.
1048  op->getBlock()->getOperations().remove(op);
1049 }
1050 
1051 void ReplaceOperationRewrite::rollback() {
1052  for (auto result : op->getResults())
1053  rewriterImpl.mapping.erase(result);
1054 }
1055 
1056 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1057  rewriter.eraseOp(op);
1058 }
1059 
1060 void CreateOperationRewrite::rollback() {
1061  for (Region &region : op->getRegions()) {
1062  while (!region.getBlocks().empty())
1063  region.getBlocks().remove(region.getBlocks().begin());
1064  }
1065  op->dropAllUses();
1066  op->erase();
1067 }
1068 
1069 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1070  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1071  const TypeConverter *converter, MaterializationKind kind, Type originalType)
1072  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1073  converterAndKind(converter, kind), originalType(originalType) {
1074  assert((!originalType || kind == MaterializationKind::Target) &&
1075  "original type is valid only for target materializations");
1076  rewriterImpl.unresolvedMaterializations[op] = this;
1077 }
1078 
1079 void UnresolvedMaterializationRewrite::rollback() {
1080  if (getMaterializationKind() == MaterializationKind::Target) {
1081  for (Value input : op->getOperands())
1082  rewriterImpl.mapping.erase(input);
1083  }
1084  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1085  op->erase();
1086 }
1087 
1089  // Commit all rewrites.
1090  IRRewriter rewriter(context, config.listener);
1091  // Note: New rewrites may be added during the "commit" phase and the
1092  // `rewrites` vector may reallocate.
1093  for (size_t i = 0; i < rewrites.size(); ++i)
1094  rewrites[i]->commit(rewriter);
1095 
1096  // Clean up all rewrites.
1097  for (auto &rewrite : rewrites)
1098  rewrite->cleanup(eraseRewriter);
1099 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // State Management
1103 
1105  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1106 }
1107 
1109  // Undo any rewrites.
1110  undoRewrites(state.numRewrites);
1111 
1112  // Pop all of the recorded ignored operations that are no longer valid.
1113  while (ignoredOps.size() != state.numIgnoredOperations)
1114  ignoredOps.pop_back();
1115 
1116  while (replacedOps.size() != state.numReplacedOps)
1117  replacedOps.pop_back();
1118 }
1119 
1120 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1121  for (auto &rewrite :
1122  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1123  rewrite->rollback();
1124  rewrites.resize(numRewritesToKeep);
1125 }
1126 
1128  StringRef valueDiagTag, std::optional<Location> inputLoc,
1129  PatternRewriter &rewriter, ValueRange values,
1130  SmallVectorImpl<Value> &remapped) {
1131  remapped.reserve(llvm::size(values));
1132 
1133  for (const auto &it : llvm::enumerate(values)) {
1134  Value operand = it.value();
1135  Type origType = operand.getType();
1136  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1137 
1138  if (!currentTypeConverter) {
1139  // The current pattern does not have a type converter. I.e., it does not
1140  // distinguish between legal and illegal types. For each operand, simply
1141  // pass through the most recently mapped value.
1142  remapped.push_back(mapping.lookupOrDefault(operand));
1143  continue;
1144  }
1145 
1146  // If there is no legal conversion, fail to match this pattern.
1147  SmallVector<Type, 1> legalTypes;
1148  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1149  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1150  diag << "unable to convert type for " << valueDiagTag << " #"
1151  << it.index() << ", type was " << origType;
1152  });
1153  return failure();
1154  }
1155 
1156  if (legalTypes.size() != 1) {
1157  // TODO: Parts of the dialect conversion infrastructure do not support
1158  // 1->N type conversions yet. Therefore, if a type is converted to 0 or
1159  // multiple types, the only thing that we can do for now is passing
1160  // through the most recently mapped value. Fixing this requires
1161  // improvements to the `ConversionValueMapping` (to be able to store 1:N
1162  // mappings) and to the `ConversionPattern` adaptor handling (to be able
1163  // to pass multiple remapped values for a single operand to the adaptor).
1164  remapped.push_back(mapping.lookupOrDefault(operand));
1165  continue;
1166  }
1167 
1168  // Handle 1->1 type conversions.
1169  Type desiredType = legalTypes.front();
1170  // Try to find a mapped value with the desired type. (Or the operand itself
1171  // if the value is not mapped at all.)
1172  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1173  if (newOperand.getType() != desiredType) {
1174  // If the looked up value's type does not have the desired type, it means
1175  // that the value was replaced with a value of different type and no
1176  // source materialization was created yet.
1178  MaterializationKind::Target, computeInsertPoint(newOperand),
1179  operandLoc,
1180  /*inputs=*/newOperand, /*outputType=*/desiredType,
1181  /*originalType=*/origType, currentTypeConverter);
1182  mapping.map(newOperand, castValue);
1183  newOperand = castValue;
1184  }
1185  remapped.push_back(newOperand);
1186  }
1187  return success();
1188 }
1189 
1191  // Check to see if this operation is ignored or was replaced.
1192  return replacedOps.count(op) || ignoredOps.count(op);
1193 }
1194 
1196  // Check to see if this operation was replaced.
1197  return replacedOps.count(op);
1198 }
1199 
1200 //===----------------------------------------------------------------------===//
1201 // Type Conversion
1202 
1204  ConversionPatternRewriter &rewriter, Region *region,
1205  const TypeConverter &converter,
1206  TypeConverter::SignatureConversion *entryConversion) {
1207  regionToConverter[region] = &converter;
1208  if (region->empty())
1209  return nullptr;
1210 
1211  // Convert the arguments of each non-entry block within the region.
1212  for (Block &block :
1213  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1214  // Compute the signature for the block with the provided converter.
1215  std::optional<TypeConverter::SignatureConversion> conversion =
1216  converter.convertBlockSignature(&block);
1217  if (!conversion)
1218  return failure();
1219  // Convert the block with the computed signature.
1220  applySignatureConversion(rewriter, &block, &converter, *conversion);
1221  }
1222 
1223  // Convert the entry block. If an entry signature conversion was provided,
1224  // use that one. Otherwise, compute the signature with the type converter.
1225  if (entryConversion)
1226  return applySignatureConversion(rewriter, &region->front(), &converter,
1227  *entryConversion);
1228  std::optional<TypeConverter::SignatureConversion> conversion =
1229  converter.convertBlockSignature(&region->front());
1230  if (!conversion)
1231  return failure();
1232  return applySignatureConversion(rewriter, &region->front(), &converter,
1233  *conversion);
1234 }
1235 
1237  ConversionPatternRewriter &rewriter, Block *block,
1238  const TypeConverter *converter,
1239  TypeConverter::SignatureConversion &signatureConversion) {
1240  // A block cannot be converted multiple times.
1241  assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
1242  "block was already converted");
1243  OpBuilder::InsertionGuard g(rewriter);
1244 
1245  // If no arguments are being changed or added, there is nothing to do.
1246  unsigned origArgCount = block->getNumArguments();
1247  auto convertedTypes = signatureConversion.getConvertedTypes();
1248  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1249  return block;
1250 
1251  // Compute the locations of all block arguments in the new block.
1252  SmallVector<Location> newLocs(convertedTypes.size(),
1253  rewriter.getUnknownLoc());
1254  for (unsigned i = 0; i < origArgCount; ++i) {
1255  auto inputMap = signatureConversion.getInputMapping(i);
1256  if (!inputMap || inputMap->replacementValue)
1257  continue;
1258  Location origLoc = block->getArgument(i).getLoc();
1259  for (unsigned j = 0; j < inputMap->size; ++j)
1260  newLocs[inputMap->inputNo + j] = origLoc;
1261  }
1262 
1263  // Insert a new block with the converted block argument types and move all ops
1264  // from the old block to the new block.
1265  Block *newBlock =
1266  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1267  convertedTypes, newLocs);
1268 
1269  // If a listener is attached to the dialect conversion, ops cannot be moved
1270  // to the destination block in bulk ("fast path"). This is because at the time
1271  // the notifications are sent, it is unknown which ops were moved. Instead,
1272  // ops should be moved one-by-one ("slow path"), so that a separate
1273  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1274  // a bit more efficient, so we try to do that when possible.
1275  bool fastPath = !config.listener;
1276  if (fastPath) {
1277  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1278  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1279  } else {
1280  while (!block->empty())
1281  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1282  }
1283 
1284  // Replace all uses of the old block with the new block.
1285  block->replaceAllUsesWith(newBlock);
1286 
1287  for (unsigned i = 0; i != origArgCount; ++i) {
1288  BlockArgument origArg = block->getArgument(i);
1289  Type origArgType = origArg.getType();
1290 
1291  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1292  signatureConversion.getInputMapping(i);
1293  if (!inputMap) {
1294  // This block argument was dropped and no replacement value was provided.
1295  // Materialize a replacement value "out of thin air".
1297  MaterializationKind::Source,
1298  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1299  /*inputs=*/ValueRange(),
1300  /*outputType=*/origArgType, /*originalType=*/Type(), converter);
1301  mapping.map(origArg, repl);
1302  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1303  continue;
1304  }
1305 
1306  if (Value repl = inputMap->replacementValue) {
1307  // This block argument was dropped and a replacement value was provided.
1308  assert(inputMap->size == 0 &&
1309  "invalid to provide a replacement value when the argument isn't "
1310  "dropped");
1311  mapping.map(origArg, repl);
1312  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1313  continue;
1314  }
1315 
1316  // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1317  // dialect conversion. Therefore, we need an argument materialization to
1318  // turn the replacement block arguments into a single SSA value that can be
1319  // used as a replacement.
1320  auto replArgs =
1321  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1323  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1324  /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1325  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1326  }
1327 
1328  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1329 
1330  // Erase the old block. (It is just unlinked for now and will be erased during
1331  // cleanup.)
1332  rewriter.eraseBlock(block);
1333 
1334  return newBlock;
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // Materializations
1339 //===----------------------------------------------------------------------===//
1340 
1341 /// Build an unresolved materialization operation given an output type and set
1342 /// of input operands.
1344  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1345  ValueRange inputs, Type outputType, Type originalType,
1346  const TypeConverter *converter) {
1347  assert((!originalType || kind == MaterializationKind::Target) &&
1348  "original type is valid only for target materializations");
1349 
1350  // Avoid materializing an unnecessary cast.
1351  if (inputs.size() == 1 && inputs.front().getType() == outputType)
1352  return inputs.front();
1353 
1354  // Create an unresolved materialization. We use a new OpBuilder to avoid
1355  // tracking the materialization like we do for other operations.
1356  OpBuilder builder(outputType.getContext());
1357  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1358  auto convertOp =
1359  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1360  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1361  originalType);
1362  return convertOp.getResult(0);
1363 }
1364 
1366  OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
1367  Value originalValue, const TypeConverter *converter) {
1368  // Insert argument materialization back to the original type.
1369  Type originalType = originalValue.getType();
1370  Value argMat =
1372  /*inputs=*/replacements, originalType,
1373  /*originalType=*/Type(), converter);
1374  mapping.map(originalValue, argMat);
1375 
1376  // Insert target materialization to the legalized type.
1377  Type legalOutputType;
1378  if (converter) {
1379  legalOutputType = converter->convertType(originalType);
1380  } else if (replacements.size() == 1) {
1381  // When there is no type converter, assume that the replacement value
1382  // types are legal. This is reasonable to assume because they were
1383  // specified by the user.
1384  // FIXME: This won't work for 1->N conversions because multiple output
1385  // types are not supported in parts of the dialect conversion. In such a
1386  // case, we currently use the original value type.
1387  legalOutputType = replacements[0].getType();
1388  }
1389  if (legalOutputType && legalOutputType != originalType) {
1391  MaterializationKind::Target, computeInsertPoint(argMat), loc,
1392  /*inputs=*/argMat, /*outputType=*/legalOutputType,
1393  /*originalType=*/originalType, converter);
1394  mapping.map(argMat, targetMat);
1395  }
1396 }
1397 
1399  Value value, const TypeConverter *converter) {
1400  // Find a replacement value with the same type.
1401  Value repl = mapping.lookupOrNull(value, value.getType());
1402  if (repl)
1403  return repl;
1404 
1405  // Check if the value is dead. No replacement value is needed in that case.
1406  // This is an approximate check that may have false negatives but does not
1407  // require computing and traversing an inverse mapping. (We may end up
1408  // building source materializations that are never used and that fold away.)
1409  if (llvm::all_of(value.getUsers(),
1410  [&](Operation *op) { return replacedOps.contains(op); }) &&
1411  !mapping.isMappedTo(value))
1412  return Value();
1413 
1414  // No replacement value was found. Get the latest replacement value
1415  // (regardless of the type) and build a source materialization to the
1416  // original type.
1417  repl = mapping.lookupOrNull(value);
1418  if (!repl) {
1419  // No replacement value is registered in the mapping. This means that the
1420  // value is dropped and no longer needed. (If the value were still needed,
1421  // a source materialization producing a replacement value "out of thin air"
1422  // would have already been created during `replaceOp` or
1423  // `applySignatureConversion`.)
1424  return Value();
1425  }
1427  MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1428  /*inputs=*/repl, /*outputType=*/value.getType(),
1429  /*originalType=*/Type(), converter);
1430  mapping.map(value, castValue);
1431  return castValue;
1432 }
1433 
1434 //===----------------------------------------------------------------------===//
1435 // Rewriter Notification Hooks
1436 
1438  Operation *op, OpBuilder::InsertPoint previous) {
1439  LLVM_DEBUG({
1440  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1441  << ")\n";
1442  });
1443  assert(!wasOpReplaced(op->getParentOp()) &&
1444  "attempting to insert into a block within a replaced/erased op");
1445 
1446  if (!previous.isSet()) {
1447  // This is a newly created op.
1448  appendRewrite<CreateOperationRewrite>(op);
1449  return;
1450  }
1451  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1452  ? nullptr
1453  : &*previous.getPoint();
1454  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1455 }
1456 
1458  Operation *op, ArrayRef<ReplacementValues> newValues) {
1459  assert(newValues.size() == op->getNumResults());
1460  assert(!ignoredOps.contains(op) && "operation was already replaced");
1461 
1462  // Check if replaced op is an unresolved materialization, i.e., an
1463  // unrealized_conversion_cast op that was created by the conversion driver.
1464  bool isUnresolvedMaterialization = false;
1465  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1466  if (unresolvedMaterializations.contains(castOp))
1467  isUnresolvedMaterialization = true;
1468 
1469  // Create mappings for each of the new result values.
1470  for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
1471  ReplacementValues repl = n;
1472  if (repl.empty()) {
1473  // This result was dropped and no replacement value was provided.
1474  if (isUnresolvedMaterialization) {
1475  // Do not create another materializations if we are erasing a
1476  // materialization.
1477  continue;
1478  }
1479 
1480  // Materialize a replacement value "out of thin air".
1482  MaterializationKind::Source, computeInsertPoint(result),
1483  result.getLoc(), /*inputs=*/ValueRange(),
1484  /*outputType=*/result.getType(), /*originalType=*/Type(),
1486  repl.push_back(sourceMat);
1487  } else {
1488  // Make sure that the user does not mess with unresolved materializations
1489  // that were inserted by the conversion driver. We keep track of these
1490  // ops in internal data structures. Erasing them must be allowed because
1491  // this can happen when the user is erasing an entire block (including
1492  // its body). But replacing them with another value should be forbidden
1493  // to avoid problems with the `mapping`.
1494  assert(!isUnresolvedMaterialization &&
1495  "attempting to replace an unresolved materialization");
1496  }
1497 
1498  // Remap result to replacement value.
1499  if (repl.empty())
1500  continue;
1501 
1502  if (repl.size() == 1) {
1503  // Single replacement value: replace directly.
1504  mapping.map(result, repl.front());
1505  } else {
1506  // Multiple replacement values: insert N:1 materialization.
1508  /*replacements=*/repl, /*outputValue=*/result,
1510  }
1511  }
1512 
1513  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1514  // Mark this operation and all nested ops as replaced.
1515  op->walk([&](Operation *op) { replacedOps.insert(op); });
1516 }
1517 
1519  appendRewrite<EraseBlockRewrite>(block);
1520 }
1521 
1523  Block *block, Region *previous, Region::iterator previousIt) {
1524  assert(!wasOpReplaced(block->getParentOp()) &&
1525  "attempting to insert into a region within a replaced/erased op");
1526  LLVM_DEBUG(
1527  {
1528  Operation *parent = block->getParentOp();
1529  if (parent) {
1530  logger.startLine() << "** Insert Block into : '" << parent->getName()
1531  << "'(" << parent << ")\n";
1532  } else {
1533  logger.startLine()
1534  << "** Insert Block into detached Region (nullptr parent op)'";
1535  }
1536  });
1537 
1538  if (!previous) {
1539  // This is a newly created block.
1540  appendRewrite<CreateBlockRewrite>(block);
1541  return;
1542  }
1543  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1544  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1545 }
1546 
1548  Block *block, Block *srcBlock, Block::iterator before) {
1549  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1550 }
1551 
1553  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1554  LLVM_DEBUG({
1556  reasonCallback(diag);
1557  logger.startLine() << "** Failure : " << diag.str() << "\n";
1558  if (config.notifyCallback)
1560  });
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // ConversionPatternRewriter
1565 //===----------------------------------------------------------------------===//
1566 
1567 ConversionPatternRewriter::ConversionPatternRewriter(
1568  MLIRContext *ctx, const ConversionConfig &config)
1569  : PatternRewriter(ctx),
1570  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1571  setListener(impl.get());
1572 }
1573 
1575 
1577  assert(op && newOp && "expected non-null op");
1578  replaceOp(op, newOp->getResults());
1579 }
1580 
1582  assert(op->getNumResults() == newValues.size() &&
1583  "incorrect # of replacement values");
1584  LLVM_DEBUG({
1585  impl->logger.startLine()
1586  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1587  });
1588  SmallVector<ReplacementValues> newVals(newValues.size());
1589  for (auto [index, val] : llvm::enumerate(newValues))
1590  if (val)
1591  newVals[index].push_back(val);
1592  impl->notifyOpReplaced(op, newVals);
1593 }
1594 
1596  Operation *op, ArrayRef<ValueRange> newValues) {
1597  assert(op->getNumResults() == newValues.size() &&
1598  "incorrect # of replacement values");
1599  LLVM_DEBUG({
1600  impl->logger.startLine()
1601  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1602  });
1603  SmallVector<ReplacementValues> newVals(newValues.size(), {});
1604  for (auto [index, val] : llvm::enumerate(newValues))
1605  llvm::append_range(newVals[index], val);
1606  impl->notifyOpReplaced(op, newVals);
1607 }
1608 
1610  LLVM_DEBUG({
1611  impl->logger.startLine()
1612  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1613  });
1614  SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
1615  impl->notifyOpReplaced(op, nullRepls);
1616 }
1617 
1619  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1620  "attempting to erase a block within a replaced/erased op");
1621 
1622  // Mark all ops for erasure.
1623  for (Operation &op : *block)
1624  eraseOp(&op);
1625 
1626  // Unlink the block from its parent region. The block is kept in the rewrite
1627  // object and will be actually destroyed when rewrites are applied. This
1628  // allows us to keep the operations in the block live and undo the removal by
1629  // re-inserting the block.
1630  impl->notifyBlockIsBeingErased(block);
1631  block->getParent()->getBlocks().remove(block);
1632 }
1633 
1635  Block *block, TypeConverter::SignatureConversion &conversion,
1636  const TypeConverter *converter) {
1637  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1638  "attempting to apply a signature conversion to a block within a "
1639  "replaced/erased op");
1640  return impl->applySignatureConversion(*this, block, converter, conversion);
1641 }
1642 
1644  Region *region, const TypeConverter &converter,
1645  TypeConverter::SignatureConversion *entryConversion) {
1646  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1647  "attempting to apply a signature conversion to a block within a "
1648  "replaced/erased op");
1649  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1650 }
1651 
1653  Value to) {
1654  LLVM_DEBUG({
1655  Operation *parentOp = from.getOwner()->getParentOp();
1656  impl->logger.startLine() << "** Replace Argument : '" << from
1657  << "'(in region of '" << parentOp->getName()
1658  << "'(" << from.getOwner()->getParentOp() << ")\n";
1659  });
1660  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1661  impl->currentTypeConverter);
1662  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1663 }
1664 
1666  SmallVector<Value> remappedValues;
1667  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1668  remappedValues)))
1669  return nullptr;
1670  return remappedValues.front();
1671 }
1672 
// Remaps every value in `keys`, storing the remapped values in `results`.
// An empty `keys` range succeeds immediately without calling remapValues.
1673 LogicalResult
1675  SmallVectorImpl<Value> &results) {
1676  if (keys.empty())
1677  return success();
1678  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1679  results);
1680 }
1681 
// Moves all operations of `source` into `dest` before `before` and replaces
// uses of `source`'s block arguments with `argValues` (one value per
// argument). It then erases `source`. Debug builds check that neither block
// belongs to a replaced/erased op and that every user of `source` is an
// ignored op.
1683  Block::iterator before,
1684  ValueRange argValues) {
1685 #ifndef NDEBUG
1686  assert(argValues.size() == source->getNumArguments() &&
1687  "incorrect # of argument replacement values");
1688  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1689  "attempting to inline a block from a replaced/erased op");
1690  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1691  "attempting to inline a block into a replaced/erased op");
1692  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1693  // The source block will be deleted, so it should not have any users (i.e.,
1694  // there should be no predecessors).
1695  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1696  "expected 'source' to have no predecessors");
1697 #endif // NDEBUG
1698 
1699  // If a listener is attached to the dialect conversion, ops cannot be moved
1700  // to the destination block in bulk ("fast path"). This is because at the time
1701  // the notifications are sent, it is unknown which ops were moved. Instead,
1702  // ops should be moved one-by-one ("slow path"), so that a separate
1703  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1704  // a bit more efficient, so we try to do that when possible.
1705  bool fastPath = !impl->config.listener;
1706 
// Only the fast path announces the whole inlining up front; the slow path
// goes through moveOpBefore for each op instead.
1707  if (fastPath)
1708  impl->notifyBlockBeingInlined(dest, source, before);
1709 
1710  // Replace all uses of block arguments.
1711  for (auto it : llvm::zip(source->getArguments(), argValues))
1712  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1713 
1714  if (fastPath) {
1715  // Move all ops at once.
1716  dest->getOperations().splice(before, source->getOperations());
1717  } else {
1718  // Move op by op.
1719  while (!source->empty())
1720  moveOpBefore(&source->front(), dest, before);
1721  }
1722 
1723  // Erase the source block.
1724  eraseBlock(source);
1725 }
1726 
// Begins an in-place modification of `op`. It appends a
// ModifyOperationRewrite (which cancelOpModification can roll back) and, in
// debug builds only, records `op` as having a pending in-place update.
1728  assert(!impl->wasOpReplaced(op) &&
1729  "attempting to modify a replaced/erased op");
1730 #ifndef NDEBUG
1731  impl->pendingRootUpdates.insert(op);
1732 #endif
1733  impl->appendRewrite<ModifyOperationRewrite>(op);
1734 }
1735 
// Ends an in-place modification of `op`. The ModifyOperationRewrite recorded
// at the start stays in the rewrite list.
1737  assert(!impl->wasOpReplaced(op) &&
1738  "attempting to modify a replaced/erased op");
1740  // There is nothing to do here, we only need to track the operation at the
1741  // start of the update.
// The erase below has a side effect inside assert(). That is only safe
// because the whole statement is compiled in debug builds only, and
// pendingRootUpdates is maintained only in debug builds.
1742 #ifndef NDEBUG
1743  assert(impl->pendingRootUpdates.erase(op) &&
1744  "operation did not have a pending in-place update");
1745 #endif
1746 }
1747 
// Cancels an in-place modification of `op`. It rolls back the most recently
// recorded ModifyOperationRewrite for `op` and removes that rewrite from the
// rewrite list.
1749 #ifndef NDEBUG
1750  assert(impl->pendingRootUpdates.erase(op) &&
1751  "operation did not have a pending in-place update");
1752 #endif
1753  // Erase the last update for this operation.
1754  auto it = llvm::find_if(
1755  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1756  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1757  return modifyRewrite && modifyRewrite->getOperation() == op;
1758  });
1759  assert(it != impl->rewrites.rend() && "no root update started on op");
1760  (*it)->rollback();
// Convert the reverse iterator into a forward index: prev(rend()) refers to
// element 0, so the distance from it to `it` is the forward index of `it`.
1761  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1762  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1763 }
1764 
// Returns the rewriter's implementation object by dereferencing `impl`.
1766  return *impl;
1767 }
1768 
1769 //===----------------------------------------------------------------------===//
1770 // ConversionPattern
1771 //===----------------------------------------------------------------------===//
1772 
// Generic entry point for conversion patterns. It treats `rewriter` as a
// ConversionPatternRewriter (static_cast, so this assumes the pattern only
// runs under dialect conversion), makes this pattern's type converter the
// current one for the duration of the call, remaps the operands of `op`, and
// then dispatches to the operand-taking matchAndRewrite overload. Operand
// remapping failure fails the match.
1773 LogicalResult
1775  PatternRewriter &rewriter) const {
1776  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1777  auto &rewriterImpl = dialectRewriter.getImpl();
1778 
1779  // Track the current conversion pattern type converter in the rewriter.
// SaveAndRestore reinstates the previous converter when this function
// returns, on every exit path.
1780  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1781  getTypeConverter());
1782 
1783  // Remap the operands of the operation.
1784  SmallVector<Value, 4> operands;
1785  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1786  op->getOperands(), operands))) {
1787  return failure();
1788  }
1789  return matchAndRewrite(op, operands, dialectRewriter);
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // OperationLegalizer
1794 //===----------------------------------------------------------------------===//
1795 
1796 namespace {
1797 /// A set of rewrite patterns that can be used to legalize a given operation.
1798 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1799 
1800 /// This class defines a recursive operation legalizer.
/// legalize() re-enters itself: folding and pattern application call it again
/// for operations they create or modify.
1801 class OperationLegalizer {
1802 public:
1803  using LegalizationAction = ConversionTarget::LegalizationAction;
1804 
1805  OperationLegalizer(const ConversionTarget &targetInfo,
1806  const FrozenRewritePatternSet &patterns,
1807  const ConversionConfig &config);
1808 
1809  /// Returns true if the given operation is known to be illegal on the target.
1810  bool isIllegal(Operation *op) const;
1811 
1812  /// Attempt to legalize the given operation. Returns success if the operation
1813  /// was legalized, failure otherwise.
1814  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1815 
1816  /// Returns the conversion target in use by the legalizer.
1817  const ConversionTarget &getTarget() { return target; }
1818 
1819 private:
1820  /// Attempt to legalize the given operation by folding it.
1821  LogicalResult legalizeWithFold(Operation *op,
1822  ConversionPatternRewriter &rewriter);
1823 
1824  /// Attempt to legalize the given operation by applying a pattern. Returns
1825  /// success if the operation was legalized, failure otherwise.
1826  LogicalResult legalizeWithPattern(Operation *op,
1827  ConversionPatternRewriter &rewriter);
1828 
1829  /// Return true if the given pattern may be applied to the given operation,
1830  /// false otherwise.
1831  bool canApplyPattern(Operation *op, const Pattern &pattern,
1832  ConversionPatternRewriter &rewriter);
1833 
1834  /// Legalize the resultant IR after successfully applying the given pattern.
1835  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1836  ConversionPatternRewriter &rewriter,
1837  RewriterState &curState);
1838 
1839  /// Legalizes the actions registered during the execution of a pattern.
1840  LogicalResult
1841  legalizePatternBlockRewrites(Operation *op,
1842  ConversionPatternRewriter &rewriter,
1844  RewriterState &state, RewriterState &newState);
1845  LogicalResult legalizePatternCreatedOperations(
1847  RewriterState &state, RewriterState &newState);
1848  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1850  RewriterState &state,
1851  RewriterState &newState);
1852 
1853  //===--------------------------------------------------------------------===//
1854  // Cost Model
1855  //===--------------------------------------------------------------------===//
1856 
1857  /// Build an optimistic legalization graph given the provided patterns. This
1858  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1859  /// patterns for operations that are not directly legal, but may be
1860  /// transitively legal for the current target given the provided patterns.
1861  void buildLegalizationGraph(
1862  LegalizationPatterns &anyOpLegalizerPatterns,
1864 
1865  /// Compute the benefit of each node within the computed legalization graph.
1866  /// This orders the patterns within 'legalizerPatterns' based upon two
1867  /// criteria:
1868  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1869  /// represent the more direct mapping to the target.
1870  /// 2) When comparing patterns with the same legalization depth, prefer the
1871  /// pattern with the highest PatternBenefit. This allows for users to
1872  /// prefer specific legalizations over others.
1873  void computeLegalizationGraphBenefit(
1874  LegalizationPatterns &anyOpLegalizerPatterns,
1876 
1877  /// Compute the legalization depth when legalizing an operation of the given
1878  /// type.
1879  unsigned computeOpLegalizationDepth(
1880  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1882 
1883  /// Apply the conversion cost model to the given set of patterns, and return
1884  /// the smallest legalization depth of any of the patterns. See
1885  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1886  unsigned applyCostModelToPatterns(
1887  LegalizationPatterns &patterns,
1888  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1890 
1891  /// The current set of patterns that have been applied.
/// Only patterns without bounded rewrite recursion are inserted (see
/// canApplyPattern). Entries are erased again once the pattern's application
/// finishes, whether it succeeded or failed.
1892  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1893 
1894  /// The legalization information provided by the target.
/// Held by reference; the target must outlive this legalizer.
1895  const ConversionTarget &target;
1896 
1897  /// The pattern applicator to use for conversions.
1898  PatternApplicator applicator;
1899 
1900  /// Dialect conversion configuration.
/// Held by reference; the referenced ConversionConfig must outlive this
/// legalizer.
1901  const ConversionConfig &config;
1902 };
1903 } // namespace
1904 
// Builds the legalization graph from `patterns` and then applies the cost
// model, which reorders the applicator's patterns (see
// computeLegalizationGraphBenefit). `targetInfo` and `config` are stored by
// reference.
1905 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1906  const FrozenRewritePatternSet &patterns,
1907  const ConversionConfig &config)
1908  : target(targetInfo), applicator(patterns), config(config) {
1909  // The set of patterns that can be applied to illegal operations to transform
1910  // them into legal ones.
1912  LegalizationPatterns anyOpLegalizerPatterns;
1913 
1914  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1915  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1916 }
1917 
1918 bool OperationLegalizer::isIllegal(Operation *op) const {
1919  return target.isIllegal(op);
1920 }
1921 
1922 LogicalResult
1923 OperationLegalizer::legalize(Operation *op,
1924  ConversionPatternRewriter &rewriter) {
1925 #ifndef NDEBUG
1926  const char *logLineComment =
1927  "//===-------------------------------------------===//\n";
1928 
1929  auto &logger = rewriter.getImpl().logger;
1930 #endif
1931  LLVM_DEBUG({
1932  logger.getOStream() << "\n";
1933  logger.startLine() << logLineComment;
1934  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1935  << op << ") {\n";
1936  logger.indent();
1937 
1938  // If the operation has no regions, just print it here.
1939  if (op->getNumRegions() == 0) {
1940  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1941  logger.getOStream() << "\n\n";
1942  }
1943  });
1944 
1945  // Check if this operation is legal on the target.
1946  if (auto legalityInfo = target.isLegal(op)) {
1947  LLVM_DEBUG({
1948  logSuccess(
1949  logger, "operation marked legal by the target{0}",
1950  legalityInfo->isRecursivelyLegal
1951  ? "; NOTE: operation is recursively legal; skipping internals"
1952  : "");
1953  logger.startLine() << logLineComment;
1954  });
1955 
1956  // If this operation is recursively legal, mark its children as ignored so
1957  // that we don't consider them for legalization.
1958  if (legalityInfo->isRecursivelyLegal) {
1959  op->walk([&](Operation *nested) {
1960  if (op != nested)
1961  rewriter.getImpl().ignoredOps.insert(nested);
1962  });
1963  }
1964 
1965  return success();
1966  }
1967 
1968  // Check to see if the operation is ignored and doesn't need to be converted.
1969  if (rewriter.getImpl().isOpIgnored(op)) {
1970  LLVM_DEBUG({
1971  logSuccess(logger, "operation marked 'ignored' during conversion");
1972  logger.startLine() << logLineComment;
1973  });
1974  return success();
1975  }
1976 
1977  // If the operation isn't legal, try to fold it in-place.
1978  // TODO: Should we always try to do this, even if the op is
1979  // already legal?
1980  if (succeeded(legalizeWithFold(op, rewriter))) {
1981  LLVM_DEBUG({
1982  logSuccess(logger, "operation was folded");
1983  logger.startLine() << logLineComment;
1984  });
1985  return success();
1986  }
1987 
1988  // Otherwise, we need to apply a legalization pattern to this operation.
1989  if (succeeded(legalizeWithPattern(op, rewriter))) {
1990  LLVM_DEBUG({
1991  logSuccess(logger, "");
1992  logger.startLine() << logLineComment;
1993  });
1994  return success();
1995  }
1996 
1997  LLVM_DEBUG({
1998  logFailure(logger, "no matched legalization pattern");
1999  logger.startLine() << logLineComment;
2000  });
2001  return failure();
2002 }
2003 
2004 LogicalResult
2005 OperationLegalizer::legalizeWithFold(Operation *op,
2006  ConversionPatternRewriter &rewriter) {
2007  auto &rewriterImpl = rewriter.getImpl();
2008  RewriterState curState = rewriterImpl.getCurrentState();
2009 
2010  LLVM_DEBUG({
2011  rewriterImpl.logger.startLine() << "* Fold {\n";
2012  rewriterImpl.logger.indent();
2013  });
2014 
2015  // Try to fold the operation.
2016  SmallVector<Value, 2> replacementValues;
2017  rewriter.setInsertionPoint(op);
2018  if (failed(rewriter.tryFold(op, replacementValues))) {
2019  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2020  return failure();
2021  }
2022  // An empty list of replacement values indicates that the fold was in-place.
2023  // As the operation changed, a new legalization needs to be attempted.
2024  if (replacementValues.empty())
2025  return legalize(op, rewriter);
2026 
2027  // Insert a replacement for 'op' with the folded replacement values.
2028  rewriter.replaceOp(op, replacementValues);
2029 
2030  // Recursively legalize any new constant operations.
2031  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2032  i != e; ++i) {
2033  auto *createOp =
2034  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2035  if (!createOp)
2036  continue;
2037  if (failed(legalize(createOp->getOperation(), rewriter))) {
2038  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2039  "failed to legalize generated constant '{0}'",
2040  createOp->getOperation()->getName()));
2041  rewriterImpl.resetState(curState);
2042  return failure();
2043  }
2044  }
2045 
2046  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2047  return success();
2048 }
2049 
// Tries to legalize `op` with the pattern applicator, wiring in three hooks:
//  - canApply: defers to canApplyPattern and, when the pattern may run,
//    notifies the listener (if any) that the pattern begins;
//  - onFailure: notifies the listener, resets the rewriter to `curState` and
//    removes the pattern from `appliedPatterns`;
//  - onSuccess: legalizes the pattern's results, removes the pattern from
//    `appliedPatterns`, resets to `curState` if that legalization failed,
//    and notifies the listener with the final result.
2050 LogicalResult
2051 OperationLegalizer::legalizeWithPattern(Operation *op,
2052  ConversionPatternRewriter &rewriter) {
2053  auto &rewriterImpl = rewriter.getImpl();
2054 
2055  // Functor that returns if the given pattern may be applied.
2056  auto canApply = [&](const Pattern &pattern) {
2057  bool canApply = canApplyPattern(op, pattern, rewriter);
2058  if (canApply && config.listener)
2059  config.listener->notifyPatternBegin(pattern, op);
2060  return canApply;
2061  };
2062 
2063  // Functor that cleans up the rewriter state after a pattern failed to match.
// `curState` is captured once, before any pattern runs. Because every
// failure resets to it, each subsequent pattern starts from the same state.
2064  RewriterState curState = rewriterImpl.getCurrentState();
2065  auto onFailure = [&](const Pattern &pattern) {
2066  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2067  LLVM_DEBUG({
2068  logFailure(rewriterImpl.logger, "pattern failed to match");
2069  if (rewriterImpl.config.notifyCallback) {
2071  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2072  << "\" on op:\n"
2073  << *op;
2074  rewriterImpl.config.notifyCallback(diag);
2075  }
2076  });
2077  if (config.listener)
2078  config.listener->notifyPatternEnd(pattern, failure());
2079  rewriterImpl.resetState(curState);
2080  appliedPatterns.erase(&pattern);
2081  };
2082 
2083  // Functor that performs additional legalization when a pattern is
2084  // successfully applied.
2085  auto onSuccess = [&](const Pattern &pattern) {
2086  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2087  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2088  appliedPatterns.erase(&pattern);
2089  if (failed(result))
2090  rewriterImpl.resetState(curState);
2091  if (config.listener)
2092  config.listener->notifyPatternEnd(pattern, result);
2093  return result;
2094  };
2095 
2096  // Try to match and rewrite a pattern on this operation.
2097  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2098  onSuccess);
2099 }
2100 
2101 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2102  ConversionPatternRewriter &rewriter) {
2103  LLVM_DEBUG({
2104  auto &os = rewriter.getImpl().logger;
2105  os.getOStream() << "\n";
2106  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2107  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2108  os.getOStream() << ")' {\n";
2109  os.indent();
2110  });
2111 
2112  // Ensure that we don't cycle by not allowing the same pattern to be
2113  // applied twice in the same recursion stack if it is not known to be safe.
2114  if (!pattern.hasBoundedRewriteRecursion() &&
2115  !appliedPatterns.insert(&pattern).second) {
2116  LLVM_DEBUG(
2117  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2118  return false;
2119  }
2120  return true;
2121 }
2122 
2123 LogicalResult
2124 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2125  ConversionPatternRewriter &rewriter,
2126  RewriterState &curState) {
2127  auto &impl = rewriter.getImpl();
2128 
2129 #ifndef NDEBUG
2130  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2131  // Check that the root was either replaced or updated in place.
2132  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2133  auto replacedRoot = [&] {
2134  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2135  };
2136  auto updatedRootInPlace = [&] {
2137  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2138  };
2139  assert((replacedRoot() || updatedRootInPlace()) &&
2140  "expected pattern to replace the root operation");
2141 #endif // NDEBUG
2142 
2143  // Legalize each of the actions registered during application.
2144  RewriterState newState = impl.getCurrentState();
2145  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2146  newState)) ||
2147  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2148  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2149  newState))) {
2150  return failure();
2151  }
2152 
2153  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2154  return success();
2155 }
2156 
2157 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2158  Operation *op, ConversionPatternRewriter &rewriter,
2159  ConversionPatternRewriterImpl &impl, RewriterState &state,
2160  RewriterState &newState) {
2161  SmallPtrSet<Operation *, 16> operationsToIgnore;
2162 
2163  // If the pattern moved or created any blocks, make sure the types of block
2164  // arguments get legalized.
2165  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2166  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2167  if (!rewrite)
2168  continue;
2169  Block *block = rewrite->getBlock();
2170  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2171  ReplaceBlockArgRewrite>(rewrite))
2172  continue;
2173  // Only check blocks outside of the current operation.
2174  Operation *parentOp = block->getParentOp();
2175  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2176  continue;
2177 
2178  // If the region of the block has a type converter, try to convert the block
2179  // directly.
2180  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2181  std::optional<TypeConverter::SignatureConversion> conversion =
2182  converter->convertBlockSignature(block);
2183  if (!conversion) {
2184  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2185  "block"));
2186  return failure();
2187  }
2188  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2189  continue;
2190  }
2191 
2192  // Otherwise, check that this operation isn't one generated by this pattern.
2193  // This is because we will attempt to legalize the parent operation, and
2194  // blocks in regions created by this pattern will already be legalized later
2195  // on. If we haven't built the set yet, build it now.
2196  if (operationsToIgnore.empty()) {
2197  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2198  ++i) {
2199  auto *createOp =
2200  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2201  if (!createOp)
2202  continue;
2203  operationsToIgnore.insert(createOp->getOperation());
2204  }
2205  }
2206 
2207  // If this operation should be considered for re-legalization, try it.
2208  if (operationsToIgnore.insert(parentOp).second &&
2209  failed(legalize(parentOp, rewriter))) {
2210  LLVM_DEBUG(logFailure(impl.logger,
2211  "operation '{0}'({1}) became illegal after rewrite",
2212  parentOp->getName(), parentOp));
2213  return failure();
2214  }
2215  }
2216  return success();
2217 }
2218 
// Legalizes every operation created by the pattern, i.e. each
// CreateOperationRewrite recorded in [state, newState). It fails on the
// first created operation that cannot be legalized.
2219 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2221  RewriterState &state, RewriterState &newState) {
2222  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2223  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2224  if (!createOp)
2225  continue;
2226  Operation *op = createOp->getOperation();
2227  if (failed(legalize(op, rewriter))) {
2228  LLVM_DEBUG(logFailure(impl.logger,
2229  "failed to legalize generated operation '{0}'({1})",
2230  op->getName(), op));
2231  return failure();
2232  }
2233  }
2234  return success();
2235 }
2236 
// Re-legalizes every operation the pattern modified in place, i.e. each
// ModifyOperationRewrite recorded in [state, newState). It fails on the
// first such operation that cannot be legalized.
2237 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2239  RewriterState &state, RewriterState &newState) {
2240  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2241  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2242  if (!rewrite)
2243  continue;
2244  Operation *op = rewrite->getOperation();
2245  if (failed(legalize(op, rewriter))) {
2246  LLVM_DEBUG(logFailure(
2247  impl.logger, "failed to legalize operation updated in-place '{0}'",
2248  op->getName()));
2249  return failure();
2250  }
2251  }
2252  return success();
2253 }
2254 
2255 //===----------------------------------------------------------------------===//
2256 // Cost Model
2257 
// Fills `legalizerPatterns` with the root-specific patterns that can
// (transitively) produce legal IR, and `anyOpLegalizerPatterns` with the
// patterns that have no specific root. Roots whose target action is Legal
// are skipped. If any root-less pattern exists, every root-specific pattern
// is kept without further filtering. Otherwise a worklist fixpoint keeps a
// pattern only once none of its generated ops is "invalid", where invalid
// means the op has no accepted patterns yet and the target has no action for
// it or marks it Illegal.
2258 void OperationLegalizer::buildLegalizationGraph(
2259  LegalizationPatterns &anyOpLegalizerPatterns,
2260  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2261  // A mapping between an operation and a set of operations that can be used to
2262  // generate it.
2264  // A mapping between an operation and any currently invalid patterns it has.
2266  // A worklist of patterns to consider for legality.
2267  SetVector<const Pattern *> patternWorklist;
2268 
2269  // Build the mapping from operations to the parent ops that may generate them.
2270  applicator.walkAllPatterns([&](const Pattern &pattern) {
2271  std::optional<OperationName> root = pattern.getRootKind();
2272 
2273  // If the pattern has no specific root, we can't analyze the relationship
2274  // between the root op and generated operations. Given that, add all such
2275  // patterns to the legalization set.
2276  if (!root) {
2277  anyOpLegalizerPatterns.push_back(&pattern);
2278  return;
2279  }
2280 
2281  // Skip operations that are always known to be legal.
2282  if (target.getOpAction(*root) == LegalizationAction::Legal)
2283  return;
2284 
2285  // Add this pattern to the invalid set for the root op and record this root
2286  // as a parent for any generated operations.
2287  invalidPatterns[*root].insert(&pattern);
2288  for (auto op : pattern.getGeneratedOps())
2289  parentOps[op].insert(*root);
2290 
2291  // Add this pattern to the worklist.
2292  patternWorklist.insert(&pattern);
2293  });
2294 
2295  // If there are any patterns that don't have a specific root kind, we can't
2296  // make direct assumptions about what operations will never be legalized.
2297  // Note: Technically we could, but it would require an analysis that may
2298  // recurse into itself. It would be better to perform this kind of filtering
2299  // at a higher level than here anyways.
2300  if (!anyOpLegalizerPatterns.empty()) {
2301  for (const Pattern *pattern : patternWorklist)
2302  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2303  return;
2304  }
2305 
2306  while (!patternWorklist.empty()) {
2307  auto *pattern = patternWorklist.pop_back_val();
2308 
2309  // Check to see if any of the generated operations are invalid.
2310  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2311  std::optional<LegalizationAction> action = target.getOpAction(op);
2312  return !legalizerPatterns.count(op) &&
2313  (!action || action == LegalizationAction::Illegal);
2314  }))
2315  continue;
2316 
2317  // Otherwise, all of the generated operations are (transitively) legal, so
2318  // this pattern is a viable legalization for its root op; record it.
2319  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2320  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2321 
2322  // Add any invalid patterns of the parent operations to see if they have now
2323  // become legal.
2324  for (auto op : parentOps[*pattern->getRootKind()])
2325  patternWorklist.set_union(invalidPatterns[op]);
2326  }
2327 }
2328 
// Computes a legalization depth for every op in `legalizerPatterns`, which
// also sorts each op's pattern list. It then sorts the root-less patterns the
// same way. Finally it installs a cost model in the applicator where a
// pattern's benefit is its distance from the end of its ordered list, so
// earlier patterns are tried first.
2329 void OperationLegalizer::computeLegalizationGraphBenefit(
2330  LegalizationPatterns &anyOpLegalizerPatterns,
2331  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2332  // The smallest pattern depth, when legalizing an operation.
2333  DenseMap<OperationName, unsigned> minOpPatternDepth;
2334 
2335  // For each operation that is transitively legal, compute a cost for it.
2336  for (auto &opIt : legalizerPatterns)
2337  if (!minOpPatternDepth.count(opIt.first))
2338  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2339  legalizerPatterns);
2340 
2341  // Apply the cost model to the patterns that can match any operation. Those
2342  // with a specific operation type are already resolved when computing the op
2343  // legalization depth.
2344  if (!anyOpLegalizerPatterns.empty())
2345  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2346  legalizerPatterns);
2347 
2348  // Apply a cost model to the pattern applicator. We order patterns first by
2349  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2350  // decreasing benefit.
2351  applicator.applyCostModel([&](const Pattern &pattern) {
2352  ArrayRef<const Pattern *> orderedPatternList;
// NOTE: DenseMap::operator[] default-inserts an empty list for a root that
// has no accepted patterns; such a pattern is then "not found" below.
2353  if (std::optional<OperationName> rootName = pattern.getRootKind())
2354  orderedPatternList = legalizerPatterns[*rootName];
2355  else
2356  orderedPatternList = anyOpLegalizerPatterns;
2357 
2358  // If the pattern is not found, then it was removed and cannot be matched.
2359  auto *it = llvm::find(orderedPatternList, &pattern);
2360  if (it == orderedPatternList.end())
2362 
2363  // Patterns found earlier in the list have higher benefit.
2364  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2365  });
2366 }
2367 
2368 unsigned OperationLegalizer::computeOpLegalizationDepth(
2369  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2370  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2371  // Check for existing depth.
2372  auto depthIt = minOpPatternDepth.find(op);
2373  if (depthIt != minOpPatternDepth.end())
2374  return depthIt->second;
2375 
2376  // If a mapping for this operation does not exist, then this operation
2377  // is always legal. Return 0 as the depth for a directly legal operation.
2378  auto opPatternsIt = legalizerPatterns.find(op);
2379  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2380  return 0u;
2381 
2382  // Record this initial depth in case we encounter this op again when
2383  // recursively computing the depth.
2384  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2385 
2386  // Apply the cost model to the operation patterns, and update the minimum
2387  // depth.
2388  unsigned minDepth = applyCostModelToPatterns(
2389  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2390  minOpPatternDepth[op] = minDepth;
2391  return minDepth;
2392 }
2393 
2394 unsigned OperationLegalizer::applyCostModelToPatterns(
2395  LegalizationPatterns &patterns,
2396  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2397  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2398  unsigned minDepth = std::numeric_limits<unsigned>::max();
2399 
2400  // Compute the depth for each pattern within the set.
2401  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2402  patternsByDepth.reserve(patterns.size());
2403  for (const Pattern *pattern : patterns) {
2404  unsigned depth = 1;
2405  for (auto generatedOp : pattern->getGeneratedOps()) {
2406  unsigned generatedOpDepth = computeOpLegalizationDepth(
2407  generatedOp, minOpPatternDepth, legalizerPatterns);
2408  depth = std::max(depth, generatedOpDepth + 1);
2409  }
2410  patternsByDepth.emplace_back(pattern, depth);
2411 
2412  // Update the minimum depth of the pattern list.
2413  minDepth = std::min(minDepth, depth);
2414  }
2415 
2416  // If the operation only has one legalization pattern, there is no need to
2417  // sort them.
2418  if (patternsByDepth.size() == 1)
2419  return minDepth;
2420 
2421  // Sort the patterns by those likely to be the most beneficial.
2422  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2423  [](const std::pair<const Pattern *, unsigned> &lhs,
2424  const std::pair<const Pattern *, unsigned> &rhs) {
2425  // First sort by the smaller pattern legalization
2426  // depth.
2427  if (lhs.second != rhs.second)
2428  return lhs.second < rhs.second;
2429 
2430  // Then sort by the larger pattern benefit.
2431  auto lhsBenefit = lhs.first->getBenefit();
2432  auto rhsBenefit = rhs.first->getBenefit();
2433  return lhsBenefit > rhsBenefit;
2434  });
2435 
2436  // Update the legalization pattern to use the new sorted list.
2437  patterns.clear();
2438  for (auto &patternIt : patternsByDepth)
2439  patterns.push_back(patternIt.first);
2440  return minDepth;
2441 }
2442 
2443 //===----------------------------------------------------------------------===//
2444 // OperationConverter
2445 //===----------------------------------------------------------------------===//
namespace {
/// Determines how OperationConverter reacts to operations it fails to
/// legalize.
enum OpConversionMode {
  Partial,  ///< Tolerate failed conversions; illegal operations may co-exist
            ///< with legal ones in the IR.
  Full,     ///< Every operation must become legal for the target for the
            ///< conversion to succeed.
  Analysis, ///< Only analyze operations for legality; no actual rewrites are
            ///< applied to them on success.
};
} // namespace
2461 
2462 namespace mlir {
2463 // This class converts operations to a given conversion target via a set of
2464 // rewrite patterns. The conversion behaves differently depending on the
2465 // conversion mode.
// The legalizer is given `this->config`, the converter's own copy, rather
// than the constructor argument. OperationLegalizer keeps only a reference,
// and `config` is declared before `opLegalizer`, so the copy is initialized
// first and lives as long as the legalizer. Keep that member order.
2467  explicit OperationConverter(const ConversionTarget &target,
2468  const FrozenRewritePatternSet &patterns,
2469  const ConversionConfig &config,
2470  OpConversionMode mode)
2471  : config(config), opLegalizer(target, patterns, this->config),
2472  mode(mode) {}
2473 
2474  /// Converts the given operations to the conversion target.
2475  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2476 
2477 private:
2478  /// Converts an operation with the given rewriter.
2479  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2480 
2481  /// Dialect conversion configuration.
/// Must stay declared before `opLegalizer`, which holds a reference to it.
2482  ConversionConfig config;
2483 
2484  /// The legalizer to use when converting operations.
2485  OperationLegalizer opLegalizer;
2486 
2487  /// The conversion mode to use when legalizing operations.
2488  OpConversionMode mode;
2489 };
2490 } // namespace mlir
2491 
2492 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2493  Operation *op) {
2494  // Legalize the given operation.
2495  if (failed(opLegalizer.legalize(op, rewriter))) {
2496  // Handle the case of a failed conversion for each of the different modes.
2497  // Full conversions expect all operations to be converted.
2498  if (mode == OpConversionMode::Full)
2499  return op->emitError()
2500  << "failed to legalize operation '" << op->getName() << "'";
2501  // Partial conversions allow conversions to fail iff the operation was not
2502  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2503  // set, non-legalizable ops are added to that set.
2504  if (mode == OpConversionMode::Partial) {
2505  if (opLegalizer.isIllegal(op))
2506  return op->emitError()
2507  << "failed to legalize operation '" << op->getName()
2508  << "' that was explicitly marked illegal";
2509  if (config.unlegalizedOps)
2510  config.unlegalizedOps->insert(op);
2511  }
2512  } else if (mode == OpConversionMode::Analysis) {
2513  // Analysis conversions don't fail if any operations fail to legalize,
2514  // they are only interested in the operations that were successfully
2515  // legalized.
2516  if (config.legalizableOps)
2517  config.legalizableOps->insert(op);
2518  }
2519  return success();
2520 }
2521 
2522 static LogicalResult
2524  UnresolvedMaterializationRewrite *rewrite) {
2525  UnrealizedConversionCastOp op = rewrite->getOperation();
2526  assert(!op.use_empty() &&
2527  "expected that dead materializations have already been DCE'd");
2528  Operation::operand_range inputOperands = op.getOperands();
2529  Type outputType = op.getResultTypes()[0];
2530 
2531  // Try to materialize the conversion.
2532  if (const TypeConverter *converter = rewrite->getConverter()) {
2533  rewriter.setInsertionPoint(op);
2534  Value newMaterialization;
2535  switch (rewrite->getMaterializationKind()) {
2537  // Try to materialize an argument conversion.
2538  newMaterialization = converter->materializeArgumentConversion(
2539  rewriter, op->getLoc(), outputType, inputOperands);
2540  if (newMaterialization)
2541  break;
2542  // If an argument materialization failed, fallback to trying a target
2543  // materialization.
2544  [[fallthrough]];
2545  case MaterializationKind::Target:
2546  newMaterialization = converter->materializeTargetConversion(
2547  rewriter, op->getLoc(), outputType, inputOperands,
2548  rewrite->getOriginalType());
2549  break;
2550  case MaterializationKind::Source:
2551  newMaterialization = converter->materializeSourceConversion(
2552  rewriter, op->getLoc(), outputType, inputOperands);
2553  break;
2554  }
2555  if (newMaterialization) {
2556  assert(newMaterialization.getType() == outputType &&
2557  "materialization callback produced value of incorrect type");
2558  rewriter.replaceOp(op, newMaterialization);
2559  return success();
2560  }
2561  }
2562 
2564  op->emitError() << "failed to legalize unresolved materialization "
2565  "from ("
2566  << inputOperands.getTypes() << ") to (" << outputType
2567  << ") that remained live after conversion";
2568  diag.attachNote(op->getUsers().begin()->getLoc())
2569  << "see existing live user here: " << *op->getUsers().begin();
2570  return failure();
2571 }
2572 
2574  if (ops.empty())
2575  return success();
2576  const ConversionTarget &target = opLegalizer.getTarget();
2577 
2578  // Compute the set of operations and blocks to convert.
2579  SmallVector<Operation *> toConvert;
2580  for (auto *op : ops) {
2582  [&](Operation *op) {
2583  toConvert.push_back(op);
2584  // Don't check this operation's children for conversion if the
2585  // operation is recursively legal.
2586  auto legalityInfo = target.isLegal(op);
2587  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2588  return WalkResult::skip();
2589  return WalkResult::advance();
2590  });
2591  }
2592 
2593  // Convert each operation and discard rewrites on failure.
2594  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2595  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2596 
2597  for (auto *op : toConvert)
2598  if (failed(convert(rewriter, op)))
2599  return rewriterImpl.undoRewrites(), failure();
2600 
2601  // After a successful conversion, apply rewrites.
2602  rewriterImpl.applyRewrites();
2603 
2604  // Gather all unresolved materializations.
2607  &materializations = rewriterImpl.unresolvedMaterializations;
2608  for (auto it : materializations) {
2609  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2610  continue;
2611  allCastOps.push_back(it.first);
2612  }
2613 
2614  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2615  // dialect conversion frameworks. (Not the one that were inserted by
2616  // patterns.)
2617  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2618  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2619 
2620  // Try to legalize all unresolved materializations.
2621  if (config.buildMaterializations) {
2622  IRRewriter rewriter(rewriterImpl.context, config.listener);
2623  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2624  auto it = materializations.find(castOp);
2625  assert(it != materializations.end() && "inconsistent state");
2626  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2627  return failure();
2628  }
2629  }
2630 
2631  return success();
2632 }
2633 
2634 //===----------------------------------------------------------------------===//
2635 // Reconcile Unrealized Casts
2636 //===----------------------------------------------------------------------===//
2637 
2640  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2641  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2642  castOps.end());
2643  // This set is maintained only if `remainingCastOps` is provided.
2644  DenseSet<Operation *> erasedOps;
2645 
2646  // Helper function that adds all operands to the worklist that are an
2647  // unrealized_conversion_cast op result.
2648  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2649  for (Value v : castOp.getInputs())
2650  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2651  worklist.insert(inputCastOp);
2652  };
2653 
2654  // Helper function that return the unrealized_conversion_cast op that
2655  // defines all inputs of the given op (in the same order). Return "nullptr"
2656  // if there is no such op.
2657  auto getInputCast =
2658  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2659  if (castOp.getInputs().empty())
2660  return {};
2661  auto inputCastOp =
2662  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2663  if (!inputCastOp)
2664  return {};
2665  if (inputCastOp.getOutputs() != castOp.getInputs())
2666  return {};
2667  return inputCastOp;
2668  };
2669 
2670  // Process ops in the worklist bottom-to-top.
2671  while (!worklist.empty()) {
2672  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2673  if (castOp->use_empty()) {
2674  // DCE: If the op has no users, erase it. Add the operands to the
2675  // worklist to find additional DCE opportunities.
2676  enqueueOperands(castOp);
2677  if (remainingCastOps)
2678  erasedOps.insert(castOp.getOperation());
2679  castOp->erase();
2680  continue;
2681  }
2682 
2683  // Traverse the chain of input cast ops to see if an op with the same
2684  // input types can be found.
2685  UnrealizedConversionCastOp nextCast = castOp;
2686  while (nextCast) {
2687  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2688  // Found a cast where the input types match the output types of the
2689  // matched op. We can directly use those inputs and the matched op can
2690  // be removed.
2691  enqueueOperands(castOp);
2692  castOp.replaceAllUsesWith(nextCast.getInputs());
2693  if (remainingCastOps)
2694  erasedOps.insert(castOp.getOperation());
2695  castOp->erase();
2696  break;
2697  }
2698  nextCast = getInputCast(nextCast);
2699  }
2700  }
2701 
2702  if (remainingCastOps)
2703  for (UnrealizedConversionCastOp op : castOps)
2704  if (!erasedOps.contains(op.getOperation()))
2705  remainingCastOps->push_back(op);
2706 }
2707 
2708 //===----------------------------------------------------------------------===//
2709 // Type Conversion
2710 //===----------------------------------------------------------------------===//
2711 
2713  ArrayRef<Type> types) {
2714  assert(!types.empty() && "expected valid types");
2715  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2716  addInputs(types);
2717 }
2718 
2720  assert(!types.empty() &&
2721  "1->0 type remappings don't need to be added explicitly");
2722  argTypes.append(types.begin(), types.end());
2723 }
2724 
2725 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2726  unsigned newInputNo,
2727  unsigned newInputCount) {
2728  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2729  assert(newInputCount != 0 && "expected valid input count");
2730  remappedInputs[origInputNo] =
2731  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2732 }
2733 
2735  Value replacementValue) {
2736  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2737  remappedInputs[origInputNo] =
2738  InputMapping{origInputNo, /*size=*/0, replacementValue};
2739 }
2740 
2742  SmallVectorImpl<Type> &results) const {
2743  {
2744  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2745  std::defer_lock);
2747  cacheReadLock.lock();
2748  auto existingIt = cachedDirectConversions.find(t);
2749  if (existingIt != cachedDirectConversions.end()) {
2750  if (existingIt->second)
2751  results.push_back(existingIt->second);
2752  return success(existingIt->second != nullptr);
2753  }
2754  auto multiIt = cachedMultiConversions.find(t);
2755  if (multiIt != cachedMultiConversions.end()) {
2756  results.append(multiIt->second.begin(), multiIt->second.end());
2757  return success();
2758  }
2759  }
2760  // Walk the added converters in reverse order to apply the most recently
2761  // registered first.
2762  size_t currentCount = results.size();
2763 
2764  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2765  std::defer_lock);
2766 
2767  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2768  if (std::optional<LogicalResult> result = converter(t, results)) {
2770  cacheWriteLock.lock();
2771  if (!succeeded(*result)) {
2772  cachedDirectConversions.try_emplace(t, nullptr);
2773  return failure();
2774  }
2775  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2776  if (newTypes.size() == 1)
2777  cachedDirectConversions.try_emplace(t, newTypes.front());
2778  else
2779  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2780  return success();
2781  }
2782  }
2783  return failure();
2784 }
2785 
2787  // Use the multi-type result version to convert the type.
2788  SmallVector<Type, 1> results;
2789  if (failed(convertType(t, results)))
2790  return nullptr;
2791 
2792  // Check to ensure that only one type was produced.
2793  return results.size() == 1 ? results.front() : nullptr;
2794 }
2795 
2796 LogicalResult
2798  SmallVectorImpl<Type> &results) const {
2799  for (Type type : types)
2800  if (failed(convertType(type, results)))
2801  return failure();
2802  return success();
2803 }
2804 
2805 bool TypeConverter::isLegal(Type type) const {
2806  return convertType(type) == type;
2807 }
2809  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2810 }
2811 
2812 bool TypeConverter::isLegal(Region *region) const {
2813  return llvm::all_of(*region, [this](Block &block) {
2814  return isLegal(block.getArgumentTypes());
2815  });
2816 }
2817 
2818 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2819  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2820 }
2821 
2822 LogicalResult
2824  SignatureConversion &result) const {
2825  // Try to convert the given input type.
2826  SmallVector<Type, 1> convertedTypes;
2827  if (failed(convertType(type, convertedTypes)))
2828  return failure();
2829 
2830  // If this argument is being dropped, there is nothing left to do.
2831  if (convertedTypes.empty())
2832  return success();
2833 
2834  // Otherwise, add the new inputs.
2835  result.addInputs(inputNo, convertedTypes);
2836  return success();
2837 }
2838 LogicalResult
2840  SignatureConversion &result,
2841  unsigned origInputOffset) const {
2842  for (unsigned i = 0, e = types.size(); i != e; ++i)
2843  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2844  return failure();
2845  return success();
2846 }
2847 
2849  Location loc,
2850  Type resultType,
2851  ValueRange inputs) const {
2852  for (const MaterializationCallbackFn &fn :
2853  llvm::reverse(argumentMaterializations))
2854  if (Value result = fn(builder, resultType, inputs, loc))
2855  return result;
2856  return nullptr;
2857 }
2858 
2860  Location loc, Type resultType,
2861  ValueRange inputs) const {
2862  for (const MaterializationCallbackFn &fn :
2863  llvm::reverse(sourceMaterializations))
2864  if (Value result = fn(builder, resultType, inputs, loc))
2865  return result;
2866  return nullptr;
2867 }
2868 
2870  Location loc, Type resultType,
2871  ValueRange inputs,
2872  Type originalType) const {
2874  builder, loc, TypeRange(resultType), inputs, originalType);
2875  if (result.empty())
2876  return nullptr;
2877  assert(result.size() == 1 && "expected single result");
2878  return result.front();
2879 }
2880 
2882  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2883  Type originalType) const {
2884  for (const TargetMaterializationCallbackFn &fn :
2885  llvm::reverse(targetMaterializations)) {
2886  SmallVector<Value> result =
2887  fn(builder, resultTypes, inputs, loc, originalType);
2888  if (result.empty())
2889  continue;
2890  assert(TypeRange(ValueRange(result)) == resultTypes &&
2891  "callback produced incorrect number of values or values with "
2892  "incorrect types");
2893  return result;
2894  }
2895  return {};
2896 }
2897 
2898 std::optional<TypeConverter::SignatureConversion>
2900  SignatureConversion conversion(block->getNumArguments());
2901  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2902  return std::nullopt;
2903  return conversion;
2904 }
2905 
2906 //===----------------------------------------------------------------------===//
2907 // Type attribute conversion
2908 //===----------------------------------------------------------------------===//
2911  return AttributeConversionResult(attr, resultTag);
2912 }
2913 
2916  return AttributeConversionResult(nullptr, naTag);
2917 }
2918 
2921  return AttributeConversionResult(nullptr, abortTag);
2922 }
2923 
2925  return impl.getInt() == resultTag;
2926 }
2927 
2929  return impl.getInt() == naTag;
2930 }
2931 
2933  return impl.getInt() == abortTag;
2934 }
2935 
2937  assert(hasResult() && "Cannot get result from N/A or abort");
2938  return impl.getPointer();
2939 }
2940 
2941 std::optional<Attribute>
2943  for (const TypeAttributeConversionCallbackFn &fn :
2944  llvm::reverse(typeAttributeConversions)) {
2945  AttributeConversionResult res = fn(type, attr);
2946  if (res.hasResult())
2947  return res.getResult();
2948  if (res.isAbort())
2949  return std::nullopt;
2950  }
2951  return std::nullopt;
2952 }
2953 
2954 //===----------------------------------------------------------------------===//
2955 // FunctionOpInterfaceSignatureConversion
2956 //===----------------------------------------------------------------------===//
2957 
2958 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
2959  const TypeConverter &typeConverter,
2960  ConversionPatternRewriter &rewriter) {
2961  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2962  if (!type)
2963  return failure();
2964 
2965  // Convert the original function types.
2966  TypeConverter::SignatureConversion result(type.getNumInputs());
2967  SmallVector<Type, 1> newResults;
2968  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
2969  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
2970  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
2971  typeConverter, &result)))
2972  return failure();
2973 
2974  // Update the function signature in-place.
2975  auto newType = FunctionType::get(rewriter.getContext(),
2976  result.getConvertedTypes(), newResults);
2977 
2978  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
2979 
2980  return success();
2981 }
2982 
2983 /// Create a default conversion pattern that rewrites the type signature of a
2984 /// FunctionOpInterface op. This only supports ops which use FunctionType to
2985 /// represent their type.
2986 namespace {
2987 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
2988  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2989  MLIRContext *ctx,
2990  const TypeConverter &converter)
2991  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
2992 
2993  LogicalResult
2994  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
2995  ConversionPatternRewriter &rewriter) const override {
2996  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2997  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
2998  }
2999 };
3000 
3001 struct AnyFunctionOpInterfaceSignatureConversion
3002  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3004 
3005  LogicalResult
3006  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3007  ConversionPatternRewriter &rewriter) const override {
3008  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3009  }
3010 };
3011 } // namespace
3012 
3013 FailureOr<Operation *>
3015  const TypeConverter &converter,
3016  ConversionPatternRewriter &rewriter) {
3017  assert(op && "Invalid op");
3018  Location loc = op->getLoc();
3019  if (converter.isLegal(op))
3020  return rewriter.notifyMatchFailure(loc, "op already legal");
3021 
3022  OperationState newOp(loc, op->getName());
3023  newOp.addOperands(operands);
3024 
3025  SmallVector<Type> newResultTypes;
3026  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3027  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3028 
3029  newOp.addTypes(newResultTypes);
3030  newOp.addAttributes(op->getAttrs());
3031  return rewriter.create(newOp);
3032 }
3033 
3035  StringRef functionLikeOpName, RewritePatternSet &patterns,
3036  const TypeConverter &converter) {
3037  patterns.add<FunctionOpInterfaceSignatureConversion>(
3038  functionLikeOpName, patterns.getContext(), converter);
3039 }
3040 
3042  RewritePatternSet &patterns, const TypeConverter &converter) {
3043  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3044  converter, patterns.getContext());
3045 }
3046 
3047 //===----------------------------------------------------------------------===//
3048 // ConversionTarget
3049 //===----------------------------------------------------------------------===//
3050 
3052  LegalizationAction action) {
3053  legalOperations[op].action = action;
3054 }
3055 
3057  LegalizationAction action) {
3058  for (StringRef dialect : dialectNames)
3059  legalDialects[dialect] = action;
3060 }
3061 
3063  -> std::optional<LegalizationAction> {
3064  std::optional<LegalizationInfo> info = getOpInfo(op);
3065  return info ? info->action : std::optional<LegalizationAction>();
3066 }
3067 
3069  -> std::optional<LegalOpDetails> {
3070  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3071  if (!info)
3072  return std::nullopt;
3073 
3074  // Returns true if this operation instance is known to be legal.
3075  auto isOpLegal = [&] {
3076  // Handle dynamic legality either with the provided legality function.
3077  if (info->action == LegalizationAction::Dynamic) {
3078  std::optional<bool> result = info->legalityFn(op);
3079  if (result)
3080  return *result;
3081  }
3082 
3083  // Otherwise, the operation is only legal if it was marked 'Legal'.
3084  return info->action == LegalizationAction::Legal;
3085  };
3086  if (!isOpLegal())
3087  return std::nullopt;
3088 
3089  // This operation is legal, compute any additional legality information.
3090  LegalOpDetails legalityDetails;
3091  if (info->isRecursivelyLegal) {
3092  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3093  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3094  legalityDetails.isRecursivelyLegal =
3095  legalityFnIt->second(op).value_or(true);
3096  } else {
3097  legalityDetails.isRecursivelyLegal = true;
3098  }
3099  }
3100  return legalityDetails;
3101 }
3102 
3104  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3105  if (!info)
3106  return false;
3107 
3108  if (info->action == LegalizationAction::Dynamic) {
3109  std::optional<bool> result = info->legalityFn(op);
3110  if (!result)
3111  return false;
3112 
3113  return !(*result);
3114  }
3115 
3116  return info->action == LegalizationAction::Illegal;
3117 }
3118 
3122  if (!oldCallback)
3123  return newCallback;
3124 
3125  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3126  Operation *op) -> std::optional<bool> {
3127  if (std::optional<bool> result = newCl(op))
3128  return *result;
3129 
3130  return oldCl(op);
3131  };
3132  return chain;
3133 }
3134 
3135 void ConversionTarget::setLegalityCallback(
3136  OperationName name, const DynamicLegalityCallbackFn &callback) {
3137  assert(callback && "expected valid legality callback");
3138  auto *infoIt = legalOperations.find(name);
3139  assert(infoIt != legalOperations.end() &&
3140  infoIt->second.action == LegalizationAction::Dynamic &&
3141  "expected operation to already be marked as dynamically legal");
3142  infoIt->second.legalityFn =
3143  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3144 }
3145 
3147  OperationName name, const DynamicLegalityCallbackFn &callback) {
3148  auto *infoIt = legalOperations.find(name);
3149  assert(infoIt != legalOperations.end() &&
3150  infoIt->second.action != LegalizationAction::Illegal &&
3151  "expected operation to already be marked as legal");
3152  infoIt->second.isRecursivelyLegal = true;
3153  if (callback)
3154  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3155  std::move(opRecursiveLegalityFns[name]), callback);
3156  else
3157  opRecursiveLegalityFns.erase(name);
3158 }
3159 
3160 void ConversionTarget::setLegalityCallback(
3161  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3162  assert(callback && "expected valid legality callback");
3163  for (StringRef dialect : dialects)
3164  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3165  std::move(dialectLegalityFns[dialect]), callback);
3166 }
3167 
3168 void ConversionTarget::setLegalityCallback(
3169  const DynamicLegalityCallbackFn &callback) {
3170  assert(callback && "expected valid legality callback");
3171  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3172 }
3173 
3174 auto ConversionTarget::getOpInfo(OperationName op) const
3175  -> std::optional<LegalizationInfo> {
3176  // Check for info for this specific operation.
3177  const auto *it = legalOperations.find(op);
3178  if (it != legalOperations.end())
3179  return it->second;
3180  // Check for info for the parent dialect.
3181  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3182  if (dialectIt != legalDialects.end()) {
3183  DynamicLegalityCallbackFn callback;
3184  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3185  if (dialectFn != dialectLegalityFns.end())
3186  callback = dialectFn->second;
3187  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3188  callback};
3189  }
3190  // Otherwise, check if we mark unknown operations as dynamic.
3191  if (unknownLegalityFn)
3192  return LegalizationInfo{LegalizationAction::Dynamic,
3193  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3194  return std::nullopt;
3195 }
3196 
3197 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3198 //===----------------------------------------------------------------------===//
3199 // PDL Configuration
3200 //===----------------------------------------------------------------------===//
3201 
3203  auto &rewriterImpl =
3204  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3205  rewriterImpl.currentTypeConverter = getTypeConverter();
3206 }
3207 
3209  auto &rewriterImpl =
3210  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3211  rewriterImpl.currentTypeConverter = nullptr;
3212 }
3213 
3214 /// Remap the given value using the rewriter and the type converter in the
3215 /// provided config.
3216 static FailureOr<SmallVector<Value>>
3218  SmallVector<Value> mappedValues;
3219  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3220  return failure();
3221  return std::move(mappedValues);
3222 }
3223 
3225  patterns.getPDLPatterns().registerRewriteFunction(
3226  "convertValue",
3227  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3228  auto results = pdllConvertValues(
3229  static_cast<ConversionPatternRewriter &>(rewriter), value);
3230  if (failed(results))
3231  return failure();
3232  return results->front();
3233  });
3234  patterns.getPDLPatterns().registerRewriteFunction(
3235  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3236  return pdllConvertValues(
3237  static_cast<ConversionPatternRewriter &>(rewriter), values);
3238  });
3239  patterns.getPDLPatterns().registerRewriteFunction(
3240  "convertType",
3241  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3242  auto &rewriterImpl =
3243  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3244  if (const TypeConverter *converter =
3245  rewriterImpl.currentTypeConverter) {
3246  if (Type newType = converter->convertType(type))
3247  return newType;
3248  return failure();
3249  }
3250  return type;
3251  });
3252  patterns.getPDLPatterns().registerRewriteFunction(
3253  "convertTypes",
3254  [](PatternRewriter &rewriter,
3255  TypeRange types) -> FailureOr<SmallVector<Type>> {
3256  auto &rewriterImpl =
3257  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3258  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3259  if (!converter)
3260  return SmallVector<Type>(types);
3261 
3262  SmallVector<Type> remappedTypes;
3263  if (failed(converter->convertTypes(types, remappedTypes)))
3264  return failure();
3265  return std::move(remappedTypes);
3266  });
3267 }
3268 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3269 
3270 //===----------------------------------------------------------------------===//
3271 // Op Conversion Entry Points
3272 //===----------------------------------------------------------------------===//
3273 
3274 //===----------------------------------------------------------------------===//
3275 // Partial Conversion
3276 
3278  ArrayRef<Operation *> ops, const ConversionTarget &target,
3279  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3280  OperationConverter opConverter(target, patterns, config,
3281  OpConversionMode::Partial);
3282  return opConverter.convertOperations(ops);
3283 }
3284 LogicalResult
3286  const FrozenRewritePatternSet &patterns,
3287  ConversionConfig config) {
3288  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3289 }
3290 
3291 //===----------------------------------------------------------------------===//
3292 // Full Conversion
3293 
3295  const ConversionTarget &target,
3296  const FrozenRewritePatternSet &patterns,
3297  ConversionConfig config) {
3298  OperationConverter opConverter(target, patterns, config,
3299  OpConversionMode::Full);
3300  return opConverter.convertOperations(ops);
3301 }
3303  const ConversionTarget &target,
3304  const FrozenRewritePatternSet &patterns,
3305  ConversionConfig config) {
3306  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3307 }
3308 
3309 //===----------------------------------------------------------------------===//
3310 // Analysis Conversion
3311 
3312 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3313 /// op is a top-level module op (which is expected to be isolated from above),
3314 /// return that op.
3316  // Check if there is a top-level operation within `ops`. If so, return that
3317  // op.
3318  for (Operation *op : ops) {
3319  if (!op->getParentOp()) {
3320 #ifndef NDEBUG
3321  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3322  "expected top-level op to be isolated from above");
3323  for (Operation *other : ops)
3324  assert(op->isAncestor(other) &&
3325  "expected ops to have a common ancestor");
3326 #endif // NDEBUG
3327  return op;
3328  }
3329  }
3330 
3331  // No top-level op. Find a common ancestor.
3332  Operation *commonAncestor =
3333  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3334  for (Operation *op : ops.drop_front()) {
3335  while (!commonAncestor->isProperAncestor(op)) {
3336  commonAncestor =
3338  assert(commonAncestor &&
3339  "expected to find a common isolated from above ancestor");
3340  }
3341  }
3342 
3343  return commonAncestor;
3344 }
3345 
3348  const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3349 #ifndef NDEBUG
3350  if (config.legalizableOps)
3351  assert(config.legalizableOps->empty() && "expected empty set");
3352 #endif // NDEBUG
3353 
3354  // Clone closted common ancestor that is isolated from above.
3355  Operation *commonAncestor = findCommonAncestor(ops);
3356  IRMapping mapping;
3357  Operation *clonedAncestor = commonAncestor->clone(mapping);
3358  // Compute inverse IR mapping.
3359  DenseMap<Operation *, Operation *> inverseOperationMap;
3360  for (auto &it : mapping.getOperationMap())
3361  inverseOperationMap[it.second] = it.first;
3362 
3363  // Convert the cloned operations. The original IR will remain unchanged.
3364  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3365  ops, [&](Operation *op) { return mapping.lookup(op); });
3366  OperationConverter opConverter(target, patterns, config,
3367  OpConversionMode::Analysis);
3368  LogicalResult status = opConverter.convertOperations(opsToConvert);
3369 
3370  // Remap `legalizableOps`, so that they point to the original ops and not the
3371  // cloned ops.
3372  if (config.legalizableOps) {
3373  DenseSet<Operation *> originalLegalizableOps;
3374  for (Operation *op : *config.legalizableOps)
3375  originalLegalizableOps.insert(inverseOperationMap[op]);
3376  *config.legalizableOps = std::move(originalLegalizableOps);
3377  }
3378 
3379  // Erase the cloned IR.
3380  clonedAncestor->erase();
3381  return status;
3382 }
3383 
3384 LogicalResult
3386  const FrozenRewritePatternSet &patterns,
3387  ConversionConfig config) {
3388  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3389 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
SmallVector< Value, 1 > ReplacementValues
A list of replacement SSA values.
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:336
Block::iterator getPoint() const
Definition: Builders.h:349
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:346
Block * getBlock() const
Definition: Builders.h:348
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:512
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:373
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:410
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:892
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
MLIRContext * getContext() const
Definition: PatternMatch.h:829
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
Definition: PatternMatch.h:835
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ArrayRef< ReplacementValues > newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, Value originalValue, const TypeConverter *converter)
Build an N:1 materialization for the given original value that was replaced with the given replacemen...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, Type originalType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.