MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 #define DEBUG_TYPE "dialect-conversion"
31 
32 /// A utility function to log a successful result for the given reason.
33 template <typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
35  LLVM_DEBUG({
36  os.unindent();
37  os.startLine() << "} -> SUCCESS";
38  if (!fmt.empty())
39  os.getOStream() << " : "
40  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41  os.getOStream() << "\n";
42  });
43 }
44 
45 /// A utility function to log a failure result for the given reason.
46 template <typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
48  LLVM_DEBUG({
49  os.unindent();
50  os.startLine() << "} -> FAILURE : "
51  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
52  << "\n";
53  });
54 }
55 
56 /// Helper function that computes an insertion point where the given value is
57 /// defined and can be used without a dominance violation.
59  Block *insertBlock = value.getParentBlock();
60  Block::iterator insertPt = insertBlock->begin();
61  if (OpResult inputRes = dyn_cast<OpResult>(value))
62  insertPt = ++inputRes.getOwner()->getIterator();
63  return OpBuilder::InsertPoint(insertBlock, insertPt);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // ConversionValueMapping
68 //===----------------------------------------------------------------------===//
69 
70 namespace {
71 /// This class wraps a IRMapping to provide recursive lookup
72 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
73 struct ConversionValueMapping {
74  /// Return "true" if an SSA value is mapped to the given value. May return
75  /// false positives.
76  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
77 
78  /// Lookup the most recently mapped value with the desired type in the
79  /// mapping.
80  ///
81  /// Special cases:
82  /// - If the desired type is "null", simply return the most recently mapped
83  /// value.
84  /// - If there is no mapping to the desired type, also return the most
85  /// recently mapped value.
86  /// - If there is no mapping for the given value at all, return the given
87  /// value.
88  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
89 
90  /// Lookup a mapped value within the map, or return null if a mapping does not
91  /// exist. If a mapping exists, this follows the same behavior of
92  /// `lookupOrDefault`.
93  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
94 
95  /// Map a value to the one provided.
96  void map(Value oldVal, Value newVal) {
97  LLVM_DEBUG({
98  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
99  assert(it != oldVal && "inserting cyclic mapping");
100  });
101  mapping.map(oldVal, newVal);
102  mappedTo.insert(newVal);
103  }
104 
105  /// Drop the last mapping for the given value.
106  void erase(Value value) { mapping.erase(value); }
107 
108 private:
109  /// Current value mappings.
110  IRMapping mapping;
111 
112  /// All SSA values that are mapped to. May contain false positives.
113  DenseSet<Value> mappedTo;
114 };
115 } // namespace
116 
117 Value ConversionValueMapping::lookupOrDefault(Value from,
118  Type desiredType) const {
119  // Try to find the deepest value that has the desired type. If there is no
120  // such value, simply return the deepest value.
121  Value desiredValue;
122  do {
123  if (!desiredType || from.getType() == desiredType)
124  desiredValue = from;
125 
126  Value mappedValue = mapping.lookupOrNull(from);
127  if (!mappedValue)
128  break;
129  from = mappedValue;
130  } while (true);
131 
132  // If the desired value was found use it, otherwise default to the leaf value.
133  return desiredValue ? desiredValue : from;
134 }
135 
136 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
137  Value result = lookupOrDefault(from, desiredType);
138  if (result == from || (desiredType && result.getType() != desiredType))
139  return nullptr;
140  return result;
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Rewriter and Translation State
145 //===----------------------------------------------------------------------===//
146 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// before a batch of rewrites makes it possible to undo back to that point.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites recorded at snapshot time.
  unsigned numRewrites;

  /// Number of operations marked as ignored at snapshot time.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
164 
165 //===----------------------------------------------------------------------===//
166 // IR rewrites
167 //===----------------------------------------------------------------------===//
168 
169 /// An IR rewrite that can be committed (upon success) or rolled back (upon
170 /// failure).
171 ///
172 /// The dialect conversion keeps track of IR modifications (requested by the
173 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
174 /// are directly applied to the IR as the rewriter API is used, some are applied
175 /// partially, and some are delayed until the `IRRewrite` objects are committed.
176 class IRRewrite {
177 public:
178  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
179  /// Enum values are ordered, so that they can be used in `classof`: first all
180  /// block rewrites, then all operation rewrites.
181  enum class Kind {
182  // Block rewrites
183  CreateBlock,
184  EraseBlock,
185  InlineBlock,
186  MoveBlock,
187  BlockTypeConversion,
188  ReplaceBlockArg,
189  // Operation rewrites
190  MoveOperation,
191  ModifyOperation,
192  ReplaceOperation,
193  CreateOperation,
194  UnresolvedMaterialization
195  };
196 
197  virtual ~IRRewrite() = default;
198 
199  /// Roll back the rewrite. Operations may be erased during rollback.
200  virtual void rollback() = 0;
201 
202  /// Commit the rewrite. At this point, it is certain that the dialect
203  /// conversion will succeed. All IR modifications, except for operation/block
204  /// erasure, must be performed through the given rewriter.
205  ///
206  /// Instead of erasing operations/blocks, they should merely be unlinked
207  /// commit phase and finally be erased during the cleanup phase. This is
208  /// because internal dialect conversion state (such as `mapping`) may still
209  /// be using them.
210  ///
211  /// Any IR modification that was already performed before the commit phase
212  /// (e.g., insertion of an op) must be communicated to the listener that may
213  /// be attached to the given rewriter.
214  virtual void commit(RewriterBase &rewriter) {}
215 
216  /// Cleanup operations/blocks. Cleanup is called after commit.
217  virtual void cleanup(RewriterBase &rewriter) {}
218 
219  Kind getKind() const { return kind; }
220 
221  static bool classof(const IRRewrite *rewrite) { return true; }
222 
223 protected:
224  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
225  : kind(kind), rewriterImpl(rewriterImpl) {}
226 
227  const ConversionConfig &getConfig() const;
228 
229  const Kind kind;
230  ConversionPatternRewriterImpl &rewriterImpl;
231 };
232 
233 /// A block rewrite.
234 class BlockRewrite : public IRRewrite {
235 public:
236  /// Return the block that this rewrite operates on.
237  Block *getBlock() const { return block; }
238 
239  static bool classof(const IRRewrite *rewrite) {
240  return rewrite->getKind() >= Kind::CreateBlock &&
241  rewrite->getKind() <= Kind::ReplaceBlockArg;
242  }
243 
244 protected:
245  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
246  Block *block)
247  : IRRewrite(kind, rewriterImpl), block(block) {}
248 
249  // The block that this rewrite operates on.
250  Block *block;
251 };
252 
253 /// Creation of a block. Block creations are immediately reflected in the IR.
254 /// There is no extra work to commit the rewrite. During rollback, the newly
255 /// created block is erased.
256 class CreateBlockRewrite : public BlockRewrite {
257 public:
258  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
259  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
260 
261  static bool classof(const IRRewrite *rewrite) {
262  return rewrite->getKind() == Kind::CreateBlock;
263  }
264 
265  void commit(RewriterBase &rewriter) override {
266  // The block was already created and inserted. Just inform the listener.
267  if (auto *listener = rewriter.getListener())
268  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
269  }
270 
271  void rollback() override {
272  // Unlink all of the operations within this block, they will be deleted
273  // separately.
274  auto &blockOps = block->getOperations();
275  while (!blockOps.empty())
276  blockOps.remove(blockOps.begin());
277  block->dropAllUses();
278  if (block->getParent())
279  block->erase();
280  else
281  delete block;
282  }
283 };
284 
285 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
286 /// blocks are immediately unlinked, but only erased during cleanup. This makes
287 /// it easier to rollback a block erasure: the block is simply inserted into its
288 /// original location.
289 class EraseBlockRewrite : public BlockRewrite {
290 public:
291  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
292  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
293  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
294 
295  static bool classof(const IRRewrite *rewrite) {
296  return rewrite->getKind() == Kind::EraseBlock;
297  }
298 
299  ~EraseBlockRewrite() override {
300  assert(!block &&
301  "rewrite was neither rolled back nor committed/cleaned up");
302  }
303 
304  void rollback() override {
305  // The block (owned by this rewrite) was not actually erased yet. It was
306  // just unlinked. Put it back into its original position.
307  assert(block && "expected block");
308  auto &blockList = region->getBlocks();
309  Region::iterator before = insertBeforeBlock
310  ? Region::iterator(insertBeforeBlock)
311  : blockList.end();
312  blockList.insert(before, block);
313  block = nullptr;
314  }
315 
316  void commit(RewriterBase &rewriter) override {
317  // Erase the block.
318  assert(block && "expected block");
319  assert(block->empty() && "expected empty block");
320 
321  // Notify the listener that the block is about to be erased.
322  if (auto *listener =
323  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
324  listener->notifyBlockErased(block);
325  }
326 
327  void cleanup(RewriterBase &rewriter) override {
328  // Erase the block.
329  block->dropAllDefinedValueUses();
330  delete block;
331  block = nullptr;
332  }
333 
334 private:
335  // The region in which this block was previously contained.
336  Region *region;
337 
338  // The original successor of this block before it was unlinked. "nullptr" if
339  // this block was the only block in the region.
340  Block *insertBeforeBlock;
341 };
342 
343 /// Inlining of a block. This rewrite is immediately reflected in the IR.
344 /// Note: This rewrite represents only the inlining of the operations. The
345 /// erasure of the inlined block is a separate rewrite.
346 class InlineBlockRewrite : public BlockRewrite {
347 public:
348  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
349  Block *sourceBlock, Block::iterator before)
350  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
351  sourceBlock(sourceBlock),
352  firstInlinedInst(sourceBlock->empty() ? nullptr
353  : &sourceBlock->front()),
354  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
355  // If a listener is attached to the dialect conversion, ops must be moved
356  // one-by-one. When they are moved in bulk, notifications cannot be sent
357  // because the ops that used to be in the source block at the time of the
358  // inlining (before the "commit" phase) are unknown at the time when
359  // notifications are sent (which is during the "commit" phase).
360  assert(!getConfig().listener &&
361  "InlineBlockRewrite not supported if listener is attached");
362  }
363 
364  static bool classof(const IRRewrite *rewrite) {
365  return rewrite->getKind() == Kind::InlineBlock;
366  }
367 
368  void rollback() override {
369  // Put the operations from the destination block (owned by the rewrite)
370  // back into the source block.
371  if (firstInlinedInst) {
372  assert(lastInlinedInst && "expected operation");
373  sourceBlock->getOperations().splice(sourceBlock->begin(),
374  block->getOperations(),
375  Block::iterator(firstInlinedInst),
376  ++Block::iterator(lastInlinedInst));
377  }
378  }
379 
380 private:
381  // The block that originally contained the operations.
382  Block *sourceBlock;
383 
384  // The first inlined operation.
385  Operation *firstInlinedInst;
386 
387  // The last inlined operation.
388  Operation *lastInlinedInst;
389 };
390 
391 /// Moving of a block. This rewrite is immediately reflected in the IR.
392 class MoveBlockRewrite : public BlockRewrite {
393 public:
394  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
395  Region *region, Block *insertBeforeBlock)
396  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
397  insertBeforeBlock(insertBeforeBlock) {}
398 
399  static bool classof(const IRRewrite *rewrite) {
400  return rewrite->getKind() == Kind::MoveBlock;
401  }
402 
403  void commit(RewriterBase &rewriter) override {
404  // The block was already moved. Just inform the listener.
405  if (auto *listener = rewriter.getListener()) {
406  // Note: `previousIt` cannot be passed because this is a delayed
407  // notification and iterators into past IR state cannot be represented.
408  listener->notifyBlockInserted(block, /*previous=*/region,
409  /*previousIt=*/{});
410  }
411  }
412 
413  void rollback() override {
414  // Move the block back to its original position.
415  Region::iterator before =
416  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
417  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
418  }
419 
420 private:
421  // The region in which this block was previously contained.
422  Region *region;
423 
424  // The original successor of this block before it was moved. "nullptr" if
425  // this block was the only block in the region.
426  Block *insertBeforeBlock;
427 };
428 
429 /// Block type conversion. This rewrite is partially reflected in the IR.
430 class BlockTypeConversionRewrite : public BlockRewrite {
431 public:
432  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
433  Block *origBlock, Block *newBlock)
434  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
435  newBlock(newBlock) {}
436 
437  static bool classof(const IRRewrite *rewrite) {
438  return rewrite->getKind() == Kind::BlockTypeConversion;
439  }
440 
441  Block *getOrigBlock() const { return block; }
442 
443  Block *getNewBlock() const { return newBlock; }
444 
445  void commit(RewriterBase &rewriter) override;
446 
447  void rollback() override;
448 
449 private:
450  /// The new block that was created as part of this signature conversion.
451  Block *newBlock;
452 };
453 
454 /// Replacing a block argument. This rewrite is not immediately reflected in the
455 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
456 /// until the rewrite is committed.
457 class ReplaceBlockArgRewrite : public BlockRewrite {
458 public:
459  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
460  Block *block, BlockArgument arg,
461  const TypeConverter *converter)
462  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
463  converter(converter) {}
464 
465  static bool classof(const IRRewrite *rewrite) {
466  return rewrite->getKind() == Kind::ReplaceBlockArg;
467  }
468 
469  void commit(RewriterBase &rewriter) override;
470 
471  void rollback() override;
472 
473 private:
474  BlockArgument arg;
475 
476  /// The current type converter when the block argument was replaced.
477  const TypeConverter *converter;
478 };
479 
480 /// An operation rewrite.
481 class OperationRewrite : public IRRewrite {
482 public:
483  /// Return the operation that this rewrite operates on.
484  Operation *getOperation() const { return op; }
485 
486  static bool classof(const IRRewrite *rewrite) {
487  return rewrite->getKind() >= Kind::MoveOperation &&
488  rewrite->getKind() <= Kind::UnresolvedMaterialization;
489  }
490 
491 protected:
492  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
493  Operation *op)
494  : IRRewrite(kind, rewriterImpl), op(op) {}
495 
496  // The operation that this rewrite operates on.
497  Operation *op;
498 };
499 
500 /// Moving of an operation. This rewrite is immediately reflected in the IR.
501 class MoveOperationRewrite : public OperationRewrite {
502 public:
503  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
504  Operation *op, Block *block, Operation *insertBeforeOp)
505  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
506  insertBeforeOp(insertBeforeOp) {}
507 
508  static bool classof(const IRRewrite *rewrite) {
509  return rewrite->getKind() == Kind::MoveOperation;
510  }
511 
512  void commit(RewriterBase &rewriter) override {
513  // The operation was already moved. Just inform the listener.
514  if (auto *listener = rewriter.getListener()) {
515  // Note: `previousIt` cannot be passed because this is a delayed
516  // notification and iterators into past IR state cannot be represented.
517  listener->notifyOperationInserted(
518  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
519  /*insertPt=*/{}));
520  }
521  }
522 
523  void rollback() override {
524  // Move the operation back to its original position.
525  Block::iterator before =
526  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
527  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
528  }
529 
530 private:
531  // The block in which this operation was previously contained.
532  Block *block;
533 
534  // The original successor of this operation before it was moved. "nullptr"
535  // if this operation was the only operation in the region.
536  Operation *insertBeforeOp;
537 };
538 
539 /// In-place modification of an op. This rewrite is immediately reflected in
540 /// the IR. The previous state of the operation is stored in this object.
541 class ModifyOperationRewrite : public OperationRewrite {
542 public:
543  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
544  Operation *op)
545  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
546  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
547  operands(op->operand_begin(), op->operand_end()),
548  successors(op->successor_begin(), op->successor_end()) {
549  if (OpaqueProperties prop = op->getPropertiesStorage()) {
550  // Make a copy of the properties.
551  propertiesStorage = operator new(op->getPropertiesStorageSize());
552  OpaqueProperties propCopy(propertiesStorage);
553  name.initOpProperties(propCopy, /*init=*/prop);
554  }
555  }
556 
557  static bool classof(const IRRewrite *rewrite) {
558  return rewrite->getKind() == Kind::ModifyOperation;
559  }
560 
561  ~ModifyOperationRewrite() override {
562  assert(!propertiesStorage &&
563  "rewrite was neither committed nor rolled back");
564  }
565 
566  void commit(RewriterBase &rewriter) override {
567  // Notify the listener that the operation was modified in-place.
568  if (auto *listener =
569  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
570  listener->notifyOperationModified(op);
571 
572  if (propertiesStorage) {
573  OpaqueProperties propCopy(propertiesStorage);
574  // Note: The operation may have been erased in the mean time, so
575  // OperationName must be stored in this object.
576  name.destroyOpProperties(propCopy);
577  operator delete(propertiesStorage);
578  propertiesStorage = nullptr;
579  }
580  }
581 
582  void rollback() override {
583  op->setLoc(loc);
584  op->setAttrs(attrs);
585  op->setOperands(operands);
586  for (const auto &it : llvm::enumerate(successors))
587  op->setSuccessor(it.value(), it.index());
588  if (propertiesStorage) {
589  OpaqueProperties propCopy(propertiesStorage);
590  op->copyProperties(propCopy);
591  name.destroyOpProperties(propCopy);
592  operator delete(propertiesStorage);
593  propertiesStorage = nullptr;
594  }
595  }
596 
597 private:
598  OperationName name;
599  LocationAttr loc;
600  DictionaryAttr attrs;
601  SmallVector<Value, 8> operands;
602  SmallVector<Block *, 2> successors;
603  void *propertiesStorage = nullptr;
604 };
605 
606 /// Replacing an operation. Erasing an operation is treated as a special case
607 /// with "null" replacements. This rewrite is not immediately reflected in the
608 /// IR. An internal IR mapping is updated, but values are not replaced and the
609 /// original op is not erased until the rewrite is committed.
610 class ReplaceOperationRewrite : public OperationRewrite {
611 public:
612  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
613  Operation *op, const TypeConverter *converter)
614  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
615  converter(converter) {}
616 
617  static bool classof(const IRRewrite *rewrite) {
618  return rewrite->getKind() == Kind::ReplaceOperation;
619  }
620 
621  void commit(RewriterBase &rewriter) override;
622 
623  void rollback() override;
624 
625  void cleanup(RewriterBase &rewriter) override;
626 
627 private:
628  /// An optional type converter that can be used to materialize conversions
629  /// between the new and old values if necessary.
630  const TypeConverter *converter;
631 };
632 
633 class CreateOperationRewrite : public OperationRewrite {
634 public:
635  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
636  Operation *op)
637  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
638 
639  static bool classof(const IRRewrite *rewrite) {
640  return rewrite->getKind() == Kind::CreateOperation;
641  }
642 
643  void commit(RewriterBase &rewriter) override {
644  // The operation was already created and inserted. Just inform the listener.
645  if (auto *listener = rewriter.getListener())
646  listener->notifyOperationInserted(op, /*previous=*/{});
647  }
648 
649  void rollback() override;
650 };
651 
/// Classifies which direction a materialization converts in.
enum MaterializationKind {
  /// Converts an illegal block argument type back to the original type.
  Argument = 0,

