MLIR  22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 #define DEBUG_TYPE "dialect-conversion"
30 
31 /// A utility function to log a successful result for the given reason.
32 template <typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
34  LLVM_DEBUG({
35  os.unindent();
36  os.startLine() << "} -> SUCCESS";
37  if (!fmt.empty())
38  os.getOStream() << " : "
39  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40  os.getOStream() << "\n";
41  });
42 }
43 
44 /// A utility function to log a failure result for the given reason.
45 template <typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
47  LLVM_DEBUG({
48  os.unindent();
49  os.startLine() << "} -> FAILURE : "
50  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
51  << "\n";
52  });
53 }
54 
55 /// Helper function that computes an insertion point where the given value is
56 /// defined and can be used without a dominance violation.
58  Block *insertBlock = value.getParentBlock();
59  Block::iterator insertPt = insertBlock->begin();
60  if (OpResult inputRes = dyn_cast<OpResult>(value))
61  insertPt = ++inputRes.getOwner()->getIterator();
62  return OpBuilder::InsertPoint(insertBlock, insertPt);
63 }
64 
65 /// Helper function that computes an insertion point where the given values are
66 /// defined and can be used without a dominance violation.
68  assert(!vals.empty() && "expected at least one value");
69  DominanceInfo domInfo;
70  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
71  for (Value v : vals.drop_front()) {
72  // Choose the "later" insertion point.
74  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
75  nextPt.getPoint())) {
76  // pt is before nextPt => choose nextPt.
77  pt = nextPt;
78  } else {
79 #ifndef NDEBUG
80  // nextPt should be before pt => choose pt.
81  // If pt, nextPt are no dominance relationship, then there is no valid
82  // insertion point at which all given values are defined.
83  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
84  pt.getBlock(), pt.getPoint());
85  assert(dom && "unable to find valid insertion point");
86 #endif // NDEBUG
87  }
88  }
89  return pt;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // ConversionValueMapping
94 //===----------------------------------------------------------------------===//
95 
96 /// A vector of SSA values, optimized for the most common case of a single
97 /// value.
98 using ValueVector = SmallVector<Value, 1>;
99 
100 namespace {
101 
102 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
103 struct ValueVectorMapInfo {
104  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
105  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
106  static ::llvm::hash_code getHashValue(const ValueVector &val) {
107  return ::llvm::hash_combine_range(val);
108  }
109  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
110  return LHS == RHS;
111  }
112 };
113 
114 /// This class wraps a IRMapping to provide recursive lookup
115 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
116 struct ConversionValueMapping {
117  /// Return "true" if an SSA value is mapped to the given value. May return
118  /// false positives.
119  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
120 
121  /// Lookup the most recently mapped values with the desired types in the
122  /// mapping.
123  ///
124  /// Special cases:
125  /// - If the desired type range is empty, simply return the most recently
126  /// mapped values.
127  /// - If there is no mapping to the desired types, also return the most
128  /// recently mapped values.
129  /// - If there is no mapping for the given values at all, return the given
130  /// value.
131  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
132 
133  /// Lookup the given value within the map, or return an empty vector if the
134  /// value is not mapped. If it is mapped, this follows the same behavior
135  /// as `lookupOrDefault`.
136  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
137 
138  template <typename T>
139  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
140 
141  /// Map a value vector to the one provided.
142  template <typename OldVal, typename NewVal>
143  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
144  map(OldVal &&oldVal, NewVal &&newVal) {
145  LLVM_DEBUG({
146  ValueVector next(newVal);
147  while (true) {
148  assert(next != oldVal && "inserting cyclic mapping");
149  auto it = mapping.find(next);
150  if (it == mapping.end())
151  break;
152  next = it->second;
153  }
154  });
155  mappedTo.insert_range(newVal);
156 
157  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
158  }
159 
160  /// Map a value vector or single value to the one provided.
161  template <typename OldVal, typename NewVal>
162  std::enable_if_t<!IsValueVector<OldVal>::value ||
163  !IsValueVector<NewVal>::value>
164  map(OldVal &&oldVal, NewVal &&newVal) {
165  if constexpr (IsValueVector<OldVal>{}) {
166  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
167  } else if constexpr (IsValueVector<NewVal>{}) {
168  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
169  } else {
170  map(ValueVector{oldVal}, ValueVector{newVal});
171  }
172  }
173 
174  void map(Value oldVal, SmallVector<Value> &&newVal) {
175  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
176  }
177 
178  /// Drop the last mapping for the given values.
179  void erase(const ValueVector &value) { mapping.erase(value); }
180 
181 private:
182  /// Current value mappings.
184 
185  /// All SSA values that are mapped to. May contain false positives.
186  DenseSet<Value> mappedTo;
187 };
188 } // namespace
189 
191 ConversionValueMapping::lookupOrDefault(Value from,
192  TypeRange desiredTypes) const {
193  // Try to find the deepest values that have the desired types. If there is no
194  // such mapping, simply return the deepest values.
195  ValueVector desiredValue;
196  ValueVector current{from};
197  do {
198  // Store the current value if the types match.
199  if (TypeRange(ValueRange(current)) == desiredTypes)
200  desiredValue = current;
201 
202  // If possible, Replace each value with (one or multiple) mapped values.
203  ValueVector next;
204  for (Value v : current) {
205  auto it = mapping.find({v});
206  if (it != mapping.end()) {
207  llvm::append_range(next, it->second);
208  } else {
209  next.push_back(v);
210  }
211  }
212  if (next != current) {
213  // If at least one value was replaced, continue the lookup from there.
214  current = std::move(next);
215  continue;
216  }
217 
218  // Otherwise: Check if there is a mapping for the entire vector. Such
219  // mappings are materializations. (N:M mapping are not supported for value
220  // replacements.)
221  //
222  // Note: From a correctness point of view, materializations do not have to
223  // be stored (and looked up) in the mapping. But for performance reasons,
224  // we choose to reuse existing IR (when possible) instead of creating it
225  // multiple times.
226  auto it = mapping.find(current);
227  if (it == mapping.end()) {
228  // No mapping found: The lookup stops here.
229  break;
230  }
231  current = it->second;
232  } while (true);
233 
234  // If the desired values were found use them, otherwise default to the leaf
235  // values.
236  // Note: If `desiredTypes` is empty, this function always returns `current`.
237  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
238 }
239 
240 ValueVector ConversionValueMapping::lookupOrNull(Value from,
241  TypeRange desiredTypes) const {
242  ValueVector result = lookupOrDefault(from, desiredTypes);
243  if (result == ValueVector{from} ||
244  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
245  return {};
246  return result;
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // Rewriter and Translation State
251 //===----------------------------------------------------------------------===//
252 namespace {
/// A snapshot of the conversion rewriter state, stored as a set of counters.
/// Taking a snapshot and later resetting to it makes it possible to undo the
/// set of rewrites that was performed in between.
struct RewriterState {
  /// Number of rewrites that had been performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;

  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}
};
270 
271 //===----------------------------------------------------------------------===//
272 // IR rewrites
273 //===----------------------------------------------------------------------===//
274 
275 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
276 
277 /// Notify the listener that the given block and its contents are being erased.
278 static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
279  for (Operation &op : b)
280  notifyIRErased(listener, op);
281  listener->notifyBlockErased(&b);
282 }
283 
284 /// Notify the listener that the given operation and its contents are being
285 /// erased.
286 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
287  for (Region &r : op.getRegions()) {
288  for (Block &b : r) {
289  notifyIRErased(listener, b);
290  }
291  }
292  listener->notifyOperationErased(&op);
293 }
294 
295 /// An IR rewrite that can be committed (upon success) or rolled back (upon
296 /// failure).
297 ///
298 /// The dialect conversion keeps track of IR modifications (requested by the
299 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
300 /// are directly applied to the IR as the rewriter API is used, some are applied
301 /// partially, and some are delayed until the `IRRewrite` objects are committed.
302 class IRRewrite {
303 public:
304  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
305  /// Enum values are ordered, so that they can be used in `classof`: first all
306  /// block rewrites, then all operation rewrites.
307  enum class Kind {
308  // Block rewrites
309  CreateBlock,
310  EraseBlock,
311  InlineBlock,
312  MoveBlock,
313  BlockTypeConversion,
314  ReplaceBlockArg,
315  // Operation rewrites
316  MoveOperation,
317  ModifyOperation,
318  ReplaceOperation,
319  CreateOperation,
320  UnresolvedMaterialization
321  };
322 
323  virtual ~IRRewrite() = default;
324 
325  /// Roll back the rewrite. Operations may be erased during rollback.
326  virtual void rollback() = 0;
327 
328  /// Commit the rewrite. At this point, it is certain that the dialect
329  /// conversion will succeed. All IR modifications, except for operation/block
330  /// erasure, must be performed through the given rewriter.
331  ///
332  /// Instead of erasing operations/blocks, they should merely be unlinked
333  /// commit phase and finally be erased during the cleanup phase. This is
334  /// because internal dialect conversion state (such as `mapping`) may still
335  /// be using them.
336  ///
337  /// Any IR modification that was already performed before the commit phase
338  /// (e.g., insertion of an op) must be communicated to the listener that may
339  /// be attached to the given rewriter.
340  virtual void commit(RewriterBase &rewriter) {}
341 
342  /// Cleanup operations/blocks. Cleanup is called after commit.
343  virtual void cleanup(RewriterBase &rewriter) {}
344 
345  Kind getKind() const { return kind; }
346 
347  static bool classof(const IRRewrite *rewrite) { return true; }
348 
349 protected:
350  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
351  : kind(kind), rewriterImpl(rewriterImpl) {}
352 
353  const ConversionConfig &getConfig() const;
354 
355  const Kind kind;
356  ConversionPatternRewriterImpl &rewriterImpl;
357 };
358 
359 /// A block rewrite.
360 class BlockRewrite : public IRRewrite {
361 public:
362  /// Return the block that this rewrite operates on.
363  Block *getBlock() const { return block; }
364 
365  static bool classof(const IRRewrite *rewrite) {
366  return rewrite->getKind() >= Kind::CreateBlock &&
367  rewrite->getKind() <= Kind::ReplaceBlockArg;
368  }
369 
370 protected:
371  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
372  Block *block)
373  : IRRewrite(kind, rewriterImpl), block(block) {}
374 
375  // The block that this rewrite operates on.
376  Block *block;
377 };
378 
379 /// Creation of a block. Block creations are immediately reflected in the IR.
380 /// There is no extra work to commit the rewrite. During rollback, the newly
381 /// created block is erased.
382 class CreateBlockRewrite : public BlockRewrite {
383 public:
384  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
385  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
386 
387  static bool classof(const IRRewrite *rewrite) {
388  return rewrite->getKind() == Kind::CreateBlock;
389  }
390 
391  void commit(RewriterBase &rewriter) override {
392  // The block was already created and inserted. Just inform the listener.
393  if (auto *listener = rewriter.getListener())
394  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
395  }
396 
397  void rollback() override {
398  // Unlink all of the operations within this block, they will be deleted
399  // separately.
400  auto &blockOps = block->getOperations();
401  while (!blockOps.empty())
402  blockOps.remove(blockOps.begin());
403  block->dropAllUses();
404  if (block->getParent())
405  block->erase();
406  else
407  delete block;
408  }
409 };
410 
411 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
412 /// blocks are immediately unlinked, but only erased during cleanup. This makes
413 /// it easier to rollback a block erasure: the block is simply inserted into its
414 /// original location.
415 class EraseBlockRewrite : public BlockRewrite {
416 public:
417  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
418  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
419  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
420 
421  static bool classof(const IRRewrite *rewrite) {
422  return rewrite->getKind() == Kind::EraseBlock;
423  }
424 
425  ~EraseBlockRewrite() override {
426  assert(!block &&
427  "rewrite was neither rolled back nor committed/cleaned up");
428  }
429 
430  void rollback() override {
431  // The block (owned by this rewrite) was not actually erased yet. It was
432  // just unlinked. Put it back into its original position.
433  assert(block && "expected block");
434  auto &blockList = region->getBlocks();
435  Region::iterator before = insertBeforeBlock
436  ? Region::iterator(insertBeforeBlock)
437  : blockList.end();
438  blockList.insert(before, block);
439  block = nullptr;
440  }
441 
442  void commit(RewriterBase &rewriter) override {
443  assert(block && "expected block");
444 
445  // Notify the listener that the block and its contents are being erased.
446  if (auto *listener =
447  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
448  notifyIRErased(listener, *block);
449  }
450 
451  void cleanup(RewriterBase &rewriter) override {
452  // Erase the contents of the block.
453  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
454  rewriter.eraseOp(&op);
455  assert(block->empty() && "expected empty block");
456 
457  // Erase the block.
458  block->dropAllDefinedValueUses();
459  delete block;
460  block = nullptr;
461  }
462 
463 private:
464  // The region in which this block was previously contained.
465  Region *region;
466 
467  // The original successor of this block before it was unlinked. "nullptr" if
468  // this block was the only block in the region.
469  Block *insertBeforeBlock;
470 };
471 
472 /// Inlining of a block. This rewrite is immediately reflected in the IR.
473 /// Note: This rewrite represents only the inlining of the operations. The
474 /// erasure of the inlined block is a separate rewrite.
475 class InlineBlockRewrite : public BlockRewrite {
476 public:
477  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
478  Block *sourceBlock, Block::iterator before)
479  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
480  sourceBlock(sourceBlock),
481  firstInlinedInst(sourceBlock->empty() ? nullptr
482  : &sourceBlock->front()),
483  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
484  // If a listener is attached to the dialect conversion, ops must be moved
485  // one-by-one. When they are moved in bulk, notifications cannot be sent
486  // because the ops that used to be in the source block at the time of the
487  // inlining (before the "commit" phase) are unknown at the time when
488  // notifications are sent (which is during the "commit" phase).
489  assert(!getConfig().listener &&
490  "InlineBlockRewrite not supported if listener is attached");
491  }
492 
493  static bool classof(const IRRewrite *rewrite) {
494  return rewrite->getKind() == Kind::InlineBlock;
495  }
496 
497  void rollback() override {
498  // Put the operations from the destination block (owned by the rewrite)
499  // back into the source block.
500  if (firstInlinedInst) {
501  assert(lastInlinedInst && "expected operation");
502  sourceBlock->getOperations().splice(sourceBlock->begin(),
503  block->getOperations(),
504  Block::iterator(firstInlinedInst),
505  ++Block::iterator(lastInlinedInst));
506  }
507  }
508 
509 private:
510  // The block that originally contained the operations.
511  Block *sourceBlock;
512 
513  // The first inlined operation.
514  Operation *firstInlinedInst;
515 
516  // The last inlined operation.
517  Operation *lastInlinedInst;
518 };
519 
520 /// Moving of a block. This rewrite is immediately reflected in the IR.
521 class MoveBlockRewrite : public BlockRewrite {
522 public:
523  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
524  Region *region, Block *insertBeforeBlock)
525  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
526  insertBeforeBlock(insertBeforeBlock) {}
527 
528  static bool classof(const IRRewrite *rewrite) {
529  return rewrite->getKind() == Kind::MoveBlock;
530  }
531 
532  void commit(RewriterBase &rewriter) override {
533  // The block was already moved. Just inform the listener.
534  if (auto *listener = rewriter.getListener()) {
535  // Note: `previousIt` cannot be passed because this is a delayed
536  // notification and iterators into past IR state cannot be represented.
537  listener->notifyBlockInserted(block, /*previous=*/region,
538  /*previousIt=*/{});
539  }
540  }
541 
542  void rollback() override {
543  // Move the block back to its original position.
544  Region::iterator before =
545  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
546  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
547  }
548 
549 private:
550  // The region in which this block was previously contained.
551  Region *region;
552 
553  // The original successor of this block before it was moved. "nullptr" if
554  // this block was the only block in the region.
555  Block *insertBeforeBlock;
556 };
557 
558 /// Block type conversion. This rewrite is partially reflected in the IR.
559 class BlockTypeConversionRewrite : public BlockRewrite {
560 public:
561  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
562  Block *origBlock, Block *newBlock)
563  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
564  newBlock(newBlock) {}
565 
566  static bool classof(const IRRewrite *rewrite) {
567  return rewrite->getKind() == Kind::BlockTypeConversion;
568  }
569 
570  Block *getOrigBlock() const { return block; }
571 
572  Block *getNewBlock() const { return newBlock; }
573 
574  void commit(RewriterBase &rewriter) override;
575 
576  void rollback() override;
577 
578 private:
579  /// The new block that was created as part of this signature conversion.
580  Block *newBlock;
581 };
582 
583 /// Replacing a block argument. This rewrite is not immediately reflected in the
584 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
585 /// until the rewrite is committed.
586 class ReplaceBlockArgRewrite : public BlockRewrite {
587 public:
588  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
589  Block *block, BlockArgument arg,
590  const TypeConverter *converter)
591  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
592  converter(converter) {}
593 
594  static bool classof(const IRRewrite *rewrite) {
595  return rewrite->getKind() == Kind::ReplaceBlockArg;
596  }
597 
598  void commit(RewriterBase &rewriter) override;
599 
600  void rollback() override;
601 
602 private:
603  BlockArgument arg;
604 
605  /// The current type converter when the block argument was replaced.
606  const TypeConverter *converter;
607 };
608 
609 /// An operation rewrite.
610 class OperationRewrite : public IRRewrite {
611 public:
612  /// Return the operation that this rewrite operates on.
613  Operation *getOperation() const { return op; }
614 
615  static bool classof(const IRRewrite *rewrite) {
616  return rewrite->getKind() >= Kind::MoveOperation &&
617  rewrite->getKind() <= Kind::UnresolvedMaterialization;
618  }
619 
620 protected:
621  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
622  Operation *op)
623  : IRRewrite(kind, rewriterImpl), op(op) {}
624 
625  // The operation that this rewrite operates on.
626  Operation *op;
627 };
628 
629 /// Moving of an operation. This rewrite is immediately reflected in the IR.
630 class MoveOperationRewrite : public OperationRewrite {
631 public:
632  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
633  Operation *op, Block *block, Operation *insertBeforeOp)
634  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
635  insertBeforeOp(insertBeforeOp) {}
636 
637  static bool classof(const IRRewrite *rewrite) {
638  return rewrite->getKind() == Kind::MoveOperation;
639  }
640 
641  void commit(RewriterBase &rewriter) override {
642  // The operation was already moved. Just inform the listener.
643  if (auto *listener = rewriter.getListener()) {
644  // Note: `previousIt` cannot be passed because this is a delayed
645  // notification and iterators into past IR state cannot be represented.
646  listener->notifyOperationInserted(
647  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
648  /*insertPt=*/{}));
649  }
650  }
651 
652  void rollback() override {
653  // Move the operation back to its original position.
654  Block::iterator before =
655  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
656  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
657  }
658 
659 private:
660  // The block in which this operation was previously contained.
661  Block *block;
662 
663  // The original successor of this operation before it was moved. "nullptr"
664  // if this operation was the only operation in the region.
665  Operation *insertBeforeOp;
666 };
667 
668 /// In-place modification of an op. This rewrite is immediately reflected in
669 /// the IR. The previous state of the operation is stored in this object.
670 class ModifyOperationRewrite : public OperationRewrite {
671 public:
672  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
673  Operation *op)
674  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
675  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
676  operands(op->operand_begin(), op->operand_end()),
677  successors(op->successor_begin(), op->successor_end()) {
678  if (OpaqueProperties prop = op->getPropertiesStorage()) {
679  // Make a copy of the properties.
680  propertiesStorage = operator new(op->getPropertiesStorageSize());
681  OpaqueProperties propCopy(propertiesStorage);
682  name.initOpProperties(propCopy, /*init=*/prop);
683  }
684  }
685 
686  static bool classof(const IRRewrite *rewrite) {
687  return rewrite->getKind() == Kind::ModifyOperation;
688  }
689 
690  ~ModifyOperationRewrite() override {
691  assert(!propertiesStorage &&
692  "rewrite was neither committed nor rolled back");
693  }
694 
695  void commit(RewriterBase &rewriter) override {
696  // Notify the listener that the operation was modified in-place.
697  if (auto *listener =
698  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
699  listener->notifyOperationModified(op);
700 
701  if (propertiesStorage) {
702  OpaqueProperties propCopy(propertiesStorage);
703  // Note: The operation may have been erased in the mean time, so
704  // OperationName must be stored in this object.
705  name.destroyOpProperties(propCopy);
706  operator delete(propertiesStorage);
707  propertiesStorage = nullptr;
708  }
709  }
710 
711  void rollback() override {
712  op->setLoc(loc);
713  op->setAttrs(attrs);
714  op->setOperands(operands);
715  for (const auto &it : llvm::enumerate(successors))
716  op->setSuccessor(it.value(), it.index());
717  if (propertiesStorage) {
718  OpaqueProperties propCopy(propertiesStorage);
719  op->copyProperties(propCopy);
720  name.destroyOpProperties(propCopy);
721  operator delete(propertiesStorage);
722  propertiesStorage = nullptr;
723  }
724  }
725 
726 private:
727  OperationName name;
728  LocationAttr loc;
729  DictionaryAttr attrs;
730  SmallVector<Value, 8> operands;
731  SmallVector<Block *, 2> successors;
732  void *propertiesStorage = nullptr;
733 };
734 
735 /// Replacing an operation. Erasing an operation is treated as a special case
736 /// with "null" replacements. This rewrite is not immediately reflected in the
737 /// IR. An internal IR mapping is updated, but values are not replaced and the
738 /// original op is not erased until the rewrite is committed.
739 class ReplaceOperationRewrite : public OperationRewrite {
740 public:
741  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
742  Operation *op, const TypeConverter *converter)
743  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
744  converter(converter) {}
745 
746  static bool classof(const IRRewrite *rewrite) {
747  return rewrite->getKind() == Kind::ReplaceOperation;
748  }
749 
750  void commit(RewriterBase &rewriter) override;
751 
752  void rollback() override;
753 
754  void cleanup(RewriterBase &rewriter) override;
755 
756 private:
757  /// An optional type converter that can be used to materialize conversions
758  /// between the new and old values if necessary.
759  const TypeConverter *converter;
760 };
761 
762 class CreateOperationRewrite : public OperationRewrite {
763 public:
764  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
765  Operation *op)
766  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
767 
768  static bool classof(const IRRewrite *rewrite) {
769  return rewrite->getKind() == Kind::CreateOperation;
770  }
771 
772  void commit(RewriterBase &rewriter) override {
773  // The operation was already created and inserted. Just inform the listener.
774  if (auto *listener = rewriter.getListener())
775  listener->notifyOperationInserted(op, /*previous=*/{});
776  }
777 
778  void rollback() override;
779 };
780 
/// The direction of a materialization.
enum MaterializationKind {
  /// This materialization materializes a conversion from an illegal type to a
  /// legal one.
  Target = 0,

