MLIR  21.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 #define DEBUG_TYPE "dialect-conversion"
30 
31 /// A utility function to log a successful result for the given reason.
32 template <typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
34  LLVM_DEBUG({
35  os.unindent();
36  os.startLine() << "} -> SUCCESS";
37  if (!fmt.empty())
38  os.getOStream() << " : "
39  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40  os.getOStream() << "\n";
41  });
42 }
43 
44 /// A utility function to log a failure result for the given reason.
45 template <typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
47  LLVM_DEBUG({
48  os.unindent();
49  os.startLine() << "} -> FAILURE : "
50  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
51  << "\n";
52  });
53 }
54 
55 /// Helper function that computes an insertion point where the given value is
56 /// defined and can be used without a dominance violation.
58  Block *insertBlock = value.getParentBlock();
59  Block::iterator insertPt = insertBlock->begin();
60  if (OpResult inputRes = dyn_cast<OpResult>(value))
61  insertPt = ++inputRes.getOwner()->getIterator();
62  return OpBuilder::InsertPoint(insertBlock, insertPt);
63 }
64 
65 /// Helper function that computes an insertion point where the given values are
66 /// defined and can be used without a dominance violation.
///
/// The result is the insertion point of the value that is defined "last". It
/// starts from the point of the first value. Each further candidate point
/// (`nextPt`) replaces the current one if the current one dominates it. A new
/// `DominanceInfo` is constructed on every call. In debug builds, an assertion
/// fires if a candidate and the current point are not ordered by dominance,
/// because then no common valid insertion point exists.
// NOTE(review): the function header (listing line 67) is missing from this
// copy. From the body, it takes a non-empty range `vals` of `Value`s
// (supporting `empty`/`front`/`drop_front`) and returns an
// `OpBuilder::InsertPoint`. TODO: restore the signature.
68  assert(!vals.empty() && "expected at least one value");
69  DominanceInfo domInfo;
70  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
71  for (Value v : vals.drop_front()) {
72  // Choose the "later" insertion point.
  // NOTE(review): the declaration of `nextPt` (listing line 73) is missing
  // from this copy. It is presumably the insertion point computed for `v`.
  // TODO: restore it.
74  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
75  nextPt.getPoint())) {
76  // pt is before nextPt => choose nextPt.
77  pt = nextPt;
78  } else {
79 #ifndef NDEBUG
80  // nextPt should be before pt => choose pt.
81  // If pt and nextPt have no dominance relationship, then there is no valid
82  // insertion point at which all given values are defined.
83  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
84  pt.getBlock(), pt.getPoint());
85  assert(dom && "unable to find valid insertion point");
86 #endif // NDEBUG
87  }
88  }
89  return pt;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // ConversionValueMapping
94 //===----------------------------------------------------------------------===//
95 
96 /// A vector of SSA values, optimized for the most common case of a single
97 /// value.
98 using ValueVector = SmallVector<Value, 1>;
99 
100 namespace {
101 
102 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
103 struct ValueVectorMapInfo {
104  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
105  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
106  static ::llvm::hash_code getHashValue(const ValueVector &val) {
107  return ::llvm::hash_combine_range(val);
108  }
109  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
110  return LHS == RHS;
111  }
112 };
113 
114 /// This class wraps a map from value vectors to value vectors (presumably
115 /// keyed via `ValueVectorMapInfo`) to provide recursive lookup functionality,
/// i.e. we will traverse if the mapped value also has a mapping. In addition,
/// every value that appears on the right-hand side of a mapping is recorded in
/// `mappedTo` (see `isMappedTo`).
116 struct ConversionValueMapping {
117  /// Return "true" if an SSA value is mapped to the given value. May return
118  /// false positives.
  /// (False positives arise, e.g., because `erase` does not remove values
  /// from `mappedTo`.)
119  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
120 
121  /// Lookup the most recently mapped values with the desired types in the
122  /// mapping.
123  ///
124  /// Special cases:
125  /// - If the desired type range is empty, simply return the most recently
126  /// mapped values.
127  /// - If there is no mapping to the desired types, also return the most
128  /// recently mapped values.
129  /// - If there is no mapping for the given values at all, return the given
130  /// value.
131  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
132 
133  /// Lookup the given value within the map, or return an empty vector if the
134  /// value is not mapped. If it is mapped, this follows the same behavior
135  /// as `lookupOrDefault`.
136  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
137 
  /// Type trait: true if `T` (after decay) is `ValueVector`. Used to select
  /// between the `map` overloads below.
138  template <typename T>
139  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
140 
141  /// Map a value vector to the one provided.
  ///
  /// In debug builds, the chain of whole-vector mappings that starts at
  /// `newVal` is walked first, asserting that it never reaches `oldVal`
  /// (which would create a cycle). Every value of `newVal` is recorded in
  /// `mappedTo`. An existing entry for `oldVal` is overwritten.
142  template <typename OldVal, typename NewVal>
143  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
144  map(OldVal &&oldVal, NewVal &&newVal) {
145  LLVM_DEBUG({
146  ValueVector next(newVal);
147  while (true) {
148  assert(next != oldVal && "inserting cyclic mapping");
149  auto it = mapping.find(next);
150  if (it == mapping.end())
151  break;
152  next = it->second;
153  }
154  });
155  mappedTo.insert_range(newVal);
156 
157  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
158  }
159 
160  /// Map a value vector or single value to the one provided.
  /// Single `Value` arguments are wrapped into one-element `ValueVector`s and
  /// forwarded to the `ValueVector`/`ValueVector` overload above.
161  template <typename OldVal, typename NewVal>
162  std::enable_if_t<!IsValueVector<OldVal>::value ||
163  !IsValueVector<NewVal>::value>
164  map(OldVal &&oldVal, NewVal &&newVal) {
165  if constexpr (IsValueVector<OldVal>{}) {
166  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
167  } else if constexpr (IsValueVector<NewVal>{}) {
168  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
169  } else {
170  map(ValueVector{oldVal}, ValueVector{newVal});
171  }
172  }
173 
  /// Map a single value to the given vector of values. The vector is moved
  /// into a `ValueVector`.
174  void map(Value oldVal, SmallVector<Value> &&newVal) {
175  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
176  }
177 
178  /// Drop the mapping entry whose key is exactly `value` (the whole vector).
  /// Values recorded in `mappedTo` are left untouched.
179  void erase(const ValueVector &value) { mapping.erase(value); }
180 
181 private:
182  /// Current value mappings.
  // NOTE(review): the declaration of `mapping` (listing line 183) is missing
  // from this copy. From its uses it is indexed by `ValueVector` keys and its
  // entries are read as value vectors. `ValueVectorMapInfo` suggests a
  // DenseMap. TODO: restore the declaration.
184 
185  /// All SSA values that are mapped to. May contain false positives.
186  DenseSet<Value> mappedTo;
187 };
188 } // namespace
189 
191 ConversionValueMapping::lookupOrDefault(Value from,
192  TypeRange desiredTypes) const {
193  // Try to find the deepest values that have the desired types. If there is no
194  // such mapping, simply return the deepest values.
195  ValueVector desiredValue;
196  ValueVector current{from};
197  do {
198  // Store the current value if the types match.
199  if (TypeRange(ValueRange(current)) == desiredTypes)
200  desiredValue = current;
201 
202  // If possible, Replace each value with (one or multiple) mapped values.
203  ValueVector next;
204  for (Value v : current) {
205  auto it = mapping.find({v});
206  if (it != mapping.end()) {
207  llvm::append_range(next, it->second);
208  } else {
209  next.push_back(v);
210  }
211  }
212  if (next != current) {
213  // If at least one value was replaced, continue the lookup from there.
214  current = std::move(next);
215  continue;
216  }
217 
218  // Otherwise: Check if there is a mapping for the entire vector. Such
219  // mappings are materializations. (N:M mapping are not supported for value
220  // replacements.)
221  //
222  // Note: From a correctness point of view, materializations do not have to
223  // be stored (and looked up) in the mapping. But for performance reasons,
224  // we choose to reuse existing IR (when possible) instead of creating it
225  // multiple times.
226  auto it = mapping.find(current);
227  if (it == mapping.end()) {
228  // No mapping found: The lookup stops here.
229  break;
230  }
231  current = it->second;
232  } while (true);
233 
234  // If the desired values were found use them, otherwise default to the leaf
235  // values.
236  // Note: If `desiredTypes` is empty, this function always returns `current`.
237  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
238 }
239 
240 ValueVector ConversionValueMapping::lookupOrNull(Value from,
241  TypeRange desiredTypes) const {
242  ValueVector result = lookupOrDefault(from, desiredTypes);
243  if (result == ValueVector{from} ||
244  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
245  return {};
246  return result;
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // Rewriter and Translation State
251 //===----------------------------------------------------------------------===//
252 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. Saving one and
/// later restoring it allows undoing every rewrite recorded in between.
struct RewriterState {
  /// Number of rewrites recorded when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;

  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}
};
270 
271 //===----------------------------------------------------------------------===//
272 // IR rewrites
273 //===----------------------------------------------------------------------===//
274 
275 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
276 
277 /// Notify the listener that the given block and its contents are being erased.
278 static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
279  for (Operation &op : b)
280  notifyIRErased(listener, op);
281  listener->notifyBlockErased(&b);
282 }
283 
284 /// Notify the listener that the given operation and its contents are being
285 /// erased.
286 static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
287  for (Region &r : op.getRegions()) {
288  for (Block &b : r) {
289  notifyIRErased(listener, b);
290  }
291  }
292  listener->notifyOperationErased(&op);
293 }
294 
295 /// An IR rewrite that can be committed (upon success) or rolled back (upon
296 /// failure).
297 ///
298 /// The dialect conversion keeps track of IR modifications (requested by the
299 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
300 /// are directly applied to the IR as the rewriter API is used, some are applied
301 /// partially, and some are delayed until the `IRRewrite` objects are committed.
302 class IRRewrite {
303 public:
304  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
305  /// Enum values are ordered, so that they can be used in `classof`: first all
306  /// block rewrites, then all operation rewrites.
307  enum class Kind {
308  // Block rewrites
309  CreateBlock,
310  EraseBlock,
311  InlineBlock,
312  MoveBlock,
313  BlockTypeConversion,
314  ReplaceBlockArg,
315  // Operation rewrites
316  MoveOperation,
317  ModifyOperation,
318  ReplaceOperation,
319  CreateOperation,
320  UnresolvedMaterialization
321  };
322 
323  virtual ~IRRewrite() = default;
324 
325  /// Roll back the rewrite. Operations may be erased during rollback.
326  virtual void rollback() = 0;
327 
328  /// Commit the rewrite. At this point, it is certain that the dialect
329  /// conversion will succeed. All IR modifications, except for operation/block
330  /// erasure, must be performed through the given rewriter.
331  ///
332  /// Instead of erasing operations/blocks, they should merely be unlinked
333  /// commit phase and finally be erased during the cleanup phase. This is
334  /// because internal dialect conversion state (such as `mapping`) may still
335  /// be using them.
336  ///
337  /// Any IR modification that was already performed before the commit phase
338  /// (e.g., insertion of an op) must be communicated to the listener that may
339  /// be attached to the given rewriter.
340  virtual void commit(RewriterBase &rewriter) {}
341 
342  /// Cleanup operations/blocks. Cleanup is called after commit.
343  virtual void cleanup(RewriterBase &rewriter) {}
344 
345  Kind getKind() const { return kind; }
346 
347  static bool classof(const IRRewrite *rewrite) { return true; }
348 
349 protected:
350  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
351  : kind(kind), rewriterImpl(rewriterImpl) {}
352 
353  const ConversionConfig &getConfig() const;
354 
355  const Kind kind;
356  ConversionPatternRewriterImpl &rewriterImpl;
357 };
358 
359 /// A block rewrite.
360 class BlockRewrite : public IRRewrite {
361 public:
362  /// Return the block that this rewrite operates on.
363  Block *getBlock() const { return block; }
364 
365  static bool classof(const IRRewrite *rewrite) {
366  return rewrite->getKind() >= Kind::CreateBlock &&
367  rewrite->getKind() <= Kind::ReplaceBlockArg;
368  }
369 
370 protected:
371  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
372  Block *block)
373  : IRRewrite(kind, rewriterImpl), block(block) {}
374 
375  // The block that this rewrite operates on.
376  Block *block;
377 };
378 
379 /// Creation of a block. Block creations are immediately reflected in the IR.
380 /// There is no extra work to commit the rewrite. During rollback, the newly
381 /// created block is erased.
382 class CreateBlockRewrite : public BlockRewrite {
383 public:
384  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
385  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
386 
387  static bool classof(const IRRewrite *rewrite) {
388  return rewrite->getKind() == Kind::CreateBlock;
389  }
390 
391  void commit(RewriterBase &rewriter) override {
392  // The block was already created and inserted. Just inform the listener.
393  if (auto *listener = rewriter.getListener())
394  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
395  }
396 
397  void rollback() override {
398  // Unlink all of the operations within this block, they will be deleted
399  // separately.
400  auto &blockOps = block->getOperations();
401  while (!blockOps.empty())
402  blockOps.remove(blockOps.begin());
403  block->dropAllUses();
404  if (block->getParent())
405  block->erase();
406  else
407  delete block;
408  }
409 };
410 
411 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
412 /// blocks are immediately unlinked, but only erased during cleanup. This makes
413 /// it easier to rollback a block erasure: the block is simply inserted into its
414 /// original location.
415 class EraseBlockRewrite : public BlockRewrite {
416 public:
417  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
418  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
419  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
420 
421  static bool classof(const IRRewrite *rewrite) {
422  return rewrite->getKind() == Kind::EraseBlock;
423  }
424 
425  ~EraseBlockRewrite() override {
426  assert(!block &&
427  "rewrite was neither rolled back nor committed/cleaned up");
428  }
429 
430  void rollback() override {
431  // The block (owned by this rewrite) was not actually erased yet. It was
432  // just unlinked. Put it back into its original position.
433  assert(block && "expected block");
434  auto &blockList = region->getBlocks();
435  Region::iterator before = insertBeforeBlock
436  ? Region::iterator(insertBeforeBlock)
437  : blockList.end();
438  blockList.insert(before, block);
439  block = nullptr;
440  }
441 
442  void commit(RewriterBase &rewriter) override {
443  assert(block && "expected block");
444 
445  // Notify the listener that the block and its contents are being erased.
446  if (auto *listener =
447  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
448  notifyIRErased(listener, *block);
449  }
450 
451  void cleanup(RewriterBase &rewriter) override {
452  // Erase the contents of the block.
453  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
454  rewriter.eraseOp(&op);
455  assert(block->empty() && "expected empty block");
456 
457  // Erase the block.
458  block->dropAllDefinedValueUses();
459  delete block;
460  block = nullptr;
461  }
462 
463 private:
464  // The region in which this block was previously contained.
465  Region *region;
466 
467  // The original successor of this block before it was unlinked. "nullptr" if
468  // this block was the only block in the region.
469  Block *insertBeforeBlock;
470 };
471 
472 /// Inlining of a block. This rewrite is immediately reflected in the IR.
473 /// Note: This rewrite represents only the inlining of the operations. The
474 /// erasure of the inlined block is a separate rewrite.
475 class InlineBlockRewrite : public BlockRewrite {
476 public:
477  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
478  Block *sourceBlock, Block::iterator before)
479  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
480  sourceBlock(sourceBlock),
481  firstInlinedInst(sourceBlock->empty() ? nullptr
482  : &sourceBlock->front()),
483  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
484  // If a listener is attached to the dialect conversion, ops must be moved
485  // one-by-one. When they are moved in bulk, notifications cannot be sent
486  // because the ops that used to be in the source block at the time of the
487  // inlining (before the "commit" phase) are unknown at the time when
488  // notifications are sent (which is during the "commit" phase).
489  assert(!getConfig().listener &&
490  "InlineBlockRewrite not supported if listener is attached");
491  }
492 
493  static bool classof(const IRRewrite *rewrite) {
494  return rewrite->getKind() == Kind::InlineBlock;
495  }
496 
497  void rollback() override {
498  // Put the operations from the destination block (owned by the rewrite)
499  // back into the source block.
500  if (firstInlinedInst) {
501  assert(lastInlinedInst && "expected operation");
502  sourceBlock->getOperations().splice(sourceBlock->begin(),
503  block->getOperations(),
504  Block::iterator(firstInlinedInst),
505  ++Block::iterator(lastInlinedInst));
506  }
507  }
508 
509 private:
510  // The block that originally contained the operations.
511  Block *sourceBlock;
512 
513  // The first inlined operation.
514  Operation *firstInlinedInst;
515 
516  // The last inlined operation.
517  Operation *lastInlinedInst;
518 };
519 
520 /// Moving of a block. This rewrite is immediately reflected in the IR.
521 class MoveBlockRewrite : public BlockRewrite {
522 public:
523  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
524  Region *region, Block *insertBeforeBlock)
525  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
526  insertBeforeBlock(insertBeforeBlock) {}
527 
528  static bool classof(const IRRewrite *rewrite) {
529  return rewrite->getKind() == Kind::MoveBlock;
530  }
531 
532  void commit(RewriterBase &rewriter) override {
533  // The block was already moved. Just inform the listener.
534  if (auto *listener = rewriter.getListener()) {
535  // Note: `previousIt` cannot be passed because this is a delayed
536  // notification and iterators into past IR state cannot be represented.
537  listener->notifyBlockInserted(block, /*previous=*/region,
538  /*previousIt=*/{});
539  }
540  }
541 
542  void rollback() override {
543  // Move the block back to its original position.
544  Region::iterator before =
545  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
546  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
547  }
548 
549 private:
550  // The region in which this block was previously contained.
551  Region *region;
552 
553  // The original successor of this block before it was moved. "nullptr" if
554  // this block was the only block in the region.
555  Block *insertBeforeBlock;
556 };
557 
558 /// Block type conversion. This rewrite is partially reflected in the IR.
559 class BlockTypeConversionRewrite : public BlockRewrite {
560 public:
561  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
562  Block *origBlock, Block *newBlock)
563  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
564  newBlock(newBlock) {}
565 
566  static bool classof(const IRRewrite *rewrite) {
567  return rewrite->getKind() == Kind::BlockTypeConversion;
568  }
569 
570  Block *getOrigBlock() const { return block; }
571 
572  Block *getNewBlock() const { return newBlock; }
573 
574  void commit(RewriterBase &rewriter) override;
575 
576  void rollback() override;
577 
578 private:
579  /// The new block that was created as part of this signature conversion.
580  Block *newBlock;
581 };
582 
583 /// Replacing a block argument. This rewrite is not immediately reflected in the
584 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
585 /// until the rewrite is committed.
586 class ReplaceBlockArgRewrite : public BlockRewrite {
587 public:
588  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
589  Block *block, BlockArgument arg,
590  const TypeConverter *converter)
591  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
592  converter(converter) {}
593 
594  static bool classof(const IRRewrite *rewrite) {
595  return rewrite->getKind() == Kind::ReplaceBlockArg;
596  }
597 
598  void commit(RewriterBase &rewriter) override;
599 
600  void rollback() override;
601 
602 private:
603  BlockArgument arg;
604 
605  /// The current type converter when the block argument was replaced.
606  const TypeConverter *converter;
607 };
608 
609 /// An operation rewrite.
610 class OperationRewrite : public IRRewrite {
611 public:
612  /// Return the operation that this rewrite operates on.
613  Operation *getOperation() const { return op; }
614 
615  static bool classof(const IRRewrite *rewrite) {
616  return rewrite->getKind() >= Kind::MoveOperation &&
617  rewrite->getKind() <= Kind::UnresolvedMaterialization;
618  }
619 
620 protected:
621  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
622  Operation *op)
623  : IRRewrite(kind, rewriterImpl), op(op) {}
624 
625  // The operation that this rewrite operates on.
626  Operation *op;
627 };
628 
629 /// Moving of an operation. This rewrite is immediately reflected in the IR.
630 class MoveOperationRewrite : public OperationRewrite {
631 public:
632  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
633  Operation *op, Block *block, Operation *insertBeforeOp)
634  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
635  insertBeforeOp(insertBeforeOp) {}
636 
637  static bool classof(const IRRewrite *rewrite) {
638  return rewrite->getKind() == Kind::MoveOperation;
639  }
640 
641  void commit(RewriterBase &rewriter) override {
642  // The operation was already moved. Just inform the listener.
643  if (auto *listener = rewriter.getListener()) {
644  // Note: `previousIt` cannot be passed because this is a delayed
645  // notification and iterators into past IR state cannot be represented.
646  listener->notifyOperationInserted(
647  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
648  /*insertPt=*/{}));
649  }
650  }
651 
652  void rollback() override {
653  // Move the operation back to its original position.
654  Block::iterator before =
655  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
656  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
657  }
658 
659 private:
660  // The block in which this operation was previously contained.
661  Block *block;
662 
663  // The original successor of this operation before it was moved. "nullptr"
664  // if this operation was the only operation in the region.
665  Operation *insertBeforeOp;
666 };
667 
668 /// In-place modification of an op. This rewrite is immediately reflected in
669 /// the IR. The previous state of the operation is stored in this object.
670 class ModifyOperationRewrite : public OperationRewrite {
671 public:
672  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
673  Operation *op)
674  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
675  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
676  operands(op->operand_begin(), op->operand_end()),
677  successors(op->successor_begin(), op->successor_end()) {
678  if (OpaqueProperties prop = op->getPropertiesStorage()) {
679  // Make a copy of the properties.
680  propertiesStorage = operator new(op->getPropertiesStorageSize());
681  OpaqueProperties propCopy(propertiesStorage);
682  name.initOpProperties(propCopy, /*init=*/prop);
683  }
684  }
685 
686  static bool classof(const IRRewrite *rewrite) {
687  return rewrite->getKind() == Kind::ModifyOperation;
688  }
689 
690  ~ModifyOperationRewrite() override {
691  assert(!propertiesStorage &&
692  "rewrite was neither committed nor rolled back");
693  }
694 
695  void commit(RewriterBase &rewriter) override {
696  // Notify the listener that the operation was modified in-place.
697  if (auto *listener =
698  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
699  listener->notifyOperationModified(op);
700 
701  if (propertiesStorage) {
702  OpaqueProperties propCopy(propertiesStorage);
703  // Note: The operation may have been erased in the mean time, so
704  // OperationName must be stored in this object.
705  name.destroyOpProperties(propCopy);
706  operator delete(propertiesStorage);
707  propertiesStorage = nullptr;
708  }
709  }
710 
711  void rollback() override {
712  op->setLoc(loc);
713  op->setAttrs(attrs);
714  op->setOperands(operands);
715  for (const auto &it : llvm::enumerate(successors))
716  op->setSuccessor(it.value(), it.index());
717  if (propertiesStorage) {
718  OpaqueProperties propCopy(propertiesStorage);
719  op->copyProperties(propCopy);
720  name.destroyOpProperties(propCopy);
721  operator delete(propertiesStorage);
722  propertiesStorage = nullptr;
723  }
724  }
725 
726 private:
727  OperationName name;
728  LocationAttr loc;
729  DictionaryAttr attrs;
730  SmallVector<Value, 8> operands;
731  SmallVector<Block *, 2> successors;
732  void *propertiesStorage = nullptr;
733 };
734 
735 /// Replacing an operation. Erasing an operation is treated as a special case
736 /// with "null" replacements. This rewrite is not immediately reflected in the
737 /// IR. An internal IR mapping is updated, but values are not replaced and the
738 /// original op is not erased until the rewrite is committed.
739 class ReplaceOperationRewrite : public OperationRewrite {
740 public:
741  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
742  Operation *op, const TypeConverter *converter)
743  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
744  converter(converter) {}
745 
746  static bool classof(const IRRewrite *rewrite) {
747  return rewrite->getKind() == Kind::ReplaceOperation;
748  }
749 
750  void commit(RewriterBase &rewriter) override;
751 
752  void rollback() override;
753 
754  void cleanup(RewriterBase &rewriter) override;
755 
756 private:
757  /// An optional type converter that can be used to materialize conversions
758  /// between the new and old values if necessary.
759  const TypeConverter *converter;
760 };
761 
762 class CreateOperationRewrite : public OperationRewrite {
763 public:
764  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
765  Operation *op)
766  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
767 
768  static bool classof(const IRRewrite *rewrite) {
769  return rewrite->getKind() == Kind::CreateOperation;
770  }
771 
772  void commit(RewriterBase &rewriter) override {
773  // The operation was already created and inserted. Just inform the listener.
774  if (auto *listener = rewriter.getListener())
775  listener->notifyOperationInserted(op, /*previous=*/{});
776  }
777 
778  void rollback() override;
779 };
780 
/// Direction of a materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
791 
792 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793 /// op. Unresolved materializations are erased at the end of the dialect
794 /// conversion.
795 class UnresolvedMaterializationRewrite : public OperationRewrite {
796 public:
797  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
798  UnrealizedConversionCastOp op,
799  const TypeConverter *converter,
800  MaterializationKind kind, Type originalType,
801  ValueVector mappedValues);
802 
803  static bool classof(const IRRewrite *rewrite) {
804  return rewrite->getKind() == Kind::UnresolvedMaterialization;
805  }
806 
807  void rollback() override;
808 
809  UnrealizedConversionCastOp getOperation() const {
810  return cast<UnrealizedConversionCastOp>(op);
811  }
812 
813  /// Return the type converter of this materialization (which may be null).
814  const TypeConverter *getConverter() const {
815  return converterAndKind.getPointer();
816  }
817 
818  /// Return the kind of this materialization.
819  MaterializationKind getMaterializationKind() const {
820  return converterAndKind.getInt();
821  }
822 
823  /// Return the original type of the SSA value.
824  Type getOriginalType() const { return originalType; }
825 
826 private:
827  /// The corresponding type converter to use when resolving this
828  /// materialization, and the kind of this materialization.
829  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
830  converterAndKind;
831 
832  /// The original type of the SSA value. Only used for target
833  /// materializations.
834  Type originalType;
835 
836  /// The values in the conversion value mapping that are being replaced by the
837  /// results of this unresolved materialization.
838  ValueVector mappedValues;
839 };
840 } // namespace
841 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// operation is `op`. Performs a linear scan over all rewrites.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getOperation() == op)
      return true;
  }
  return false;
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// block is `block`. Performs a linear scan over all rewrites.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getBlock() == block)
      return true;
  }
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
863 
864 //===----------------------------------------------------------------------===//
865 // ConversionPatternRewriterImpl
866 //===----------------------------------------------------------------------===//
867 namespace mlir {
868 namespace detail {
// Internal state behind `ConversionPatternRewriter`. It records every IR
// change as a rewrite object, so that changes can be committed
// (`applyRewrites`) or rolled back (`resetState` / `undoRewrites`). It also
// owns the conversion value mapping and the table of unresolved
// materializations.
// NOTE(review): this rendering drops several source lines of this struct.
// These include its header (original lines 869-870) and many member
// declarations (e.g. original lines 1061, 1068, 1073, 1079). Restore them
// from the upstream file before compiling.
871  const ConversionConfig &config)
872  : context(ctx), config(config) {}
873 
874  //===--------------------------------------------------------------------===//
875  // State Management
876  //===--------------------------------------------------------------------===//
877 
878  /// Return the current state of the rewriter.
879  RewriterState getCurrentState();
880 
881  /// Apply all requested operation rewrites. This method is invoked when the
882  /// conversion process succeeds.
883  void applyRewrites();
884 
885  /// Reset the state of the rewriter to a previously saved point. Optionally,
886  /// the name of the pattern that triggered the rollback can be specified for
887  /// debugging purposes.
888  void resetState(RewriterState state, StringRef patternName = "");
889 
890  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
891  /// failure.
892  template <typename RewriteTy, typename... Args>
893  void appendRewrite(Args &&...args) {
894  rewrites.push_back(
895  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
896  }
897 
898  /// Undo the rewrites (motions, splits) one by one in reverse order until
899  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
900  /// that triggered the rollback can be specified for debugging purposes.
901  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
902 
903  /// Remap the given values to those with potentially different types. Returns
904  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
905  /// is the tag used when describing a value within a diagnostic, e.g.
906  /// "operand".
907  LogicalResult remapValues(StringRef valueDiagTag,
908  std::optional<Location> inputLoc,
909  PatternRewriter &rewriter, ValueRange values,
910  SmallVector<ValueVector> &remapped);
911 
912  /// Return "true" if the given operation is ignored, and does not need to be
913  /// converted.
914  bool isOpIgnored(Operation *op) const;
915 
916  /// Return "true" if the given operation was replaced or erased.
917  bool wasOpReplaced(Operation *op) const;
918 
919  //===--------------------------------------------------------------------===//
920  // IR Rewrites / Type Conversion
921  //===--------------------------------------------------------------------===//
922 
923  /// Convert the types of block arguments within the given region.
924  FailureOr<Block *>
925  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
926  const TypeConverter &converter,
927  TypeConverter::SignatureConversion *entryConversion);
928 
929  /// Apply the given signature conversion on the given block. The new block
930  /// containing the updated signature is returned. If no conversions were
931  /// necessary, e.g. if the block has no arguments, `block` is returned.
932  /// `converter` is used to generate any necessary cast operations that
933  /// translate between the original argument types and those specified in the
934  /// signature conversion.
935  Block *applySignatureConversion(
936  ConversionPatternRewriter &rewriter, Block *block,
937  const TypeConverter *converter,
938  TypeConverter::SignatureConversion &signatureConversion);
939 
940  /// Replace the results of the given operation with the given values and
941  /// erase the operation.
942  ///
943  /// There can be multiple replacement values for each result (1:N
944  /// replacement). If the replacement values are empty, the respective result
945  /// is dropped and a source materialization is built if the result still has
946  /// uses.
947  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
948 
949  /// Replace the given block argument with the given values. The specified
950  /// converter is used to build materializations (if necessary).
951  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
952  const TypeConverter *converter);
953 
954  /// Erase the given block and its contents.
955  void eraseBlock(Block *block);
956 
957  /// Inline the source block into the destination block before the given
958  /// iterator.
959  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
960 
961  //===--------------------------------------------------------------------===//
962  // Materializations
963  //===--------------------------------------------------------------------===//
964 
965  /// Build an unresolved materialization operation given a range of output
966  /// types and a list of input operands. The input types must differ from the
967  /// output types, i.e. a materialization must actually be necessary.
968  ///
969  /// If a cast op was built, it can optionally be returned with the `castOp`
970  /// output argument.
971  ///
972  /// If `valuesToMap` is non-empty, then those values are mapped to
973  /// the results of the unresolved materialization in the conversion value
974  /// mapping.
975  ValueRange buildUnresolvedMaterialization(
976  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
977  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
978  Type originalType, const TypeConverter *converter,
979  UnrealizedConversionCastOp *castOp = nullptr);
980 
981  /// Find a replacement value for the given SSA value in the conversion value
982  /// mapping. The replacement value must have the same type as the given SSA
983  /// value. If there is no replacement value with the correct type, find the
984  /// latest replacement value (regardless of the type) and build a source
985  /// materialization.
986  Value findOrBuildReplacementValue(Value value,
987  const TypeConverter *converter);
988 
989  //===--------------------------------------------------------------------===//
990  // Rewriter Notification Hooks
991  //===--------------------------------------------------------------------===//
992 
993  /// Notifies that an op was inserted.
994  void notifyOperationInserted(Operation *op,
995  OpBuilder::InsertPoint previous) override;
996 
997  /// Notifies that a block was inserted.
998  void notifyBlockInserted(Block *block, Region *previous,
999  Region::iterator previousIt) override;
1000 
1001  /// Notifies that a pattern match failed for the given reason.
1002  void
1003  notifyMatchFailure(Location loc,
1004  function_ref<void(Diagnostic &)> reasonCallback) override;
1005 
1006  //===--------------------------------------------------------------------===//
1007  // IR Erasure
1008  //===--------------------------------------------------------------------===//
1009 
1010  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1011  /// operation or block is erased multiple times. This rewriter assumes that
1012  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
// NOTE(review): this listing is missing the nested class header (original
// line 1013) and the constructor name (original line 1015).
1014  public:
1016  MLIRContext *context,
1017  std::function<void(Operation *)> opErasedCallback = nullptr)
1018  : RewriterBase(context, /*listener=*/this),
1019  opErasedCallback(opErasedCallback) {}
1020 
1021  /// Erase the given op (unless it was already erased).
1022  void eraseOp(Operation *op) override {
1023  if (wasErased(op))
1024  return;
1025  op->dropAllUses();
// NOTE(review): original line 1026 is missing here. As listed, this override
// only drops the uses of `op` and never actually erases it.
1027  }
1028 
1029  /// Erase the given block (unless it was already erased).
1030  void eraseBlock(Block *block) override {
1031  if (wasErased(block))
1032  return;
1033  assert(block->empty() && "expected empty block");
1034  block->dropAllDefinedValueUses();
1035  RewriterBase::eraseBlock(block);
1036  }
1037 
1038  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1039 
// Listener hook: remember the erased op, then invoke the optional callback.
1040  void notifyOperationErased(Operation *op) override {
1041  erased.insert(op);
1042  if (opErasedCallback)
1043  opErasedCallback(op);
1044  }
1045 
// Listener hook: remember the erased block so it is not erased again.
1046  void notifyBlockErased(Block *block) override { erased.insert(block); }
1047 
1048  private:
1049  /// Pointers to all erased operations and blocks.
1050  DenseSet<void *> erased;
1051 
1052  /// A callback that is invoked when an operation is erased.
1053  std::function<void(Operation *)> opErasedCallback;
1054  };
1055 
1056  //===--------------------------------------------------------------------===//
1057  // State
1058  //===--------------------------------------------------------------------===//
1059 
1060  /// MLIR context.
1062 
1063  // Mapping between replaced values that differ in type. This happens when
1064  // replacing a value with one of a different type.
1065  ConversionValueMapping mapping;
1066 
1067  /// Ordered list of all recorded IR rewrites (creations, motions, replacements).
1069 
1070  /// A set of operations that should no longer be considered for legalization.
1071  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1072  /// tracked separately.
1074 
1075  /// A set of operations that were replaced/erased. Such ops are not erased
1076  /// immediately but only when the dialect conversion succeeds. In the
1077  /// meantime, they should no longer be considered for legalization and any attempt
1078  /// to modify/access them is invalid rewriter API usage.
1080 
1081  /// A set of operations that were created by the current pattern.
1083 
1084  /// A set of operations that were modified by the current pattern.
1086 
1087  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
1088  /// by the current pattern.
1090 
1091  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092  /// to the corresponding rewrite objects.
1095 
1096  /// The current type converter, or nullptr if no type converter is currently
1097  /// active.
1098  const TypeConverter *currentTypeConverter = nullptr;
1099 
1100  /// A mapping of regions to type converters that should be used when
1101  /// converting the arguments of blocks within that region.
1103 
1104  /// Dialect conversion configuration.
1106 
1107 #ifndef NDEBUG
1108  /// A set of operations that have pending updates. This tracking isn't
1109  /// strictly necessary, and is thus only active during debug builds for extra
1110  /// verification.
1112 
1113  /// A logger used to emit diagnostics during the conversion process.
1114  llvm::ScopedPrinter logger{llvm::dbgs()};
1115 #endif
1116 };
1117 } // namespace detail
1118 } // namespace mlir
1119 
1120 const ConversionConfig &IRRewrite::getConfig() const {
1121  return rewriterImpl.config;
1122 }
1123 
1124 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1125  // Inform the listener about all IR modifications that have already taken
1126  // place: References to the original block have been replaced with the new
1127  // block.
1128  if (auto *listener =
1129  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1130  for (Operation *op : getNewBlock()->getUsers())
1131  listener->notifyOperationModified(op);
1132 }
1133 
1134 void BlockTypeConversionRewrite::rollback() {
1135  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1136 }
1137 
1138 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1139  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1140  if (!repl)
1141  return;
1142 
1143  if (isa<BlockArgument>(repl)) {
1144  rewriter.replaceAllUsesWith(arg, repl);
1145  return;
1146  }
1147 
1148  // If the replacement value is an operation, we check to make sure that we
1149  // don't replace uses that are within the parent operation of the
1150  // replacement value.
1151  Operation *replOp = cast<OpResult>(repl).getOwner();
1152  Block *replBlock = replOp->getBlock();
1153  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1154  Operation *user = operand.getOwner();
1155  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1156  });
1157 }
1158 
1159 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1160 
1161 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1162  auto *listener =
1163  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1164 
1165  // Compute replacement values.
1166  SmallVector<Value> replacements =
1167  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1168  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1169  });
1170 
1171  // Notify the listener that the operation is about to be replaced.
1172  if (listener)
1173  listener->notifyOperationReplaced(op, replacements);
1174 
1175  // Replace all uses with the new values.
1176  for (auto [result, newValue] :
1177  llvm::zip_equal(op->getResults(), replacements))
1178  if (newValue)
1179  rewriter.replaceAllUsesWith(result, newValue);
1180 
1181  // The original op will be erased, so remove it from the set of unlegalized
1182  // ops.
1183  if (getConfig().unlegalizedOps)
1184  getConfig().unlegalizedOps->erase(op);
1185 
1186  // Notify the listener that the operation and its contents are being erased.
1187  if (listener)
1188  notifyIRErased(listener, *op);
1189 
1190  // Do not erase the operation yet. It may still be referenced in `mapping`.
1191  // Just unlink it for now and erase it during cleanup.
1192  op->getBlock()->getOperations().remove(op);
1193 }
1194 
1195 void ReplaceOperationRewrite::rollback() {
1196  for (auto result : op->getResults())
1197  rewriterImpl.mapping.erase({result});
1198 }
1199 
1200 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1201  rewriter.eraseOp(op);
1202 }
1203 
1204 void CreateOperationRewrite::rollback() {
1205  for (Region &region : op->getRegions()) {
1206  while (!region.getBlocks().empty())
1207  region.getBlocks().remove(region.getBlocks().begin());
1208  }
1209  op->dropAllUses();
1210  op->erase();
1211 }
1212 
1213 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1214  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216  ValueVector mappedValues)
1217  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218  converterAndKind(converter, kind), originalType(originalType),
1219  mappedValues(std::move(mappedValues)) {
1220  assert((!originalType || kind == MaterializationKind::Target) &&
1221  "original type is valid only for target materializations");
1222  rewriterImpl.unresolvedMaterializations[op] = this;
1223 }
1224 
1225 void UnresolvedMaterializationRewrite::rollback() {
1226  if (!mappedValues.empty())
1227  rewriterImpl.mapping.erase(mappedValues);
1228  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1229  op->erase();
1230 }
1231 
1233  // Commit all rewrites.
1234  IRRewriter rewriter(context, config.listener);
1235  // Note: New rewrites may be added during the "commit" phase and the
1236  // `rewrites` vector may reallocate.
1237  for (size_t i = 0; i < rewrites.size(); ++i)
1238  rewrites[i]->commit(rewriter);
1239 
1240  // Clean up all rewrites.
1241  SingleEraseRewriter eraseRewriter(
1242  context, /*opErasedCallback=*/[&](Operation *op) {
1243  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1244  unresolvedMaterializations.erase(castOp);
1245  });
1246  for (auto &rewrite : rewrites)
1247  rewrite->cleanup(eraseRewriter);
1248 }
1249 
1250 //===----------------------------------------------------------------------===//
1251 // State Management
1252 //===----------------------------------------------------------------------===//
1253 
1255  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1256 }
1257 
1259  StringRef patternName) {
1260  // Undo any rewrites.
1261  undoRewrites(state.numRewrites, patternName);
1262 
1263  // Pop all of the recorded ignored operations that are no longer valid.
1264  while (ignoredOps.size() != state.numIgnoredOperations)
1265  ignoredOps.pop_back();
1266 
1267  while (replacedOps.size() != state.numReplacedOps)
1268  replacedOps.pop_back();
1269 }
1270 
/// Rolls back the recorded rewrites in reverse order until only the first
/// `numRewritesToKeep` remain, then truncates `rewrites` to that size.
/// `patternName` is only used in the fatal-error message.
/// Unresolved-materialization rewrites never trigger the fatal error.
// NOTE(review): original line 1275 (the first half of the `if` condition that
// guards the fatal error) is missing from this listing. Whether other rewrites
// may be rolled back depends on that missing condition.
1271 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1272  StringRef patternName) {
1273  for (auto &rewrite :
1274  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
1276  !isa<UnresolvedMaterializationRewrite>(rewrite)) {
1277  // Unresolved materializations can always be rolled back (erased).
1278  llvm::report_fatal_error("pattern '" + patternName +
1279  "' rollback of IR modifications requested");
1280  }
1281  rewrite->rollback();
1282  }
// Drop the rolled-back rewrite objects from the log.
1283  rewrites.resize(numRewritesToKeep);
1284 }
1285 
1287  StringRef valueDiagTag, std::optional<Location> inputLoc,
1288  PatternRewriter &rewriter, ValueRange values,
1289  SmallVector<ValueVector> &remapped) {
1290  remapped.reserve(llvm::size(values));
1291 
1292  for (const auto &it : llvm::enumerate(values)) {
1293  Value operand = it.value();
1294  Type origType = operand.getType();
1295  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1296 
1297  if (!currentTypeConverter) {
1298  // The current pattern does not have a type converter. I.e., it does not
1299  // distinguish between legal and illegal types. For each operand, simply
1300  // pass through the most recently mapped values.
1301  remapped.push_back(mapping.lookupOrDefault(operand));
1302  continue;
1303  }
1304 
1305  // If there is no legal conversion, fail to match this pattern.
1306  SmallVector<Type, 1> legalTypes;
1307  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1308  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1309  diag << "unable to convert type for " << valueDiagTag << " #"
1310  << it.index() << ", type was " << origType;
1311  });
1312  return failure();
1313  }
1314  // If a type is converted to 0 types, there is nothing to do.
1315  if (legalTypes.empty()) {
1316  remapped.push_back({});
1317  continue;
1318  }
1319 
1320  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1321  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1322  // Mapped values have the correct type or there is an existing
1323  // materialization. Or the operand is not mapped at all and has the
1324  // correct type.
1325  remapped.push_back(std::move(repl));
1326  continue;
1327  }
1328 
1329  // Create a materialization for the most recently mapped values.
1330  repl = mapping.lookupOrDefault(operand);
1332  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1333  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1334  /*originalType=*/origType, currentTypeConverter);
1335  remapped.push_back(castValues);
1336  }
1337  return success();
1338 }
1339 
1341  // Check to see if this operation is ignored or was replaced.
1342  return replacedOps.count(op) || ignoredOps.count(op);
1343 }
1344 
1346  // Check to see if this operation was replaced.
1347  return replacedOps.count(op);
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // Type Conversion
1352 //===----------------------------------------------------------------------===//
1353 
1355  ConversionPatternRewriter &rewriter, Region *region,
1356  const TypeConverter &converter,
1357  TypeConverter::SignatureConversion *entryConversion) {
1358  regionToConverter[region] = &converter;
1359  if (region->empty())
1360  return nullptr;
1361 
1362  // Convert the arguments of each non-entry block within the region.
1363  for (Block &block :
1364  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1365  // Compute the signature for the block with the provided converter.
1366  std::optional<TypeConverter::SignatureConversion> conversion =
1367  converter.convertBlockSignature(&block);
1368  if (!conversion)
1369  return failure();
1370  // Convert the block with the computed signature.
1371  applySignatureConversion(rewriter, &block, &converter, *conversion);
1372  }
1373 
1374  // Convert the entry block. If an entry signature conversion was provided,
1375  // use that one. Otherwise, compute the signature with the type converter.
1376  if (entryConversion)
1377  return applySignatureConversion(rewriter, &region->front(), &converter,
1378  *entryConversion);
1379  std::optional<TypeConverter::SignatureConversion> conversion =
1380  converter.convertBlockSignature(&region->front());
1381  if (!conversion)
1382  return failure();
1383  return applySignatureConversion(rewriter, &region->front(), &converter,
1384  *conversion);
1385 }
1386 
1388  ConversionPatternRewriter &rewriter, Block *block,
1389  const TypeConverter *converter,
1390  TypeConverter::SignatureConversion &signatureConversion) {
1391 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1392  // A block cannot be converted multiple times.
1393  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1394  llvm::report_fatal_error("block was already converted");
1395 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1396 
1397  OpBuilder::InsertionGuard g(rewriter);
1398 
1399  // If no arguments are being changed or added, there is nothing to do.
1400  unsigned origArgCount = block->getNumArguments();
1401  auto convertedTypes = signatureConversion.getConvertedTypes();
1402  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1403  return block;
1404 
1405  // Compute the locations of all block arguments in the new block.
1406  SmallVector<Location> newLocs(convertedTypes.size(),
1407  rewriter.getUnknownLoc());
1408  for (unsigned i = 0; i < origArgCount; ++i) {
1409  auto inputMap = signatureConversion.getInputMapping(i);
1410  if (!inputMap || inputMap->replacedWithValues())
1411  continue;
1412  Location origLoc = block->getArgument(i).getLoc();
1413  for (unsigned j = 0; j < inputMap->size; ++j)
1414  newLocs[inputMap->inputNo + j] = origLoc;
1415  }
1416 
1417  // Insert a new block with the converted block argument types and move all ops
1418  // from the old block to the new block.
1419  Block *newBlock =
1420  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1421  convertedTypes, newLocs);
1422 
1423  // If a listener is attached to the dialect conversion, ops cannot be moved
1424  // to the destination block in bulk ("fast path"). This is because at the time
1425  // the notifications are sent, it is unknown which ops were moved. Instead,
1426  // ops should be moved one-by-one ("slow path"), so that a separate
1427  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1428  // a bit more efficient, so we try to do that when possible.
1429  bool fastPath = !config.listener;
1430  if (fastPath) {
1431  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1432  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1433  } else {
1434  while (!block->empty())
1435  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1436  }
1437 
1438  // Replace all uses of the old block with the new block.
1439  block->replaceAllUsesWith(newBlock);
1440 
1441  for (unsigned i = 0; i != origArgCount; ++i) {
1442  BlockArgument origArg = block->getArgument(i);
1443  Type origArgType = origArg.getType();
1444 
1445  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1446  signatureConversion.getInputMapping(i);
1447  if (!inputMap) {
1448  // This block argument was dropped and no replacement value was provided.
1449  // Materialize a replacement value "out of thin air".
1450  Value mat =
1452  MaterializationKind::Source,
1453  OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1454  origArg.getLoc(),
1455  /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1456  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1457  .front();
1458  replaceUsesOfBlockArgument(origArg, mat, converter);
1459  continue;
1460  }
1461 
1462  if (inputMap->replacedWithValues()) {
1463  // This block argument was dropped and replacement values were provided.
1464  assert(inputMap->size == 0 &&
1465  "invalid to provide a replacement value when the argument isn't "
1466  "dropped");
1467  replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1468  converter);
1469  continue;
1470  }
1471 
1472  // This is a 1->1+ mapping.
1473  auto replArgs =
1474  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1475  replaceUsesOfBlockArgument(origArg, replArgs, converter);
1476  }
1477 
1478  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1479 
1480  // Erase the old block. (It is just unlinked for now and will be erased during
1481  // cleanup.)
1482  rewriter.eraseBlock(block);
1483 
1484  return newBlock;
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // Materializations
1489 //===----------------------------------------------------------------------===//
1490 
1491 /// Build an unresolved materialization operation given an output type and set
1492 /// of input operands.
1494  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1495  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1496  Type originalType, const TypeConverter *converter,
1497  UnrealizedConversionCastOp *castOp) {
1498  assert((!originalType || kind == MaterializationKind::Target) &&
1499  "original type is valid only for target materializations");
1500  assert(TypeRange(inputs) != outputTypes &&
1501  "materialization is not necessary");
1502 
1503  // Create an unresolved materialization. We use a new OpBuilder to avoid
1504  // tracking the materialization like we do for other operations.
1505  OpBuilder builder(outputTypes.front().getContext());
1506  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1507  auto convertOp =
1508  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1509  if (!valuesToMap.empty())
1510  mapping.map(valuesToMap, convertOp.getResults());
1511  if (castOp)
1512  *castOp = convertOp;
1513  appendRewrite<UnresolvedMaterializationRewrite>(
1514  convertOp, converter, kind, originalType, std::move(valuesToMap));
1515  return convertOp.getResults();
1516 }
1517 
1519  Value value, const TypeConverter *converter) {
1520  // Try to find a replacement value with the same type in the conversion value
1521  // mapping. This includes cached materializations. We try to reuse those
1522  // instead of generating duplicate IR.
1523  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1524  if (!repl.empty())
1525  return repl.front();
1526 
1527  // Check if the value is dead. No replacement value is needed in that case.
1528  // This is an approximate check that may have false negatives but does not
1529  // require computing and traversing an inverse mapping. (We may end up
1530  // building source materializations that are never used and that fold away.)
1531  if (llvm::all_of(value.getUsers(),
1532  [&](Operation *op) { return replacedOps.contains(op); }) &&
1533  !mapping.isMappedTo(value))
1534  return Value();
1535 
1536  // No replacement value was found. Get the latest replacement value
1537  // (regardless of the type) and build a source materialization to the
1538  // original type.
1539  repl = mapping.lookupOrNull(value);
1540  if (repl.empty()) {
1541  // No replacement value is registered in the mapping. This means that the
1542  // value is dropped and no longer needed. (If the value were still needed,
1543  // a source materialization producing a replacement value "out of thin air"
1544  // would have already been created during `replaceOp` or
1545  // `applySignatureConversion`.)
1546  return Value();
1547  }
1548 
1549  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1550  // which all values in `repl` are defined. It is important to emit the
1551  // materialization at that location because the same materialization may be
1552  // reused in a different context. (That's because materializations are cached
1553  // in the conversion value mapping.) The insertion point of the
1554  // materialization must be valid for all future users that may be created
1555  // later in the conversion process.
1556  Value castValue =
1557  buildUnresolvedMaterialization(MaterializationKind::Source,
1558  computeInsertPoint(repl), value.getLoc(),
1559  /*valuesToMap=*/repl, /*inputs=*/repl,
1560  /*outputTypes=*/value.getType(),
1561  /*originalType=*/Type(), converter)
1562  .front();
1563  return castValue;
1564 }
1565 
1566 //===----------------------------------------------------------------------===//
1567 // Rewriter Notification Hooks
1568 //===----------------------------------------------------------------------===//
1569 
1571  Operation *op, OpBuilder::InsertPoint previous) {
1572  LLVM_DEBUG({
1573  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1574  << ")\n";
1575  });
1576  assert(!wasOpReplaced(op->getParentOp()) &&
1577  "attempting to insert into a block within a replaced/erased op");
1578 
1579  if (!previous.isSet()) {
1580  // This is a newly created op.
1581  appendRewrite<CreateOperationRewrite>(op);
1582  patternNewOps.insert(op);
1583  return;
1584  }
1585  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1586  ? nullptr
1587  : &*previous.getPoint();
1588  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1589 }
1590 
/// Records the replacement of `op`.
///
/// Each result is mapped to its replacement values in the conversion value
/// mapping. A result with no replacement values instead gets a source
/// materialization built "out of thin air". A `ReplaceOperationRewrite` is then
/// recorded, and `op` plus all ops nested in it are added to `replacedOps`.
/// `op` itself is not erased here.
// NOTE(review): the definition header (original line 1591) and two lines of
// the `buildUnresolvedMaterialization` call (original lines 1611 and 1615) are
// missing from this listing.
1592  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1593  assert(newValues.size() == op->getNumResults());
// NOTE(review): this assert checks `ignoredOps`, but its message talks about
// replacement, and replaced ops are tracked in `replacedOps`. Confirm which
// set is intended.
1594  assert(!ignoredOps.contains(op) && "operation was already replaced");
1595 
1596  // Check if replaced op is an unresolved materialization, i.e., an
1597  // unrealized_conversion_cast op that was created by the conversion driver.
1598  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1599  // Make sure that the user does not mess with unresolved materializations
1600  // that were inserted by the conversion driver. We keep track of these
1601  // ops in internal data structures.
1602  assert(!unresolvedMaterializations.contains(castOp) &&
1603  "attempting to replace/erase an unresolved materialization");
1604  }
1605 
1606  // Create mappings for each of the new result values.
1607  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1608  if (repl.empty()) {
1609  // This result was dropped and no replacement value was provided.
1610  // Materialize a replacement value "out of thin air".
1612  MaterializationKind::Source, computeInsertPoint(result),
1613  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1614  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1616  continue;
1617  }
1618 
1619  // Remap result to replacement value.
// NOTE(review): this check is unreachable, because the branch above already
// `continue`s whenever `repl` is empty.
1620  if (repl.empty())
1621  continue;
1622  mapping.map(static_cast<Value>(result), std::move(repl));
1623  }
1624 
1625  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1626  // Mark this operation and all nested ops as replaced.
1627  op->walk([&](Operation *op) { replacedOps.insert(op); });
1628 }
1629 
1631  BlockArgument from, ValueRange to, const TypeConverter *converter) {
1632  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
1633  mapping.map(from, to);
1634 }
1635 
1637  assert(!wasOpReplaced(block->getParentOp()) &&
1638  "attempting to erase a block within a replaced/erased op");
1639  appendRewrite<EraseBlockRewrite>(block);
1640 
1641  // Unlink the block from its parent region. The block is kept in the rewrite
1642  // object and will be actually destroyed when rewrites are applied. This
1643  // allows us to keep the operations in the block live and undo the removal by
1644  // re-inserting the block.
1645  block->getParent()->getBlocks().remove(block);
1646 
1647  // Mark all nested ops as erased.
1648  block->walk([&](Operation *op) { replacedOps.insert(op); });
1649 }
1650 
1652  Block *block, Region *previous, Region::iterator previousIt) {
1653  assert(!wasOpReplaced(block->getParentOp()) &&
1654  "attempting to insert into a region within a replaced/erased op");
1655  LLVM_DEBUG(
1656  {
1657  Operation *parent = block->getParentOp();
1658  if (parent) {
1659  logger.startLine() << "** Insert Block into : '" << parent->getName()
1660  << "'(" << parent << ")\n";
1661  } else {
1662  logger.startLine()
1663  << "** Insert Block into detached Region (nullptr parent op)'\n";
1664  }
1665  });
1666 
1667  patternInsertedBlocks.insert(block);
1668 
1669  if (!previous) {
1670  // This is a newly created block.
1671  appendRewrite<CreateBlockRewrite>(block);
1672  return;
1673  }
1674  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1675  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1676 }
1677 
1679  Block *dest,
1680  Block::iterator before) {
1681  appendRewrite<InlineBlockRewrite>(dest, source, before);
1682 }
1683 
1685  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1686  LLVM_DEBUG({
1688  reasonCallback(diag);
1689  logger.startLine() << "** Failure : " << diag.str() << "\n";
1690  if (config.notifyCallback)
1692  });
1693 }
1694 
1695 //===----------------------------------------------------------------------===//
1696 // ConversionPatternRewriter
1697 //===----------------------------------------------------------------------===//
1698 
1699 ConversionPatternRewriter::ConversionPatternRewriter(
1700  MLIRContext *ctx, const ConversionConfig &config)
1701  : PatternRewriter(ctx),
1702  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1703  setListener(impl.get());
1704 }
1705 
1707 
1709  assert(op && newOp && "expected non-null op");
1710  replaceOp(op, newOp->getResults());
1711 }
1712 
1714  assert(op->getNumResults() == newValues.size() &&
1715  "incorrect # of replacement values");
1716  LLVM_DEBUG({
1717  impl->logger.startLine()
1718  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1719  });
1721  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
1722  return v ? SmallVector<Value>{v} : SmallVector<Value>();
1723  });
1724  impl->replaceOp(op, std::move(newVals));
1725 }
1726 
1728  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1729  assert(op->getNumResults() == newValues.size() &&
1730  "incorrect # of replacement values");
1731  LLVM_DEBUG({
1732  impl->logger.startLine()
1733  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1734  });
1735  impl->replaceOp(op, std::move(newValues));
1736 }
1737 
1739  LLVM_DEBUG({
1740  impl->logger.startLine()
1741  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1742  });
1743  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1744  impl->replaceOp(op, std::move(nullRepls));
1745 }
1746 
1748  impl->eraseBlock(block);
1749 }
1750 
1752  Block *block, TypeConverter::SignatureConversion &conversion,
1753  const TypeConverter *converter) {
1754  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1755  "attempting to apply a signature conversion to a block within a "
1756  "replaced/erased op");
1757  return impl->applySignatureConversion(*this, block, converter, conversion);
1758 }
1759 
1761  Region *region, const TypeConverter &converter,
1762  TypeConverter::SignatureConversion *entryConversion) {
1763  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1764  "attempting to apply a signature conversion to a block within a "
1765  "replaced/erased op");
1766  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1767 }
1768 
1770  ValueRange to) {
1771  LLVM_DEBUG({
1772  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1773  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1774  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1775  << "' (" << parentOp << ")\n";
1776  } else {
1777  impl->logger.getOStream() << " (unlinked block)\n";
1778  }
1779  });
1780  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
1781 }
1782 
1784  SmallVector<ValueVector> remappedValues;
1785  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1786  remappedValues)))
1787  return nullptr;
1788  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1789  return remappedValues.front().front();
1790 }
1791 
1792 LogicalResult
1794  SmallVectorImpl<Value> &results) {
1795  if (keys.empty())
1796  return success();
1797  SmallVector<ValueVector> remapped;
1798  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1799  remapped)))
1800  return failure();
1801  for (const auto &values : remapped) {
1802  assert(values.size() == 1 && "1:N conversion not supported");
1803  results.push_back(values.front());
1804  }
1805  return success();
1806 }
1807 
1809  Block::iterator before,
1810  ValueRange argValues) {
1811 #ifndef NDEBUG
1812  assert(argValues.size() == source->getNumArguments() &&
1813  "incorrect # of argument replacement values");
1814  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1815  "attempting to inline a block from a replaced/erased op");
1816  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1817  "attempting to inline a block into a replaced/erased op");
1818  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1819  // The source block will be deleted, so it should not have any users (i.e.,
1820  // there should be no predecessors).
1821  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1822  "expected 'source' to have no predecessors");
1823 #endif // NDEBUG
1824 
1825  // If a listener is attached to the dialect conversion, ops cannot be moved
1826  // to the destination block in bulk ("fast path"). This is because at the time
1827  // the notifications are sent, it is unknown which ops were moved. Instead,
1828  // ops should be moved one-by-one ("slow path"), so that a separate
1829  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1830  // a bit more efficient, so we try to do that when possible.
1831  bool fastPath = !impl->config.listener;
1832 
1833  if (fastPath)
1834  impl->inlineBlockBefore(source, dest, before);
1835 
1836  // Replace all uses of block arguments.
1837  for (auto it : llvm::zip(source->getArguments(), argValues))
1838  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1839 
1840  if (fastPath) {
1841  // Move all ops at once.
1842  dest->getOperations().splice(before, source->getOperations());
1843  } else {
1844  // Move op by op.
1845  while (!source->empty())
1846  moveOpBefore(&source->front(), dest, before);
1847  }
1848 
1849  // Erase the source block.
1850  eraseBlock(source);
1851 }
1852 
1854  assert(!impl->wasOpReplaced(op) &&
1855  "attempting to modify a replaced/erased op");
1856 #ifndef NDEBUG
1857  impl->pendingRootUpdates.insert(op);
1858 #endif
1859  impl->appendRewrite<ModifyOperationRewrite>(op);
1860 }
1861 
1863  assert(!impl->wasOpReplaced(op) &&
1864  "attempting to modify a replaced/erased op");
1866  impl->patternModifiedOps.insert(op);
1867 
1868  // There is nothing to do here, we only need to track the operation at the
1869  // start of the update.
1870 #ifndef NDEBUG
1871  assert(impl->pendingRootUpdates.erase(op) &&
1872  "operation did not have a pending in-place update");
1873 #endif
1874 }
1875 
1877 #ifndef NDEBUG
1878  assert(impl->pendingRootUpdates.erase(op) &&
1879  "operation did not have a pending in-place update");
1880 #endif
1881  // Erase the last update for this operation.
1882  auto it = llvm::find_if(
1883  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1884  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1885  return modifyRewrite && modifyRewrite->getOperation() == op;
1886  });
1887  assert(it != impl->rewrites.rend() && "no root update started on op");
1888  (*it)->rollback();
1889  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1890  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1891 }
1892 
1894  return *impl;
1895 }
1896 
1897 //===----------------------------------------------------------------------===//
1898 // ConversionPattern
1899 //===----------------------------------------------------------------------===//
1900 
1902  ArrayRef<ValueRange> operands) const {
1903  SmallVector<Value> oneToOneOperands;
1904  oneToOneOperands.reserve(operands.size());
1905  for (ValueRange operand : operands) {
1906  if (operand.size() != 1)
1907  llvm::report_fatal_error("pattern '" + getDebugName() +
1908  "' does not support 1:N conversion");
1909  oneToOneOperands.push_back(operand.front());
1910  }
1911  return oneToOneOperands;
1912 }
1913 
1914 LogicalResult
1916  PatternRewriter &rewriter) const {
1917  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1918  auto &rewriterImpl = dialectRewriter.getImpl();
1919 
1920  // Track the current conversion pattern type converter in the rewriter.
1921  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1922  getTypeConverter());
1923 
1924  // Remap the operands of the operation.
1925  SmallVector<ValueVector> remapped;
1926  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1927  op->getOperands(), remapped))) {
1928  return failure();
1929  }
1930  SmallVector<ValueRange> remappedAsRange =
1931  llvm::to_vector_of<ValueRange>(remapped);
1932  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1933 }
1934 
1935 //===----------------------------------------------------------------------===//
1936 // OperationLegalizer
1937 //===----------------------------------------------------------------------===//
1938 
1939 namespace {
1940 /// A set of rewrite patterns that can be used to legalize a given operation.
1941 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1942 
1943 /// This class defines a recursive operation legalizer.
1944 class OperationLegalizer {
1945 public:
1946  using LegalizationAction = ConversionTarget::LegalizationAction;
1947 
1948  OperationLegalizer(const ConversionTarget &targetInfo,
1950  const ConversionConfig &config);
1951 
1952  /// Returns true if the given operation is known to be illegal on the target.
1953  bool isIllegal(Operation *op) const;
1954 
1955  /// Attempt to legalize the given operation. Returns success if the operation
1956  /// was legalized, failure otherwise.
1957  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1958 
1959  /// Returns the conversion target in use by the legalizer.
1960  const ConversionTarget &getTarget() { return target; }
1961 
1962 private:
1963  /// Attempt to legalize the given operation by folding it.
1964  LogicalResult legalizeWithFold(Operation *op,
1965  ConversionPatternRewriter &rewriter);
1966 
1967  /// Attempt to legalize the given operation by applying a pattern. Returns
1968  /// success if the operation was legalized, failure otherwise.
1969  LogicalResult legalizeWithPattern(Operation *op,
1970  ConversionPatternRewriter &rewriter);
1971 
1972  /// Return true if the given pattern may be applied to the given operation,
1973  /// false otherwise.
1974  bool canApplyPattern(Operation *op, const Pattern &pattern,
1975  ConversionPatternRewriter &rewriter);
1976 
1977  /// Legalize the resultant IR after successfully applying the given pattern.
1978  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1979  ConversionPatternRewriter &rewriter,
1980  const SetVector<Operation *> &newOps,
1981  const SetVector<Operation *> &modifiedOps,
1982  const SetVector<Block *> &insertedBlocks);
1983 
1984  /// Legalizes the actions registered during the execution of a pattern.
1985  LogicalResult
1986  legalizePatternBlockRewrites(Operation *op,
1987  ConversionPatternRewriter &rewriter,
1989  const SetVector<Block *> &insertedBlocks,
1990  const SetVector<Operation *> &newOps);
1991  LogicalResult
1992  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
1994  const SetVector<Operation *> &newOps);
1995  LogicalResult
1996  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1998  const SetVector<Operation *> &modifiedOps);
1999 
2000  //===--------------------------------------------------------------------===//
2001  // Cost Model
2002  //===--------------------------------------------------------------------===//
2003 
2004  /// Build an optimistic legalization graph given the provided patterns. This
2005  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2006  /// patterns for operations that are not directly legal, but may be
2007  /// transitively legal for the current target given the provided patterns.
2008  void buildLegalizationGraph(
2009  LegalizationPatterns &anyOpLegalizerPatterns,
2011 
2012  /// Compute the benefit of each node within the computed legalization graph.
2013  /// This orders the patterns within 'legalizerPatterns' based upon two
2014  /// criteria:
2015  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2016  /// represent the more direct mapping to the target.
2017  /// 2) When comparing patterns with the same legalization depth, prefer the
2018  /// pattern with the highest PatternBenefit. This allows for users to
2019  /// prefer specific legalizations over others.
2020  void computeLegalizationGraphBenefit(
2021  LegalizationPatterns &anyOpLegalizerPatterns,
2023 
2024  /// Compute the legalization depth when legalizing an operation of the given
2025  /// type.
2026  unsigned computeOpLegalizationDepth(
2027  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2029 
2030  /// Apply the conversion cost model to the given set of patterns, and return
2031  /// the smallest legalization depth of any of the patterns. See
2032  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2033  unsigned applyCostModelToPatterns(
2034  LegalizationPatterns &patterns,
2035  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2037 
2038  /// The current set of patterns that have been applied.
2039  SmallPtrSet<const Pattern *, 8> appliedPatterns;
2040 
2041  /// The legalization information provided by the target.
2042  const ConversionTarget &target;
2043 
2044  /// The pattern applicator to use for conversions.
2045  PatternApplicator applicator;
2046 
2047  /// Dialect conversion configuration.
2048  const ConversionConfig &config;
2049 };
2050 } // namespace
2051 
2052 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2054  const ConversionConfig &config)
2055  : target(targetInfo), applicator(patterns), config(config) {
2056  // The set of patterns that can be applied to illegal operations to transform
2057  // them into legal ones.
2059  LegalizationPatterns anyOpLegalizerPatterns;
2060 
2061  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2062  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2063 }
2064 
2065 bool OperationLegalizer::isIllegal(Operation *op) const {
2066  return target.isIllegal(op);
2067 }
2068 
2069 LogicalResult
2070 OperationLegalizer::legalize(Operation *op,
2071  ConversionPatternRewriter &rewriter) {
2072 #ifndef NDEBUG
2073  const char *logLineComment =
2074  "//===-------------------------------------------===//\n";
2075 
2076  auto &logger = rewriter.getImpl().logger;
2077 #endif
2078  LLVM_DEBUG({
2079  logger.getOStream() << "\n";
2080  logger.startLine() << logLineComment;
2081  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2082  << op << ") {\n";
2083  logger.indent();
2084 
2085  // If the operation has no regions, just print it here.
2086  if (op->getNumRegions() == 0) {
2087  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2088  logger.getOStream() << "\n\n";
2089  }
2090  });
2091 
2092  // Check if this operation is legal on the target.
2093  if (auto legalityInfo = target.isLegal(op)) {
2094  LLVM_DEBUG({
2095  logSuccess(
2096  logger, "operation marked legal by the target{0}",
2097  legalityInfo->isRecursivelyLegal
2098  ? "; NOTE: operation is recursively legal; skipping internals"
2099  : "");
2100  logger.startLine() << logLineComment;
2101  });
2102 
2103  // If this operation is recursively legal, mark its children as ignored so
2104  // that we don't consider them for legalization.
2105  if (legalityInfo->isRecursivelyLegal) {
2106  op->walk([&](Operation *nested) {
2107  if (op != nested)
2108  rewriter.getImpl().ignoredOps.insert(nested);
2109  });
2110  }
2111 
2112  return success();
2113  }
2114 
2115  // Check to see if the operation is ignored and doesn't need to be converted.
2116  if (rewriter.getImpl().isOpIgnored(op)) {
2117  LLVM_DEBUG({
2118  logSuccess(logger, "operation marked 'ignored' during conversion");
2119  logger.startLine() << logLineComment;
2120  });
2121  return success();
2122  }
2123 
2124  // If the operation isn't legal, try to fold it in-place.
2125  // TODO: Should we always try to do this, even if the op is
2126  // already legal?
2127  if (succeeded(legalizeWithFold(op, rewriter))) {
2128  LLVM_DEBUG({
2129  logSuccess(logger, "operation was folded");
2130  logger.startLine() << logLineComment;
2131  });
2132  return success();
2133  }
2134 
2135  // Otherwise, we need to apply a legalization pattern to this operation.
2136  if (succeeded(legalizeWithPattern(op, rewriter))) {
2137  LLVM_DEBUG({
2138  logSuccess(logger, "");
2139  logger.startLine() << logLineComment;
2140  });
2141  return success();
2142  }
2143 
2144  LLVM_DEBUG({
2145  logFailure(logger, "no matched legalization pattern");
2146  logger.startLine() << logLineComment;
2147  });
2148  return failure();
2149 }
2150 
/// Moves the contents out of `obj` and returns them. `obj` is left holding a
/// freshly value-initialized `T`, so it is in a valid, empty state again and
/// can be reused.
template <typename T>
static T moveAndReset(T &obj) {
  // std::exchange move-constructs the result from `obj`, then move-assigns
  // the temporary `T()` into `obj`.
  return std::exchange(obj, T());
}
2159 
2160 LogicalResult
2161 OperationLegalizer::legalizeWithFold(Operation *op,
2162  ConversionPatternRewriter &rewriter) {
2163  auto &rewriterImpl = rewriter.getImpl();
2164  LLVM_DEBUG({
2165  rewriterImpl.logger.startLine() << "* Fold {\n";
2166  rewriterImpl.logger.indent();
2167  });
2168  (void)rewriterImpl;
2169 
2170  // Try to fold the operation.
2171  SmallVector<Value, 2> replacementValues;
2172  SmallVector<Operation *, 2> newOps;
2173  rewriter.setInsertionPoint(op);
2174  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2175  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2176  return failure();
2177  }
2178 
2179  // An empty list of replacement values indicates that the fold was in-place.
2180  // As the operation changed, a new legalization needs to be attempted.
2181  if (replacementValues.empty())
2182  return legalize(op, rewriter);
2183 
2184  // Recursively legalize any new constant operations.
2185  for (Operation *newOp : newOps) {
2186  if (failed(legalize(newOp, rewriter))) {
2187  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2188  "failed to legalize generated constant '{0}'",
2189  newOp->getName()));
2190  // Legalization failed: erase all materialized constants.
2191  for (Operation *op : newOps)
2192  rewriter.eraseOp(op);
2193  return failure();
2194  }
2195  }
2196 
2197  // Insert a replacement for 'op' with the folded replacement values.
2198  rewriter.replaceOp(op, replacementValues);
2199 
2200  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2201  return success();
2202 }
2203 
2204 LogicalResult
2205 OperationLegalizer::legalizeWithPattern(Operation *op,
2206  ConversionPatternRewriter &rewriter) {
2207  auto &rewriterImpl = rewriter.getImpl();
2208 
2209  // Functor that returns if the given pattern may be applied.
2210  auto canApply = [&](const Pattern &pattern) {
2211  bool canApply = canApplyPattern(op, pattern, rewriter);
2212  if (canApply && config.listener)
2213  config.listener->notifyPatternBegin(pattern, op);
2214  return canApply;
2215  };
2216 
2217  // Functor that cleans up the rewriter state after a pattern failed to match.
2218  RewriterState curState = rewriterImpl.getCurrentState();
2219  auto onFailure = [&](const Pattern &pattern) {
2220  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2221  rewriterImpl.patternNewOps.clear();
2222  rewriterImpl.patternModifiedOps.clear();
2223  rewriterImpl.patternInsertedBlocks.clear();
2224  LLVM_DEBUG({
2225  logFailure(rewriterImpl.logger, "pattern failed to match");
2226  if (rewriterImpl.config.notifyCallback) {
2228  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2229  << "\" on op:\n"
2230  << *op;
2231  rewriterImpl.config.notifyCallback(diag);
2232  }
2233  });
2234  if (config.listener)
2235  config.listener->notifyPatternEnd(pattern, failure());
2236  rewriterImpl.resetState(curState, pattern.getDebugName());
2237  appliedPatterns.erase(&pattern);
2238  };
2239 
2240  // Functor that performs additional legalization when a pattern is
2241  // successfully applied.
2242  auto onSuccess = [&](const Pattern &pattern) {
2243  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2244  SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2245  SetVector<Operation *> modifiedOps =
2246  moveAndReset(rewriterImpl.patternModifiedOps);
2247  SetVector<Block *> insertedBlocks =
2248  moveAndReset(rewriterImpl.patternInsertedBlocks);
2249  auto result = legalizePatternResult(op, pattern, rewriter, newOps,
2250  modifiedOps, insertedBlocks);
2251  appliedPatterns.erase(&pattern);
2252  if (failed(result)) {
2253  if (!rewriterImpl.config.allowPatternRollback)
2254  op->emitError("pattern '")
2255  << pattern.getDebugName()
2256  << "' produced IR that could not be legalized";
2257  rewriterImpl.resetState(curState, pattern.getDebugName());
2258  }
2259  if (config.listener)
2260  config.listener->notifyPatternEnd(pattern, result);
2261  return result;
2262  };
2263 
2264  // Try to match and rewrite a pattern on this operation.
2265  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2266  onSuccess);
2267 }
2268 
2269 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2270  ConversionPatternRewriter &rewriter) {
2271  LLVM_DEBUG({
2272  auto &os = rewriter.getImpl().logger;
2273  os.getOStream() << "\n";
2274  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2275  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2276  os.getOStream() << ")' {\n";
2277  os.indent();
2278  });
2279 
2280  // Ensure that we don't cycle by not allowing the same pattern to be
2281  // applied twice in the same recursion stack if it is not known to be safe.
2282  if (!pattern.hasBoundedRewriteRecursion() &&
2283  !appliedPatterns.insert(&pattern).second) {
2284  LLVM_DEBUG(
2285  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2286  return false;
2287  }
2288  return true;
2289 }
2290 
/// Legalize the IR produced by a successfully applied pattern rooted at `op`.
///
/// The pattern's side effects are legalized in this order: blocks the pattern
/// inserted or moved (`insertedBlocks`), operations it modified in place
/// (`modifiedOps`), and operations it created (`newOps`). Because of the `||`
/// chain, the first step that fails short-circuits the rest and failure is
/// returned. Requires that no in-place op modification is still pending.
2291 LogicalResult OperationLegalizer::legalizePatternResult(
2292  Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2293  const SetVector<Operation *> &newOps,
2294  const SetVector<Operation *> &modifiedOps,
2295  const SetVector<Block *> &insertedBlocks) {
2296  auto &impl = rewriter.getImpl();
2297  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2298 
2299 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2300  // Check that the root was either replaced or updated in place.
  // NOTE(review): `curState` is not declared in this function. It is not a
  // parameter, and not an OperationLegalizer member either (the class only has
  // appliedPatterns, target, applicator and config). It is a local of
  // legalizeWithPattern. This branch therefore does not compile when
  // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS is enabled. A likely fix is to
  // pass the pre-pattern RewriterState (or its numRewrites) in as a parameter;
  // that requires updating the class declaration and the caller as well.
2301  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2302  auto replacedRoot = [&] {
2303  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2304  };
2305  auto updatedRootInPlace = [&] {
2306  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2307  };
2308  if (!replacedRoot() && !updatedRootInPlace())
2309  llvm::report_fatal_error("expected pattern to replace the root operation");
2310 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2311 
2312  // Legalize each of the actions registered during application.
  // Order: inserted/moved blocks, then in-place modified ops, then newly
  // created ops; stop at the first failure.
2313  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
2314  newOps)) ||
2315  failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
2316  failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
2317  return failure();
2318  }
2319 
2320  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2321  return success();
2322 }
2323 
2324 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2325  Operation *op, ConversionPatternRewriter &rewriter,
2327  const SetVector<Block *> &insertedBlocks,
2328  const SetVector<Operation *> &newOps) {
2329  SmallPtrSet<Operation *, 16> alreadyLegalized;
2330 
2331  // If the pattern moved or created any blocks, make sure the types of block
2332  // arguments get legalized.
2333  for (Block *block : insertedBlocks) {
2334  // Only check blocks outside of the current operation.
2335  Operation *parentOp = block->getParentOp();
2336  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2337  continue;
2338 
2339  // If the region of the block has a type converter, try to convert the block
2340  // directly.
2341  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2342  std::optional<TypeConverter::SignatureConversion> conversion =
2343  converter->convertBlockSignature(block);
2344  if (!conversion) {
2345  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2346  "block"));
2347  return failure();
2348  }
2349  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2350  continue;
2351  }
2352 
2353  // Otherwise, try to legalize the parent operation if it was not generated
2354  // by this pattern. This is because we will attempt to legalize the parent
2355  // operation, and blocks in regions created by this pattern will already be
2356  // legalized later on.
2357  if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2358  if (failed(legalize(parentOp, rewriter))) {
2359  LLVM_DEBUG(logFailure(
2360  impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2361  parentOp->getName(), parentOp));
2362  return failure();
2363  }
2364  }
2365  }
2366  return success();
2367 }
2368 
2369 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2371  const SetVector<Operation *> &newOps) {
2372  for (Operation *op : newOps) {
2373  if (failed(legalize(op, rewriter))) {
2374  LLVM_DEBUG(logFailure(impl.logger,
2375  "failed to legalize generated operation '{0}'({1})",
2376  op->getName(), op));
2377  return failure();
2378  }
2379  }
2380  return success();
2381 }
2382 
2383 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2385  const SetVector<Operation *> &modifiedOps) {
2386  for (Operation *op : modifiedOps) {
2387  if (failed(legalize(op, rewriter))) {
2388  LLVM_DEBUG(logFailure(
2389  impl.logger, "failed to legalize operation updated in-place '{0}'",
2390  op->getName()));
2391  return failure();
2392  }
2393  }
2394  return success();
2395 }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // Cost Model
2399 //===----------------------------------------------------------------------===//
2400 
2401 void OperationLegalizer::buildLegalizationGraph(
2402  LegalizationPatterns &anyOpLegalizerPatterns,
2403  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2404  // A mapping between an operation and a set of operations that can be used to
2405  // generate it.
2407  // A mapping between an operation and any currently invalid patterns it has.
2409  // A worklist of patterns to consider for legality.
2410  SetVector<const Pattern *> patternWorklist;
2411 
2412  // Build the mapping from operations to the parent ops that may generate them.
2413  applicator.walkAllPatterns([&](const Pattern &pattern) {
2414  std::optional<OperationName> root = pattern.getRootKind();
2415 
2416  // If the pattern has no specific root, we can't analyze the relationship
2417  // between the root op and generated operations. Given that, add all such
2418  // patterns to the legalization set.
2419  if (!root) {
2420  anyOpLegalizerPatterns.push_back(&pattern);
2421  return;
2422  }
2423 
2424  // Skip operations that are always known to be legal.
2425  if (target.getOpAction(*root) == LegalizationAction::Legal)
2426  return;
2427 
2428  // Add this pattern to the invalid set for the root op and record this root
2429  // as a parent for any generated operations.
2430  invalidPatterns[*root].insert(&pattern);
2431  for (auto op : pattern.getGeneratedOps())
2432  parentOps[op].insert(*root);
2433 
2434  // Add this pattern to the worklist.
2435  patternWorklist.insert(&pattern);
2436  });
2437 
2438  // If there are any patterns that don't have a specific root kind, we can't
2439  // make direct assumptions about what operations will never be legalized.
2440  // Note: Technically we could, but it would require an analysis that may
2441  // recurse into itself. It would be better to perform this kind of filtering
2442  // at a higher level than here anyways.
2443  if (!anyOpLegalizerPatterns.empty()) {
2444  for (const Pattern *pattern : patternWorklist)
2445  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2446  return;
2447  }
2448 
2449  while (!patternWorklist.empty()) {
2450  auto *pattern = patternWorklist.pop_back_val();
2451 
2452  // Check to see if any of the generated operations are invalid.
2453  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2454  std::optional<LegalizationAction> action = target.getOpAction(op);
2455  return !legalizerPatterns.count(op) &&
2456  (!action || action == LegalizationAction::Illegal);
2457  }))
2458  continue;
2459 
2460  // Otherwise, if all of the generated operation are valid, this op is now
2461  // legal so add all of the child patterns to the worklist.
2462  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2463  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2464 
2465  // Add any invalid patterns of the parent operations to see if they have now
2466  // become legal.
2467  for (auto op : parentOps[*pattern->getRootKind()])
2468  patternWorklist.set_union(invalidPatterns[op]);
2469  }
2470 }
2471 
2472 void OperationLegalizer::computeLegalizationGraphBenefit(
2473  LegalizationPatterns &anyOpLegalizerPatterns,
2474  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2475  // The smallest pattern depth, when legalizing an operation.
2476  DenseMap<OperationName, unsigned> minOpPatternDepth;
2477 
2478  // For each operation that is transitively legal, compute a cost for it.
2479  for (auto &opIt : legalizerPatterns)
2480  if (!minOpPatternDepth.count(opIt.first))
2481  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2482  legalizerPatterns);
2483 
2484  // Apply the cost model to the patterns that can match any operation. Those
2485  // with a specific operation type are already resolved when computing the op
2486  // legalization depth.
2487  if (!anyOpLegalizerPatterns.empty())
2488  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2489  legalizerPatterns);
2490 
2491  // Apply a cost model to the pattern applicator. We order patterns first by
2492  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2493  // decreasing benefit.
2494  applicator.applyCostModel([&](const Pattern &pattern) {
2495  ArrayRef<const Pattern *> orderedPatternList;
2496  if (std::optional<OperationName> rootName = pattern.getRootKind())
2497  orderedPatternList = legalizerPatterns[*rootName];
2498  else
2499  orderedPatternList = anyOpLegalizerPatterns;
2500 
2501  // If the pattern is not found, then it was removed and cannot be matched.
2502  auto *it = llvm::find(orderedPatternList, &pattern);
2503  if (it == orderedPatternList.end())
2505 
2506  // Patterns found earlier in the list have higher benefit.
2507  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2508  });
2509 }
2510 
2511 unsigned OperationLegalizer::computeOpLegalizationDepth(
2512  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2513  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2514  // Check for existing depth.
2515  auto depthIt = minOpPatternDepth.find(op);
2516  if (depthIt != minOpPatternDepth.end())
2517  return depthIt->second;
2518 
2519  // If a mapping for this operation does not exist, then this operation
2520  // is always legal. Return 0 as the depth for a directly legal operation.
2521  auto opPatternsIt = legalizerPatterns.find(op);
2522  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2523  return 0u;
2524 
2525  // Record this initial depth in case we encounter this op again when
2526  // recursively computing the depth.
2527  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2528 
2529  // Apply the cost model to the operation patterns, and update the minimum
2530  // depth.
2531  unsigned minDepth = applyCostModelToPatterns(
2532  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2533  minOpPatternDepth[op] = minDepth;
2534  return minDepth;
2535 }
2536 
2537 unsigned OperationLegalizer::applyCostModelToPatterns(
2538  LegalizationPatterns &patterns,
2539  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2540  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2541  unsigned minDepth = std::numeric_limits<unsigned>::max();
2542 
2543  // Compute the depth for each pattern within the set.
2544  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2545  patternsByDepth.reserve(patterns.size());
2546  for (const Pattern *pattern : patterns) {
2547  unsigned depth = 1;
2548  for (auto generatedOp : pattern->getGeneratedOps()) {
2549  unsigned generatedOpDepth = computeOpLegalizationDepth(
2550  generatedOp, minOpPatternDepth, legalizerPatterns);
2551  depth = std::max(depth, generatedOpDepth + 1);
2552  }
2553  patternsByDepth.emplace_back(pattern, depth);
2554 
2555  // Update the minimum depth of the pattern list.
2556  minDepth = std::min(minDepth, depth);
2557  }
2558 
2559  // If the operation only has one legalization pattern, there is no need to
2560  // sort them.
2561  if (patternsByDepth.size() == 1)
2562  return minDepth;
2563 
2564  // Sort the patterns by those likely to be the most beneficial.
2565  llvm::stable_sort(patternsByDepth,
2566  [](const std::pair<const Pattern *, unsigned> &lhs,
2567  const std::pair<const Pattern *, unsigned> &rhs) {
2568  // First sort by the smaller pattern legalization
2569  // depth.
2570  if (lhs.second != rhs.second)
2571  return lhs.second < rhs.second;
2572 
2573  // Then sort by the larger pattern benefit.
2574  auto lhsBenefit = lhs.first->getBenefit();
2575  auto rhsBenefit = rhs.first->getBenefit();
2576  return lhsBenefit > rhsBenefit;
2577  });
2578 
2579  // Update the legalization pattern to use the new sorted list.
2580  patterns.clear();
2581  for (auto &patternIt : patternsByDepth)
2582  patterns.push_back(patternIt.first);
2583  return minDepth;
2584 }
2585 
2586 //===----------------------------------------------------------------------===//
2587 // OperationConverter
2588 //===----------------------------------------------------------------------===//
namespace {
/// Selects how `OperationConverter` reacts when an operation fails to
/// legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in the
  /// IR alongside converted ones.
  Partial = 0,

  /// Every operation must end up legal for the target, or the conversion as a
  /// whole fails.
  Full = 1,

  /// Operations are only checked for legality. Successful rewrites are not
  /// kept in the IR.
  Analysis = 2,
};
} // namespace
2604 
2605 namespace mlir {
2606 // This class converts operations to a given conversion target via a set of
2607 // rewrite patterns. The conversion behaves differently depending on the
2608 // conversion mode.
2610  explicit OperationConverter(const ConversionTarget &target,
2612  const ConversionConfig &config,
2613  OpConversionMode mode)
2614  : config(config), opLegalizer(target, patterns, this->config),
2615  mode(mode) {}
2616 
2617  /// Converts the given operations to the conversion target.
2618  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2619 
2620 private:
2621  /// Converts an operation with the given rewriter.
2622  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2623 
2624  /// Dialect conversion configuration.
2625  ConversionConfig config;
2626 
2627  /// The legalizer to use when converting operations.
2628  OperationLegalizer opLegalizer;
2629 
2630  /// The conversion mode to use when legalizing operations.
2631  OpConversionMode mode;
2632 };
2633 } // namespace mlir
2634 
2635 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2636  Operation *op) {
2637  // Legalize the given operation.
2638  if (failed(opLegalizer.legalize(op, rewriter))) {
2639  // Handle the case of a failed conversion for each of the different modes.
2640  // Full conversions expect all operations to be converted.
2641  if (mode == OpConversionMode::Full)
2642  return op->emitError()
2643  << "failed to legalize operation '" << op->getName() << "'";
2644  // Partial conversions allow conversions to fail iff the operation was not
2645  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2646  // set, non-legalizable ops are added to that set.
2647  if (mode == OpConversionMode::Partial) {
2648  if (opLegalizer.isIllegal(op))
2649  return op->emitError()
2650  << "failed to legalize operation '" << op->getName()
2651  << "' that was explicitly marked illegal";
2652  if (config.unlegalizedOps)
2653  config.unlegalizedOps->insert(op);
2654  }
2655  } else if (mode == OpConversionMode::Analysis) {
2656  // Analysis conversions don't fail if any operations fail to legalize,
2657  // they are only interested in the operations that were successfully
2658  // legalized.
2659  if (config.legalizableOps)
2660  config.legalizableOps->insert(op);
2661  }
2662  return success();
2663 }
2664 
2665 static LogicalResult
2667  UnresolvedMaterializationRewrite *rewrite) {
2668  UnrealizedConversionCastOp op = rewrite->getOperation();
2669  assert(!op.use_empty() &&
2670  "expected that dead materializations have already been DCE'd");
2671  Operation::operand_range inputOperands = op.getOperands();
2672 
2673  // Try to materialize the conversion.
2674  if (const TypeConverter *converter = rewrite->getConverter()) {
2675  rewriter.setInsertionPoint(op);
2676  SmallVector<Value> newMaterialization;
2677  switch (rewrite->getMaterializationKind()) {
2678  case MaterializationKind::Target:
2679  newMaterialization = converter->materializeTargetConversion(
2680  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2681  rewrite->getOriginalType());
2682  break;
2683  case MaterializationKind::Source:
2684  assert(op->getNumResults() == 1 && "expected single result");
2685  Value sourceMat = converter->materializeSourceConversion(
2686  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2687  if (sourceMat)
2688  newMaterialization.push_back(sourceMat);
2689  break;
2690  }
2691  if (!newMaterialization.empty()) {
2692 #ifndef NDEBUG
2693  ValueRange newMaterializationRange(newMaterialization);
2694  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2695  "materialization callback produced value of incorrect type");
2696 #endif // NDEBUG
2697  rewriter.replaceOp(op, newMaterialization);
2698  return success();
2699  }
2700  }
2701 
2702  InFlightDiagnostic diag = op->emitError()
2703  << "failed to legalize unresolved materialization "
2704  "from ("
2705  << inputOperands.getTypes() << ") to ("
2706  << op.getResultTypes()
2707  << ") that remained live after conversion";
2708  diag.attachNote(op->getUsers().begin()->getLoc())
2709  << "see existing live user here: " << *op->getUsers().begin();
2710  return failure();
2711 }
2712 
2714  if (ops.empty())
2715  return success();
2716  const ConversionTarget &target = opLegalizer.getTarget();
2717 
2718  // Compute the set of operations and blocks to convert.
2719  SmallVector<Operation *> toConvert;
2720  for (auto *op : ops) {
2722  [&](Operation *op) {
2723  toConvert.push_back(op);
2724  // Don't check this operation's children for conversion if the
2725  // operation is recursively legal.
2726  auto legalityInfo = target.isLegal(op);
2727  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2728  return WalkResult::skip();
2729  return WalkResult::advance();
2730  });
2731  }
2732 
2733  // Convert each operation and discard rewrites on failure.
2734  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2735  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2736 
2737  for (auto *op : toConvert) {
2738  if (failed(convert(rewriter, op))) {
2739  // Dialect conversion failed.
2740  if (rewriterImpl.config.allowPatternRollback) {
2741  // Rollback is allowed: restore the original IR.
2742  rewriterImpl.undoRewrites();
2743  } else {
2744  // Rollback is not allowed: apply all modifications that have been
2745  // performed so far.
2746  rewriterImpl.applyRewrites();
2747  }
2748  return failure();
2749  }
2750  }
2751 
2752  // After a successful conversion, apply rewrites.
2753  rewriterImpl.applyRewrites();
2754 
2755  // Gather all unresolved materializations.
2758  &materializations = rewriterImpl.unresolvedMaterializations;
2759  for (auto it : materializations)
2760  allCastOps.push_back(it.first);
2761 
2762  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2763  // dialect conversion frameworks. (Not the one that were inserted by
2764  // patterns.)
2765  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2766  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2767 
2768  // Try to legalize all unresolved materializations.
2769  if (config.buildMaterializations) {
2770  IRRewriter rewriter(rewriterImpl.context, config.listener);
2771  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2772  auto it = materializations.find(castOp);
2773  assert(it != materializations.end() && "inconsistent state");
2774  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2775  return failure();
2776  }
2777  }
2778 
2779  return success();
2780 }
2781 
2782 //===----------------------------------------------------------------------===//
2783 // Reconcile Unrealized Casts
2784 //===----------------------------------------------------------------------===//
2785 
2788  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2789  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
2790  // This set is maintained only if `remainingCastOps` is provided.
2791  DenseSet<Operation *> erasedOps;
2792 
2793  // Helper function that adds all operands to the worklist that are an
2794  // unrealized_conversion_cast op result.
2795  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2796  for (Value v : castOp.getInputs())
2797  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2798  worklist.insert(inputCastOp);
2799  };
2800 
2801  // Helper function that return the unrealized_conversion_cast op that
2802  // defines all inputs of the given op (in the same order). Return "nullptr"
2803  // if there is no such op.
2804  auto getInputCast =
2805  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2806  if (castOp.getInputs().empty())
2807  return {};
2808  auto inputCastOp =
2809  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2810  if (!inputCastOp)
2811  return {};
2812  if (inputCastOp.getOutputs() != castOp.getInputs())
2813  return {};
2814  return inputCastOp;
2815  };
2816 
2817  // Process ops in the worklist bottom-to-top.
2818  while (!worklist.empty()) {
2819  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2820  if (castOp->use_empty()) {
2821  // DCE: If the op has no users, erase it. Add the operands to the
2822  // worklist to find additional DCE opportunities.
2823  enqueueOperands(castOp);
2824  if (remainingCastOps)
2825  erasedOps.insert(castOp.getOperation());
2826  castOp->erase();
2827  continue;
2828  }
2829 
2830  // Traverse the chain of input cast ops to see if an op with the same
2831  // input types can be found.
2832  UnrealizedConversionCastOp nextCast = castOp;
2833  while (nextCast) {
2834  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2835  // Found a cast where the input types match the output types of the
2836  // matched op. We can directly use those inputs and the matched op can
2837  // be removed.
2838  enqueueOperands(castOp);
2839  castOp.replaceAllUsesWith(nextCast.getInputs());
2840  if (remainingCastOps)
2841  erasedOps.insert(castOp.getOperation());
2842  castOp->erase();
2843  break;
2844  }
2845  nextCast = getInputCast(nextCast);
2846  }
2847  }
2848 
2849  if (remainingCastOps)
2850  for (UnrealizedConversionCastOp op : castOps)
2851  if (!erasedOps.contains(op.getOperation()))
2852  remainingCastOps->push_back(op);
2853 }
2854 
2855 //===----------------------------------------------------------------------===//
2856 // Type Conversion
2857 //===----------------------------------------------------------------------===//
2858 
2860  ArrayRef<Type> types) {
2861  assert(!types.empty() && "expected valid types");
2862  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2863  addInputs(types);
2864 }
2865 
2867  assert(!types.empty() &&
2868  "1->0 type remappings don't need to be added explicitly");
2869  argTypes.append(types.begin(), types.end());
2870 }
2871 
2872 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2873  unsigned newInputNo,
2874  unsigned newInputCount) {
2875  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2876  assert(newInputCount != 0 && "expected valid input count");
2877  remappedInputs[origInputNo] =
2878  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2879 }
2880 
2882  unsigned origInputNo, ArrayRef<Value> replacements) {
2883  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2884  remappedInputs[origInputNo] = InputMapping{
2885  origInputNo, /*size=*/0,
2886  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2887 }
2888 
2890  SmallVectorImpl<Type> &results) const {
2891  assert(t && "expected non-null type");
2892 
2893  {
2894  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2895  std::defer_lock);
2897  cacheReadLock.lock();
2898  auto existingIt = cachedDirectConversions.find(t);
2899  if (existingIt != cachedDirectConversions.end()) {
2900  if (existingIt->second)
2901  results.push_back(existingIt->second);
2902  return success(existingIt->second != nullptr);
2903  }
2904  auto multiIt = cachedMultiConversions.find(t);
2905  if (multiIt != cachedMultiConversions.end()) {
2906  results.append(multiIt->second.begin(), multiIt->second.end());
2907  return success();
2908  }
2909  }
2910  // Walk the added converters in reverse order to apply the most recently
2911  // registered first.
2912  size_t currentCount = results.size();
2913 
2914  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2915  std::defer_lock);
2916 
2917  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2918  if (std::optional<LogicalResult> result = converter(t, results)) {
2920  cacheWriteLock.lock();
2921  if (!succeeded(*result)) {
2922  assert(results.size() == currentCount &&
2923  "failed type conversion should not change results");
2924  cachedDirectConversions.try_emplace(t, nullptr);
2925  return failure();
2926  }
2927  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2928  if (newTypes.size() == 1)
2929  cachedDirectConversions.try_emplace(t, newTypes.front());
2930  else
2931  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2932  return success();
2933  } else {
2934  assert(results.size() == currentCount &&
2935  "failed type conversion should not change results");
2936  }
2937  }
2938  return failure();
2939 }
2940 
2942  // Use the multi-type result version to convert the type.
2943  SmallVector<Type, 1> results;
2944  if (failed(convertType(t, results)))
2945  return nullptr;
2946 
2947  // Check to ensure that only one type was produced.
2948  return results.size() == 1 ? results.front() : nullptr;
2949 }
2950 
2951 LogicalResult
2953  SmallVectorImpl<Type> &results) const {
2954  for (Type type : types)
2955  if (failed(convertType(type, results)))
2956  return failure();
2957  return success();
2958 }
2959 
2960 bool TypeConverter::isLegal(Type type) const {
2961  return convertType(type) == type;
2962 }
2964  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2965 }
2966 
2967 bool TypeConverter::isLegal(Region *region) const {
2968  return llvm::all_of(*region, [this](Block &block) {
2969  return isLegal(block.getArgumentTypes());
2970  });
2971 }
2972 
2973 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2974  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2975 }
2976 
2977 LogicalResult
2979  SignatureConversion &result) const {
2980  // Try to convert the given input type.
2981  SmallVector<Type, 1> convertedTypes;
2982  if (failed(convertType(type, convertedTypes)))
2983  return failure();
2984 
2985  // If this argument is being dropped, there is nothing left to do.
2986  if (convertedTypes.empty())
2987  return success();
2988 
2989  // Otherwise, add the new inputs.
2990  result.addInputs(inputNo, convertedTypes);
2991  return success();
2992 }
2993 LogicalResult
2995  SignatureConversion &result,
2996  unsigned origInputOffset) const {
2997  for (unsigned i = 0, e = types.size(); i != e; ++i)
2998  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2999  return failure();
3000  return success();
3001 }
3002 
3004  Location loc, Type resultType,
3005  ValueRange inputs) const {
3006  for (const SourceMaterializationCallbackFn &fn :
3007  llvm::reverse(sourceMaterializations))
3008  if (Value result = fn(builder, resultType, inputs, loc))
3009  return result;
3010  return nullptr;
3011 }
3012 
3014  Location loc, Type resultType,
3015  ValueRange inputs,
3016  Type originalType) const {
3018  builder, loc, TypeRange(resultType), inputs, originalType);
3019  if (result.empty())
3020  return nullptr;
3021  assert(result.size() == 1 && "expected single result");
3022  return result.front();
3023 }
3024 
3026  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3027  Type originalType) const {
3028  for (const TargetMaterializationCallbackFn &fn :
3029  llvm::reverse(targetMaterializations)) {
3030  SmallVector<Value> result =
3031  fn(builder, resultTypes, inputs, loc, originalType);
3032  if (result.empty())
3033  continue;
3034  assert(TypeRange(ValueRange(result)) == resultTypes &&
3035  "callback produced incorrect number of values or values with "
3036  "incorrect types");
3037  return result;
3038  }
3039  return {};
3040 }
3041 
3042 std::optional<TypeConverter::SignatureConversion>
3044  SignatureConversion conversion(block->getNumArguments());
3045  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3046  return std::nullopt;
3047  return conversion;
3048 }
3049 
3050 //===----------------------------------------------------------------------===//
3051 // Type attribute conversion
3052 //===----------------------------------------------------------------------===//
3055  return AttributeConversionResult(attr, resultTag);
3056 }
3057 
3060  return AttributeConversionResult(nullptr, naTag);
3061 }
3062 
3065  return AttributeConversionResult(nullptr, abortTag);
3066 }
3067 
3069  return impl.getInt() == resultTag;
3070 }
3071 
3073  return impl.getInt() == naTag;
3074 }
3075 
3077  return impl.getInt() == abortTag;
3078 }
3079 
3081  assert(hasResult() && "Cannot get result from N/A or abort");
3082  return impl.getPointer();
3083 }
3084 
3085 std::optional<Attribute>
3087  for (const TypeAttributeConversionCallbackFn &fn :
3088  llvm::reverse(typeAttributeConversions)) {
3089  AttributeConversionResult res = fn(type, attr);
3090  if (res.hasResult())
3091  return res.getResult();
3092  if (res.isAbort())
3093  return std::nullopt;
3094  }
3095  return std::nullopt;
3096 }
3097 
3098 //===----------------------------------------------------------------------===//
3099 // FunctionOpInterfaceSignatureConversion
3100 //===----------------------------------------------------------------------===//
3101 
3102 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3103  const TypeConverter &typeConverter,
3104  ConversionPatternRewriter &rewriter) {
3105  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3106  if (!type)
3107  return failure();
3108 
3109  // Convert the original function types.
3110  TypeConverter::SignatureConversion result(type.getNumInputs());
3111  SmallVector<Type, 1> newResults;
3112  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3113  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3114  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3115  typeConverter, &result)))
3116  return failure();
3117 
3118  // Update the function signature in-place.
3119  auto newType = FunctionType::get(rewriter.getContext(),
3120  result.getConvertedTypes(), newResults);
3121 
3122  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3123 
3124  return success();
3125 }
3126 
3127 /// Create a default conversion pattern that rewrites the type signature of a
3128 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3129 /// represent their type.
3130 namespace {
3131 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3132  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3133  MLIRContext *ctx,
3134  const TypeConverter &converter)
3135  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3136 
3137  LogicalResult
3138  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3139  ConversionPatternRewriter &rewriter) const override {
3140  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3141  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3142  }
3143 };
3144 
3145 struct AnyFunctionOpInterfaceSignatureConversion
3146  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3148 
3149  LogicalResult
3150  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3151  ConversionPatternRewriter &rewriter) const override {
3152  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3153  }
3154 };
3155 } // namespace
3156 
3157 FailureOr<Operation *>
3159  const TypeConverter &converter,
3160  ConversionPatternRewriter &rewriter) {
3161  assert(op && "Invalid op");
3162  Location loc = op->getLoc();
3163  if (converter.isLegal(op))
3164  return rewriter.notifyMatchFailure(loc, "op already legal");
3165 
3166  OperationState newOp(loc, op->getName());
3167  newOp.addOperands(operands);
3168 
3169  SmallVector<Type> newResultTypes;
3170  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3171  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3172 
3173  newOp.addTypes(newResultTypes);
3174  newOp.addAttributes(op->getAttrs());
3175  return rewriter.create(newOp);
3176 }
3177 
3179  StringRef functionLikeOpName, RewritePatternSet &patterns,
3180  const TypeConverter &converter) {
3181  patterns.add<FunctionOpInterfaceSignatureConversion>(
3182  functionLikeOpName, patterns.getContext(), converter);
3183 }
3184 
3186  RewritePatternSet &patterns, const TypeConverter &converter) {
3187  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3188  converter, patterns.getContext());
3189 }
3190 
3191 //===----------------------------------------------------------------------===//
3192 // ConversionTarget
3193 //===----------------------------------------------------------------------===//
3194 
3196  LegalizationAction action) {
3197  legalOperations[op].action = action;
3198 }
3199 
3201  LegalizationAction action) {
3202  for (StringRef dialect : dialectNames)
3203  legalDialects[dialect] = action;
3204 }
3205 
3207  -> std::optional<LegalizationAction> {
3208  std::optional<LegalizationInfo> info = getOpInfo(op);
3209  return info ? info->action : std::optional<LegalizationAction>();
3210 }
3211 
3213  -> std::optional<LegalOpDetails> {
3214  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3215  if (!info)
3216  return std::nullopt;
3217 
3218  // Returns true if this operation instance is known to be legal.
3219  auto isOpLegal = [&] {
3220  // Handle dynamic legality either with the provided legality function.
3221  if (info->action == LegalizationAction::Dynamic) {
3222  std::optional<bool> result = info->legalityFn(op);
3223  if (result)
3224  return *result;
3225  }
3226 
3227  // Otherwise, the operation is only legal if it was marked 'Legal'.
3228  return info->action == LegalizationAction::Legal;
3229  };
3230  if (!isOpLegal())
3231  return std::nullopt;
3232 
3233  // This operation is legal, compute any additional legality information.
3234  LegalOpDetails legalityDetails;
3235  if (info->isRecursivelyLegal) {
3236  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3237  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3238  legalityDetails.isRecursivelyLegal =
3239  legalityFnIt->second(op).value_or(true);
3240  } else {
3241  legalityDetails.isRecursivelyLegal = true;
3242  }
3243  }
3244  return legalityDetails;
3245 }
3246 
3248  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3249  if (!info)
3250  return false;
3251 
3252  if (info->action == LegalizationAction::Dynamic) {
3253  std::optional<bool> result = info->legalityFn(op);
3254  if (!result)
3255  return false;
3256 
3257  return !(*result);
3258  }
3259 
3260  return info->action == LegalizationAction::Illegal;
3261 }
3262 
3266  if (!oldCallback)
3267  return newCallback;
3268 
3269  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3270  Operation *op) -> std::optional<bool> {
3271  if (std::optional<bool> result = newCl(op))
3272  return *result;
3273 
3274  return oldCl(op);
3275  };
3276  return chain;
3277 }
3278 
3279 void ConversionTarget::setLegalityCallback(
3280  OperationName name, const DynamicLegalityCallbackFn &callback) {
3281  assert(callback && "expected valid legality callback");
3282  auto *infoIt = legalOperations.find(name);
3283  assert(infoIt != legalOperations.end() &&
3284  infoIt->second.action == LegalizationAction::Dynamic &&
3285  "expected operation to already be marked as dynamically legal");
3286  infoIt->second.legalityFn =
3287  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3288 }
3289 
3291  OperationName name, const DynamicLegalityCallbackFn &callback) {
3292  auto *infoIt = legalOperations.find(name);
3293  assert(infoIt != legalOperations.end() &&
3294  infoIt->second.action != LegalizationAction::Illegal &&
3295  "expected operation to already be marked as legal");
3296  infoIt->second.isRecursivelyLegal = true;
3297  if (callback)
3298  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3299  std::move(opRecursiveLegalityFns[name]), callback);
3300  else
3301  opRecursiveLegalityFns.erase(name);
3302 }
3303 
3304 void ConversionTarget::setLegalityCallback(
3305  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3306  assert(callback && "expected valid legality callback");
3307  for (StringRef dialect : dialects)
3308  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3309  std::move(dialectLegalityFns[dialect]), callback);
3310 }
3311 
3312 void ConversionTarget::setLegalityCallback(
3313  const DynamicLegalityCallbackFn &callback) {
3314  assert(callback && "expected valid legality callback");
3315  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3316 }
3317 
3318 auto ConversionTarget::getOpInfo(OperationName op) const
3319  -> std::optional<LegalizationInfo> {
3320  // Check for info for this specific operation.
3321  const auto *it = legalOperations.find(op);
3322  if (it != legalOperations.end())
3323  return it->second;
3324  // Check for info for the parent dialect.
3325  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3326  if (dialectIt != legalDialects.end()) {
3327  DynamicLegalityCallbackFn callback;
3328  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3329  if (dialectFn != dialectLegalityFns.end())
3330  callback = dialectFn->second;
3331  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3332  callback};
3333  }
3334  // Otherwise, check if we mark unknown operations as dynamic.
3335  if (unknownLegalityFn)
3336  return LegalizationInfo{LegalizationAction::Dynamic,
3337  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3338  return std::nullopt;
3339 }
3340 
3341 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3342 //===----------------------------------------------------------------------===//
3343 // PDL Configuration
3344 //===----------------------------------------------------------------------===//
3345 
3347  auto &rewriterImpl =
3348  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3349  rewriterImpl.currentTypeConverter = getTypeConverter();
3350 }
3351 
3353  auto &rewriterImpl =
3354  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3355  rewriterImpl.currentTypeConverter = nullptr;
3356 }
3357 
3358 /// Remap the given value using the rewriter and the type converter in the
3359 /// provided config.
3360 static FailureOr<SmallVector<Value>>
3362  SmallVector<Value> mappedValues;
3363  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3364  return failure();
3365  return std::move(mappedValues);
3366 }
3367 
3369  patterns.getPDLPatterns().registerRewriteFunction(
3370  "convertValue",
3371  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3372  auto results = pdllConvertValues(
3373  static_cast<ConversionPatternRewriter &>(rewriter), value);
3374  if (failed(results))
3375  return failure();
3376  return results->front();
3377  });
3378  patterns.getPDLPatterns().registerRewriteFunction(
3379  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3380  return pdllConvertValues(
3381  static_cast<ConversionPatternRewriter &>(rewriter), values);
3382  });
3383  patterns.getPDLPatterns().registerRewriteFunction(
3384  "convertType",
3385  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3386  auto &rewriterImpl =
3387  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3388  if (const TypeConverter *converter =
3389  rewriterImpl.currentTypeConverter) {
3390  if (Type newType = converter->convertType(type))
3391  return newType;
3392  return failure();
3393  }
3394  return type;
3395  });
3396  patterns.getPDLPatterns().registerRewriteFunction(
3397  "convertTypes",
3398  [](PatternRewriter &rewriter,
3399  TypeRange types) -> FailureOr<SmallVector<Type>> {
3400  auto &rewriterImpl =
3401  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3402  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3403  if (!converter)
3404  return SmallVector<Type>(types);
3405 
3406  SmallVector<Type> remappedTypes;
3407  if (failed(converter->convertTypes(types, remappedTypes)))
3408  return failure();
3409  return std::move(remappedTypes);
3410  });
3411 }
3412 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3413 
3414 //===----------------------------------------------------------------------===//
3415 // Op Conversion Entry Points
3416 //===----------------------------------------------------------------------===//
3417 
3418 //===----------------------------------------------------------------------===//
3419 // Partial Conversion
3420 //===----------------------------------------------------------------------===//
3421 
3423  ArrayRef<Operation *> ops, const ConversionTarget &target,
3425  OperationConverter opConverter(target, patterns, config,
3426  OpConversionMode::Partial);
3427  return opConverter.convertOperations(ops);
3428 }
3429 LogicalResult
3433  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3434 }
3435 
3436 //===----------------------------------------------------------------------===//
3437 // Full Conversion
3438 //===----------------------------------------------------------------------===//
3439 
3441  const ConversionTarget &target,
3444  OperationConverter opConverter(target, patterns, config,
3445  OpConversionMode::Full);
3446  return opConverter.convertOperations(ops);
3447 }
3449  const ConversionTarget &target,
3452  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3453 }
3454 
3455 //===----------------------------------------------------------------------===//
3456 // Analysis Conversion
3457 //===----------------------------------------------------------------------===//
3458 
3459 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3460 /// op is a top-level op (i.e., it has no parent op; such an op is expected to
3461 /// be isolated from above), return that op.
3463  // Check if there is a top-level operation within `ops`. If so, return that
3464  // op.
3465  for (Operation *op : ops) {
3466  if (!op->getParentOp()) {
3467 #ifndef NDEBUG
3468  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3469  "expected top-level op to be isolated from above");
3470  for (Operation *other : ops)
3471  assert(op->isAncestor(other) &&
3472  "expected ops to have a common ancestor");
3473 #endif // NDEBUG
3474  return op;
3475  }
3476  }
3477 
3478  // No top-level op. Find a common ancestor.
3479  Operation *commonAncestor =
3480  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3481  for (Operation *op : ops.drop_front()) {
3482  while (!commonAncestor->isProperAncestor(op)) {
3483  commonAncestor =
3485  assert(commonAncestor &&
3486  "expected to find a common isolated from above ancestor");
3487  }
3488  }
3489 
3490  return commonAncestor;
3491 }
3492 
3496 #ifndef NDEBUG
3497  if (config.legalizableOps)
3498  assert(config.legalizableOps->empty() && "expected empty set");
3499 #endif // NDEBUG
3500 
3501  // Clone the closest common ancestor that is isolated from above.
3502  Operation *commonAncestor = findCommonAncestor(ops);
3503  IRMapping mapping;
3504  Operation *clonedAncestor = commonAncestor->clone(mapping);
3505  // Compute inverse IR mapping.
3506  DenseMap<Operation *, Operation *> inverseOperationMap;
3507  for (auto &it : mapping.getOperationMap())
3508  inverseOperationMap[it.second] = it.first;
3509 
3510  // Convert the cloned operations. The original IR will remain unchanged.
3511  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3512  ops, [&](Operation *op) { return mapping.lookup(op); });
3513  OperationConverter opConverter(target, patterns, config,
3514  OpConversionMode::Analysis);
3515  LogicalResult status = opConverter.convertOperations(opsToConvert);
3516 
3517  // Remap `legalizableOps`, so that they point to the original ops and not the
3518  // cloned ops.
3519  if (config.legalizableOps) {
3520  DenseSet<Operation *> originalLegalizableOps;
3521  for (Operation *op : *config.legalizableOps)
3522  originalLegalizableOps.insert(inverseOperationMap[op]);
3523  *config.legalizableOps = std::move(originalLegalizableOps);
3524  }
3525 
3526  // Erase the cloned IR.
3527  clonedAncestor->erase();
3528  return status;
3529 }
3530 
3531 LogicalResult
3535  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3536 }
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given values using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used without a dominance violation.
union mlir::linalg::@1215::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:94
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:202
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:295
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
operand_iterator operand_begin()
Definition: Operation.h:374
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_iterator operand_end()
Definition: Operation.h:375
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:606
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:896
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
succ_iterator successor_end()
Definition: Operation.h:702
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:365
succ_iterator successor_begin()
Definition: Operation.h:701
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: WalkResult.h:48
static WalkResult advance()
Definition: WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition: Builders.h:306
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition: Builders.h:296
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
Definition: PatternMatch.h:369
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
Definition: PatternMatch.h:393
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
Definition: PatternMatch.h:366
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, const TypeConverter *converter)
Replace the given block argument with the given values.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.