  /// Converts from an illegal type to a legal one.
  Target = 1,

  /// Converts from a legal type back to an illegal one.
  Source = 2
};
666 
667 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
668 /// op. Unresolved materializations are erased at the end of the dialect
669 /// conversion.
670 class UnresolvedMaterializationRewrite : public OperationRewrite {
671 public:
672  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
673  UnrealizedConversionCastOp op,
674  const TypeConverter *converter,
675  MaterializationKind kind, Type originalType,
676  Value mappedValue);
677 
678  static bool classof(const IRRewrite *rewrite) {
679  return rewrite->getKind() == Kind::UnresolvedMaterialization;
680  }
681 
682  void rollback() override;
683 
684  UnrealizedConversionCastOp getOperation() const {
685  return cast<UnrealizedConversionCastOp>(op);
686  }
687 
688  /// Return the type converter of this materialization (which may be null).
689  const TypeConverter *getConverter() const {
690  return converterAndKind.getPointer();
691  }
692 
693  /// Return the kind of this materialization.
694  MaterializationKind getMaterializationKind() const {
695  return converterAndKind.getInt();
696  }
697 
698  /// Return the original type of the SSA value.
699  Type getOriginalType() const { return originalType; }
700 
701 private:
702  /// The corresponding type converter to use when resolving this
703  /// materialization, and the kind of this materialization.
704  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
705  converterAndKind;
706 
707  /// The original type of the SSA value. Only used for target
708  /// materializations.
709  Type originalType;
710 
711  /// The value in the conversion value mapping that is being replaced by the
712  /// results of this unresolved materialization.
713  Value mappedValue;
714 };
715 } // namespace
716 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// targets the operation `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getOperation() == op)
      return true;
  }
  return false;
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// targets the block `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getBlock() == block)
      return true;
  }
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
738 
739 //===----------------------------------------------------------------------===//
740 // ConversionPatternRewriterImpl
741 //===----------------------------------------------------------------------===//
742 namespace mlir {
743 namespace detail {
746  const ConversionConfig &config)
747  : context(ctx), eraseRewriter(ctx), config(config) {}
748 
749  //===--------------------------------------------------------------------===//
750  // State Management
751  //===--------------------------------------------------------------------===//
752 
753  /// Return the current state of the rewriter.
754  RewriterState getCurrentState();
755 
756  /// Apply all requested operation rewrites. This method is invoked when the
757  /// conversion process succeeds.
758  void applyRewrites();
759 
760  /// Reset the state of the rewriter to a previously saved point.
761  void resetState(RewriterState state);
762 
763  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
764  /// failure.
765  template <typename RewriteTy, typename... Args>
766  void appendRewrite(Args &&...args) {
767  rewrites.push_back(
768  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
769  }
770 
771  /// Undo the rewrites (motions, splits) one by one in reverse order until
772  /// "numRewritesToKeep" rewrites remains.
773  void undoRewrites(unsigned numRewritesToKeep = 0);
774 
775  /// Remap the given values to those with potentially different types. Returns
776  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
777  /// is the tag used when describing a value within a diagnostic, e.g.
778  /// "operand".
779  LogicalResult remapValues(StringRef valueDiagTag,
780  std::optional<Location> inputLoc,
781  PatternRewriter &rewriter, ValueRange values,
782  SmallVector<SmallVector<Value>> &remapped);
783 
784  /// Return "true" if the given operation is ignored, and does not need to be
785  /// converted.
786  bool isOpIgnored(Operation *op) const;
787 
788  /// Return "true" if the given operation was replaced or erased.
789  bool wasOpReplaced(Operation *op) const;
790 
791  //===--------------------------------------------------------------------===//
792  // Type Conversion
793  //===--------------------------------------------------------------------===//
794 
795  /// Convert the types of block arguments within the given region.
796  FailureOr<Block *>
797  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
798  const TypeConverter &converter,
799  TypeConverter::SignatureConversion *entryConversion);
800 
801  /// Apply the given signature conversion on the given block. The new block
802  /// containing the updated signature is returned. If no conversions were
803  /// necessary, e.g. if the block has no arguments, `block` is returned.
804  /// `converter` is used to generate any necessary cast operations that
805  /// translate between the origin argument types and those specified in the
806  /// signature conversion.
807  Block *applySignatureConversion(
808  ConversionPatternRewriter &rewriter, Block *block,
809  const TypeConverter *converter,
810  TypeConverter::SignatureConversion &signatureConversion);
811 
812  //===--------------------------------------------------------------------===//
813  // Materializations
814  //===--------------------------------------------------------------------===//
815 
816  /// Build an unresolved materialization operation given a range of output
817  /// types and a list of input operands. Returns the inputs if they their
818  /// types match the output types.
819  ///
820  /// If a cast op was built, it can optionally be returned with the `castOp`
821  /// output argument.
822  ///
823  /// If `valueToMap` is set to a non-null Value, then that value is mapped to
824  /// the results of the unresolved materialization in the conversion value
825  /// mapping.
826  ValueRange buildUnresolvedMaterialization(
827  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
828  Value valueToMap, ValueRange inputs, TypeRange outputTypes,
829  Type originalType, const TypeConverter *converter,
830  UnrealizedConversionCastOp *castOp = nullptr);
832  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
833  Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
834  const TypeConverter *converter,
835  UnrealizedConversionCastOp *castOp = nullptr) {
836  return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
837  TypeRange(outputType), originalType,
838  converter, castOp)
839  .front();
840  }
841 
842  /// Build an N:1 materialization for the given original value that was
843  /// replaced with the given replacement values.
844  ///
845  /// This is a workaround for incomplete 1:N support in the dialect
846  /// conversion driver. The conversion mapping can store only 1:1 replacements
847  /// and the conversion patterns only support single Value replacements in the
848  /// adaptor, so N values must be converted back to a single value. This
849  /// function will be deleted when full 1:N support has been added.
850  ///
851  /// This function inserts an argument materialization back to the original
852  /// type.
853  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
854  ValueRange replacements, Value originalValue,
855  const TypeConverter *converter);
856 
857  /// Find a replacement value for the given SSA value in the conversion value
858  /// mapping. The replacement value must have the same type as the given SSA
859  /// value. If there is no replacement value with the correct type, find the
860  /// latest replacement value (regardless of the type) and build a source
861  /// materialization.
862  Value findOrBuildReplacementValue(Value value,
863  const TypeConverter *converter);
864 
865  /// Unpack an N:1 materialization and return the inputs of the
866  /// materialization. This function unpacks only those materializations that
867  /// were built with `insertNTo1Materialization`.
868  ///
869  /// This is a workaround for incomplete 1:N support in the dialect
870  /// conversion driver. It allows us to write 1:N conversion patterns while
871  /// 1:N support is still missing in the conversion value mapping. This
872  /// function will be deleted when full 1:N support has been added.
873  SmallVector<Value> unpackNTo1Materialization(Value value);
874 
875  //===--------------------------------------------------------------------===//
876  // Rewriter Notification Hooks
877  //===--------------------------------------------------------------------===//
878 
879  /// Notifies that an op was inserted.
880  void notifyOperationInserted(Operation *op,
881  OpBuilder::InsertPoint previous) override;
882 
883  /// Notifies that an op is about to be replaced with the given values.
884  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
885 
886  /// Notifies that a block is about to be erased.
887  void notifyBlockIsBeingErased(Block *block);
888 
889  /// Notifies that a block was inserted.
890  void notifyBlockInserted(Block *block, Region *previous,
891  Region::iterator previousIt) override;
892 
893  /// Notifies that a block is being inlined into another block.
894  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
895  Block::iterator before);
896 
897  /// Notifies that a pattern match failed for the given reason.
898  void
899  notifyMatchFailure(Location loc,
900  function_ref<void(Diagnostic &)> reasonCallback) override;
901 
902  //===--------------------------------------------------------------------===//
903  // IR Erasure
904  //===--------------------------------------------------------------------===//
905 
906  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
907  /// operation or block is erased multiple times. This rewriter assumes that
908  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
910  public:
912  : RewriterBase(context, /*listener=*/this) {}
913 
914  /// Erase the given op (unless it was already erased).
915  void eraseOp(Operation *op) override {
916  if (wasErased(op))
917  return;
918  op->dropAllUses();
920  }
921 
922  /// Erase the given block (unless it was already erased).
923  void eraseBlock(Block *block) override {
924  if (wasErased(block))
925  return;
926  assert(block->empty() && "expected empty block");
927  block->dropAllDefinedValueUses();
929  }
930 
931  bool wasErased(void *ptr) const { return erased.contains(ptr); }
932 
933  void notifyOperationErased(Operation *op) override { erased.insert(op); }
934 
935  void notifyBlockErased(Block *block) override { erased.insert(block); }
936 
937  private:
938  /// Pointers to all erased operations and blocks.
939  DenseSet<void *> erased;
940  };
941 
942  //===--------------------------------------------------------------------===//
943  // State
944  //===--------------------------------------------------------------------===//
945 
946  /// MLIR context.
948 
949  /// A rewriter that keeps track of ops/block that were already erased and
950  /// skips duplicate op/block erasures. This rewriter is used during the
951  /// "cleanup" phase.
953 
954  // Mapping between replaced values that differ in type. This happens when
955  // replacing a value with one of a different type.
956  ConversionValueMapping mapping;
957 
958  /// Ordered list of block operations (creations, splits, motions).
960 
961  /// A set of operations that should no longer be considered for legalization.
962  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
963  /// tracked separately.
965 
966  /// A set of operations that were replaced/erased. Such ops are not erased
967  /// immediately but only when the dialect conversion succeeds. In the mean
968  /// time, they should no longer be considered for legalization and any attempt
969  /// to modify/access them is invalid rewriter API usage.
971 
972  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
973  /// to the corresponding rewrite objects.
976 
977  /// A set of all N:1 materializations that were added to work around
978  /// incomplete 1:N support in the dialect conversion driver.
980 
981  /// The current type converter, or nullptr if no type converter is currently
982  /// active.
983  const TypeConverter *currentTypeConverter = nullptr;
984 
985  /// A mapping of regions to type converters that should be used when
986  /// converting the arguments of blocks within that region.
988 
989  /// Dialect conversion configuration.
991 
992 #ifndef NDEBUG
993  /// A set of operations that have pending updates. This tracking isn't
994  /// strictly necessary, and is thus only active during debug builds for extra
995  /// verification.
997 
998  /// A logger used to emit diagnostics during the conversion process.
999  llvm::ScopedPrinter logger{llvm::dbgs()};
1000 #endif
1001 };
1002 } // namespace detail
1003 } // namespace mlir
1004 
1005 const ConversionConfig &IRRewrite::getConfig() const {
1006  return rewriterImpl.config;
1007 }
1008 
1009 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1010  // Inform the listener about all IR modifications that have already taken
1011  // place: References to the original block have been replaced with the new
1012  // block.
1013  if (auto *listener =
1014  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1015  for (Operation *op : getNewBlock()->getUsers())
1016  listener->notifyOperationModified(op);
1017 }
1018 
1019 void BlockTypeConversionRewrite::rollback() {
1020  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1021 }
1022 
1023 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1024  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1025  if (!repl)
1026  return;
1027 
1028  if (isa<BlockArgument>(repl)) {
1029  rewriter.replaceAllUsesWith(arg, repl);
1030  return;
1031  }
1032 
1033  // If the replacement value is an operation, we check to make sure that we
1034  // don't replace uses that are within the parent operation of the
1035  // replacement value.
1036  Operation *replOp = cast<OpResult>(repl).getOwner();
1037  Block *replBlock = replOp->getBlock();
1038  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1039  Operation *user = operand.getOwner();
1040  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1041  });
1042 }
1043 
1044 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
1045 
1046 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1047  auto *listener =
1048  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1049 
1050  // Compute replacement values.
1051  SmallVector<Value> replacements =
1052  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1053  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1054  });
1055 
1056  // Notify the listener that the operation is about to be replaced.
1057  if (listener)
1058  listener->notifyOperationReplaced(op, replacements);
1059 
1060  // Replace all uses with the new values.
1061  for (auto [result, newValue] :
1062  llvm::zip_equal(op->getResults(), replacements))
1063  if (newValue)
1064  rewriter.replaceAllUsesWith(result, newValue);
1065 
1066  // The original op will be erased, so remove it from the set of unlegalized
1067  // ops.
1068  if (getConfig().unlegalizedOps)
1069  getConfig().unlegalizedOps->erase(op);
1070 
1071  // Notify the listener that the operation (and its nested operations) was
1072  // erased.
1073  if (listener) {
1075  [&](Operation *op) { listener->notifyOperationErased(op); });
1076  }
1077 
1078  // Do not erase the operation yet. It may still be referenced in `mapping`.
1079  // Just unlink it for now and erase it during cleanup.
1080  op->getBlock()->getOperations().remove(op);
1081 }
1082 
1083 void ReplaceOperationRewrite::rollback() {
1084  for (auto result : op->getResults())
1085  rewriterImpl.mapping.erase(result);
1086 }
1087 
1088 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1089  rewriter.eraseOp(op);
1090 }
1091 
1092 void CreateOperationRewrite::rollback() {
1093  for (Region &region : op->getRegions()) {
1094  while (!region.getBlocks().empty())
1095  region.getBlocks().remove(region.getBlocks().begin());
1096  }
1097  op->dropAllUses();
1098  op->erase();
1099 }
1100 
1101 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1102  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1103  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1104  Value mappedValue)
1105  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1106  converterAndKind(converter, kind), originalType(originalType),
1107  mappedValue(mappedValue) {
1108  assert((!originalType || kind == MaterializationKind::Target) &&
1109  "original type is valid only for target materializations");
1110  rewriterImpl.unresolvedMaterializations[op] = this;
1111 }
1112 
1113 void UnresolvedMaterializationRewrite::rollback() {
1114  if (mappedValue)
1115  rewriterImpl.mapping.erase(mappedValue);
1116  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1117  rewriterImpl.nTo1TempMaterializations.erase(getOperation());
1118  op->erase();
1119 }
1120 
1122  // Commit all rewrites.
1123  IRRewriter rewriter(context, config.listener);
1124  // Note: New rewrites may be added during the "commit" phase and the
1125  // `rewrites` vector may reallocate.
1126  for (size_t i = 0; i < rewrites.size(); ++i)
1127  rewrites[i]->commit(rewriter);
1128 
1129  // Clean up all rewrites.
1130  for (auto &rewrite : rewrites)
1131  rewrite->cleanup(eraseRewriter);
1132 }
1133 
1134 //===----------------------------------------------------------------------===//
1135 // State Management
1136 
1138  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1139 }
1140 
1142  // Undo any rewrites.
1143  undoRewrites(state.numRewrites);
1144 
1145  // Pop all of the recorded ignored operations that are no longer valid.
1146  while (ignoredOps.size() != state.numIgnoredOperations)
1147  ignoredOps.pop_back();
1148 
1149  while (replacedOps.size() != state.numReplacedOps)
1150  replacedOps.pop_back();
1151 }
1152 
1153 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1154  for (auto &rewrite :
1155  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1156  rewrite->rollback();
1157  rewrites.resize(numRewritesToKeep);
1158 }
1159 
1161  StringRef valueDiagTag, std::optional<Location> inputLoc,
1162  PatternRewriter &rewriter, ValueRange values,
1163  SmallVector<SmallVector<Value>> &remapped) {
1164  remapped.reserve(llvm::size(values));
1165 
1166  for (const auto &it : llvm::enumerate(values)) {
1167  Value operand = it.value();
1168  Type origType = operand.getType();
1169  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1170 
1171  // Find the most recently mapped value. Unpack all temporary N:1
1172  // materializations. Such conversions are a workaround around missing
1173  // 1:N support in the ConversionValueMapping. (The conversion patterns
1174  // already support 1:N replacements.)
1175  Value repl = mapping.lookupOrDefault(operand);
1177 
1178  if (!currentTypeConverter) {
1179  // The current pattern does not have a type converter. I.e., it does not
1180  // distinguish between legal and illegal types. For each operand, simply
1181  // pass through the most recently mapped value.
1182  remapped.push_back(std::move(unpacked));
1183  continue;
1184  }
1185 
1186  // If there is no legal conversion, fail to match this pattern.
1187  SmallVector<Type, 1> legalTypes;
1188  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1189  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1190  diag << "unable to convert type for " << valueDiagTag << " #"
1191  << it.index() << ", type was " << origType;
1192  });
1193  return failure();
1194  }
1195 
1196  // If a type is converted to 0 types, there is nothing to do.
1197  if (legalTypes.empty()) {
1198  remapped.push_back({});
1199  continue;
1200  }
1201 
1202  if (legalTypes.size() != 1) {
1203  // TODO: This is a 1:N conversion. The conversion value mapping does not
1204  // store such materializations yet. If the types of the most recently
1205  // mapped values do not match, build a target materialization.
1206  ValueRange unpackedRange(unpacked);
1207  if (TypeRange(unpackedRange) == legalTypes) {
1208  remapped.push_back(std::move(unpacked));
1209  continue;
1210  }
1211 
1212  // Insert a target materialization if the current pattern expects
1213  // different legalized types.
1215  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1216  /*valueToMap=*/Value(), /*inputs=*/unpacked,
1217  /*outputType=*/legalTypes, /*originalType=*/origType,
1219  remapped.push_back(targetMat);
1220  continue;
1221  }
1222 
1223  // Handle 1->1 type conversions.
1224  Type desiredType = legalTypes.front();
1225  // Try to find a mapped value with the desired type. (Or the operand itself
1226  // if the value is not mapped at all.)
1227  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1228  if (newOperand.getType() != desiredType) {
1229  // If the looked up value's type does not have the desired type, it means
1230  // that the value was replaced with a value of different type and no
1231  // target materialization was created yet.
1233  MaterializationKind::Target, computeInsertPoint(newOperand),
1234  operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
1235  /*outputType=*/desiredType, /*originalType=*/origType,
1237  newOperand = castValue;
1238  }
1239  remapped.push_back({newOperand});
1240  }
1241  return success();
1242 }
1243 
1245  // Check to see if this operation is ignored or was replaced.
1246  return replacedOps.count(op) || ignoredOps.count(op);
1247 }
1248 
1250  // Check to see if this operation was replaced.
1251  return replacedOps.count(op);
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // Type Conversion
1256 
1258  ConversionPatternRewriter &rewriter, Region *region,
1259  const TypeConverter &converter,
1260  TypeConverter::SignatureConversion *entryConversion) {
1261  regionToConverter[region] = &converter;
1262  if (region->empty())
1263  return nullptr;
1264 
1265  // Convert the arguments of each non-entry block within the region.
1266  for (Block &block :
1267  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1268  // Compute the signature for the block with the provided converter.
1269  std::optional<TypeConverter::SignatureConversion> conversion =
1270  converter.convertBlockSignature(&block);
1271  if (!conversion)
1272  return failure();
1273  // Convert the block with the computed signature.
1274  applySignatureConversion(rewriter, &block, &converter, *conversion);
1275  }
1276 
1277  // Convert the entry block. If an entry signature conversion was provided,
1278  // use that one. Otherwise, compute the signature with the type converter.
1279  if (entryConversion)
1280  return applySignatureConversion(rewriter, &region->front(), &converter,
1281  *entryConversion);
1282  std::optional<TypeConverter::SignatureConversion> conversion =
1283  converter.convertBlockSignature(&region->front());
1284  if (!conversion)
1285  return failure();
1286  return applySignatureConversion(rewriter, &region->front(), &converter,
1287  *conversion);
1288 }
1289 
1291  ConversionPatternRewriter &rewriter, Block *block,
1292  const TypeConverter *converter,
1293  TypeConverter::SignatureConversion &signatureConversion) {
1294 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1295  // A block cannot be converted multiple times.
1296  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1297  llvm::report_fatal_error("block was already converted");
1298 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1299 
1300  OpBuilder::InsertionGuard g(rewriter);
1301 
1302  // If no arguments are being changed or added, there is nothing to do.
1303  unsigned origArgCount = block->getNumArguments();
1304  auto convertedTypes = signatureConversion.getConvertedTypes();
1305  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1306  return block;
1307 
1308  // Compute the locations of all block arguments in the new block.
1309  SmallVector<Location> newLocs(convertedTypes.size(),
1310  rewriter.getUnknownLoc());
1311  for (unsigned i = 0; i < origArgCount; ++i) {
1312  auto inputMap = signatureConversion.getInputMapping(i);
1313  if (!inputMap || inputMap->replacementValue)
1314  continue;
1315  Location origLoc = block->getArgument(i).getLoc();
1316  for (unsigned j = 0; j < inputMap->size; ++j)
1317  newLocs[inputMap->inputNo + j] = origLoc;
1318  }
1319 
1320  // Insert a new block with the converted block argument types and move all ops
1321  // from the old block to the new block.
1322  Block *newBlock =
1323  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1324  convertedTypes, newLocs);
1325 
1326  // If a listener is attached to the dialect conversion, ops cannot be moved
1327  // to the destination block in bulk ("fast path"). This is because at the time
1328  // the notifications are sent, it is unknown which ops were moved. Instead,
1329  // ops should be moved one-by-one ("slow path"), so that a separate
1330  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1331  // a bit more efficient, so we try to do that when possible.
1332  bool fastPath = !config.listener;
1333  if (fastPath) {
1334  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1335  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1336  } else {
1337  while (!block->empty())
1338  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1339  }
1340 
1341  // Replace all uses of the old block with the new block.
1342  block->replaceAllUsesWith(newBlock);
1343 
1344  for (unsigned i = 0; i != origArgCount; ++i) {
1345  BlockArgument origArg = block->getArgument(i);
1346  Type origArgType = origArg.getType();
1347 
1348  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1349  signatureConversion.getInputMapping(i);
1350  if (!inputMap) {
1351  // This block argument was dropped and no replacement value was provided.
1352  // Materialize a replacement value "out of thin air".
1354  MaterializationKind::Source,
1355  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1356  /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
1357  /*outputType=*/origArgType, /*originalType=*/Type(), converter);
1358  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1359  continue;
1360  }
1361 
1362  if (Value repl = inputMap->replacementValue) {
1363  // This block argument was dropped and a replacement value was provided.
1364  assert(inputMap->size == 0 &&
1365  "invalid to provide a replacement value when the argument isn't "
1366  "dropped");
1367  mapping.map(origArg, repl);
1368  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1369  continue;
1370  }
1371 
1372  // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1373  // dialect conversion. Therefore, we need an argument materialization to
1374  // turn the replacement block arguments into a single SSA value that can be
1375  // used as a replacement.
1376  auto replArgs =
1377  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1378  if (replArgs.size() == 1) {
1379  mapping.map(origArg, replArgs.front());
1380  } else {
1382  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1383  /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1384  }
1385  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1386  }
1387 
1388  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1389 
1390  // Erase the old block. (It is just unlinked for now and will be erased during
1391  // cleanup.)
1392  rewriter.eraseBlock(block);
1393 
1394  return newBlock;
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // Materializations
1399 //===----------------------------------------------------------------------===//
1400 
1401 /// Build an unresolved materialization operation given an output type and set
1402 /// of input operands.
1404  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1405  Value valueToMap, ValueRange inputs, TypeRange outputTypes,
1406  Type originalType, const TypeConverter *converter,
1407  UnrealizedConversionCastOp *castOp) {
1408  assert((!originalType || kind == MaterializationKind::Target) &&
1409  "original type is valid only for target materializations");
1410 
1411  // Avoid materializing an unnecessary cast.
1412  if (TypeRange(inputs) == outputTypes) {
1413  if (valueToMap) {
1414  assert(inputs.size() == 1 && "1:N mapping is not supported");
1415  mapping.map(valueToMap, inputs.front());
1416  }
1417  return inputs;
1418  }
1419 
1420  // Create an unresolved materialization. We use a new OpBuilder to avoid
1421  // tracking the materialization like we do for other operations.
1422  OpBuilder builder(outputTypes.front().getContext());
1423  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1424  auto convertOp =
1425  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1426  if (valueToMap) {
1427  assert(outputTypes.size() == 1 && "1:N mapping is not supported");
1428  mapping.map(valueToMap, convertOp.getResult(0));
1429  }
1430  if (castOp)
1431  *castOp = convertOp;
1432  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1433  originalType, valueToMap);
1434  return convertOp.getResults();
1435 }
1436 
1438  OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
1439  Value originalValue, const TypeConverter *converter) {
1440  // Insert argument materialization back to the original type.
1441  Type originalType = originalValue.getType();
1442  UnrealizedConversionCastOp argCastOp;
1444  MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
1445  /*inputs=*/replacements, originalType,
1446  /*originalType=*/Type(), converter, &argCastOp);
1447  if (argCastOp)
1448  nTo1TempMaterializations.insert(argCastOp);
1449 }
1450 
1452  Value value, const TypeConverter *converter) {
1453  // Find a replacement value with the same type.
1454  Value repl = mapping.lookupOrNull(value, value.getType());
1455  if (repl)
1456  return repl;
1457 
1458  // Check if the value is dead. No replacement value is needed in that case.
1459  // This is an approximate check that may have false negatives but does not
1460  // require computing and traversing an inverse mapping. (We may end up
1461  // building source materializations that are never used and that fold away.)
1462  if (llvm::all_of(value.getUsers(),
1463  [&](Operation *op) { return replacedOps.contains(op); }) &&
1464  !mapping.isMappedTo(value))
1465  return Value();
1466 
1467  // No replacement value was found. Get the latest replacement value
1468  // (regardless of the type) and build a source materialization to the
1469  // original type.
1470  repl = mapping.lookupOrNull(value);
1471  if (!repl) {
1472  // No replacement value is registered in the mapping. This means that the
1473  // value is dropped and no longer needed. (If the value were still needed,
1474  // a source materialization producing a replacement value "out of thin air"
1475  // would have already been created during `replaceOp` or
1476  // `applySignatureConversion`.)
1477  return Value();
1478  }
1480  MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1481  /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
1482  /*originalType=*/Type(), converter);
1483  mapping.map(value, castValue);
1484  return castValue;
1485 }
1486 
1489  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
1490  // workaround.
1491  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
1492  if (!castOp)
1493  return {value};
1494  if (!nTo1TempMaterializations.contains(castOp))
1495  return {value};
1496  assert(castOp->getNumResults() == 1 && "expected single result");
1497 
1498  SmallVector<Value> result;
1499  for (Value v : castOp.getOperands()) {
1500  // Keep unpacking if possible. This is needed because during block
1501  // signature conversions and 1:N op replacements, the driver may have
1502  // inserted two materializations back-to-back: first an argument
1503  // materialization, then a target materialization.
1504  llvm::append_range(result, unpackNTo1Materialization(v));
1505  }
1506  return result;
1507 }
1508 
1509 //===----------------------------------------------------------------------===//
1510 // Rewriter Notification Hooks
1511 
1513  Operation *op, OpBuilder::InsertPoint previous) {
1514  LLVM_DEBUG({
1515  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1516  << ")\n";
1517  });
1518  assert(!wasOpReplaced(op->getParentOp()) &&
1519  "attempting to insert into a block within a replaced/erased op");
1520 
1521  if (!previous.isSet()) {
1522  // This is a newly created op.
1523  appendRewrite<CreateOperationRewrite>(op);
1524  return;
1525  }
1526  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1527  ? nullptr
1528  : &*previous.getPoint();
1529  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1530 }
1531 
1533  Operation *op, ArrayRef<ValueRange> newValues) {
1534  assert(newValues.size() == op->getNumResults());
1535  assert(!ignoredOps.contains(op) && "operation was already replaced");
1536 
1537  // Check if replaced op is an unresolved materialization, i.e., an
1538  // unrealized_conversion_cast op that was created by the conversion driver.
1539  bool isUnresolvedMaterialization = false;
1540  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1541  if (unresolvedMaterializations.contains(castOp))
1542  isUnresolvedMaterialization = true;
1543 
1544  // Create mappings for each of the new result values.
1545  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1546  if (repl.empty()) {
1547  // This result was dropped and no replacement value was provided.
1548  if (isUnresolvedMaterialization) {
1549  // Do not create another materializations if we are erasing a
1550  // materialization.
1551  continue;
1552  }
1553 
1554  // Materialize a replacement value "out of thin air".
1556  MaterializationKind::Source, computeInsertPoint(result),
1557  result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
1558  /*outputType=*/result.getType(), /*originalType=*/Type(),
1560  continue;
1561  } else {
1562  // Make sure that the user does not mess with unresolved materializations
1563  // that were inserted by the conversion driver. We keep track of these
1564  // ops in internal data structures. Erasing them must be allowed because
1565  // this can happen when the user is erasing an entire block (including
1566  // its body). But replacing them with another value should be forbidden
1567  // to avoid problems with the `mapping`.
1568  assert(!isUnresolvedMaterialization &&
1569  "attempting to replace an unresolved materialization");
1570  }
1571 
1572  // Remap result to replacement value.
1573  if (repl.empty())
1574  continue;
1575 
1576  if (repl.size() == 1) {
1577  // Single replacement value: replace directly.
1578  mapping.map(result, repl.front());
1579  } else {
1580  // Multiple replacement values: insert N:1 materialization.
1582  /*replacements=*/repl, /*outputValue=*/result,
1584  }
1585  }
1586 
1587  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1588  // Mark this operation and all nested ops as replaced.
1589  op->walk([&](Operation *op) { replacedOps.insert(op); });
1590 }
1591 
1593  appendRewrite<EraseBlockRewrite>(block);
1594 }
1595 
1597  Block *block, Region *previous, Region::iterator previousIt) {
1598  assert(!wasOpReplaced(block->getParentOp()) &&
1599  "attempting to insert into a region within a replaced/erased op");
1600  LLVM_DEBUG(
1601  {
1602  Operation *parent = block->getParentOp();
1603  if (parent) {
1604  logger.startLine() << "** Insert Block into : '" << parent->getName()
1605  << "'(" << parent << ")\n";
1606  } else {
1607  logger.startLine()
1608  << "** Insert Block into detached Region (nullptr parent op)'";
1609  }
1610  });
1611 
1612  if (!previous) {
1613  // This is a newly created block.
1614  appendRewrite<CreateBlockRewrite>(block);
1615  return;
1616  }
1617  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1618  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1619 }
1620 
1622  Block *block, Block *srcBlock, Block::iterator before) {
1623  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1624 }
1625 
1627  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1628  LLVM_DEBUG({
1630  reasonCallback(diag);
1631  logger.startLine() << "** Failure : " << diag.str() << "\n";
1632  if (config.notifyCallback)
1634  });
1635 }
1636 
1637 //===----------------------------------------------------------------------===//
1638 // ConversionPatternRewriter
1639 //===----------------------------------------------------------------------===//
1640 
1641 ConversionPatternRewriter::ConversionPatternRewriter(
1642  MLIRContext *ctx, const ConversionConfig &config)
1643  : PatternRewriter(ctx),
1644  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1645  setListener(impl.get());
1646 }
1647 
1649 
1651  assert(op && newOp && "expected non-null op");
1652  replaceOp(op, newOp->getResults());
1653 }
1654 
1656  assert(op->getNumResults() == newValues.size() &&
1657  "incorrect # of replacement values");
1658  LLVM_DEBUG({
1659  impl->logger.startLine()
1660  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1661  });
1662  SmallVector<ValueRange> newVals;
1663  for (size_t i = 0; i < newValues.size(); ++i)
1664  newVals.push_back(newValues.slice(i, 1));
1665  impl->notifyOpReplaced(op, newVals);
1666 }
1667 
1669  Operation *op, ArrayRef<ValueRange> newValues) {
1670  assert(op->getNumResults() == newValues.size() &&
1671  "incorrect # of replacement values");
1672  LLVM_DEBUG({
1673  impl->logger.startLine()
1674  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1675  });
1676  impl->notifyOpReplaced(op, newValues);
1677 }
1678 
1680  LLVM_DEBUG({
1681  impl->logger.startLine()
1682  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1683  });
1684  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
1685  impl->notifyOpReplaced(op, nullRepls);
1686 }
1687 
1689  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1690  "attempting to erase a block within a replaced/erased op");
1691 
1692  // Mark all ops for erasure.
1693  for (Operation &op : *block)
1694  eraseOp(&op);
1695 
1696  // Unlink the block from its parent region. The block is kept in the rewrite
1697  // object and will be actually destroyed when rewrites are applied. This
1698  // allows us to keep the operations in the block live and undo the removal by
1699  // re-inserting the block.
1700  impl->notifyBlockIsBeingErased(block);
1701  block->getParent()->getBlocks().remove(block);
1702 }
1703 
1705  Block *block, TypeConverter::SignatureConversion &conversion,
1706  const TypeConverter *converter) {
1707  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1708  "attempting to apply a signature conversion to a block within a "
1709  "replaced/erased op");
1710  return impl->applySignatureConversion(*this, block, converter, conversion);
1711 }
1712 
1714  Region *region, const TypeConverter &converter,
1715  TypeConverter::SignatureConversion *entryConversion) {
1716  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1717  "attempting to apply a signature conversion to a block within a "
1718  "replaced/erased op");
1719  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1720 }
1721 
1723  Value to) {
1724  LLVM_DEBUG({
1725  Operation *parentOp = from.getOwner()->getParentOp();
1726  impl->logger.startLine() << "** Replace Argument : '" << from
1727  << "'(in region of '" << parentOp->getName()
1728  << "'(" << from.getOwner()->getParentOp() << ")\n";
1729  });
1730  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1731  impl->currentTypeConverter);
1732  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1733 }
1734 
1736  SmallVector<SmallVector<Value>> remappedValues;
1737  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1738  remappedValues)))
1739  return nullptr;
1740  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1741  return remappedValues.front().front();
1742 }
1743 
1744 LogicalResult
1746  SmallVectorImpl<Value> &results) {
1747  if (keys.empty())
1748  return success();
1750  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1751  remapped)))
1752  return failure();
1753  for (const auto &values : remapped) {
1754  assert(values.size() == 1 && "1:N conversion not supported");
1755  results.push_back(values.front());
1756  }
1757  return success();
1758 }
1759 
1761  Block::iterator before,
1762  ValueRange argValues) {
1763 #ifndef NDEBUG
1764  assert(argValues.size() == source->getNumArguments() &&
1765  "incorrect # of argument replacement values");
1766  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1767  "attempting to inline a block from a replaced/erased op");
1768  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1769  "attempting to inline a block into a replaced/erased op");
1770  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1771  // The source block will be deleted, so it should not have any users (i.e.,
1772  // there should be no predecessors).
1773  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1774  "expected 'source' to have no predecessors");
1775 #endif // NDEBUG
1776 
1777  // If a listener is attached to the dialect conversion, ops cannot be moved
1778  // to the destination block in bulk ("fast path"). This is because at the time
1779  // the notifications are sent, it is unknown which ops were moved. Instead,
1780  // ops should be moved one-by-one ("slow path"), so that a separate
1781  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1782  // a bit more efficient, so we try to do that when possible.
1783  bool fastPath = !impl->config.listener;
1784 
1785  if (fastPath)
1786  impl->notifyBlockBeingInlined(dest, source, before);
1787 
1788  // Replace all uses of block arguments.
1789  for (auto it : llvm::zip(source->getArguments(), argValues))
1790  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1791 
1792  if (fastPath) {
1793  // Move all ops at once.
1794  dest->getOperations().splice(before, source->getOperations());
1795  } else {
1796  // Move op by op.
1797  while (!source->empty())
1798  moveOpBefore(&source->front(), dest, before);
1799  }
1800 
1801  // Erase the source block.
1802  eraseBlock(source);
1803 }
1804 
1806  assert(!impl->wasOpReplaced(op) &&
1807  "attempting to modify a replaced/erased op");
1808 #ifndef NDEBUG
1809  impl->pendingRootUpdates.insert(op);
1810 #endif
1811  impl->appendRewrite<ModifyOperationRewrite>(op);
1812 }
1813 
1815  assert(!impl->wasOpReplaced(op) &&
1816  "attempting to modify a replaced/erased op");
1818  // There is nothing to do here, we only need to track the operation at the
1819  // start of the update.
1820 #ifndef NDEBUG
1821  assert(impl->pendingRootUpdates.erase(op) &&
1822  "operation did not have a pending in-place update");
1823 #endif
1824 }
1825 
1827 #ifndef NDEBUG
1828  assert(impl->pendingRootUpdates.erase(op) &&
1829  "operation did not have a pending in-place update");
1830 #endif
1831  // Erase the last update for this operation.
1832  auto it = llvm::find_if(
1833  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1834  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1835  return modifyRewrite && modifyRewrite->getOperation() == op;
1836  });
1837  assert(it != impl->rewrites.rend() && "no root update started on op");
1838  (*it)->rollback();
1839  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1840  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1841 }
1842 
1844  return *impl;
1845 }
1846 
1847 //===----------------------------------------------------------------------===//
1848 // ConversionPattern
1849 //===----------------------------------------------------------------------===//
1850 
1852  ArrayRef<ValueRange> operands) const {
1853  SmallVector<Value> oneToOneOperands;
1854  oneToOneOperands.reserve(operands.size());
1855  for (ValueRange operand : operands) {
1856  if (operand.size() != 1)
1857  llvm::report_fatal_error("pattern '" + getDebugName() +
1858  "' does not support 1:N conversion");
1859  oneToOneOperands.push_back(operand.front());
1860  }
1861  return oneToOneOperands;
1862 }
1863 
1864 LogicalResult
1866  PatternRewriter &rewriter) const {
1867  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1868  auto &rewriterImpl = dialectRewriter.getImpl();
1869 
1870  // Track the current conversion pattern type converter in the rewriter.
1871  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1872  getTypeConverter());
1873 
1874  // Remap the operands of the operation.
1876  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1877  op->getOperands(), remapped))) {
1878  return failure();
1879  }
1880  SmallVector<ValueRange> remappedAsRange =
1881  llvm::to_vector_of<ValueRange>(remapped);
1882  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // OperationLegalizer
1887 //===----------------------------------------------------------------------===//
1888 
1889 namespace {
1890 /// A set of rewrite patterns that can be used to legalize a given operation.
1891 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1892 
1893 /// This class defines a recursive operation legalizer.
1894 class OperationLegalizer {
1895 public:
1896  using LegalizationAction = ConversionTarget::LegalizationAction;
1897 
1898  OperationLegalizer(const ConversionTarget &targetInfo,
1900  const ConversionConfig &config);
1901 
1902  /// Returns true if the given operation is known to be illegal on the target.
1903  bool isIllegal(Operation *op) const;
1904 
1905  /// Attempt to legalize the given operation. Returns success if the operation
1906  /// was legalized, failure otherwise.
1907  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1908 
1909  /// Returns the conversion target in use by the legalizer.
1910  const ConversionTarget &getTarget() { return target; }
1911 
1912 private:
1913  /// Attempt to legalize the given operation by folding it.
1914  LogicalResult legalizeWithFold(Operation *op,
1915  ConversionPatternRewriter &rewriter);
1916 
1917  /// Attempt to legalize the given operation by applying a pattern. Returns
1918  /// success if the operation was legalized, failure otherwise.
1919  LogicalResult legalizeWithPattern(Operation *op,
1920  ConversionPatternRewriter &rewriter);
1921 
1922  /// Return true if the given pattern may be applied to the given operation,
1923  /// false otherwise.
1924  bool canApplyPattern(Operation *op, const Pattern &pattern,
1925  ConversionPatternRewriter &rewriter);
1926 
1927  /// Legalize the resultant IR after successfully applying the given pattern.
1928  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1929  ConversionPatternRewriter &rewriter,
1930  RewriterState &curState);
1931 
1932  /// Legalizes the actions registered during the execution of a pattern.
1933  LogicalResult
1934  legalizePatternBlockRewrites(Operation *op,
1935  ConversionPatternRewriter &rewriter,
1937  RewriterState &state, RewriterState &newState);
1938  LogicalResult legalizePatternCreatedOperations(
1940  RewriterState &state, RewriterState &newState);
1941  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1943  RewriterState &state,
1944  RewriterState &newState);
1945 
1946  //===--------------------------------------------------------------------===//
1947  // Cost Model
1948  //===--------------------------------------------------------------------===//
1949 
1950  /// Build an optimistic legalization graph given the provided patterns. This
1951  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1952  /// patterns for operations that are not directly legal, but may be
1953  /// transitively legal for the current target given the provided patterns.
1954  void buildLegalizationGraph(
1955  LegalizationPatterns &anyOpLegalizerPatterns,
1957 
1958  /// Compute the benefit of each node within the computed legalization graph.
1959  /// This orders the patterns within 'legalizerPatterns' based upon two
1960  /// criteria:
1961  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1962  /// represent the more direct mapping to the target.
1963  /// 2) When comparing patterns with the same legalization depth, prefer the
1964  /// pattern with the highest PatternBenefit. This allows for users to
1965  /// prefer specific legalizations over others.
1966  void computeLegalizationGraphBenefit(
1967  LegalizationPatterns &anyOpLegalizerPatterns,
1969 
1970  /// Compute the legalization depth when legalizing an operation of the given
1971  /// type.
1972  unsigned computeOpLegalizationDepth(
1973  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1975 
1976  /// Apply the conversion cost model to the given set of patterns, and return
1977  /// the smallest legalization depth of any of the patterns. See
1978  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1979  unsigned applyCostModelToPatterns(
1980  LegalizationPatterns &patterns,
1981  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1983 
1984  /// The current set of patterns that have been applied.
1985  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1986 
1987  /// The legalization information provided by the target.
1988  const ConversionTarget &target;
1989 
1990  /// The pattern applicator to use for conversions.
1991  PatternApplicator applicator;
1992 
1993  /// Dialect conversion configuration.
1994  const ConversionConfig &config;
1995 };
1996 } // namespace
1997 
1998 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2000  const ConversionConfig &config)
2001  : target(targetInfo), applicator(patterns), config(config) {
2002  // The set of patterns that can be applied to illegal operations to transform
2003  // them into legal ones.
2005  LegalizationPatterns anyOpLegalizerPatterns;
2006 
2007  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2008  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2009 }
2010 
2011 bool OperationLegalizer::isIllegal(Operation *op) const {
2012  return target.isIllegal(op);
2013 }
2014 
2015 LogicalResult
2016 OperationLegalizer::legalize(Operation *op,
2017  ConversionPatternRewriter &rewriter) {
2018 #ifndef NDEBUG
2019  const char *logLineComment =
2020  "//===-------------------------------------------===//\n";
2021 
2022  auto &logger = rewriter.getImpl().logger;
2023 #endif
2024  LLVM_DEBUG({
2025  logger.getOStream() << "\n";
2026  logger.startLine() << logLineComment;
2027  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2028  << op << ") {\n";
2029  logger.indent();
2030 
2031  // If the operation has no regions, just print it here.
2032  if (op->getNumRegions() == 0) {
2033  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2034  logger.getOStream() << "\n\n";
2035  }
2036  });
2037 
2038  // Check if this operation is legal on the target.
2039  if (auto legalityInfo = target.isLegal(op)) {
2040  LLVM_DEBUG({
2041  logSuccess(
2042  logger, "operation marked legal by the target{0}",
2043  legalityInfo->isRecursivelyLegal
2044  ? "; NOTE: operation is recursively legal; skipping internals"
2045  : "");
2046  logger.startLine() << logLineComment;
2047  });
2048 
2049  // If this operation is recursively legal, mark its children as ignored so
2050  // that we don't consider them for legalization.
2051  if (legalityInfo->isRecursivelyLegal) {
2052  op->walk([&](Operation *nested) {
2053  if (op != nested)
2054  rewriter.getImpl().ignoredOps.insert(nested);
2055  });
2056  }
2057 
2058  return success();
2059  }
2060 
2061  // Check to see if the operation is ignored and doesn't need to be converted.
2062  if (rewriter.getImpl().isOpIgnored(op)) {
2063  LLVM_DEBUG({
2064  logSuccess(logger, "operation marked 'ignored' during conversion");
2065  logger.startLine() << logLineComment;
2066  });
2067  return success();
2068  }
2069 
2070  // If the operation isn't legal, try to fold it in-place.
2071  // TODO: Should we always try to do this, even if the op is
2072  // already legal?
2073  if (succeeded(legalizeWithFold(op, rewriter))) {
2074  LLVM_DEBUG({
2075  logSuccess(logger, "operation was folded");
2076  logger.startLine() << logLineComment;
2077  });
2078  return success();
2079  }
2080 
2081  // Otherwise, we need to apply a legalization pattern to this operation.
2082  if (succeeded(legalizeWithPattern(op, rewriter))) {
2083  LLVM_DEBUG({
2084  logSuccess(logger, "");
2085  logger.startLine() << logLineComment;
2086  });
2087  return success();
2088  }
2089 
2090  LLVM_DEBUG({
2091  logFailure(logger, "no matched legalization pattern");
2092  logger.startLine() << logLineComment;
2093  });
2094  return failure();
2095 }
2096 
2097 LogicalResult
2098 OperationLegalizer::legalizeWithFold(Operation *op,
2099  ConversionPatternRewriter &rewriter) {
2100  auto &rewriterImpl = rewriter.getImpl();
2101  RewriterState curState = rewriterImpl.getCurrentState();
2102 
2103  LLVM_DEBUG({
2104  rewriterImpl.logger.startLine() << "* Fold {\n";
2105  rewriterImpl.logger.indent();
2106  });
2107 
2108  // Try to fold the operation.
2109  SmallVector<Value, 2> replacementValues;
2110  rewriter.setInsertionPoint(op);
2111  if (failed(rewriter.tryFold(op, replacementValues))) {
2112  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2113  return failure();
2114  }
2115  // An empty list of replacement values indicates that the fold was in-place.
2116  // As the operation changed, a new legalization needs to be attempted.
2117  if (replacementValues.empty())
2118  return legalize(op, rewriter);
2119 
2120  // Insert a replacement for 'op' with the folded replacement values.
2121  rewriter.replaceOp(op, replacementValues);
2122 
2123  // Recursively legalize any new constant operations.
2124  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2125  i != e; ++i) {
2126  auto *createOp =
2127  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2128  if (!createOp)
2129  continue;
2130  if (failed(legalize(createOp->getOperation(), rewriter))) {
2131  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2132  "failed to legalize generated constant '{0}'",
2133  createOp->getOperation()->getName()));
2134  rewriterImpl.resetState(curState);
2135  return failure();
2136  }
2137  }
2138 
2139  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2140  return success();
2141 }
2142 
2143 LogicalResult
2144 OperationLegalizer::legalizeWithPattern(Operation *op,
2145  ConversionPatternRewriter &rewriter) {
2146  auto &rewriterImpl = rewriter.getImpl();
2147 
2148  // Functor that returns if the given pattern may be applied.
2149  auto canApply = [&](const Pattern &pattern) {
2150  bool canApply = canApplyPattern(op, pattern, rewriter);
2151  if (canApply && config.listener)
2152  config.listener->notifyPatternBegin(pattern, op);
2153  return canApply;
2154  };
2155 
2156  // Functor that cleans up the rewriter state after a pattern failed to match.
2157  RewriterState curState = rewriterImpl.getCurrentState();
2158  auto onFailure = [&](const Pattern &pattern) {
2159  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2160  LLVM_DEBUG({
2161  logFailure(rewriterImpl.logger, "pattern failed to match");
2162  if (rewriterImpl.config.notifyCallback) {
2164  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2165  << "\" on op:\n"
2166  << *op;
2167  rewriterImpl.config.notifyCallback(diag);
2168  }
2169  });
2170  if (config.listener)
2171  config.listener->notifyPatternEnd(pattern, failure());
2172  rewriterImpl.resetState(curState);
2173  appliedPatterns.erase(&pattern);
2174  };
2175 
2176  // Functor that performs additional legalization when a pattern is
2177  // successfully applied.
2178  auto onSuccess = [&](const Pattern &pattern) {
2179  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2180  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2181  appliedPatterns.erase(&pattern);
2182  if (failed(result))
2183  rewriterImpl.resetState(curState);
2184  if (config.listener)
2185  config.listener->notifyPatternEnd(pattern, result);
2186  return result;
2187  };
2188 
2189  // Try to match and rewrite a pattern on this operation.
2190  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2191  onSuccess);
2192 }
2193 
2194 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2195  ConversionPatternRewriter &rewriter) {
2196  LLVM_DEBUG({
2197  auto &os = rewriter.getImpl().logger;
2198  os.getOStream() << "\n";
2199  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2200  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2201  os.getOStream() << ")' {\n";
2202  os.indent();
2203  });
2204 
2205  // Ensure that we don't cycle by not allowing the same pattern to be
2206  // applied twice in the same recursion stack if it is not known to be safe.
2207  if (!pattern.hasBoundedRewriteRecursion() &&
2208  !appliedPatterns.insert(&pattern).second) {
2209  LLVM_DEBUG(
2210  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2211  return false;
2212  }
2213  return true;
2214 }
2215 
2216 LogicalResult
2217 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2218  ConversionPatternRewriter &rewriter,
2219  RewriterState &curState) {
2220  auto &impl = rewriter.getImpl();
2221  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2222 
2223 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2224  // Check that the root was either replaced or updated in place.
2225  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2226  auto replacedRoot = [&] {
2227  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2228  };
2229  auto updatedRootInPlace = [&] {
2230  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2231  };
2232  if (!replacedRoot() && !updatedRootInPlace())
2233  llvm::report_fatal_error("expected pattern to replace the root operation");
2234 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2235 
2236  // Legalize each of the actions registered during application.
2237  RewriterState newState = impl.getCurrentState();
2238  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2239  newState)) ||
2240  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2241  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2242  newState))) {
2243  return failure();
2244  }
2245 
2246  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2247  return success();
2248 }
2249 
2250 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2251  Operation *op, ConversionPatternRewriter &rewriter,
2252  ConversionPatternRewriterImpl &impl, RewriterState &state,
2253  RewriterState &newState) {
2254  SmallPtrSet<Operation *, 16> operationsToIgnore;
2255 
2256  // If the pattern moved or created any blocks, make sure the types of block
2257  // arguments get legalized.
2258  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2259  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2260  if (!rewrite)
2261  continue;
2262  Block *block = rewrite->getBlock();
2263  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2264  ReplaceBlockArgRewrite>(rewrite))
2265  continue;
2266  // Only check blocks outside of the current operation.
2267  Operation *parentOp = block->getParentOp();
2268  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2269  continue;
2270 
2271  // If the region of the block has a type converter, try to convert the block
2272  // directly.
2273  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2274  std::optional<TypeConverter::SignatureConversion> conversion =
2275  converter->convertBlockSignature(block);
2276  if (!conversion) {
2277  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2278  "block"));
2279  return failure();
2280  }
2281  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2282  continue;
2283  }
2284 
2285  // Otherwise, check that this operation isn't one generated by this pattern.
2286  // This is because we will attempt to legalize the parent operation, and
2287  // blocks in regions created by this pattern will already be legalized later
2288  // on. If we haven't built the set yet, build it now.
2289  if (operationsToIgnore.empty()) {
2290  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2291  ++i) {
2292  auto *createOp =
2293  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2294  if (!createOp)
2295  continue;
2296  operationsToIgnore.insert(createOp->getOperation());
2297  }
2298  }
2299 
2300  // If this operation should be considered for re-legalization, try it.
2301  if (operationsToIgnore.insert(parentOp).second &&
2302  failed(legalize(parentOp, rewriter))) {
2303  LLVM_DEBUG(logFailure(impl.logger,
2304  "operation '{0}'({1}) became illegal after rewrite",
2305  parentOp->getName(), parentOp));
2306  return failure();
2307  }
2308  }
2309  return success();
2310 }
2311 
2312 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2314  RewriterState &state, RewriterState &newState) {
2315  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2316  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2317  if (!createOp)
2318  continue;
2319  Operation *op = createOp->getOperation();
2320  if (failed(legalize(op, rewriter))) {
2321  LLVM_DEBUG(logFailure(impl.logger,
2322  "failed to legalize generated operation '{0}'({1})",
2323  op->getName(), op));
2324  return failure();
2325  }
2326  }
2327  return success();
2328 }
2329 
2330 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2332  RewriterState &state, RewriterState &newState) {
2333  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2334  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2335  if (!rewrite)
2336  continue;
2337  Operation *op = rewrite->getOperation();
2338  if (failed(legalize(op, rewriter))) {
2339  LLVM_DEBUG(logFailure(
2340  impl.logger, "failed to legalize operation updated in-place '{0}'",
2341  op->getName()));
2342  return failure();
2343  }
2344  }
2345  return success();
2346 }
2347 
2348 //===----------------------------------------------------------------------===//
2349 // Cost Model
2350 
2351 void OperationLegalizer::buildLegalizationGraph(
2352  LegalizationPatterns &anyOpLegalizerPatterns,
2353  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2354  // A mapping between an operation and a set of operations that can be used to
2355  // generate it.
2357  // A mapping between an operation and any currently invalid patterns it has.
2359  // A worklist of patterns to consider for legality.
2360  SetVector<const Pattern *> patternWorklist;
2361 
2362  // Build the mapping from operations to the parent ops that may generate them.
2363  applicator.walkAllPatterns([&](const Pattern &pattern) {
2364  std::optional<OperationName> root = pattern.getRootKind();
2365 
2366  // If the pattern has no specific root, we can't analyze the relationship
2367  // between the root op and generated operations. Given that, add all such
2368  // patterns to the legalization set.
2369  if (!root) {
2370  anyOpLegalizerPatterns.push_back(&pattern);
2371  return;
2372  }
2373 
2374  // Skip operations that are always known to be legal.
2375  if (target.getOpAction(*root) == LegalizationAction::Legal)
2376  return;
2377 
2378  // Add this pattern to the invalid set for the root op and record this root
2379  // as a parent for any generated operations.
2380  invalidPatterns[*root].insert(&pattern);
2381  for (auto op : pattern.getGeneratedOps())
2382  parentOps[op].insert(*root);
2383 
2384  // Add this pattern to the worklist.
2385  patternWorklist.insert(&pattern);
2386  });
2387 
2388  // If there are any patterns that don't have a specific root kind, we can't
2389  // make direct assumptions about what operations will never be legalized.
2390  // Note: Technically we could, but it would require an analysis that may
2391  // recurse into itself. It would be better to perform this kind of filtering
2392  // at a higher level than here anyways.
2393  if (!anyOpLegalizerPatterns.empty()) {
2394  for (const Pattern *pattern : patternWorklist)
2395  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2396  return;
2397  }
2398 
2399  while (!patternWorklist.empty()) {
2400  auto *pattern = patternWorklist.pop_back_val();
2401 
2402  // Check to see if any of the generated operations are invalid.
2403  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2404  std::optional<LegalizationAction> action = target.getOpAction(op);
2405  return !legalizerPatterns.count(op) &&
2406  (!action || action == LegalizationAction::Illegal);
2407  }))
2408  continue;
2409 
2410  // Otherwise, if all of the generated operation are valid, this op is now
2411  // legal so add all of the child patterns to the worklist.
2412  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2413  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2414 
2415  // Add any invalid patterns of the parent operations to see if they have now
2416  // become legal.
2417  for (auto op : parentOps[*pattern->getRootKind()])
2418  patternWorklist.set_union(invalidPatterns[op]);
2419  }
2420 }
2421 
2422 void OperationLegalizer::computeLegalizationGraphBenefit(
2423  LegalizationPatterns &anyOpLegalizerPatterns,
2424  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2425  // The smallest pattern depth, when legalizing an operation.
2426  DenseMap<OperationName, unsigned> minOpPatternDepth;
2427 
2428  // For each operation that is transitively legal, compute a cost for it.
2429  for (auto &opIt : legalizerPatterns)
2430  if (!minOpPatternDepth.count(opIt.first))
2431  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2432  legalizerPatterns);
2433 
2434  // Apply the cost model to the patterns that can match any operation. Those
2435  // with a specific operation type are already resolved when computing the op
2436  // legalization depth.
2437  if (!anyOpLegalizerPatterns.empty())
2438  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2439  legalizerPatterns);
2440 
2441  // Apply a cost model to the pattern applicator. We order patterns first by
2442  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2443  // decreasing benefit.
2444  applicator.applyCostModel([&](const Pattern &pattern) {
2445  ArrayRef<const Pattern *> orderedPatternList;
2446  if (std::optional<OperationName> rootName = pattern.getRootKind())
2447  orderedPatternList = legalizerPatterns[*rootName];
2448  else
2449  orderedPatternList = anyOpLegalizerPatterns;
2450 
2451  // If the pattern is not found, then it was removed and cannot be matched.
2452  auto *it = llvm::find(orderedPatternList, &pattern);
2453  if (it == orderedPatternList.end())
2455 
2456  // Patterns found earlier in the list have higher benefit.
2457  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2458  });
2459 }
2460 
2461 unsigned OperationLegalizer::computeOpLegalizationDepth(
2462  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2463  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2464  // Check for existing depth.
2465  auto depthIt = minOpPatternDepth.find(op);
2466  if (depthIt != minOpPatternDepth.end())
2467  return depthIt->second;
2468 
2469  // If a mapping for this operation does not exist, then this operation
2470  // is always legal. Return 0 as the depth for a directly legal operation.
2471  auto opPatternsIt = legalizerPatterns.find(op);
2472  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2473  return 0u;
2474 
2475  // Record this initial depth in case we encounter this op again when
2476  // recursively computing the depth.
2477  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2478 
2479  // Apply the cost model to the operation patterns, and update the minimum
2480  // depth.
2481  unsigned minDepth = applyCostModelToPatterns(
2482  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2483  minOpPatternDepth[op] = minDepth;
2484  return minDepth;
2485 }
2486 
2487 unsigned OperationLegalizer::applyCostModelToPatterns(
2488  LegalizationPatterns &patterns,
2489  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2490  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2491  unsigned minDepth = std::numeric_limits<unsigned>::max();
2492 
2493  // Compute the depth for each pattern within the set.
2495  patternsByDepth.reserve(patterns.size());
2496  for (const Pattern *pattern : patterns) {
2497  unsigned depth = 1;
2498  for (auto generatedOp : pattern->getGeneratedOps()) {
2499  unsigned generatedOpDepth = computeOpLegalizationDepth(
2500  generatedOp, minOpPatternDepth, legalizerPatterns);
2501  depth = std::max(depth, generatedOpDepth + 1);
2502  }
2503  patternsByDepth.emplace_back(pattern, depth);
2504 
2505  // Update the minimum depth of the pattern list.
2506  minDepth = std::min(minDepth, depth);
2507  }
2508 
2509  // If the operation only has one legalization pattern, there is no need to
2510  // sort them.
2511  if (patternsByDepth.size() == 1)
2512  return minDepth;
2513 
2514  // Sort the patterns by those likely to be the most beneficial.
2515  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2516  [](const std::pair<const Pattern *, unsigned> &lhs,
2517  const std::pair<const Pattern *, unsigned> &rhs) {
2518  // First sort by the smaller pattern legalization
2519  // depth.
2520  if (lhs.second != rhs.second)
2521  return lhs.second < rhs.second;
2522 
2523  // Then sort by the larger pattern benefit.
2524  auto lhsBenefit = lhs.first->getBenefit();
2525  auto rhsBenefit = rhs.first->getBenefit();
2526  return lhsBenefit > rhsBenefit;
2527  });
2528 
2529  // Update the legalization pattern to use the new sorted list.
2530  patterns.clear();
2531  for (auto &patternIt : patternsByDepth)
2532  patterns.push_back(patternIt.first);
2533  return minDepth;
2534 }
2535 
2536 //===----------------------------------------------------------------------===//
2537 // OperationConverter
2538 //===----------------------------------------------------------------------===//
namespace {
/// Selects how `OperationConverter` reacts to operations that fail to
/// legalize.
enum OpConversionMode {
  /// Failed legalizations are tolerated, so illegal operations may remain in
  /// the IR after conversion.
  Partial = 0,