  /// This materialization materializes a conversion from a legal type back to
  /// an illegal one.
  Source = 1
};
791 
792 /// Helper class that stores metadata about an unresolved materialization.
793 class UnresolvedMaterializationInfo {
794 public:
795  UnresolvedMaterializationInfo() = default;
796  UnresolvedMaterializationInfo(const TypeConverter *converter,
797  MaterializationKind kind, Type originalType)
798  : converterAndKind(converter, kind), originalType(originalType) {}
799 
800  /// Return the type converter of this materialization (which may be null).
801  const TypeConverter *getConverter() const {
802  return converterAndKind.getPointer();
803  }
804 
805  /// Return the kind of this materialization.
806  MaterializationKind getMaterializationKind() const {
807  return converterAndKind.getInt();
808  }
809 
810  /// Return the original type of the SSA value.
811  Type getOriginalType() const { return originalType; }
812 
813 private:
814  /// The corresponding type converter to use when resolving this
815  /// materialization, and the kind of this materialization.
816  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
817  converterAndKind;
818 
819  /// The original type of the SSA value. Only used for target
820  /// materializations.
821  Type originalType;
822 };
823 
824 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
825 /// op. Unresolved materializations fold away or are replaced with
826 /// source/target materializations at the end of the dialect conversion.
827 class UnresolvedMaterializationRewrite : public OperationRewrite {
828 public:
829  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
830  UnrealizedConversionCastOp op,
831  ValueVector mappedValues)
832  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
833  mappedValues(std::move(mappedValues)) {}
834 
835  static bool classof(const IRRewrite *rewrite) {
836  return rewrite->getKind() == Kind::UnresolvedMaterialization;
837  }
838 
839  void rollback() override;
840 
841  UnrealizedConversionCastOp getOperation() const {
842  return cast<UnrealizedConversionCastOp>(op);
843  }
844 
845 private:
846  /// The values in the conversion value mapping that are being replaced by the
847  /// results of this unresolved materialization.
848  ValueVector mappedValues;
849 };
850 } // namespace
851 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on operation `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto isMatch = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getOperation() == op;
  };
  return llvm::any_of(std::forward<R>(rewrites), isMatch);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on block `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto isMatch = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getBlock() == block;
  };
  return llvm::any_of(std::forward<R>(rewrites), isMatch);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
873 
874 //===----------------------------------------------------------------------===//
875 // ConversionPatternRewriterImpl
876 //===----------------------------------------------------------------------===//
877 namespace mlir {
878 namespace detail {
881  const ConversionConfig &config)
882  : context(ctx), config(config) {}
883 
884  //===--------------------------------------------------------------------===//
885  // State Management
886  //===--------------------------------------------------------------------===//
887 
888  /// Return the current state of the rewriter.
889  RewriterState getCurrentState();
890 
891  /// Apply all requested operation rewrites. This method is invoked when the
892  /// conversion process succeeds.
893  void applyRewrites();
894 
895  /// Reset the state of the rewriter to a previously saved point. Optionally,
896  /// the name of the pattern that triggered the rollback can be specified for
897  /// debugging purposes.
898  void resetState(RewriterState state, StringRef patternName = "");
899 
900  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
901  /// failure.
902  template <typename RewriteTy, typename... Args>
903  void appendRewrite(Args &&...args) {
904  rewrites.push_back(
905  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
906  }
907 
908  /// Undo the rewrites (motions, splits) one by one in reverse order until
909  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
910  /// that triggered the rollback can be specified for debugging purposes.
911  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
912 
913  /// Remap the given values to those with potentially different types. Returns
914  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
915  /// is the tag used when describing a value within a diagnostic, e.g.
916  /// "operand".
917  LogicalResult remapValues(StringRef valueDiagTag,
918  std::optional<Location> inputLoc,
919  PatternRewriter &rewriter, ValueRange values,
920  SmallVector<ValueVector> &remapped);
921 
922  /// Return "true" if the given operation is ignored, and does not need to be
923  /// converted.
924  bool isOpIgnored(Operation *op) const;
925 
926  /// Return "true" if the given operation was replaced or erased.
927  bool wasOpReplaced(Operation *op) const;
928 
929  //===--------------------------------------------------------------------===//
930  // IR Rewrites / Type Conversion
931  //===--------------------------------------------------------------------===//
932 
933  /// Convert the types of block arguments within the given region.
934  FailureOr<Block *>
935  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
936  const TypeConverter &converter,
937  TypeConverter::SignatureConversion *entryConversion);
938 
939  /// Apply the given signature conversion on the given block. The new block
940  /// containing the updated signature is returned. If no conversions were
941  /// necessary, e.g. if the block has no arguments, `block` is returned.
942  /// `converter` is used to generate any necessary cast operations that
943  /// translate between the origin argument types and those specified in the
944  /// signature conversion.
945  Block *applySignatureConversion(
946  ConversionPatternRewriter &rewriter, Block *block,
947  const TypeConverter *converter,
948  TypeConverter::SignatureConversion &signatureConversion);
949 
950  /// Replace the results of the given operation with the given values and
951  /// erase the operation.
952  ///
953  /// There can be multiple replacement values for each result (1:N
954  /// replacement). If the replacement values are empty, the respective result
955  /// is dropped and a source materialization is built if the result still has
956  /// uses.
957  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
958 
959  /// Replace the given block argument with the given values. The specified
960  /// converter is used to build materializations (if necessary).
961  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
962  const TypeConverter *converter);
963 
964  /// Erase the given block and its contents.
965  void eraseBlock(Block *block);
966 
967  /// Inline the source block into the destination block before the given
968  /// iterator.
969  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
970 
971  //===--------------------------------------------------------------------===//
972  // Materializations
973  //===--------------------------------------------------------------------===//
974 
975  /// Build an unresolved materialization operation given a range of output
976  /// types and a list of input operands. The input types must differ from the
977  /// output types; requesting an unnecessary materialization is invalid.
978  ///
979  /// The created cast op can optionally be returned with the `castOp` output
980  /// argument.
981  ///
982  /// If `valuesToMap` is non-empty, those values are mapped to the results of
983  /// the unresolved materialization in the conversion value mapping. That
984  /// mapping is erased again if the materialization is rolled back.
985  ValueRange buildUnresolvedMaterialization(
986  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
987  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
988  Type originalType, const TypeConverter *converter,
989  UnrealizedConversionCastOp *castOp = nullptr);
990 
991  /// Find a replacement value for the given SSA value in the conversion value
992  /// mapping. The replacement value must have the same type as the given SSA
993  /// value. If there is no replacement value with the correct type, find the
994  /// latest replacement value (regardless of the type) and build a source
995  /// materialization.
996  Value findOrBuildReplacementValue(Value value,
997  const TypeConverter *converter);
998 
999  //===--------------------------------------------------------------------===//
1000  // Rewriter Notification Hooks
1001  //===--------------------------------------------------------------------===//
1002 
1003  /// Notifies that an op was inserted.
1004  void notifyOperationInserted(Operation *op,
1005  OpBuilder::InsertPoint previous) override;
1006 
1007  /// Notifies that a block was inserted.
1008  void notifyBlockInserted(Block *block, Region *previous,
1009  Region::iterator previousIt) override;
1010 
1011  /// Notifies that a pattern match failed for the given reason.
1012  void
1013  notifyMatchFailure(Location loc,
1014  function_ref<void(Diagnostic &)> reasonCallback) override;
1015 
1016  //===--------------------------------------------------------------------===//
1017  // IR Erasure
1018  //===--------------------------------------------------------------------===//
1019 
1020  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1021  /// operation or block is erased multiple times. This rewriter assumes that
1022  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1024  public:
1026  MLIRContext *context,
1027  std::function<void(Operation *)> opErasedCallback = nullptr)
1028  : RewriterBase(context, /*listener=*/this),
1029  opErasedCallback(opErasedCallback) {}
1030 
1031  /// Erase the given op (unless it was already erased).
1032  void eraseOp(Operation *op) override {
1033  if (wasErased(op))
1034  return;
1035  op->dropAllUses();
1037  }
1038 
1039  /// Erase the given block (unless it was already erased).
1040  void eraseBlock(Block *block) override {
1041  if (wasErased(block))
1042  return;
1043  assert(block->empty() && "expected empty block");
1044  block->dropAllDefinedValueUses();
1045  RewriterBase::eraseBlock(block);
1046  }
1047 
1048  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1049 
1050  void notifyOperationErased(Operation *op) override {
1051  erased.insert(op);
1052  if (opErasedCallback)
1053  opErasedCallback(op);
1054  }
1055 
1056  void notifyBlockErased(Block *block) override { erased.insert(block); }
1057 
1058  private:
1059  /// Pointers to all erased operations and blocks.
1060  DenseSet<void *> erased;
1061 
1062  /// A callback that is invoked when an operation is erased.
1063  std::function<void(Operation *)> opErasedCallback;
1064  };
1065 
1066  //===--------------------------------------------------------------------===//
1067  // State
1068  //===--------------------------------------------------------------------===//
1069 
1070  /// MLIR context.
1072 
1073  // Mapping between replaced values that differ in type. This happens when
1074  // replacing a value with one of a different type.
1075  ConversionValueMapping mapping;
1076 
1077  /// Ordered list of block operations (creations, splits, motions).
1079 
1080  /// A set of operations that should no longer be considered for legalization.
1081  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1082  /// tracked separately.
1084 
1085  /// A set of operations that were replaced/erased. Such ops are not erased
1086  /// immediately but only when the dialect conversion succeeds. In the
1087  /// meantime, they should no longer be considered for legalization and any
1088  /// attempt to modify/access them is invalid rewriter API usage.
1090 
1091  /// A set of operations that were created by the current pattern.
1093 
1094  /// A set of operations that were modified by the current pattern.
1096 
1097  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1098  /// by the current pattern.
1100 
1101  /// A mapping for looking up metadata of unresolved materializations.
1104 
1105  /// The current type converter, or nullptr if no type converter is currently
1106  /// active.
1107  const TypeConverter *currentTypeConverter = nullptr;
1108 
1109  /// A mapping of regions to type converters that should be used when
1110  /// converting the arguments of blocks within that region.
1112 
1113  /// Dialect conversion configuration.
1115 
1116 #ifndef NDEBUG
1117  /// A set of operations that have pending updates. This tracking isn't
1118  /// strictly necessary, and is thus only active during debug builds for extra
1119  /// verification.
1121 
1122  /// A logger used to emit diagnostics during the conversion process.
1123  llvm::ScopedPrinter logger{llvm::dbgs()};
1124 #endif
1125 };
1126 } // namespace detail
1127 } // namespace mlir
1128 
1129 const ConversionConfig &IRRewrite::getConfig() const {
1130  return rewriterImpl.config;
1131 }
1132 
1133 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1134  // Inform the listener about all IR modifications that have already taken
1135  // place: References to the original block have been replaced with the new
1136  // block.
1137  if (auto *listener =
1138  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1139  for (Operation *op : getNewBlock()->getUsers())
1140  listener->notifyOperationModified(op);
1141 }
1142 
1143 void BlockTypeConversionRewrite::rollback() {
1144  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1145 }
1146 
1147 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1148  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1149  if (!repl)
1150  return;
1151 
1152  if (isa<BlockArgument>(repl)) {
1153  rewriter.replaceAllUsesWith(arg, repl);
1154  return;
1155  }
1156 
1157  // If the replacement value is an operation, we check to make sure that we
1158  // don't replace uses that are within the parent operation of the
1159  // replacement value.
1160  Operation *replOp = cast<OpResult>(repl).getOwner();
1161  Block *replBlock = replOp->getBlock();
1162  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1163  Operation *user = operand.getOwner();
1164  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1165  });
1166 }
1167 
1168 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1169 
1170 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1171  auto *listener =
1172  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1173 
1174  // Compute replacement values.
1175  SmallVector<Value> replacements =
1176  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1177  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1178  });
1179 
1180  // Notify the listener that the operation is about to be replaced.
1181  if (listener)
1182  listener->notifyOperationReplaced(op, replacements);
1183 
1184  // Replace all uses with the new values.
1185  for (auto [result, newValue] :
1186  llvm::zip_equal(op->getResults(), replacements))
1187  if (newValue)
1188  rewriter.replaceAllUsesWith(result, newValue);
1189 
1190  // The original op will be erased, so remove it from the set of unlegalized
1191  // ops.
1192  if (getConfig().unlegalizedOps)
1193  getConfig().unlegalizedOps->erase(op);
1194 
1195  // Notify the listener that the operation and its contents are being erased.
1196  if (listener)
1197  notifyIRErased(listener, *op);
1198 
1199  // Do not erase the operation yet. It may still be referenced in `mapping`.
1200  // Just unlink it for now and erase it during cleanup.
1201  op->getBlock()->getOperations().remove(op);
1202 }
1203 
1204 void ReplaceOperationRewrite::rollback() {
1205  for (auto result : op->getResults())
1206  rewriterImpl.mapping.erase({result});
1207 }
1208 
1209 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1210  rewriter.eraseOp(op);
1211 }
1212 
1213 void CreateOperationRewrite::rollback() {
1214  for (Region &region : op->getRegions()) {
1215  while (!region.getBlocks().empty())
1216  region.getBlocks().remove(region.getBlocks().begin());
1217  }
1218  op->dropAllUses();
1219  op->erase();
1220 }
1221 
1222 void UnresolvedMaterializationRewrite::rollback() {
1223  if (!mappedValues.empty())
1224  rewriterImpl.mapping.erase(mappedValues);
1225  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1226  op->erase();
1227 }
1228 
1230  // Commit all rewrites.
1231  IRRewriter rewriter(context, config.listener);
1232  // Note: New rewrites may be added during the "commit" phase and the
1233  // `rewrites` vector may reallocate.
1234  for (size_t i = 0; i < rewrites.size(); ++i)
1235  rewrites[i]->commit(rewriter);
1236 
1237  // Clean up all rewrites.
1238  SingleEraseRewriter eraseRewriter(
1239  context, /*opErasedCallback=*/[&](Operation *op) {
1240  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1241  unresolvedMaterializations.erase(castOp);
1242  });
1243  for (auto &rewrite : rewrites)
1244  rewrite->cleanup(eraseRewriter);
1245 }
1246 
1247 //===----------------------------------------------------------------------===//
1248 // State Management
1249 //===----------------------------------------------------------------------===//
1250 
1252  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1253 }
1254 
1256  StringRef patternName) {
1257  // Undo any rewrites.
1258  undoRewrites(state.numRewrites, patternName);
1259 
1260  // Pop all of the recorded ignored operations that are no longer valid.
1261  while (ignoredOps.size() != state.numIgnoredOperations)
1262  ignoredOps.pop_back();
1263 
1264  while (replacedOps.size() != state.numReplacedOps)
1265  replacedOps.pop_back();
1266 }
1267 
1268 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1269  StringRef patternName) {
1270  for (auto &rewrite :
1271  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
1273  !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1274  // Unresolved materializations can always be rolled back (erased).
1275  llvm::report_fatal_error("pattern '" + patternName +
1276  "' rollback of IR modifications requested");
1277  }
1278  rewrite->rollback();
1279  }
1280  rewrites.resize(numRewritesToKeep);
1281 }
1282 
1284  StringRef valueDiagTag, std::optional<Location> inputLoc,
1285  PatternRewriter &rewriter, ValueRange values,
1286  SmallVector<ValueVector> &remapped) {
1287  remapped.reserve(llvm::size(values));
1288 
1289  for (const auto &it : llvm::enumerate(values)) {
1290  Value operand = it.value();
1291  Type origType = operand.getType();
1292  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1293 
1294  if (!currentTypeConverter) {
1295  // The current pattern does not have a type converter. I.e., it does not
1296  // distinguish between legal and illegal types. For each operand, simply
1297  // pass through the most recently mapped values.
1298  remapped.push_back(mapping.lookupOrDefault(operand));
1299  continue;
1300  }
1301 
1302  // If there is no legal conversion, fail to match this pattern.
1303  SmallVector<Type, 1> legalTypes;
1304  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1305  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1306  diag << "unable to convert type for " << valueDiagTag << " #"
1307  << it.index() << ", type was " << origType;
1308  });
1309  return failure();
1310  }
1311  // If a type is converted to 0 types, there is nothing to do.
1312  if (legalTypes.empty()) {
1313  remapped.push_back({});
1314  continue;
1315  }
1316 
1317  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1318  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1319  // Mapped values have the correct type or there is an existing
1320  // materialization. Or the operand is not mapped at all and has the
1321  // correct type.
1322  remapped.push_back(std::move(repl));
1323  continue;
1324  }
1325 
1326  // Create a materialization for the most recently mapped values.
1327  repl = mapping.lookupOrDefault(operand);
1329  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1330  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1331  /*originalType=*/origType, currentTypeConverter);
1332  remapped.push_back(castValues);
1333  }
1334  return success();
1335 }
1336 
1338  // Check to see if this operation is ignored or was replaced.
1339  return replacedOps.count(op) || ignoredOps.count(op);
1340 }
1341 
1343  // Check to see if this operation was replaced.
1344  return replacedOps.count(op);
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // Type Conversion
1349 //===----------------------------------------------------------------------===//
1350 
1352  ConversionPatternRewriter &rewriter, Region *region,
1353  const TypeConverter &converter,
1354  TypeConverter::SignatureConversion *entryConversion) {
1355  regionToConverter[region] = &converter;
1356  if (region->empty())
1357  return nullptr;
1358 
1359  // Convert the arguments of each non-entry block within the region.
1360  for (Block &block :
1361  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1362  // Compute the signature for the block with the provided converter.
1363  std::optional<TypeConverter::SignatureConversion> conversion =
1364  converter.convertBlockSignature(&block);
1365  if (!conversion)
1366  return failure();
1367  // Convert the block with the computed signature.
1368  applySignatureConversion(rewriter, &block, &converter, *conversion);
1369  }
1370 
1371  // Convert the entry block. If an entry signature conversion was provided,
1372  // use that one. Otherwise, compute the signature with the type converter.
1373  if (entryConversion)
1374  return applySignatureConversion(rewriter, &region->front(), &converter,
1375  *entryConversion);
1376  std::optional<TypeConverter::SignatureConversion> conversion =
1377  converter.convertBlockSignature(&region->front());
1378  if (!conversion)
1379  return failure();
1380  return applySignatureConversion(rewriter, &region->front(), &converter,
1381  *conversion);
1382 }
1383 
1385  ConversionPatternRewriter &rewriter, Block *block,
1386  const TypeConverter *converter,
1387  TypeConverter::SignatureConversion &signatureConversion) {
1388 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1389  // A block cannot be converted multiple times.
1390  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1391  llvm::report_fatal_error("block was already converted");
1392 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1393 
1394  OpBuilder::InsertionGuard g(rewriter);
1395 
1396  // If no arguments are being changed or added, there is nothing to do.
1397  unsigned origArgCount = block->getNumArguments();
1398  auto convertedTypes = signatureConversion.getConvertedTypes();
1399  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1400  return block;
1401 
1402  // Compute the locations of all block arguments in the new block.
1403  SmallVector<Location> newLocs(convertedTypes.size(),
1404  rewriter.getUnknownLoc());
1405  for (unsigned i = 0; i < origArgCount; ++i) {
1406  auto inputMap = signatureConversion.getInputMapping(i);
1407  if (!inputMap || inputMap->replacedWithValues())
1408  continue;
1409  Location origLoc = block->getArgument(i).getLoc();
1410  for (unsigned j = 0; j < inputMap->size; ++j)
1411  newLocs[inputMap->inputNo + j] = origLoc;
1412  }
1413 
1414  // Insert a new block with the converted block argument types and move all ops
1415  // from the old block to the new block.
1416  Block *newBlock =
1417  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1418  convertedTypes, newLocs);
1419 
1420  // If a listener is attached to the dialect conversion, ops cannot be moved
1421  // to the destination block in bulk ("fast path"). This is because at the time
1422  // the notifications are sent, it is unknown which ops were moved. Instead,
1423  // ops should be moved one-by-one ("slow path"), so that a separate
1424  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1425  // a bit more efficient, so we try to do that when possible.
1426  bool fastPath = !config.listener;
1427  if (fastPath) {
1428  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1429  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1430  } else {
1431  while (!block->empty())
1432  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1433  }
1434 
1435  // Replace all uses of the old block with the new block.
1436  block->replaceAllUsesWith(newBlock);
1437 
1438  for (unsigned i = 0; i != origArgCount; ++i) {
1439  BlockArgument origArg = block->getArgument(i);
1440  Type origArgType = origArg.getType();
1441 
1442  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1443  signatureConversion.getInputMapping(i);
1444  if (!inputMap) {
1445  // This block argument was dropped and no replacement value was provided.
1446  // Materialize a replacement value "out of thin air".
1447  Value mat =
1449  MaterializationKind::Source,
1450  OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1451  origArg.getLoc(),
1452  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1453  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1454  .front();
1455  replaceUsesOfBlockArgument(origArg, mat, converter);
1456  continue;
1457  }
1458 
1459  if (inputMap->replacedWithValues()) {
1460  // This block argument was dropped and replacement values were provided.
1461  assert(inputMap->size == 0 &&
1462  "invalid to provide a replacement value when the argument isn't "
1463  "dropped");
1464  replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1465  converter);
1466  continue;
1467  }
1468 
1469  // This is a 1->1+ mapping.
1470  auto replArgs =
1471  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1472  replaceUsesOfBlockArgument(origArg, replArgs, converter);
1473  }
1474 
1475  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1476 
1477  // Erase the old block. (It is just unlinked for now and will be erased during
1478  // cleanup.)
1479  rewriter.eraseBlock(block);
1480 
1481  return newBlock;
1482 }
1483 
1484 //===----------------------------------------------------------------------===//
1485 // Materializations
1486 //===----------------------------------------------------------------------===//
1487 
1488 /// Build an unresolved materialization operation given a range of output types
1489 /// and a list of input operands.
1491  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1492  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1493  Type originalType, const TypeConverter *converter,
1494  UnrealizedConversionCastOp *castOp) {
1495  assert((!originalType || kind == MaterializationKind::Target) &&
1496  "original type is valid only for target materializations");
1497  assert(TypeRange(inputs) != outputTypes &&
1498  "materialization is not necessary");
1499 
1500  // Create an unresolved materialization. We use a new OpBuilder to avoid
1501  // tracking the materialization like we do for other operations.
1502  OpBuilder builder(outputTypes.front().getContext());
1503  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1504  auto convertOp =
1505  UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1506  if (!valuesToMap.empty())
1507  mapping.map(valuesToMap, convertOp.getResults());
1508  if (castOp)
1509  *castOp = convertOp;
1510  unresolvedMaterializations[convertOp] =
1511  UnresolvedMaterializationInfo(converter, kind, originalType);
1512  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1513  std::move(valuesToMap));
1514  return convertOp.getResults();
1515 }
1516 
1518  Value value, const TypeConverter *converter) {
1519  // Try to find a replacement value with the same type in the conversion value
1520  // mapping. This includes cached materializations. We try to reuse those
1521  // instead of generating duplicate IR.
1522  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1523  if (!repl.empty())
1524  return repl.front();
1525 
1526  // Check if the value is dead. No replacement value is needed in that case.
1527  // This is an approximate check that may have false negatives but does not
1528  // require computing and traversing an inverse mapping. (We may end up
1529  // building source materializations that are never used and that fold away.)
1530  if (llvm::all_of(value.getUsers(),
1531  [&](Operation *op) { return replacedOps.contains(op); }) &&
1532  !mapping.isMappedTo(value))
1533  return Value();
1534 
1535  // No replacement value was found. Get the latest replacement value
1536  // (regardless of the type) and build a source materialization to the
1537  // original type.
1538  repl = mapping.lookupOrNull(value);
1539  if (repl.empty()) {
1540  // No replacement value is registered in the mapping. This means that the
1541  // value is dropped and no longer needed. (If the value were still needed,
1542  // a source materialization producing a replacement value "out of thin air"
1543  // would have already been created during `replaceOp` or
1544  // `applySignatureConversion`.)
1545  return Value();
1546  }
1547 
1548  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1549  // which all values in `repl` are defined. It is important to emit the
1550  // materialization at that location because the same materialization may be
1551  // reused in a different context. (That's because materializations are cached
1552  // in the conversion value mapping.) The insertion point of the
1553  // materialization must be valid for all future users that may be created
1554  // later in the conversion process.
1555  Value castValue =
1556  buildUnresolvedMaterialization(MaterializationKind::Source,
1557  computeInsertPoint(repl), value.getLoc(),
1558  /*valuesToMap=*/repl, /*inputs=*/repl,
1559  /*outputTypes=*/value.getType(),
1560  /*originalType=*/Type(), converter)
1561  .front();
1562  return castValue;
1563 }
1564 
1565 //===----------------------------------------------------------------------===//
1566 // Rewriter Notification Hooks
1567 //===----------------------------------------------------------------------===//
1568 
1570  Operation *op, OpBuilder::InsertPoint previous) {
1571  LLVM_DEBUG({
1572  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1573  << ")\n";
1574  });
1575  assert(!wasOpReplaced(op->getParentOp()) &&
1576  "attempting to insert into a block within a replaced/erased op");
1577 
1578  if (!previous.isSet()) {
1579  // This is a newly created op.
1580  appendRewrite<CreateOperationRewrite>(op);
1581  patternNewOps.insert(op);
1582  return;
1583  }
1584  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1585  ? nullptr
1586  : &*previous.getPoint();
1587  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1588 }
1589 
1591  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1592  assert(newValues.size() == op->getNumResults());
1593  assert(!ignoredOps.contains(op) && "operation was already replaced");
1594 
1595  // Check if replaced op is an unresolved materialization, i.e., an
1596  // unrealized_conversion_cast op that was created by the conversion driver.
1597  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1598  // Make sure that the user does not mess with unresolved materializations
1599  // that were inserted by the conversion driver. We keep track of these
1600  // ops in internal data structures.
1601  assert(!unresolvedMaterializations.contains(castOp) &&
1602  "attempting to replace/erase an unresolved materialization");
1603  }
1604 
1605  // Create mappings for each of the new result values.
1606  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1607  if (repl.empty()) {
1608  // This result was dropped and no replacement value was provided.
1609  // Materialize a replacement value "out of thin air".
1611  MaterializationKind::Source, computeInsertPoint(result),
1612  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1613  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1615  continue;
1616  }
1617 
1618  // Remap result to replacement value.
1619  if (repl.empty())
1620  continue;
1621  mapping.map(static_cast<Value>(result), std::move(repl));
1622  }
1623 
1624  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1625  // Mark this operation and all nested ops as replaced.
1626  op->walk([&](Operation *op) { replacedOps.insert(op); });
1627 }
1628 
1630  BlockArgument from, ValueRange to, const TypeConverter *converter) {
1631  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
1632  mapping.map(from, to);
1633 }
1634 
1636  assert(!wasOpReplaced(block->getParentOp()) &&
1637  "attempting to erase a block within a replaced/erased op");
1638  appendRewrite<EraseBlockRewrite>(block);
1639 
1640  // Unlink the block from its parent region. The block is kept in the rewrite
1641  // object and will be actually destroyed when rewrites are applied. This
1642  // allows us to keep the operations in the block live and undo the removal by
1643  // re-inserting the block.
1644  block->getParent()->getBlocks().remove(block);
1645 
1646  // Mark all nested ops as erased.
1647  block->walk([&](Operation *op) { replacedOps.insert(op); });
1648 }
1649 
1651  Block *block, Region *previous, Region::iterator previousIt) {
1652  assert(!wasOpReplaced(block->getParentOp()) &&
1653  "attempting to insert into a region within a replaced/erased op");
1654  LLVM_DEBUG(
1655  {
1656  Operation *parent = block->getParentOp();
1657  if (parent) {
1658  logger.startLine() << "** Insert Block into : '" << parent->getName()
1659  << "'(" << parent << ")\n";
1660  } else {
1661  logger.startLine()
1662  << "** Insert Block into detached Region (nullptr parent op)'\n";
1663  }
1664  });
1665 
1666  patternInsertedBlocks.insert(block);
1667 
1668  if (!previous) {
1669  // This is a newly created block.
1670  appendRewrite<CreateBlockRewrite>(block);
1671  return;
1672  }
1673  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1674  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1675 }
1676 
    Block *dest,
    Block::iterator before) {
  // Only records the inlining of `source` into `dest` before `before` as a
  // rewrite. In the visible caller (ConversionPatternRewriter::
  // inlineBlockBefore), the operations themselves are spliced separately.
  appendRewrite<InlineBlockRewrite>(dest, source, before);
}
1682 
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  // The failure reason is only built and printed inside LLVM_DEBUG. Without
  // debug output enabled, `reasonCallback` is never invoked.
  LLVM_DEBUG({
    reasonCallback(diag);
    logger.startLine() << "** Failure : " << diag.str() << "\n";
    if (config.notifyCallback)
  });
}
1693 
1694 //===----------------------------------------------------------------------===//
1695 // ConversionPatternRewriter
1696 //===----------------------------------------------------------------------===//
1697 
1698 ConversionPatternRewriter::ConversionPatternRewriter(
1699  MLIRContext *ctx, const ConversionConfig &config)
1700  : PatternRewriter(ctx),
1701  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1702  setListener(impl.get());
1703 }
1704 
1706 
  assert(op && newOp && "expected non-null op");
  // Replacing with an op is shorthand for replacing with all of its results.
  replaceOp(op, newOp->getResults());
}
1711 
  // 1:1 replacement: result i of `op` is replaced by newValues[i].
  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  LLVM_DEBUG({
    impl->logger.startLine()
        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
  });
      llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
        // A null value becomes an empty replacement list, which the
        // implementation treats as a dropped result; otherwise the
        // replacement is a single value.
        return v ? SmallVector<Value>{v} : SmallVector<Value>();
      });
  impl->replaceOp(op, std::move(newVals));
}
1725 
    Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
  // 1:N replacement: result i of `op` is replaced by the (possibly empty) list
  // newValues[i]. The lists are moved into the implementation.
  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  LLVM_DEBUG({
    impl->logger.startLine()
        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
  });
  impl->replaceOp(op, std::move(newValues));
}
1736 
  LLVM_DEBUG({
    impl->logger.startLine()
        << "** Erase : '" << op->getName() << "'(" << op << ")\n";
  });
  // Erasure is modeled as a replacement in which every result gets an empty
  // replacement list. ConversionPatternRewriterImpl::replaceOp builds an
  // unresolved source materialization for each such result.
  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
  impl->replaceOp(op, std::move(nullRepls));
}
1745 
  // The implementation records the erasure, unlinks the block and marks its
  // nested ops as erased.
  impl->eraseBlock(block);
}
1749 
    Block *block, TypeConverter::SignatureConversion &conversion,
    const TypeConverter *converter) {
  // Converting a block inside an op that is already replaced/erased is a
  // programming error.
  assert(!impl->wasOpReplaced(block->getParentOp()) &&
         "attempting to apply a signature conversion to a block within a "
         "replaced/erased op");
  // Note: the implementation takes `converter` before `conversion`.
  return impl->applySignatureConversion(*this, block, converter, conversion);
}
1758 
    Region *region, const TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  // `entryConversion` is an optional, caller-provided conversion for the
  // entry block (it may be null); the work is delegated to the implementation.
  assert(!impl->wasOpReplaced(region->getParentOp()) &&
         "attempting to apply a signature conversion to a block within a "
         "replaced/erased op");
  return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