  /// Every operation must end up legal for the given target; any failed
  /// legalization makes the whole conversion fail.
  Full = 1,

  /// Operations are only analyzed for legality; rewrites are not kept even
  /// when legalization succeeds.
  Analysis = 2,
};
} // namespace
2554 
2555 namespace mlir {
2556 // This class converts operations to a given conversion target via a set of
2557 // rewrite patterns. The conversion behaves differently depending on the
2558 // conversion mode.
2560  explicit OperationConverter(const ConversionTarget &target,
2562  const ConversionConfig &config,
2563  OpConversionMode mode)
2564  : config(config), opLegalizer(target, patterns, this->config),
2565  mode(mode) {}
2566 
2567  /// Converts the given operations to the conversion target.
2568  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2569 
2570 private:
2571  /// Converts an operation with the given rewriter.
2572  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2573 
2574  /// Dialect conversion configuration.
2575  ConversionConfig config;
2576 
2577  /// The legalizer to use when converting operations.
2578  OperationLegalizer opLegalizer;
2579 
2580  /// The conversion mode to use when legalizing operations.
2581  OpConversionMode mode;
2582 };
2583 } // namespace mlir
2584 
2585 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2586  Operation *op) {
2587  // Legalize the given operation.
2588  if (failed(opLegalizer.legalize(op, rewriter))) {
2589  // Handle the case of a failed conversion for each of the different modes.
2590  // Full conversions expect all operations to be converted.
2591  if (mode == OpConversionMode::Full)
2592  return op->emitError()
2593  << "failed to legalize operation '" << op->getName() << "'";
2594  // Partial conversions allow conversions to fail iff the operation was not
2595  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2596  // set, non-legalizable ops are added to that set.
2597  if (mode == OpConversionMode::Partial) {
2598  if (opLegalizer.isIllegal(op))
2599  return op->emitError()
2600  << "failed to legalize operation '" << op->getName()
2601  << "' that was explicitly marked illegal";
2602  if (config.unlegalizedOps)
2603  config.unlegalizedOps->insert(op);
2604  }
2605  } else if (mode == OpConversionMode::Analysis) {
2606  // Analysis conversions don't fail if any operations fail to legalize,
2607  // they are only interested in the operations that were successfully
2608  // legalized.
2609  if (config.legalizableOps)
2610  config.legalizableOps->insert(op);
2611  }
2612  return success();
2613 }
2614 
2615 static LogicalResult
2617  UnresolvedMaterializationRewrite *rewrite) {
2618  UnrealizedConversionCastOp op = rewrite->getOperation();
2619  assert(!op.use_empty() &&
2620  "expected that dead materializations have already been DCE'd");
2621  Operation::operand_range inputOperands = op.getOperands();
2622 
2623  // Try to materialize the conversion.
2624  if (const TypeConverter *converter = rewrite->getConverter()) {
2625  rewriter.setInsertionPoint(op);
2626  SmallVector<Value> newMaterialization;
2627  switch (rewrite->getMaterializationKind()) {
2629  // Try to materialize an argument conversion.
2630  assert(op->getNumResults() == 1 && "expected single result");
2631  Value argMat = converter->materializeArgumentConversion(
2632  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2633  if (argMat) {
2634  newMaterialization.push_back(argMat);
2635  break;
2636  }
2637  }
2638  // If an argument materialization failed, fallback to trying a target
2639  // materialization.
2640  [[fallthrough]];
2641  case MaterializationKind::Target:
2642  newMaterialization = converter->materializeTargetConversion(
2643  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2644  rewrite->getOriginalType());
2645  break;
2646  case MaterializationKind::Source:
2647  assert(op->getNumResults() == 1 && "expected single result");
2648  Value sourceMat = converter->materializeSourceConversion(
2649  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2650  if (sourceMat)
2651  newMaterialization.push_back(sourceMat);
2652  break;
2653  }
2654  if (!newMaterialization.empty()) {
2655 #ifndef NDEBUG
2656  ValueRange newMaterializationRange(newMaterialization);
2657  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2658  "materialization callback produced value of incorrect type");
2659 #endif // NDEBUG
2660  rewriter.replaceOp(op, newMaterialization);
2661  return success();
2662  }
2663  }
2664 
2665  InFlightDiagnostic diag = op->emitError()
2666  << "failed to legalize unresolved materialization "
2667  "from ("
2668  << inputOperands.getTypes() << ") to ("
2669  << op.getResultTypes()
2670  << ") that remained live after conversion";
2671  diag.attachNote(op->getUsers().begin()->getLoc())
2672  << "see existing live user here: " << *op->getUsers().begin();
2673  return failure();
2674 }
2675 
2677  if (ops.empty())
2678  return success();
2679  const ConversionTarget &target = opLegalizer.getTarget();
2680 
2681  // Compute the set of operations and blocks to convert.
2682  SmallVector<Operation *> toConvert;
2683  for (auto *op : ops) {
2685  [&](Operation *op) {
2686  toConvert.push_back(op);
2687  // Don't check this operation's children for conversion if the
2688  // operation is recursively legal.
2689  auto legalityInfo = target.isLegal(op);
2690  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2691  return WalkResult::skip();
2692  return WalkResult::advance();
2693  });
2694  }
2695 
2696  // Convert each operation and discard rewrites on failure.
2697  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2698  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2699 
2700  for (auto *op : toConvert)
2701  if (failed(convert(rewriter, op)))
2702  return rewriterImpl.undoRewrites(), failure();
2703 
2704  // After a successful conversion, apply rewrites.
2705  rewriterImpl.applyRewrites();
2706 
2707  // Gather all unresolved materializations.
2710  &materializations = rewriterImpl.unresolvedMaterializations;
2711  for (auto it : materializations) {
2712  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2713  continue;
2714  allCastOps.push_back(it.first);
2715  }
2716 
2717  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2718  // dialect conversion framework (not the ones that were inserted by
2719  // patterns).
2720  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2721  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2722 
2723  // Try to legalize all unresolved materializations.
2724  if (config.buildMaterializations) {
2725  IRRewriter rewriter(rewriterImpl.context, config.listener);
2726  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2727  auto it = materializations.find(castOp);
2728  assert(it != materializations.end() && "inconsistent state");
2729  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2730  return failure();
2731  }
2732  }
2733 
2734  return success();
2735 }
2736 
2737 //===----------------------------------------------------------------------===//
2738 // Reconcile Unrealized Casts
2739 //===----------------------------------------------------------------------===//
2740 
2743  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2744  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2745  castOps.end());
2746  // This set is maintained only if `remainingCastOps` is provided.
2747  DenseSet<Operation *> erasedOps;
2748 
2749  // Helper function that adds all operands to the worklist that are an
2750  // unrealized_conversion_cast op result.
2751  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2752  for (Value v : castOp.getInputs())
2753  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2754  worklist.insert(inputCastOp);
2755  };
2756 
2757  // Helper function that returns the unrealized_conversion_cast op that
2758  // defines all inputs of the given op (in the same order). Return "nullptr"
2759  // if there is no such op.
2760  auto getInputCast =
2761  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2762  if (castOp.getInputs().empty())
2763  return {};
2764  auto inputCastOp =
2765  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2766  if (!inputCastOp)
2767  return {};
2768  if (inputCastOp.getOutputs() != castOp.getInputs())
2769  return {};
2770  return inputCastOp;
2771  };
2772 
2773  // Process ops in the worklist bottom-to-top.
2774  while (!worklist.empty()) {
2775  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2776  if (castOp->use_empty()) {
2777  // DCE: If the op has no users, erase it. Add the operands to the
2778  // worklist to find additional DCE opportunities.
2779  enqueueOperands(castOp);
2780  if (remainingCastOps)
2781  erasedOps.insert(castOp.getOperation());
2782  castOp->erase();
2783  continue;
2784  }
2785 
2786  // Traverse the chain of input cast ops to see if an op with the same
2787  // input types can be found.
2788  UnrealizedConversionCastOp nextCast = castOp;
2789  while (nextCast) {
2790  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2791  // Found a cast where the input types match the output types of the
2792  // matched op. We can directly use those inputs and the matched op can
2793  // be removed.
2794  enqueueOperands(castOp);
2795  castOp.replaceAllUsesWith(nextCast.getInputs());
2796  if (remainingCastOps)
2797  erasedOps.insert(castOp.getOperation());
2798  castOp->erase();
2799  break;
2800  }
2801  nextCast = getInputCast(nextCast);
2802  }
2803  }
2804 
2805  if (remainingCastOps)
2806  for (UnrealizedConversionCastOp op : castOps)
2807  if (!erasedOps.contains(op.getOperation()))
2808  remainingCastOps->push_back(op);
2809 }
2810 
2811 //===----------------------------------------------------------------------===//
2812 // Type Conversion
2813 //===----------------------------------------------------------------------===//
2814 
2816  ArrayRef<Type> types) {
2817  assert(!types.empty() && "expected valid types");
2818  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2819  addInputs(types);
2820 }
2821 
2823  assert(!types.empty() &&
2824  "1->0 type remappings don't need to be added explicitly");
2825  argTypes.append(types.begin(), types.end());
2826 }
2827 
2828 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2829  unsigned newInputNo,
2830  unsigned newInputCount) {
2831  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2832  assert(newInputCount != 0 && "expected valid input count");
2833  remappedInputs[origInputNo] =
2834  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2835 }
2836 
2838  Value replacementValue) {
2839  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2840  remappedInputs[origInputNo] =
2841  InputMapping{origInputNo, /*size=*/0, replacementValue};
2842 }
2843 
2845  SmallVectorImpl<Type> &results) const {
2846  assert(t && "expected non-null type");
2847 
2848  {
2849  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2850  std::defer_lock);
2852  cacheReadLock.lock();
2853  auto existingIt = cachedDirectConversions.find(t);
2854  if (existingIt != cachedDirectConversions.end()) {
2855  if (existingIt->second)
2856  results.push_back(existingIt->second);
2857  return success(existingIt->second != nullptr);
2858  }
2859  auto multiIt = cachedMultiConversions.find(t);
2860  if (multiIt != cachedMultiConversions.end()) {
2861  results.append(multiIt->second.begin(), multiIt->second.end());
2862  return success();
2863  }
2864  }
2865  // Walk the added converters in reverse order to apply the most recently
2866  // registered first.
2867  size_t currentCount = results.size();
2868 
2869  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2870  std::defer_lock);
2871 
2872  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2873  if (std::optional<LogicalResult> result = converter(t, results)) {
2875  cacheWriteLock.lock();
2876  if (!succeeded(*result)) {
2877  cachedDirectConversions.try_emplace(t, nullptr);
2878  return failure();
2879  }
2880  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2881  if (newTypes.size() == 1)
2882  cachedDirectConversions.try_emplace(t, newTypes.front());
2883  else
2884  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2885  return success();
2886  }
2887  }
2888  return failure();
2889 }
2890 
2892  // Use the multi-type result version to convert the type.
2893  SmallVector<Type, 1> results;
2894  if (failed(convertType(t, results)))
2895  return nullptr;
2896 
2897  // Check to ensure that only one type was produced.
2898  return results.size() == 1 ? results.front() : nullptr;
2899 }
2900 
2901 LogicalResult
2903  SmallVectorImpl<Type> &results) const {
2904  for (Type type : types)
2905  if (failed(convertType(type, results)))
2906  return failure();
2907  return success();
2908 }
2909 
2910 bool TypeConverter::isLegal(Type type) const {
2911  return convertType(type) == type;
2912 }
2914  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2915 }
2916 
2917 bool TypeConverter::isLegal(Region *region) const {
2918  return llvm::all_of(*region, [this](Block &block) {
2919  return isLegal(block.getArgumentTypes());
2920  });
2921 }
2922 
2923 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2924  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2925 }
2926 
2927 LogicalResult
2929  SignatureConversion &result) const {
2930  // Try to convert the given input type.
2931  SmallVector<Type, 1> convertedTypes;
2932  if (failed(convertType(type, convertedTypes)))
2933  return failure();
2934 
2935  // If this argument is being dropped, there is nothing left to do.
2936  if (convertedTypes.empty())
2937  return success();
2938 
2939  // Otherwise, add the new inputs.
2940  result.addInputs(inputNo, convertedTypes);
2941  return success();
2942 }
2943 LogicalResult
2945  SignatureConversion &result,
2946  unsigned origInputOffset) const {
2947  for (unsigned i = 0, e = types.size(); i != e; ++i)
2948  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2949  return failure();
2950  return success();
2951 }
2952 
2954  Location loc,
2955  Type resultType,
2956  ValueRange inputs) const {
2957  for (const MaterializationCallbackFn &fn :
2958  llvm::reverse(argumentMaterializations))
2959  if (Value result = fn(builder, resultType, inputs, loc))
2960  return result;
2961  return nullptr;
2962 }
2963 
2965  Location loc, Type resultType,
2966  ValueRange inputs) const {
2967  for (const MaterializationCallbackFn &fn :
2968  llvm::reverse(sourceMaterializations))
2969  if (Value result = fn(builder, resultType, inputs, loc))
2970  return result;
2971  return nullptr;
2972 }
2973 
2975  Location loc, Type resultType,
2976  ValueRange inputs,
2977  Type originalType) const {
2979  builder, loc, TypeRange(resultType), inputs, originalType);
2980  if (result.empty())
2981  return nullptr;
2982  assert(result.size() == 1 && "expected single result");
2983  return result.front();
2984 }
2985 
2987  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2988  Type originalType) const {
2989  for (const TargetMaterializationCallbackFn &fn :
2990  llvm::reverse(targetMaterializations)) {
2991  SmallVector<Value> result =
2992  fn(builder, resultTypes, inputs, loc, originalType);
2993  if (result.empty())
2994  continue;
2995  assert(TypeRange(ValueRange(result)) == resultTypes &&
2996  "callback produced incorrect number of values or values with "
2997  "incorrect types");
2998  return result;
2999  }
3000  return {};
3001 }
3002 
3003 std::optional<TypeConverter::SignatureConversion>
3005  SignatureConversion conversion(block->getNumArguments());
3006  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3007  return std::nullopt;
3008  return conversion;
3009 }
3010 
3011 //===----------------------------------------------------------------------===//
3012 // Type attribute conversion
3013 //===----------------------------------------------------------------------===//
3016  return AttributeConversionResult(attr, resultTag);
3017 }
3018 
3021  return AttributeConversionResult(nullptr, naTag);
3022 }
3023 
3026  return AttributeConversionResult(nullptr, abortTag);
3027 }
3028 
3030  return impl.getInt() == resultTag;
3031 }
3032 
3034  return impl.getInt() == naTag;
3035 }
3036 
3038  return impl.getInt() == abortTag;
3039 }
3040 
3042  assert(hasResult() && "Cannot get result from N/A or abort");
3043  return impl.getPointer();
3044 }
3045 
3046 std::optional<Attribute>
3048  for (const TypeAttributeConversionCallbackFn &fn :
3049  llvm::reverse(typeAttributeConversions)) {
3050  AttributeConversionResult res = fn(type, attr);
3051  if (res.hasResult())
3052  return res.getResult();
3053  if (res.isAbort())
3054  return std::nullopt;
3055  }
3056  return std::nullopt;
3057 }
3058 
3059 //===----------------------------------------------------------------------===//
3060 // FunctionOpInterfaceSignatureConversion
3061 //===----------------------------------------------------------------------===//
3062 
3063 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3064  const TypeConverter &typeConverter,
3065  ConversionPatternRewriter &rewriter) {
3066  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3067  if (!type)
3068  return failure();
3069 
3070  // Convert the original function types.
3071  TypeConverter::SignatureConversion result(type.getNumInputs());
3072  SmallVector<Type, 1> newResults;
3073  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3074  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3075  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3076  typeConverter, &result)))
3077  return failure();
3078 
3079  // Update the function signature in-place.
3080  auto newType = FunctionType::get(rewriter.getContext(),
3081  result.getConvertedTypes(), newResults);
3082 
3083  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3084 
3085  return success();
3086 }
3087 
3088 /// Create a default conversion pattern that rewrites the type signature of a
3089 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3090 /// represent their type.
3091 namespace {
3092 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3093  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3094  MLIRContext *ctx,
3095  const TypeConverter &converter)
3096  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3097 
3098  LogicalResult
3099  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3100  ConversionPatternRewriter &rewriter) const override {
3101  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3102  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3103  }
3104 };
3105 
3106 struct AnyFunctionOpInterfaceSignatureConversion
3107  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3109 
3110  LogicalResult
3111  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3112  ConversionPatternRewriter &rewriter) const override {
3113  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3114  }
3115 };
3116 } // namespace
3117 
3118 FailureOr<Operation *>
3120  const TypeConverter &converter,
3121  ConversionPatternRewriter &rewriter) {
3122  assert(op && "Invalid op");
3123  Location loc = op->getLoc();
3124  if (converter.isLegal(op))
3125  return rewriter.notifyMatchFailure(loc, "op already legal");
3126 
3127  OperationState newOp(loc, op->getName());
3128  newOp.addOperands(operands);
3129 
3130  SmallVector<Type> newResultTypes;
3131  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3132  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3133 
3134  newOp.addTypes(newResultTypes);
3135  newOp.addAttributes(op->getAttrs());
3136  return rewriter.create(newOp);
3137 }
3138 
3140  StringRef functionLikeOpName, RewritePatternSet &patterns,
3141  const TypeConverter &converter) {
3142  patterns.add<FunctionOpInterfaceSignatureConversion>(
3143  functionLikeOpName, patterns.getContext(), converter);
3144 }
3145 
3147  RewritePatternSet &patterns, const TypeConverter &converter) {
3148  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3149  converter, patterns.getContext());
3150 }
3151 
3152 //===----------------------------------------------------------------------===//
3153 // ConversionTarget
3154 //===----------------------------------------------------------------------===//
3155 
3157  LegalizationAction action) {
3158  legalOperations[op].action = action;
3159 }
3160 
3162  LegalizationAction action) {
3163  for (StringRef dialect : dialectNames)
3164  legalDialects[dialect] = action;
3165 }
3166 
3168  -> std::optional<LegalizationAction> {
3169  std::optional<LegalizationInfo> info = getOpInfo(op);
3170  return info ? info->action : std::optional<LegalizationAction>();
3171 }
3172 
3174  -> std::optional<LegalOpDetails> {
3175  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3176  if (!info)
3177  return std::nullopt;
3178 
3179  // Returns true if this operation instance is known to be legal.
3180  auto isOpLegal = [&] {
3181  // Handle dynamic legality either with the provided legality function.
3182  if (info->action == LegalizationAction::Dynamic) {
3183  std::optional<bool> result = info->legalityFn(op);
3184  if (result)
3185  return *result;
3186  }
3187 
3188  // Otherwise, the operation is only legal if it was marked 'Legal'.
3189  return info->action == LegalizationAction::Legal;
3190  };
3191  if (!isOpLegal())
3192  return std::nullopt;
3193 
3194  // This operation is legal, compute any additional legality information.
3195  LegalOpDetails legalityDetails;
3196  if (info->isRecursivelyLegal) {
3197  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3198  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3199  legalityDetails.isRecursivelyLegal =
3200  legalityFnIt->second(op).value_or(true);
3201  } else {
3202  legalityDetails.isRecursivelyLegal = true;
3203  }
3204  }
3205  return legalityDetails;
3206 }
3207 
3209  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3210  if (!info)
3211  return false;
3212 
3213  if (info->action == LegalizationAction::Dynamic) {
3214  std::optional<bool> result = info->legalityFn(op);
3215  if (!result)
3216  return false;
3217 
3218  return !(*result);
3219  }
3220 
3221  return info->action == LegalizationAction::Illegal;
3222 }
3223 
3227  if (!oldCallback)
3228  return newCallback;
3229 
3230  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3231  Operation *op) -> std::optional<bool> {
3232  if (std::optional<bool> result = newCl(op))
3233  return *result;
3234 
3235  return oldCl(op);
3236  };
3237  return chain;
3238 }
3239 
3240 void ConversionTarget::setLegalityCallback(
3241  OperationName name, const DynamicLegalityCallbackFn &callback) {
3242  assert(callback && "expected valid legality callback");
3243  auto *infoIt = legalOperations.find(name);
3244  assert(infoIt != legalOperations.end() &&
3245  infoIt->second.action == LegalizationAction::Dynamic &&
3246  "expected operation to already be marked as dynamically legal");
3247  infoIt->second.legalityFn =
3248  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3249 }
3250 
3252  OperationName name, const DynamicLegalityCallbackFn &callback) {
3253  auto *infoIt = legalOperations.find(name);
3254  assert(infoIt != legalOperations.end() &&
3255  infoIt->second.action != LegalizationAction::Illegal &&
3256  "expected operation to already be marked as legal");
3257  infoIt->second.isRecursivelyLegal = true;
3258  if (callback)
3259  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3260  std::move(opRecursiveLegalityFns[name]), callback);
3261  else
3262  opRecursiveLegalityFns.erase(name);
3263 }
3264 
3265 void ConversionTarget::setLegalityCallback(
3266  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3267  assert(callback && "expected valid legality callback");
3268  for (StringRef dialect : dialects)
3269  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3270  std::move(dialectLegalityFns[dialect]), callback);
3271 }
3272 
3273 void ConversionTarget::setLegalityCallback(
3274  const DynamicLegalityCallbackFn &callback) {
3275  assert(callback && "expected valid legality callback");
3276  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3277 }
3278 
3279 auto ConversionTarget::getOpInfo(OperationName op) const
3280  -> std::optional<LegalizationInfo> {
3281  // Check for info for this specific operation.
3282  const auto *it = legalOperations.find(op);
3283  if (it != legalOperations.end())
3284  return it->second;
3285  // Check for info for the parent dialect.
3286  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3287  if (dialectIt != legalDialects.end()) {
3288  DynamicLegalityCallbackFn callback;
3289  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3290  if (dialectFn != dialectLegalityFns.end())
3291  callback = dialectFn->second;
3292  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3293  callback};
3294  }
3295  // Otherwise, check if we mark unknown operations as dynamic.
3296  if (unknownLegalityFn)
3297  return LegalizationInfo{LegalizationAction::Dynamic,
3298  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3299  return std::nullopt;
3300 }
3301 
3302 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3303 //===----------------------------------------------------------------------===//
3304 // PDL Configuration
3305 //===----------------------------------------------------------------------===//
3306 
3308  auto &rewriterImpl =
3309  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3310  rewriterImpl.currentTypeConverter = getTypeConverter();
3311 }
3312 
3314  auto &rewriterImpl =
3315  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3316  rewriterImpl.currentTypeConverter = nullptr;
3317 }
3318 
3319 /// Remap the given value using the rewriter and the type converter in the
3320 /// provided config.
3321 static FailureOr<SmallVector<Value>>
3323  SmallVector<Value> mappedValues;
3324  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3325  return failure();
3326  return std::move(mappedValues);
3327 }
3328 
3330  patterns.getPDLPatterns().registerRewriteFunction(
3331  "convertValue",
3332  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3333  auto results = pdllConvertValues(
3334  static_cast<ConversionPatternRewriter &>(rewriter), value);
3335  if (failed(results))
3336  return failure();
3337  return results->front();
3338  });
3339  patterns.getPDLPatterns().registerRewriteFunction(
3340  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3341  return pdllConvertValues(
3342  static_cast<ConversionPatternRewriter &>(rewriter), values);
3343  });
3344  patterns.getPDLPatterns().registerRewriteFunction(
3345  "convertType",
3346  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3347  auto &rewriterImpl =
3348  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3349  if (const TypeConverter *converter =
3350  rewriterImpl.currentTypeConverter) {
3351  if (Type newType = converter->convertType(type))
3352  return newType;
3353  return failure();
3354  }
3355  return type;
3356  });
3357  patterns.getPDLPatterns().registerRewriteFunction(
3358  "convertTypes",
3359  [](PatternRewriter &rewriter,
3360  TypeRange types) -> FailureOr<SmallVector<Type>> {
3361  auto &rewriterImpl =
3362  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3363  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3364  if (!converter)
3365  return SmallVector<Type>(types);
3366 
3367  SmallVector<Type> remappedTypes;
3368  if (failed(converter->convertTypes(types, remappedTypes)))
3369  return failure();
3370  return std::move(remappedTypes);
3371  });
3372 }
3373 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3374 
3375 //===----------------------------------------------------------------------===//
3376 // Op Conversion Entry Points
3377 //===----------------------------------------------------------------------===//
3378 
3379 //===----------------------------------------------------------------------===//
3380 // Partial Conversion
3381 
3383  ArrayRef<Operation *> ops, const ConversionTarget &target,
3385  OperationConverter opConverter(target, patterns, config,
3386  OpConversionMode::Partial);
3387  return opConverter.convertOperations(ops);
3388 }
3389 LogicalResult
3393  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3394 }
3395 
3396 //===----------------------------------------------------------------------===//
3397 // Full Conversion
3398 
3400  const ConversionTarget &target,
3403  OperationConverter opConverter(target, patterns, config,
3404  OpConversionMode::Full);
3405  return opConverter.convertOperations(ops);
3406 }
3408  const ConversionTarget &target,
3411  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3412 }
3413 
3414 //===----------------------------------------------------------------------===//
3415 // Analysis Conversion
3416 
3417 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3418 /// op is a top-level module op (which is expected to be isolated from above),
3419 /// return that op.
3421  // Check if there is a top-level operation within `ops`. If so, return that
3422  // op.
3423  for (Operation *op : ops) {
3424  if (!op->getParentOp()) {
3425 #ifndef NDEBUG
3426  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3427  "expected top-level op to be isolated from above");
3428  for (Operation *other : ops)
3429  assert(op->isAncestor(other) &&
3430  "expected ops to have a common ancestor");
3431 #endif // NDEBUG
3432  return op;
3433  }
3434  }
3435 
3436  // No top-level op. Find a common ancestor.
3437  Operation *commonAncestor =
3438  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3439  for (Operation *op : ops.drop_front()) {
3440  while (!commonAncestor->isProperAncestor(op)) {
3441  commonAncestor =
3443  assert(commonAncestor &&
3444  "expected to find a common isolated from above ancestor");
3445  }
3446  }
3447 
3448  return commonAncestor;
3449 }
3450 
3454 #ifndef NDEBUG
3455  if (config.legalizableOps)
3456  assert(config.legalizableOps->empty() && "expected empty set");
3457 #endif // NDEBUG
3458 
3459  // Clone closest common ancestor that is isolated from above.
3460  Operation *commonAncestor = findCommonAncestor(ops);
3461  IRMapping mapping;
3462  Operation *clonedAncestor = commonAncestor->clone(mapping);
3463  // Compute inverse IR mapping.
3464  DenseMap<Operation *, Operation *> inverseOperationMap;
3465  for (auto &it : mapping.getOperationMap())
3466  inverseOperationMap[it.second] = it.first;
3467 
3468  // Convert the cloned operations. The original IR will remain unchanged.
3469  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3470  ops, [&](Operation *op) { return mapping.lookup(op); });
3471  OperationConverter opConverter(target, patterns, config,
3472  OpConversionMode::Analysis);
3473  LogicalResult status = opConverter.convertOperations(opsToConvert);
3474 
3475  // Remap `legalizableOps`, so that they point to the original ops and not the
3476  // cloned ops.
3477  if (config.legalizableOps) {
3478  DenseSet<Operation *> originalLegalizableOps;
3479  for (Operation *op : *config.legalizableOps)
3480  originalLegalizableOps.insert(inverseOperationMap[op]);
3481  *config.legalizableOps = std::move(originalLegalizableOps);
3482  }
3483 
3484  // Erase the cloned IR.
3485  clonedAncestor->erase();
3486  return status;
3487 }
3488 
3489 LogicalResult
3493  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3494 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used without a dominance violation.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:336
Block::iterator getPoint() const
Definition: Builders.h:349
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:346
Block * getBlock() const
Definition: Builders.h:348
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:512
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:897
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Definition: Argument.h:64
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
SmallVector< Value > unpackNTo1Materialization(Value value)
Unpack an N:1 materialization and return the inputs of the materialization.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, Value valueToMap, ValueRange inputs, Type outputType, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, Value originalValue, const TypeConverter *converter)
Build an N:1 materialization for the given original value that was replaced with the given replacemen...
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/blocks that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
DenseSet< UnrealizedConversionCastOp > nTo1TempMaterializations
A set of all N:1 materializations that were added to work around incomplete 1:N support in the dialec...
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, Value valueToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
void notifyOpReplaced(Operation *op, ArrayRef< ValueRange > newValues)
Notifies that an op is about to be replaced with the given values.
const ConversionConfig & config
Dialect conversion configuration.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< SmallVector< Value >> &remapped)
Remap the given values to those with potentially different types.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.