1767 
1769  ValueRange to) {
1770  LLVM_DEBUG({
1771  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1772  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1773  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1774  << "' (" << parentOp << ")\n";
1775  } else {
1776  impl->logger.getOStream() << " (unlinked block)\n";
1777  }
1778  });
1779  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
1780 }
1781 
  // Looks up the value that `key` is currently mapped to. Returns a null Value
  // if remapping fails. Only 1:1 mappings are supported (asserted).
  SmallVector<ValueVector> remappedValues;
  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                               remappedValues)))
    return nullptr;
  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
  return remappedValues.front().front();
}
1790 
// Remaps every value in `keys` and appends the results to `results`, one value
// per key (1:N mappings are asserted against). On failure nothing is appended.
LogicalResult
    SmallVectorImpl<Value> &results) {
  if (keys.empty())
    return success();
  SmallVector<ValueVector> remapped;
  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                               remapped)))
    return failure();
  for (const auto &values : remapped) {
    assert(values.size() == 1 && "1:N conversion not supported");
    results.push_back(values.front());
  }
  return success();
}
1806 
    Block::iterator before,
    ValueRange argValues) {
  // Moves all ops of `source` into `dest` before `before`, replaces the uses of
  // the source block's arguments with `argValues`, and erases `source`.
#ifndef NDEBUG
  assert(argValues.size() == source->getNumArguments() &&
         "incorrect # of argument replacement values");
  assert(!impl->wasOpReplaced(source->getParentOp()) &&
         "attempting to inline a block from a replaced/erased op");
  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
         "attempting to inline a block into a replaced/erased op");
  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
  // The source block will be deleted, so it should not have any users (i.e.,
  // there should be no predecessors).
  assert(llvm::all_of(source->getUsers(), opIgnored) &&
         "expected 'source' to have no predecessors");
#endif // NDEBUG

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !impl->config.listener;

  // On the fast path a single InlineBlockRewrite is recorded up front.
  if (fastPath)
    impl->inlineBlockBefore(source, dest, before);

  // Replace all uses of block arguments. `llvm::zip` stops at the shorter
  // range; equal lengths are only checked by the assert above.
  for (auto it : llvm::zip(source->getArguments(), argValues))
    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));

  if (fastPath) {
    // Move all ops at once.
    dest->getOperations().splice(before, source->getOperations());
  } else {
    // Move op by op.
    while (!source->empty())
      moveOpBefore(&source->front(), dest, before);
  }

  // Erase the source block.
  eraseBlock(source);
}
1851 
  // Begins an in-place modification of `op`.
  assert(!impl->wasOpReplaced(op) &&
         "attempting to modify a replaced/erased op");
#ifndef NDEBUG
  // Tracked only in assertion-enabled builds so that finalize/cancel can check
  // that a modification was actually started.
  impl->pendingRootUpdates.insert(op);
#endif
  // The recorded rewrite is what cancelOpModification looks up and rolls back.
  impl->appendRewrite<ModifyOperationRewrite>(op);
}
1860 
  // Completes an in-place modification of `op`.
  assert(!impl->wasOpReplaced(op) &&
         "attempting to modify a replaced/erased op");
  // Remember the op so that the legalizer re-legalizes it after the current
  // pattern succeeds.
  impl->patternModifiedOps.insert(op);

  // There is nothing to do here, we only need to track the operation at the
  // start of the update.
#ifndef NDEBUG
  // The erase has a side effect inside an assert. This is safe because
  // `pendingRootUpdates` is only populated in the same !NDEBUG configuration.
  assert(impl->pendingRootUpdates.erase(op) &&
         "operation did not have a pending in-place update");
#endif
}
1874 
  // Undoes the most recently started in-place modification of `op`: its
  // ModifyOperationRewrite is rolled back and removed from the rewrite log.
#ifndef NDEBUG
  assert(impl->pendingRootUpdates.erase(op) &&
         "operation did not have a pending in-place update");
#endif
  // Erase the last update for this operation.
  auto it = llvm::find_if(
      llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
        auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
        return modifyRewrite && modifyRewrite->getOperation() == op;
      });
  assert(it != impl->rewrites.rend() && "no root update started on op");
  (*it)->rollback();
  // Convert the reverse iterator to a forward index: std::prev(rend()) refers
  // to element 0, so the distance from `it` to it is the element's index.
  int updateIdx = std::prev(impl->rewrites.rend()) - it;
  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
}
1891 
  // Exposes the internal rewriter state. The legalizer reads it for logging and
  // to update its op/block tracking sets.
  return *impl;
}
1895 
1896 //===----------------------------------------------------------------------===//
1897 // ConversionPattern
1898 //===----------------------------------------------------------------------===//
1899 
    ArrayRef<ValueRange> operands) const {
  // Flattens per-operand value ranges into one value per operand. A range with
  // any size other than 1 (including 0) is a fatal error, because this pattern
  // does not support 1:N conversion.
  SmallVector<Value> oneToOneOperands;
  oneToOneOperands.reserve(operands.size());
  for (ValueRange operand : operands) {
    if (operand.size() != 1)
      llvm::report_fatal_error("pattern '" + getDebugName() +
                               "' does not support 1:N conversion");
    oneToOneOperands.push_back(operand.front());
  }
  return oneToOneOperands;
}
1912 
// Generic PatternRewriter entry point. It remaps the op's operands and forwards
// to the conversion-specific matchAndRewrite overload.
LogicalResult
    PatternRewriter &rewriter) const {
  // Unchecked downcast: conversion patterns are expected to run only under a
  // ConversionPatternRewriter.
  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
  auto &rewriterImpl = dialectRewriter.getImpl();

  // Track the current conversion pattern type converter in the rewriter.
  // The previous value is restored when this function returns.
  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
                                             getTypeConverter());

  // Remap the operands of the operation.
  SmallVector<ValueVector> remapped;
  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                      op->getOperands(), remapped))) {
    return failure();
  }
  // Non-owning views into `remapped`, which outlives the call below.
  SmallVector<ValueRange> remappedAsRange =
      llvm::to_vector_of<ValueRange>(remapped);
  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
}
1933 
1934 //===----------------------------------------------------------------------===//
1935 // OperationLegalizer
1936 //===----------------------------------------------------------------------===//
1937 
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer.
///
/// Legalizing an op may apply a pattern or a fold, and the IR produced by that
/// step is itself legalized recursively (see the legalizePattern* helpers).
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  OperationLegalizer(const ConversionTarget &targetInfo,
                     const ConversionConfig &config);

  /// Returns true if the given operation is known to be illegal on the target.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  const ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise.
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      const SetVector<Operation *> &newOps,
                                      const SetVector<Operation *> &modifiedOps,
                                      const SetVector<Block *> &insertedBlocks);

  /// Legalizes the actions registered during the execution of a pattern.
  /// Block rewrites: converts argument types of blocks the pattern inserted or
  /// moved, or re-legalizes their parent op.
  LogicalResult
  legalizePatternBlockRewrites(Operation *op,
                               ConversionPatternRewriter &rewriter,
                               const SetVector<Block *> &insertedBlocks,
                               const SetVector<Operation *> &newOps);
  /// Legalizes every op created by the pattern.
  LogicalResult
  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
                                   const SetVector<Operation *> &newOps);
  /// Legalizes every op the pattern modified in place.
  LogicalResult
  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                             const SetVector<Operation *> &modifiedOps);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,

  /// Compute the legalization depth when legalizing an operation of the given
  /// type.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,

  /// The current set of patterns that have been applied.
  /// A pattern is inserted by canApplyPattern and erased again once its
  /// application finishes (success or failure). A pattern without bounded
  /// rewrite recursion is therefore rejected while it is still active higher
  /// up the legalization stack.
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  const ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;

  /// Dialect conversion configuration.
  const ConversionConfig &config;
};
} // namespace
2050 
// Builds the legalization graph for the given patterns and target, then orders
// the patterns in the applicator by the resulting cost model.
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
                                       const ConversionConfig &config)
    : target(targetInfo), applicator(patterns), config(config) {
  // The set of patterns that can be applied to illegal operations to transform
  // them into legal ones.
  // Patterns without a specific root op are collected separately.
  LegalizationPatterns anyOpLegalizerPatterns;

  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
}
2063 
2064 bool OperationLegalizer::isIllegal(Operation *op) const {
2065  return target.isIllegal(op);
2066 }
2067 
2068 LogicalResult
2069 OperationLegalizer::legalize(Operation *op,
2070  ConversionPatternRewriter &rewriter) {
2071 #ifndef NDEBUG
2072  const char *logLineComment =
2073  "//===-------------------------------------------===//\n";
2074 
2075  auto &logger = rewriter.getImpl().logger;
2076 #endif
2077 
2078  // Check to see if the operation is ignored and doesn't need to be converted.
2079  bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2080 
2081  LLVM_DEBUG({
2082  logger.getOStream() << "\n";
2083  logger.startLine() << logLineComment;
2084  logger.startLine() << "Legalizing operation : ";
2085  // Do not print the operation name if the operation is ignored. Ignored ops
2086  // may have been erased and should not be accessed. The pointer can be
2087  // printed safely.
2088  if (!isIgnored)
2089  logger.getOStream() << "'" << op->getName() << "' ";
2090  logger.getOStream() << "(" << op << ") {\n";
2091  logger.indent();
2092 
2093  // If the operation has no regions, just print it here.
2094  if (!isIgnored && op->getNumRegions() == 0) {
2095  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2096  logger.getOStream() << "\n\n";
2097  }
2098  });
2099 
2100  if (isIgnored) {
2101  LLVM_DEBUG({
2102  logSuccess(logger, "operation marked 'ignored' during conversion");
2103  logger.startLine() << logLineComment;
2104  });
2105  return success();
2106  }
2107 
2108  // Check if this operation is legal on the target.
2109  if (auto legalityInfo = target.isLegal(op)) {
2110  LLVM_DEBUG({
2111  logSuccess(
2112  logger, "operation marked legal by the target{0}",
2113  legalityInfo->isRecursivelyLegal
2114  ? "; NOTE: operation is recursively legal; skipping internals"
2115  : "");
2116  logger.startLine() << logLineComment;
2117  });
2118 
2119  // If this operation is recursively legal, mark its children as ignored so
2120  // that we don't consider them for legalization.
2121  if (legalityInfo->isRecursivelyLegal) {
2122  op->walk([&](Operation *nested) {
2123  if (op != nested)
2124  rewriter.getImpl().ignoredOps.insert(nested);
2125  });
2126  }
2127 
2128  return success();
2129  }
2130 
2131  // If the operation isn't legal, try to fold it in-place.
2132  // TODO: Should we always try to do this, even if the op is
2133  // already legal?
2134  if (succeeded(legalizeWithFold(op, rewriter))) {
2135  LLVM_DEBUG({
2136  logSuccess(logger, "operation was folded");
2137  logger.startLine() << logLineComment;
2138  });
2139  return success();
2140  }
2141 
2142  // Otherwise, we need to apply a legalization pattern to this operation.
2143  if (succeeded(legalizeWithPattern(op, rewriter))) {
2144  LLVM_DEBUG({
2145  logSuccess(logger, "");
2146  logger.startLine() << logLineComment;
2147  });
2148  return success();
2149  }
2150 
2151  LLVM_DEBUG({
2152  logFailure(logger, "no matched legalization pattern");
2153  logger.startLine() << logLineComment;
2154  });
2155  return failure();
2156 }
2157 
/// Steals the contents of `obj` and returns them, leaving `obj` holding a
/// freshly value-initialized `T`. A plain move would only guarantee a valid
/// but unspecified state; this guarantees the empty state.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2166 
2167 LogicalResult
2168 OperationLegalizer::legalizeWithFold(Operation *op,
2169  ConversionPatternRewriter &rewriter) {
2170  auto &rewriterImpl = rewriter.getImpl();
2171  LLVM_DEBUG({
2172  rewriterImpl.logger.startLine() << "* Fold {\n";
2173  rewriterImpl.logger.indent();
2174  });
2175  (void)rewriterImpl;
2176 
2177  // Try to fold the operation.
2178  StringRef opName = op->getName().getStringRef();
2179  SmallVector<Value, 2> replacementValues;
2180  SmallVector<Operation *, 2> newOps;
2181  rewriter.setInsertionPoint(op);
2182  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2183  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2184  return failure();
2185  }
2186 
2187  // An empty list of replacement values indicates that the fold was in-place.
2188  // As the operation changed, a new legalization needs to be attempted.
2189  if (replacementValues.empty())
2190  return legalize(op, rewriter);
2191 
2192  // Recursively legalize any new constant operations.
2193  for (Operation *newOp : newOps) {
2194  if (failed(legalize(newOp, rewriter))) {
2195  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2196  "failed to legalize generated constant '{0}'",
2197  newOp->getName()));
2198  if (!config.allowPatternRollback) {
2199  // Rolling back a folder is like rolling back a pattern.
2200  llvm::report_fatal_error(
2201  "op '" + opName +
2202  "' folder rollback of IR modifications requested");
2203  }
2204  // Legalization failed: erase all materialized constants.
2205  for (Operation *op : newOps)
2206  rewriter.eraseOp(op);
2207  return failure();
2208  }
2209  }
2210 
2211  // Insert a replacement for 'op' with the folded replacement values.
2212  rewriter.replaceOp(op, replacementValues);
2213 
2214  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2215  return success();
2216 }
2217 
// Tries the registered patterns on `op` through the pattern applicator. Three
// callbacks are passed to the applicator:
//  - canApply: filters out patterns that are currently active;
//  - onFailure: rolls the rewriter back after a failed match;
//  - onSuccess: legalizes the IR the pattern produced.
LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
                                        ConversionPatternRewriter &rewriter) {
  auto &rewriterImpl = rewriter.getImpl();

  // Functor that returns if the given pattern may be applied.
  auto canApply = [&](const Pattern &pattern) {
    bool canApply = canApplyPattern(op, pattern, rewriter);
    if (canApply && config.listener)
      config.listener->notifyPatternBegin(pattern, op);
    return canApply;
  };

  // Functor that cleans up the rewriter state after a pattern failed to match.
  // `curState` is captured once, before any pattern runs. Every failed attempt,
  // and every successful one whose result cannot be legalized, resets the
  // rewriter to it.
  RewriterState curState = rewriterImpl.getCurrentState();
  auto onFailure = [&](const Pattern &pattern) {
    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
    rewriterImpl.patternNewOps.clear();
    rewriterImpl.patternModifiedOps.clear();
    rewriterImpl.patternInsertedBlocks.clear();
    LLVM_DEBUG({
      logFailure(rewriterImpl.logger, "pattern failed to match");
      if (rewriterImpl.config.notifyCallback) {
        diag << "Failed to apply pattern \"" << pattern.getDebugName()
             << "\" on op:\n"
             << *op;
        rewriterImpl.config.notifyCallback(diag);
      }
    });
    if (config.listener)
      config.listener->notifyPatternEnd(pattern, failure());
    rewriterImpl.resetState(curState, pattern.getDebugName());
    // The pattern is no longer active and may be tried again elsewhere.
    appliedPatterns.erase(&pattern);
  };

  // Functor that performs additional legalization when a pattern is
  // successfully applied.
  auto onSuccess = [&](const Pattern &pattern) {
    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
    // Take ownership of the pattern-scoped tracking sets and leave them empty
    // for the next pattern, including patterns applied during the recursive
    // legalization below.
    SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
    SetVector<Operation *> modifiedOps =
        moveAndReset(rewriterImpl.patternModifiedOps);
    SetVector<Block *> insertedBlocks =
        moveAndReset(rewriterImpl.patternInsertedBlocks);
    auto result = legalizePatternResult(op, pattern, rewriter, newOps,
                                        modifiedOps, insertedBlocks);
    appliedPatterns.erase(&pattern);
    if (failed(result)) {
      if (!rewriterImpl.config.allowPatternRollback)
        llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
                                 "' produced IR that could not be legalized");
      rewriterImpl.resetState(curState, pattern.getDebugName());
    }
    if (config.listener)
      config.listener->notifyPatternEnd(pattern, result);
    return result;
  };

  // Try to match and rewrite a pattern on this operation.
  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
                                    onSuccess);
}
2281 
2282 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2283  ConversionPatternRewriter &rewriter) {
2284  LLVM_DEBUG({
2285  auto &os = rewriter.getImpl().logger;
2286  os.getOStream() << "\n";
2287  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2288  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2289  os.getOStream() << ")' {\n";
2290  os.indent();
2291  });
2292 
2293  // Ensure that we don't cycle by not allowing the same pattern to be
2294  // applied twice in the same recursion stack if it is not known to be safe.
2295  if (!pattern.hasBoundedRewriteRecursion() &&
2296  !appliedPatterns.insert(&pattern).second) {
2297  LLVM_DEBUG(
2298  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2299  return false;
2300  }
2301  return true;
2302 }
2303 
2304 LogicalResult OperationLegalizer::legalizePatternResult(
2305  Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2306  const SetVector<Operation *> &newOps,
2307  const SetVector<Operation *> &modifiedOps,
2308  const SetVector<Block *> &insertedBlocks) {
2309  auto &impl = rewriter.getImpl();
2310  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2311 
2312 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2313  // Check that the root was either replaced or updated in place.
2314  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2315  auto replacedRoot = [&] {
2316  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2317  };
2318  auto updatedRootInPlace = [&] {
2319  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2320  };
2321  if (!replacedRoot() && !updatedRootInPlace())
2322  llvm::report_fatal_error("expected pattern to replace the root operation");
2323 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2324 
2325  // Legalize each of the actions registered during application.
2326  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
2327  newOps)) ||
2328  failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
2329  failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
2330  return failure();
2331  }
2332 
2333  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2334  return success();
2335 }
2336 
// Legalizes blocks that the last pattern inserted or moved (`insertedBlocks`).
// A block with arguments that lives outside the pattern's root `op` either gets
// its signature converted with its region's type converter, or its parent op
// is re-legalized.
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
    Operation *op, ConversionPatternRewriter &rewriter,
    const SetVector<Block *> &insertedBlocks,
    const SetVector<Operation *> &newOps) {
  // Parent ops re-legalized so far, so each is legalized at most once.
  SmallPtrSet<Operation *, 16> alreadyLegalized;

  // If the pattern moved or created any blocks, make sure the types of block
  // arguments get legalized.
  for (Block *block : insertedBlocks) {
    // Only check blocks outside of the current operation.
    Operation *parentOp = block->getParentOp();
    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
      continue;

    // If the region of the block has a type converter, try to convert the block
    // directly.
    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
      std::optional<TypeConverter::SignatureConversion> conversion =
          converter->convertBlockSignature(block);
      if (!conversion) {
        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                           "block"));
        return failure();
      }
      impl.applySignatureConversion(rewriter, block, converter, *conversion);
      continue;
    }

    // Otherwise, try to legalize the parent operation if it was not generated
    // by this pattern. This is because we will attempt to legalize the parent
    // operation, and blocks in regions created by this pattern will already be
    // legalized later on.
    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
      if (failed(legalize(parentOp, rewriter))) {
        LLVM_DEBUG(logFailure(
            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
            parentOp->getName(), parentOp));
        return failure();
      }
    }
  }
  return success();
}
2381 
// Legalizes, in insertion order, every op created by the last pattern. Stops at
// the first op that cannot be legalized.
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
    const SetVector<Operation *> &newOps) {
  for (Operation *op : newOps) {
    if (failed(legalize(op, rewriter))) {
      LLVM_DEBUG(logFailure(impl.logger,
                            "failed to legalize generated operation '{0}'({1})",
                            op->getName(), op));
      return failure();
    }
  }
  return success();
}
2395 
// Re-legalizes every op the last pattern modified in place (ops recorded by
// finalizeOpModification). Stops at the first failure.
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
    const SetVector<Operation *> &modifiedOps) {
  for (Operation *op : modifiedOps) {
    if (failed(legalize(op, rewriter))) {
      LLVM_DEBUG(logFailure(
          impl.logger, "failed to legalize operation updated in-place '{0}'",
          op->getName()));
      return failure();
    }
  }
  return success();
}
2409 
2410 //===----------------------------------------------------------------------===//
2411 // Cost Model
2412 //===----------------------------------------------------------------------===//
2413 
// Computes, by fixed-point iteration over a worklist, which root-specific
// patterns can transitively produce legal IR, and records them in
// `legalizerPatterns`. Patterns without a root kind go to
// `anyOpLegalizerPatterns`.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  // A mapping between an operation and any currently invalid patterns it has.
  // A worklist of patterns to consider for legality.
  SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    std::optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid. A generated
    // op counts as invalid if no legalizing pattern has been found for it yet
    // and the target either has no action for it or marks it Illegal.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          std::optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operation are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
2484 
// Orders the patterns by legalization depth and then by benefit. The result is
// installed as the applicator's cost model.
void OperationLegalizer::computeLegalizationGraphBenefit(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // The smallest pattern depth, when legalizing an operation.
  DenseMap<OperationName, unsigned> minOpPatternDepth;

  // For each operation that is transitively legal, compute a cost for it.
  for (auto &opIt : legalizerPatterns)
    if (!minOpPatternDepth.count(opIt.first))
      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
                                 legalizerPatterns);

  // Apply the cost model to the patterns that can match any operation. Those
  // with a specific operation type are already resolved when computing the op
  // legalization depth.
  if (!anyOpLegalizerPatterns.empty())
    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
                             legalizerPatterns);

  // Apply a cost model to the pattern applicator. We order patterns first by
  // depth then benefit. `legalizerPatterns` contains per-op patterns by
  // decreasing benefit.
  applicator.applyCostModel([&](const Pattern &pattern) {
    ArrayRef<const Pattern *> orderedPatternList;
    if (std::optional<OperationName> rootName = pattern.getRootKind())
      orderedPatternList = legalizerPatterns[*rootName];
    else
      orderedPatternList = anyOpLegalizerPatterns;

    // If the pattern is not found, then it was removed and cannot be matched.
    auto *it = llvm::find(orderedPatternList, &pattern);
    if (it == orderedPatternList.end())

    // Patterns found earlier in the list have higher benefit.
    // The benefit is the number of entries from `it` to the end of the list.
    return PatternBenefit(std::distance(it, orderedPatternList.end()));
  });
}
2523 
2524 unsigned OperationLegalizer::computeOpLegalizationDepth(
2525  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2526  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2527  // Check for existing depth.
2528  auto depthIt = minOpPatternDepth.find(op);
2529  if (depthIt != minOpPatternDepth.end())
2530  return depthIt->second;
2531 
2532  // If a mapping for this operation does not exist, then this operation
2533  // is always legal. Return 0 as the depth for a directly legal operation.
2534  auto opPatternsIt = legalizerPatterns.find(op);
2535  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2536  return 0u;
2537 
2538  // Record this initial depth in case we encounter this op again when
2539  // recursively computing the depth.
2540  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2541 
2542  // Apply the cost model to the operation patterns, and update the minimum
2543  // depth.
2544  unsigned minDepth = applyCostModelToPatterns(
2545  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2546  minOpPatternDepth[op] = minDepth;
2547  return minDepth;
2548 }
2549 
2550 unsigned OperationLegalizer::applyCostModelToPatterns(
2551  LegalizationPatterns &patterns,
2552  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2553  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2554  unsigned minDepth = std::numeric_limits<unsigned>::max();
2555 
2556  // Compute the depth for each pattern within the set.
2557  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2558  patternsByDepth.reserve(patterns.size());
2559  for (const Pattern *pattern : patterns) {
2560  unsigned depth = 1;
2561  for (auto generatedOp : pattern->getGeneratedOps()) {
2562  unsigned generatedOpDepth = computeOpLegalizationDepth(
2563  generatedOp, minOpPatternDepth, legalizerPatterns);
2564  depth = std::max(depth, generatedOpDepth + 1);
2565  }
2566  patternsByDepth.emplace_back(pattern, depth);
2567 
2568  // Update the minimum depth of the pattern list.
2569  minDepth = std::min(minDepth, depth);
2570  }
2571 
2572  // If the operation only has one legalization pattern, there is no need to
2573  // sort them.
2574  if (patternsByDepth.size() == 1)
2575  return minDepth;
2576 
2577  // Sort the patterns by those likely to be the most beneficial.
2578  llvm::stable_sort(patternsByDepth,
2579  [](const std::pair<const Pattern *, unsigned> &lhs,
2580  const std::pair<const Pattern *, unsigned> &rhs) {
2581  // First sort by the smaller pattern legalization
2582  // depth.
2583  if (lhs.second != rhs.second)
2584  return lhs.second < rhs.second;
2585 
2586  // Then sort by the larger pattern benefit.
2587  auto lhsBenefit = lhs.first->getBenefit();
2588  auto rhsBenefit = rhs.first->getBenefit();
2589  return lhsBenefit > rhsBenefit;
2590  });
2591 
2592  // Update the legalization pattern to use the new sorted list.
2593  patterns.clear();
2594  for (auto &patternIt : patternsByDepth)
2595  patterns.push_back(patternIt.first);
2596  return minDepth;
2597 }
2598 
2599 //===----------------------------------------------------------------------===//
2600 // OperationConverter
2601 //===----------------------------------------------------------------------===//
namespace {
/// Selects how `OperationConverter` reacts to the outcome of legalizing each
/// operation.
enum OpConversionMode {
  /// Failures to legalize are tolerated (unless the operation is explicitly
  /// marked illegal), so illegal operations may remain in the IR afterwards.
  Partial,

  /// Every operation must become legal for the target; any failure to
  /// legalize makes the whole conversion fail.
  Full,

  /// Only determines which operations are legalizable; the rewrites found on
  /// success are not actually applied to the operations.
  Analysis,
};
} // namespace
2617 
2618 namespace mlir {
2619 // This class converts operations to a given conversion target via a set of
2620 // rewrite patterns. The conversion behaves differently depending on the
2621 // conversion mode.
2623  explicit OperationConverter(const ConversionTarget &target,
2625  const ConversionConfig &config,
2626  OpConversionMode mode)
2627  : config(config), opLegalizer(target, patterns, this->config),
2628  mode(mode) {}
2629 
2630  /// Converts the given operations to the conversion target.
2631  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2632 
2633 private:
2634  /// Converts an operation with the given rewriter.
2635  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2636 
2637  /// Dialect conversion configuration.
2638  ConversionConfig config;
2639 
2640  /// The legalizer to use when converting operations.
2641  OperationLegalizer opLegalizer;
2642 
2643  /// The conversion mode to use when legalizing operations.
2644  OpConversionMode mode;
2645 };
2646 } // namespace mlir
2647 
2648 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2649  Operation *op) {
2650  // Legalize the given operation.
2651  if (failed(opLegalizer.legalize(op, rewriter))) {
2652  // Handle the case of a failed conversion for each of the different modes.
2653  // Full conversions expect all operations to be converted.
2654  if (mode == OpConversionMode::Full)
2655  return op->emitError()
2656  << "failed to legalize operation '" << op->getName() << "'";
2657  // Partial conversions allow conversions to fail iff the operation was not
2658  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2659  // set, non-legalizable ops are added to that set.
2660  if (mode == OpConversionMode::Partial) {
2661  if (opLegalizer.isIllegal(op))
2662  return op->emitError()
2663  << "failed to legalize operation '" << op->getName()
2664  << "' that was explicitly marked illegal";
2665  if (config.unlegalizedOps)
2666  config.unlegalizedOps->insert(op);
2667  }
2668  } else if (mode == OpConversionMode::Analysis) {
2669  // Analysis conversions don't fail if any operations fail to legalize,
2670  // they are only interested in the operations that were successfully
2671  // legalized.
2672  if (config.legalizableOps)
2673  config.legalizableOps->insert(op);
2674  }
2675  return success();
2676 }
2677 
2678 static LogicalResult
2680  UnrealizedConversionCastOp op,
2681  const UnresolvedMaterializationInfo &info) {
2682  assert(!op.use_empty() &&
2683  "expected that dead materializations have already been DCE'd");
2684  Operation::operand_range inputOperands = op.getOperands();
2685 
2686  // Try to materialize the conversion.
2687  if (const TypeConverter *converter = info.getConverter()) {
2688  rewriter.setInsertionPoint(op);
2689  SmallVector<Value> newMaterialization;
2690  switch (info.getMaterializationKind()) {
2691  case MaterializationKind::Target:
2692  newMaterialization = converter->materializeTargetConversion(
2693  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2694  info.getOriginalType());
2695  break;
2696  case MaterializationKind::Source:
2697  assert(op->getNumResults() == 1 && "expected single result");
2698  Value sourceMat = converter->materializeSourceConversion(
2699  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2700  if (sourceMat)
2701  newMaterialization.push_back(sourceMat);
2702  break;
2703  }
2704  if (!newMaterialization.empty()) {
2705 #ifndef NDEBUG
2706  ValueRange newMaterializationRange(newMaterialization);
2707  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2708  "materialization callback produced value of incorrect type");
2709 #endif // NDEBUG
2710  rewriter.replaceOp(op, newMaterialization);
2711  return success();
2712  }
2713  }
2714 
2715  InFlightDiagnostic diag = op->emitError()
2716  << "failed to legalize unresolved materialization "
2717  "from ("
2718  << inputOperands.getTypes() << ") to ("
2719  << op.getResultTypes()
2720  << ") that remained live after conversion";
2721  diag.attachNote(op->getUsers().begin()->getLoc())
2722  << "see existing live user here: " << *op->getUsers().begin();
2723  return failure();
2724 }
2725 
2727  assert(!ops.empty() && "expected at least one operation");
2728  const ConversionTarget &target = opLegalizer.getTarget();
2729 
2730  // Compute the set of operations and blocks to convert.
2731  SmallVector<Operation *> toConvert;
2732  for (auto *op : ops) {
2734  [&](Operation *op) {
2735  toConvert.push_back(op);
2736  // Don't check this operation's children for conversion if the
2737  // operation is recursively legal.
2738  auto legalityInfo = target.isLegal(op);
2739  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2740  return WalkResult::skip();
2741  return WalkResult::advance();
2742  });
2743  }
2744 
2745  // Convert each operation and discard rewrites on failure.
2746  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2747  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2748 
2749  for (auto *op : toConvert) {
2750  if (failed(convert(rewriter, op))) {
2751  // Dialect conversion failed.
2752  if (rewriterImpl.config.allowPatternRollback) {
2753  // Rollback is allowed: restore the original IR.
2754  rewriterImpl.undoRewrites();
2755  } else {
2756  // Rollback is not allowed: apply all modifications that have been
2757  // performed so far.
2758  rewriterImpl.applyRewrites();
2759  }
2760  return failure();
2761  }
2762  }
2763 
2764  // After a successful conversion, apply rewrites.
2765  rewriterImpl.applyRewrites();
2766 
2767  // Gather all unresolved materializations.
2770  &materializations = rewriterImpl.unresolvedMaterializations;
2771  for (auto it : materializations)
2772  allCastOps.push_back(it.first);
2773 
2774  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2775  // dialect conversion frameworks. (Not the one that were inserted by
2776  // patterns.)
2777  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2778  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2779 
2780  // Try to legalize all unresolved materializations.
2781  if (config.buildMaterializations) {
2782  IRRewriter rewriter(rewriterImpl.context, config.listener);
2783  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2784  auto it = materializations.find(castOp);
2785  assert(it != materializations.end() && "inconsistent state");
2786  if (failed(
2787  legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
2788  return failure();
2789  }
2790  }
2791 
2792  return success();
2793 }
2794 
2795 //===----------------------------------------------------------------------===//
2796 // Reconcile Unrealized Casts
2797 //===----------------------------------------------------------------------===//
2798 
2801  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2802  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2803  // This set is maintained only if `remainingCastOps` is provided.
2804  DenseSet<Operation *> erasedOps;
2805 
2806  // Helper function that adds all operands to the worklist that are an
2807  // unrealized_conversion_cast op result.
2808  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2809  for (Value v : castOp.getInputs())
2810  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2811  worklist.insert(inputCastOp);
2812  };
2813 
2814  // Helper function that return the unrealized_conversion_cast op that
2815  // defines all inputs of the given op (in the same order). Return "nullptr"
2816  // if there is no such op.
2817  auto getInputCast =
2818  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2819  if (castOp.getInputs().empty())
2820  return {};
2821  auto inputCastOp =
2822  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2823  if (!inputCastOp)
2824  return {};
2825  if (inputCastOp.getOutputs() != castOp.getInputs())
2826  return {};
2827  return inputCastOp;
2828  };
2829 
2830  // Process ops in the worklist bottom-to-top.
2831  while (!worklist.empty()) {
2832  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2833  if (castOp->use_empty()) {
2834  // DCE: If the op has no users, erase it. Add the operands to the
2835  // worklist to find additional DCE opportunities.
2836  enqueueOperands(castOp);
2837  if (remainingCastOps)
2838  erasedOps.insert(castOp.getOperation());
2839  castOp->erase();
2840  continue;
2841  }
2842 
2843  // Traverse the chain of input cast ops to see if an op with the same
2844  // input types can be found.
2845  UnrealizedConversionCastOp nextCast = castOp;
2846  while (nextCast) {
2847  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2848  // Found a cast where the input types match the output types of the
2849  // matched op. We can directly use those inputs and the matched op can
2850  // be removed.
2851  enqueueOperands(castOp);
2852  castOp.replaceAllUsesWith(nextCast.getInputs());
2853  if (remainingCastOps)
2854  erasedOps.insert(castOp.getOperation());
2855  castOp->erase();
2856  break;
2857  }
2858  nextCast = getInputCast(nextCast);
2859  }
2860  }
2861 
2862  if (remainingCastOps)
2863  for (UnrealizedConversionCastOp op : castOps)
2864  if (!erasedOps.contains(op.getOperation()))
2865  remainingCastOps->push_back(op);
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // Type Conversion
2870 //===----------------------------------------------------------------------===//
2871 
2873  ArrayRef<Type> types) {
2874  assert(!types.empty() && "expected valid types");
2875  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2876  addInputs(types);
2877 }
2878 
2880  assert(!types.empty() &&
2881  "1->0 type remappings don't need to be added explicitly");
2882  argTypes.append(types.begin(), types.end());
2883 }
2884 
2885 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2886  unsigned newInputNo,
2887  unsigned newInputCount) {
2888  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2889  assert(newInputCount != 0 && "expected valid input count");
2890  remappedInputs[origInputNo] =
2891  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2892 }
2893 
2895  unsigned origInputNo, ArrayRef<Value> replacements) {
2896  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2897  remappedInputs[origInputNo] = InputMapping{
2898  origInputNo, /*size=*/0,
2899  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2900 }
2901 
2903  SmallVectorImpl<Type> &results) const {
2904  assert(t && "expected non-null type");
2905 
2906  {
2907  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2908  std::defer_lock);
2910  cacheReadLock.lock();
2911  auto existingIt = cachedDirectConversions.find(t);
2912  if (existingIt != cachedDirectConversions.end()) {
2913  if (existingIt->second)
2914  results.push_back(existingIt->second);
2915  return success(existingIt->second != nullptr);
2916  }
2917  auto multiIt = cachedMultiConversions.find(t);
2918  if (multiIt != cachedMultiConversions.end()) {
2919  results.append(multiIt->second.begin(), multiIt->second.end());
2920  return success();
2921  }
2922  }
2923  // Walk the added converters in reverse order to apply the most recently
2924  // registered first.
2925  size_t currentCount = results.size();
2926 
2927  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2928  std::defer_lock);
2929 
2930  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2931  if (std::optional<LogicalResult> result = converter(t, results)) {
2933  cacheWriteLock.lock();
2934  if (!succeeded(*result)) {
2935  assert(results.size() == currentCount &&
2936  "failed type conversion should not change results");
2937  cachedDirectConversions.try_emplace(t, nullptr);
2938  return failure();
2939  }
2940  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2941  if (newTypes.size() == 1)
2942  cachedDirectConversions.try_emplace(t, newTypes.front());
2943  else
2944  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2945  return success();
2946  } else {
2947  assert(results.size() == currentCount &&
2948  "failed type conversion should not change results");
2949  }
2950  }
2951  return failure();
2952 }
2953 
2955  // Use the multi-type result version to convert the type.
2956  SmallVector<Type, 1> results;
2957  if (failed(convertType(t, results)))
2958  return nullptr;
2959 
2960  // Check to ensure that only one type was produced.
2961  return results.size() == 1 ? results.front() : nullptr;
2962 }
2963 
2964 LogicalResult
2966  SmallVectorImpl<Type> &results) const {
2967  for (Type type : types)
2968  if (failed(convertType(type, results)))
2969  return failure();
2970  return success();
2971 }
2972 
2973 bool TypeConverter::isLegal(Type type) const {
2974  return convertType(type) == type;
2975 }
2977  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2978 }
2979 
2980 bool TypeConverter::isLegal(Region *region) const {
2981  return llvm::all_of(*region, [this](Block &block) {
2982  return isLegal(block.getArgumentTypes());
2983  });
2984 }
2985 
2986 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2987  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2988 }
2989 
2990 LogicalResult
2992  SignatureConversion &result) const {
2993  // Try to convert the given input type.
2994  SmallVector<Type, 1> convertedTypes;
2995  if (failed(convertType(type, convertedTypes)))
2996  return failure();
2997 
2998  // If this argument is being dropped, there is nothing left to do.
2999  if (convertedTypes.empty())
3000  return success();
3001 
3002  // Otherwise, add the new inputs.
3003  result.addInputs(inputNo, convertedTypes);
3004  return success();
3005 }
3006 LogicalResult
3008  SignatureConversion &result,
3009  unsigned origInputOffset) const {
3010  for (unsigned i = 0, e = types.size(); i != e; ++i)
3011  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3012  return failure();
3013  return success();
3014 }
3015 
3017  Location loc, Type resultType,
3018  ValueRange inputs) const {
3019  for (const SourceMaterializationCallbackFn &fn :
3020  llvm::reverse(sourceMaterializations))
3021  if (Value result = fn(builder, resultType, inputs, loc))
3022  return result;
3023  return nullptr;
3024 }
3025 
3027  Location loc, Type resultType,
3028  ValueRange inputs,
3029  Type originalType) const {
3031  builder, loc, TypeRange(resultType), inputs, originalType);
3032  if (result.empty())
3033  return nullptr;
3034  assert(result.size() == 1 && "expected single result");
3035  return result.front();
3036 }
3037 
3039  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3040  Type originalType) const {
3041  for (const TargetMaterializationCallbackFn &fn :
3042  llvm::reverse(targetMaterializations)) {
3043  SmallVector<Value> result =
3044  fn(builder, resultTypes, inputs, loc, originalType);
3045  if (result.empty())
3046  continue;
3047  assert(TypeRange(ValueRange(result)) == resultTypes &&
3048  "callback produced incorrect number of values or values with "
3049  "incorrect types");
3050  return result;
3051  }
3052  return {};
3053 }
3054 
3055 std::optional<TypeConverter::SignatureConversion>
3057  SignatureConversion conversion(block->getNumArguments());
3058  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3059  return std::nullopt;
3060  return conversion;
3061 }
3062 
3063 //===----------------------------------------------------------------------===//
3064 // Type attribute conversion
3065 //===----------------------------------------------------------------------===//
3068  return AttributeConversionResult(attr, resultTag);
3069 }
3070 
3073  return AttributeConversionResult(nullptr, naTag);
3074 }
3075 
3078  return AttributeConversionResult(nullptr, abortTag);
3079 }
3080 
3082  return impl.getInt() == resultTag;
3083 }
3084 
3086  return impl.getInt() == naTag;
3087 }
3088 
3090  return impl.getInt() == abortTag;
3091 }
3092 
3094  assert(hasResult() && "Cannot get result from N/A or abort");
3095  return impl.getPointer();
3096 }
3097 
3098 std::optional<Attribute>
3100  for (const TypeAttributeConversionCallbackFn &fn :
3101  llvm::reverse(typeAttributeConversions)) {
3102  AttributeConversionResult res = fn(type, attr);
3103  if (res.hasResult())
3104  return res.getResult();
3105  if (res.isAbort())
3106  return std::nullopt;
3107  }
3108  return std::nullopt;
3109 }
3110 
3111 //===----------------------------------------------------------------------===//
3112 // FunctionOpInterfaceSignatureConversion
3113 //===----------------------------------------------------------------------===//
3114 
3115 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3116  const TypeConverter &typeConverter,
3117  ConversionPatternRewriter &rewriter) {
3118  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3119  if (!type)
3120  return failure();
3121 
3122  // Convert the original function types.
3123  TypeConverter::SignatureConversion result(type.getNumInputs());
3124  SmallVector<Type, 1> newResults;
3125  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3126  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3127  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3128  typeConverter, &result)))
3129  return failure();
3130 
3131  // Update the function signature in-place.
3132  auto newType = FunctionType::get(rewriter.getContext(),
3133  result.getConvertedTypes(), newResults);
3134 
3135  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3136 
3137  return success();
3138 }
3139 
3140 /// Create a default conversion pattern that rewrites the type signature of a
3141 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3142 /// represent their type.
3143 namespace {
3144 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3145  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3146  MLIRContext *ctx,
3147  const TypeConverter &converter)
3148  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3149 
3150  LogicalResult
3151  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3152  ConversionPatternRewriter &rewriter) const override {
3153  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3154  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3155  }
3156 };
3157 
3158 struct AnyFunctionOpInterfaceSignatureConversion
3159  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3161 
3162  LogicalResult
3163  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3164  ConversionPatternRewriter &rewriter) const override {
3165  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3166  }
3167 };
3168 } // namespace
3169 
3170 FailureOr<Operation *>
3172  const TypeConverter &converter,
3173  ConversionPatternRewriter &rewriter) {
3174  assert(op && "Invalid op");
3175  Location loc = op->getLoc();
3176  if (converter.isLegal(op))
3177  return rewriter.notifyMatchFailure(loc, "op already legal");
3178 
3179  OperationState newOp(loc, op->getName());
3180  newOp.addOperands(operands);
3181 
3182  SmallVector<Type> newResultTypes;
3183  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3184  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3185 
3186  newOp.addTypes(newResultTypes);
3187  newOp.addAttributes(op->getAttrs());
3188  return rewriter.create(newOp);
3189 }
3190 
3192  StringRef functionLikeOpName, RewritePatternSet &patterns,
3193  const TypeConverter &converter) {
3194  patterns.add<FunctionOpInterfaceSignatureConversion>(
3195  functionLikeOpName, patterns.getContext(), converter);
3196 }
3197 
3199  RewritePatternSet &patterns, const TypeConverter &converter) {
3200  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3201  converter, patterns.getContext());
3202 }
3203 
3204 //===----------------------------------------------------------------------===//
3205 // ConversionTarget
3206 //===----------------------------------------------------------------------===//
3207 
3209  LegalizationAction action) {
3210  legalOperations[op].action = action;
3211 }
3212 
3214  LegalizationAction action) {
3215  for (StringRef dialect : dialectNames)
3216  legalDialects[dialect] = action;
3217 }
3218 
3220  -> std::optional<LegalizationAction> {
3221  std::optional<LegalizationInfo> info = getOpInfo(op);
3222  return info ? info->action : std::optional<LegalizationAction>();
3223 }
3224 
3226  -> std::optional<LegalOpDetails> {
3227  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3228  if (!info)
3229  return std::nullopt;
3230 
3231  // Returns true if this operation instance is known to be legal.
3232  auto isOpLegal = [&] {
3233  // Handle dynamic legality either with the provided legality function.
3234  if (info->action == LegalizationAction::Dynamic) {
3235  std::optional<bool> result = info->legalityFn(op);
3236  if (result)
3237  return *result;
3238  }
3239 
3240  // Otherwise, the operation is only legal if it was marked 'Legal'.
3241  return info->action == LegalizationAction::Legal;
3242  };
3243  if (!isOpLegal())
3244  return std::nullopt;
3245 
3246  // This operation is legal, compute any additional legality information.
3247  LegalOpDetails legalityDetails;
3248  if (info->isRecursivelyLegal) {
3249  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3250  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3251  legalityDetails.isRecursivelyLegal =
3252  legalityFnIt->second(op).value_or(true);
3253  } else {
3254  legalityDetails.isRecursivelyLegal = true;
3255  }
3256  }
3257  return legalityDetails;
3258 }
3259 
3261  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3262  if (!info)
3263  return false;
3264 
3265  if (info->action == LegalizationAction::Dynamic) {
3266  std::optional<bool> result = info->legalityFn(op);
3267  if (!result)
3268  return false;
3269 
3270  return !(*result);
3271  }
3272 
3273  return info->action == LegalizationAction::Illegal;
3274 }
3275 
3279  if (!oldCallback)
3280  return newCallback;
3281 
3282  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3283  Operation *op) -> std::optional<bool> {
3284  if (std::optional<bool> result = newCl(op))
3285  return *result;
3286 
3287  return oldCl(op);
3288  };
3289  return chain;
3290 }
3291 
3292 void ConversionTarget::setLegalityCallback(
3293  OperationName name, const DynamicLegalityCallbackFn &callback) {
3294  assert(callback && "expected valid legality callback");
3295  auto *infoIt = legalOperations.find(name);
3296  assert(infoIt != legalOperations.end() &&
3297  infoIt->second.action == LegalizationAction::Dynamic &&
3298  "expected operation to already be marked as dynamically legal");
3299  infoIt->second.legalityFn =
3300  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3301 }
3302 
3304  OperationName name, const DynamicLegalityCallbackFn &callback) {
3305  auto *infoIt = legalOperations.find(name);
3306  assert(infoIt != legalOperations.end() &&
3307  infoIt->second.action != LegalizationAction::Illegal &&
3308  "expected operation to already be marked as legal");
3309  infoIt->second.isRecursivelyLegal = true;
3310  if (callback)
3311  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3312  std::move(opRecursiveLegalityFns[name]), callback);
3313  else
3314  opRecursiveLegalityFns.erase(name);
3315 }
3316 
3317 void ConversionTarget::setLegalityCallback(
3318  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3319  assert(callback && "expected valid legality callback");
3320  for (StringRef dialect : dialects)
3321  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3322  std::move(dialectLegalityFns[dialect]), callback);
3323 }
3324 
3325 void ConversionTarget::setLegalityCallback(
3326  const DynamicLegalityCallbackFn &callback) {
3327  assert(callback && "expected valid legality callback");
3328  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3329 }
3330 
3331 auto ConversionTarget::getOpInfo(OperationName op) const
3332  -> std::optional<LegalizationInfo> {
3333  // Check for info for this specific operation.
3334  const auto *it = legalOperations.find(op);
3335  if (it != legalOperations.end())
3336  return it->second;
3337  // Check for info for the parent dialect.
3338  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3339  if (dialectIt != legalDialects.end()) {
3340  DynamicLegalityCallbackFn callback;
3341  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3342  if (dialectFn != dialectLegalityFns.end())
3343  callback = dialectFn->second;
3344  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3345  callback};
3346  }
3347  // Otherwise, check if we mark unknown operations as dynamic.
3348  if (unknownLegalityFn)
3349  return LegalizationInfo{LegalizationAction::Dynamic,
3350  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3351  return std::nullopt;
3352 }
3353 
3354 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3355 //===----------------------------------------------------------------------===//
3356 // PDL Configuration
3357 //===----------------------------------------------------------------------===//
3358 
3360  auto &rewriterImpl =
3361  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3362  rewriterImpl.currentTypeConverter = getTypeConverter();
3363 }
3364 
3366  auto &rewriterImpl =
3367  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3368  rewriterImpl.currentTypeConverter = nullptr;
3369 }
3370 
3371 /// Remap the given value using the rewriter and the type converter in the
3372 /// provided config.
3373 static FailureOr<SmallVector<Value>>
3375  SmallVector<Value> mappedValues;
3376  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3377  return failure();
3378  return std::move(mappedValues);
3379 }
3380 
// Register the dialect-conversion PDL rewrite functions "convertValue",
// "convertValues", "convertType" and "convertTypes" on `patterns`.
// NOTE: every callback static_casts its PatternRewriter to
// ConversionPatternRewriter without any check, so these functions must only
// run while a dialect conversion is driving the PDL patterns.
3382  patterns.getPDLPatterns().registerRewriteFunction(
3383  "convertValue",
// Remap a single value and return the first remapped value.
// NOTE(review): assumes a successful remap yields at least one value.
3384  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3385  auto results = pdllConvertValues(
3386  static_cast<ConversionPatternRewriter &>(rewriter), value);
3387  if (failed(results))
3388  return failure();
3389  return results->front();
3390  });
// Remap a range of values; a remapping failure is propagated.
3391  patterns.getPDLPatterns().registerRewriteFunction(
3392  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3393  return pdllConvertValues(
3394  static_cast<ConversionPatternRewriter &>(rewriter), values);
3395  });
// Convert one type with the currently active type converter. Without an
// active converter the type is returned as-is; a null conversion result is
// reported as failure.
3396  patterns.getPDLPatterns().registerRewriteFunction(
3397  "convertType",
3398  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3399  auto &rewriterImpl =
3400  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3401  if (const TypeConverter *converter =
3402  rewriterImpl.currentTypeConverter) {
3403  if (Type newType = converter->convertType(type))
3404  return newType;
3405  return failure();
3406  }
3407  return type;
3408  });
// Convert a range of types. Without an active converter the types are returned
// unchanged; otherwise TypeConverter::convertTypes failure is propagated.
3409  patterns.getPDLPatterns().registerRewriteFunction(
3410  "convertTypes",
3411  [](PatternRewriter &rewriter,
3412  TypeRange types) -> FailureOr<SmallVector<Type>> {
3413  auto &rewriterImpl =
3414  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3415  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3416  if (!converter)
3417  return SmallVector<Type>(types);
3418 
3419  SmallVector<Type> remappedTypes;
3420  if (failed(converter->convertTypes(types, remappedTypes)))
3421  return failure();
3422  return std::move(remappedTypes);
3423  });
3424 }
3425 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3426 
3427 //===----------------------------------------------------------------------===//
3428 // Op Conversion Entry Points
3429 //===----------------------------------------------------------------------===//
3430 
3431 /// This is the type of Action that is dispatched when a conversion is applied.
/// applyConversion wraps each conversion in one of these actions, passing the
/// root operations being converted as the action's IR units.
3433  : public tracing::ActionImpl<ApplyConversionAction> {
3434 public:
/// Identifier of this action kind; it is also all that print() emits.
3437  static constexpr StringLiteral tag = "apply-conversion";
/// Human-readable description of this action kind.
3438  static constexpr StringLiteral desc =
3439  "Encapsulate the application of a dialect conversion";
3440 
/// Print the action's tag (the IR units are not printed here).
3441  void print(raw_ostream &os) const override { os << tag; }
3442 };
3443 
/// Shared driver behind the partial, full and analysis conversion entry
/// points: converts `ops` with an OperationConverter configured for `mode`.
/// Returns success immediately if `ops` is empty; otherwise returns the result
/// of OperationConverter::convertOperations.
3444 static LogicalResult applyConversion(ArrayRef<Operation *> ops,
3445  const ConversionTarget &target,
3448  OpConversionMode mode) {
3449  if (ops.empty())
3450  return success();
// Context used to dispatch the conversion as an ApplyConversionAction.
3451  MLIRContext *ctx = ops.front()->getContext();
// NOTE(review): if an action handler skips the callback below, this initial
// success() is what gets returned.
3452  LogicalResult status = success();
// The root ops are passed along with the action as its IR units.
3453  SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
// The conversion itself runs inside the action callback; `status` is captured
// by reference to carry its result out of the callback.
3455  [&] {
3456  OperationConverter opConverter(target, patterns, config, mode);
3457  status = opConverter.convertOperations(ops);
3458  },
3459  irUnits);
3460  return status;
3461 }
3462 
3463 //===----------------------------------------------------------------------===//
3464 // Partial Conversion
3465 //===----------------------------------------------------------------------===//
3466 
3468  ArrayRef<Operation *> ops, const ConversionTarget &target,
// Partial mode: forward to the shared applyConversion driver.
3470  return applyConversion(ops, target, patterns, config,
3471  OpConversionMode::Partial);
3472 }
/// Single-operation overload: converts `op` via the ArrayRef overload, using a
/// one-element list.
3473 LogicalResult
3477  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3478 }
3479 
3480 //===----------------------------------------------------------------------===//
3481 // Full Conversion
3482 //===----------------------------------------------------------------------===//
3483 
3485  const ConversionTarget &target,
// Full mode: forward to the shared applyConversion driver.
3488  return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
3489 }
3491  const ConversionTarget &target,
// Single-operation overload: converts `op` via the ArrayRef overload, using a
// one-element list.
3494  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3495 }
3496 
3497 //===----------------------------------------------------------------------===//
3498 // Analysis Conversion
3499 //===----------------------------------------------------------------------===//
3500 
3501 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3502 /// op is a top-level op (it has no parent op; such an op is expected to be
3503 /// isolated from above), return that op.
/// Otherwise, start from the closest IsolatedFromAbove ancestor of the first
/// op and, for each remaining op, replace the candidate until it properly
/// contains that op.
/// Precondition: `ops` is non-empty (`ops.front()` is used unconditionally).
3505  // Check if there is a top-level operation within `ops`. If so, return that
3506  // op.
3507  for (Operation *op : ops) {
3508  if (!op->getParentOp()) {
// Debug-only sanity checks: the top-level op must be isolated from above and
// must contain every op in `ops`.
3509 #ifndef NDEBUG
3510  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3511  "expected top-level op to be isolated from above");
3512  for (Operation *other : ops)
3513  assert(op->isAncestor(other) &&
3514  "expected ops to have a common ancestor");
3515 #endif // NDEBUG
3516  return op;
3517  }
3518  }
3519 
3520  // No top-level op. Find a common ancestor.
// NOTE(review): if the first op has no IsolatedFromAbove ancestor, this starts
// as nullptr and, when `ops` has more than one element, is dereferenced by
// isProperAncestor below before the assert can fire.
3521  Operation *commonAncestor =
3522  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3523  for (Operation *op : ops.drop_front()) {
// Keep replacing the candidate until it strictly encloses `op`.
// NOTE(review): the right-hand side of the reassignment below is not visible
// in this listing; presumably it is the next enclosing IsolatedFromAbove op.
3524  while (!commonAncestor->isProperAncestor(op)) {
3525  commonAncestor =
3527  assert(commonAncestor &&
3528  "expected to find a common isolated from above ancestor");
3529  }
3530  }
3531 
3532  return commonAncestor;
3533 }
3534 
3538 #ifndef NDEBUG
3539  if (config.legalizableOps)
3540  assert(config.legalizableOps->empty() && "expected empty set");
3541 #endif // NDEBUG
3542 
  // Analysis conversion: run the conversion on a clone of the IR so that the
  // original IR is left unchanged, then report (via config.legalizableOps)
  // which of the original ops were found legalizable.
  //
  // An empty `ops` list has nothing to analyze. Bail out here because
  // findCommonAncestor dereferences ops.front() unconditionally; this also
  // matches applyConversion, which returns success for an empty list.
  if (ops.empty())
    return success();

3543  // Clone the closest common ancestor that is isolated from above.
3544  Operation *commonAncestor = findCommonAncestor(ops);
3545  IRMapping mapping;
3546  Operation *clonedAncestor = commonAncestor->clone(mapping);
3547  // Compute inverse IR mapping.
3548  DenseMap<Operation *, Operation *> inverseOperationMap;
3549  for (auto &it : mapping.getOperationMap())
3550  inverseOperationMap[it.second] = it.first;
3551 
3552  // Convert the cloned operations. The original IR will remain unchanged.
3553  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3554  ops, [&](Operation *op) { return mapping.lookup(op); });
3555  LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
3556  OpConversionMode::Analysis);
3557 
3558  // Remap `legalizableOps`, so that they point to the original ops and not the
3559  // cloned ops.
  // NOTE(review): operator[] maps an op that is absent from the clone mapping
  // to nullptr.
3560  if (config.legalizableOps) {
3561  DenseSet<Operation *> originalLegalizableOps;
3562  for (Operation *op : *config.legalizableOps)
3563  originalLegalizableOps.insert(inverseOperationMap[op]);
3564  *config.legalizableOps = std::move(originalLegalizableOps);
3565  }
3566 
3567  // Erase the cloned IR.
3568  clonedAncestor->erase();
3569  return status;
3570 }
3571 
/// Single-operation overload: analyzes `op` via the ArrayRef overload, using a
/// one-element list.
3572 LogicalResult
3576  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3577 }
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
union mlir::linalg::@1223::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:94
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:750
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:264
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:606
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:365
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:622
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:614
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
CRTP Implementation of an action.
Definition: Action.h:76
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:306
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:296
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:369
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:393
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with the results of another operation.
Definition: PatternMatch.h:377
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:366
This struct represents a range of new types or a range of values that remaps an existing input of the original signature.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, const TypeConverter *converter)
Replace the given block argument with the given values.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of IR rewrites (e.g., block creations, splits, and motions), in the order they were applied.
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.