MLIR  21.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 #define DEBUG_TYPE "dialect-conversion"
32 
33 /// A utility function to log a successful result for the given reason.
34 template <typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36  LLVM_DEBUG({
37  os.unindent();
38  os.startLine() << "} -> SUCCESS";
39  if (!fmt.empty())
40  os.getOStream() << " : "
41  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42  os.getOStream() << "\n";
43  });
44 }
45 
46 /// A utility function to log a failure result for the given reason.
47 template <typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49  LLVM_DEBUG({
50  os.unindent();
51  os.startLine() << "} -> FAILURE : "
52  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53  << "\n";
54  });
55 }
56 
57 /// Helper function that computes an insertion point where the given value is
58 /// defined and can be used without a dominance violation.
60  Block *insertBlock = value.getParentBlock();
61  Block::iterator insertPt = insertBlock->begin();
62  if (OpResult inputRes = dyn_cast<OpResult>(value))
63  insertPt = ++inputRes.getOwner()->getIterator();
64  return OpBuilder::InsertPoint(insertBlock, insertPt);
65 }
66 
67 /// Helper function that computes an insertion point where the given values are
68 /// defined and can be used without a dominance violation.
70  assert(!vals.empty() && "expected at least one value");
71  DominanceInfo domInfo;
72  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
73  for (Value v : vals.drop_front()) {
74  // Choose the "later" insertion point.
76  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
77  nextPt.getPoint())) {
78  // pt is before nextPt => choose nextPt.
79  pt = nextPt;
80  } else {
81 #ifndef NDEBUG
82  // nextPt should be before pt => choose pt.
83  // If pt, nextPt are no dominance relationship, then there is no valid
84  // insertion point at which all given values are defined.
85  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
86  pt.getBlock(), pt.getPoint());
87  assert(dom && "unable to find valid insertion point");
88 #endif // NDEBUG
89  }
90  }
91  return pt;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // ConversionValueMapping
96 //===----------------------------------------------------------------------===//
97 
98 /// A vector of SSA values, optimized for the most common case of a single
99 /// value.
100 using ValueVector = SmallVector<Value, 1>;
101 
102 namespace {
103 
104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105 struct ValueVectorMapInfo {
106  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108  static ::llvm::hash_code getHashValue(const ValueVector &val) {
109  return ::llvm::hash_combine_range(val.begin(), val.end());
110  }
111  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112  return LHS == RHS;
113  }
114 };
115 
116 /// This class wraps a IRMapping to provide recursive lookup
117 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
118 struct ConversionValueMapping {
119  /// Return "true" if an SSA value is mapped to the given value. May return
120  /// false positives.
121  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
122 
123  /// Lookup the most recently mapped values with the desired types in the
124  /// mapping.
125  ///
126  /// Special cases:
127  /// - If the desired type range is empty, simply return the most recently
128  /// mapped values.
129  /// - If there is no mapping to the desired types, also return the most
130  /// recently mapped values.
131  /// - If there is no mapping for the given values at all, return the given
132  /// value.
133  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134 
135  /// Lookup the given value within the map, or return an empty vector if the
136  /// value is not mapped. If it is mapped, this follows the same behavior
137  /// as `lookupOrDefault`.
138  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139 
140  template <typename T>
141  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 
143  /// Map a value vector to the one provided.
144  template <typename OldVal, typename NewVal>
145  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146  map(OldVal &&oldVal, NewVal &&newVal) {
147  LLVM_DEBUG({
148  ValueVector next(newVal);
149  while (true) {
150  assert(next != oldVal && "inserting cyclic mapping");
151  auto it = mapping.find(next);
152  if (it == mapping.end())
153  break;
154  next = it->second;
155  }
156  });
157  mappedTo.insert_range(newVal);
158 
159  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
160  }
161 
162  /// Map a value vector or single value to the one provided.
163  template <typename OldVal, typename NewVal>
164  std::enable_if_t<!IsValueVector<OldVal>::value ||
165  !IsValueVector<NewVal>::value>
166  map(OldVal &&oldVal, NewVal &&newVal) {
167  if constexpr (IsValueVector<OldVal>{}) {
168  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
169  } else if constexpr (IsValueVector<NewVal>{}) {
170  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
171  } else {
172  map(ValueVector{oldVal}, ValueVector{newVal});
173  }
174  }
175 
176  void map(Value oldVal, SmallVector<Value> &&newVal) {
177  map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
178  }
179 
180  /// Drop the last mapping for the given values.
181  void erase(const ValueVector &value) { mapping.erase(value); }
182 
183 private:
184  /// Current value mappings.
186 
187  /// All SSA values that are mapped to. May contain false positives.
188  DenseSet<Value> mappedTo;
189 };
190 } // namespace
191 
193 ConversionValueMapping::lookupOrDefault(Value from,
194  TypeRange desiredTypes) const {
195  // Try to find the deepest values that have the desired types. If there is no
196  // such mapping, simply return the deepest values.
197  ValueVector desiredValue;
198  ValueVector current{from};
199  do {
200  // Store the current value if the types match.
201  if (TypeRange(ValueRange(current)) == desiredTypes)
202  desiredValue = current;
203 
204  // If possible, Replace each value with (one or multiple) mapped values.
205  ValueVector next;
206  for (Value v : current) {
207  auto it = mapping.find({v});
208  if (it != mapping.end()) {
209  llvm::append_range(next, it->second);
210  } else {
211  next.push_back(v);
212  }
213  }
214  if (next != current) {
215  // If at least one value was replaced, continue the lookup from there.
216  current = std::move(next);
217  continue;
218  }
219 
220  // Otherwise: Check if there is a mapping for the entire vector. Such
221  // mappings are materializations. (N:M mapping are not supported for value
222  // replacements.)
223  //
224  // Note: From a correctness point of view, materializations do not have to
225  // be stored (and looked up) in the mapping. But for performance reasons,
226  // we choose to reuse existing IR (when possible) instead of creating it
227  // multiple times.
228  auto it = mapping.find(current);
229  if (it == mapping.end()) {
230  // No mapping found: The lookup stops here.
231  break;
232  }
233  current = it->second;
234  } while (true);
235 
236  // If the desired values were found use them, otherwise default to the leaf
237  // values.
238  // Note: If `desiredTypes` is empty, this function always returns `current`.
239  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
240 }
241 
242 ValueVector ConversionValueMapping::lookupOrNull(Value from,
243  TypeRange desiredTypes) const {
244  ValueVector result = lookupOrDefault(from, desiredTypes);
245  if (result == ValueVector{from} ||
246  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
247  return {};
248  return result;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Rewriter and Translation State
253 //===----------------------------------------------------------------------===//
254 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Capturing a
/// snapshot and later resetting to it is how a set of rewrites is saved and
/// undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites that had been performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
272 
273 //===----------------------------------------------------------------------===//
274 // IR rewrites
275 //===----------------------------------------------------------------------===//
276 
277 /// An IR rewrite that can be committed (upon success) or rolled back (upon
278 /// failure).
279 ///
280 /// The dialect conversion keeps track of IR modifications (requested by the
281 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
282 /// are directly applied to the IR as the rewriter API is used, some are applied
283 /// partially, and some are delayed until the `IRRewrite` objects are committed.
284 class IRRewrite {
285 public:
286  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
287  /// Enum values are ordered, so that they can be used in `classof`: first all
288  /// block rewrites, then all operation rewrites.
289  enum class Kind {
290  // Block rewrites
291  CreateBlock,
292  EraseBlock,
293  InlineBlock,
294  MoveBlock,
295  BlockTypeConversion,
296  ReplaceBlockArg,
297  // Operation rewrites
298  MoveOperation,
299  ModifyOperation,
300  ReplaceOperation,
301  CreateOperation,
302  UnresolvedMaterialization
303  };
304 
305  virtual ~IRRewrite() = default;
306 
307  /// Roll back the rewrite. Operations may be erased during rollback.
308  virtual void rollback() = 0;
309 
310  /// Commit the rewrite. At this point, it is certain that the dialect
311  /// conversion will succeed. All IR modifications, except for operation/block
312  /// erasure, must be performed through the given rewriter.
313  ///
314  /// Instead of erasing operations/blocks, they should merely be unlinked
315  /// commit phase and finally be erased during the cleanup phase. This is
316  /// because internal dialect conversion state (such as `mapping`) may still
317  /// be using them.
318  ///
319  /// Any IR modification that was already performed before the commit phase
320  /// (e.g., insertion of an op) must be communicated to the listener that may
321  /// be attached to the given rewriter.
322  virtual void commit(RewriterBase &rewriter) {}
323 
324  /// Cleanup operations/blocks. Cleanup is called after commit.
325  virtual void cleanup(RewriterBase &rewriter) {}
326 
327  Kind getKind() const { return kind; }
328 
329  static bool classof(const IRRewrite *rewrite) { return true; }
330 
331 protected:
332  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
333  : kind(kind), rewriterImpl(rewriterImpl) {}
334 
335  const ConversionConfig &getConfig() const;
336 
337  const Kind kind;
338  ConversionPatternRewriterImpl &rewriterImpl;
339 };
340 
341 /// A block rewrite.
342 class BlockRewrite : public IRRewrite {
343 public:
344  /// Return the block that this rewrite operates on.
345  Block *getBlock() const { return block; }
346 
347  static bool classof(const IRRewrite *rewrite) {
348  return rewrite->getKind() >= Kind::CreateBlock &&
349  rewrite->getKind() <= Kind::ReplaceBlockArg;
350  }
351 
352 protected:
353  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
354  Block *block)
355  : IRRewrite(kind, rewriterImpl), block(block) {}
356 
357  // The block that this rewrite operates on.
358  Block *block;
359 };
360 
361 /// Creation of a block. Block creations are immediately reflected in the IR.
362 /// There is no extra work to commit the rewrite. During rollback, the newly
363 /// created block is erased.
364 class CreateBlockRewrite : public BlockRewrite {
365 public:
366  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
367  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
368 
369  static bool classof(const IRRewrite *rewrite) {
370  return rewrite->getKind() == Kind::CreateBlock;
371  }
372 
373  void commit(RewriterBase &rewriter) override {
374  // The block was already created and inserted. Just inform the listener.
375  if (auto *listener = rewriter.getListener())
376  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
377  }
378 
379  void rollback() override {
380  // Unlink all of the operations within this block, they will be deleted
381  // separately.
382  auto &blockOps = block->getOperations();
383  while (!blockOps.empty())
384  blockOps.remove(blockOps.begin());
385  block->dropAllUses();
386  if (block->getParent())
387  block->erase();
388  else
389  delete block;
390  }
391 };
392 
393 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
394 /// blocks are immediately unlinked, but only erased during cleanup. This makes
395 /// it easier to rollback a block erasure: the block is simply inserted into its
396 /// original location.
397 class EraseBlockRewrite : public BlockRewrite {
398 public:
399  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
400  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
401  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
402 
403  static bool classof(const IRRewrite *rewrite) {
404  return rewrite->getKind() == Kind::EraseBlock;
405  }
406 
407  ~EraseBlockRewrite() override {
408  assert(!block &&
409  "rewrite was neither rolled back nor committed/cleaned up");
410  }
411 
412  void rollback() override {
413  // The block (owned by this rewrite) was not actually erased yet. It was
414  // just unlinked. Put it back into its original position.
415  assert(block && "expected block");
416  auto &blockList = region->getBlocks();
417  Region::iterator before = insertBeforeBlock
418  ? Region::iterator(insertBeforeBlock)
419  : blockList.end();
420  blockList.insert(before, block);
421  block = nullptr;
422  }
423 
424  void commit(RewriterBase &rewriter) override {
425  // Erase the block.
426  assert(block && "expected block");
427  assert(block->empty() && "expected empty block");
428 
429  // Notify the listener that the block is about to be erased.
430  if (auto *listener =
431  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
432  listener->notifyBlockErased(block);
433  }
434 
435  void cleanup(RewriterBase &rewriter) override {
436  // Erase the block.
437  block->dropAllDefinedValueUses();
438  delete block;
439  block = nullptr;
440  }
441 
442 private:
443  // The region in which this block was previously contained.
444  Region *region;
445 
446  // The original successor of this block before it was unlinked. "nullptr" if
447  // this block was the only block in the region.
448  Block *insertBeforeBlock;
449 };
450 
451 /// Inlining of a block. This rewrite is immediately reflected in the IR.
452 /// Note: This rewrite represents only the inlining of the operations. The
453 /// erasure of the inlined block is a separate rewrite.
454 class InlineBlockRewrite : public BlockRewrite {
455 public:
456  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
457  Block *sourceBlock, Block::iterator before)
458  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
459  sourceBlock(sourceBlock),
460  firstInlinedInst(sourceBlock->empty() ? nullptr
461  : &sourceBlock->front()),
462  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
463  // If a listener is attached to the dialect conversion, ops must be moved
464  // one-by-one. When they are moved in bulk, notifications cannot be sent
465  // because the ops that used to be in the source block at the time of the
466  // inlining (before the "commit" phase) are unknown at the time when
467  // notifications are sent (which is during the "commit" phase).
468  assert(!getConfig().listener &&
469  "InlineBlockRewrite not supported if listener is attached");
470  }
471 
472  static bool classof(const IRRewrite *rewrite) {
473  return rewrite->getKind() == Kind::InlineBlock;
474  }
475 
476  void rollback() override {
477  // Put the operations from the destination block (owned by the rewrite)
478  // back into the source block.
479  if (firstInlinedInst) {
480  assert(lastInlinedInst && "expected operation");
481  sourceBlock->getOperations().splice(sourceBlock->begin(),
482  block->getOperations(),
483  Block::iterator(firstInlinedInst),
484  ++Block::iterator(lastInlinedInst));
485  }
486  }
487 
488 private:
489  // The block that originally contained the operations.
490  Block *sourceBlock;
491 
492  // The first inlined operation.
493  Operation *firstInlinedInst;
494 
495  // The last inlined operation.
496  Operation *lastInlinedInst;
497 };
498 
499 /// Moving of a block. This rewrite is immediately reflected in the IR.
500 class MoveBlockRewrite : public BlockRewrite {
501 public:
502  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
503  Region *region, Block *insertBeforeBlock)
504  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
505  insertBeforeBlock(insertBeforeBlock) {}
506 
507  static bool classof(const IRRewrite *rewrite) {
508  return rewrite->getKind() == Kind::MoveBlock;
509  }
510 
511  void commit(RewriterBase &rewriter) override {
512  // The block was already moved. Just inform the listener.
513  if (auto *listener = rewriter.getListener()) {
514  // Note: `previousIt` cannot be passed because this is a delayed
515  // notification and iterators into past IR state cannot be represented.
516  listener->notifyBlockInserted(block, /*previous=*/region,
517  /*previousIt=*/{});
518  }
519  }
520 
521  void rollback() override {
522  // Move the block back to its original position.
523  Region::iterator before =
524  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
525  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
526  }
527 
528 private:
529  // The region in which this block was previously contained.
530  Region *region;
531 
532  // The original successor of this block before it was moved. "nullptr" if
533  // this block was the only block in the region.
534  Block *insertBeforeBlock;
535 };
536 
537 /// Block type conversion. This rewrite is partially reflected in the IR.
538 class BlockTypeConversionRewrite : public BlockRewrite {
539 public:
540  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
541  Block *origBlock, Block *newBlock)
542  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
543  newBlock(newBlock) {}
544 
545  static bool classof(const IRRewrite *rewrite) {
546  return rewrite->getKind() == Kind::BlockTypeConversion;
547  }
548 
549  Block *getOrigBlock() const { return block; }
550 
551  Block *getNewBlock() const { return newBlock; }
552 
553  void commit(RewriterBase &rewriter) override;
554 
555  void rollback() override;
556 
557 private:
558  /// The new block that was created as part of this signature conversion.
559  Block *newBlock;
560 };
561 
562 /// Replacing a block argument. This rewrite is not immediately reflected in the
563 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
564 /// until the rewrite is committed.
565 class ReplaceBlockArgRewrite : public BlockRewrite {
566 public:
567  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568  Block *block, BlockArgument arg,
569  const TypeConverter *converter)
570  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
571  converter(converter) {}
572 
573  static bool classof(const IRRewrite *rewrite) {
574  return rewrite->getKind() == Kind::ReplaceBlockArg;
575  }
576 
577  void commit(RewriterBase &rewriter) override;
578 
579  void rollback() override;
580 
581 private:
582  BlockArgument arg;
583 
584  /// The current type converter when the block argument was replaced.
585  const TypeConverter *converter;
586 };
587 
588 /// An operation rewrite.
589 class OperationRewrite : public IRRewrite {
590 public:
591  /// Return the operation that this rewrite operates on.
592  Operation *getOperation() const { return op; }
593 
594  static bool classof(const IRRewrite *rewrite) {
595  return rewrite->getKind() >= Kind::MoveOperation &&
596  rewrite->getKind() <= Kind::UnresolvedMaterialization;
597  }
598 
599 protected:
600  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
601  Operation *op)
602  : IRRewrite(kind, rewriterImpl), op(op) {}
603 
604  // The operation that this rewrite operates on.
605  Operation *op;
606 };
607 
608 /// Moving of an operation. This rewrite is immediately reflected in the IR.
609 class MoveOperationRewrite : public OperationRewrite {
610 public:
611  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
612  Operation *op, Block *block, Operation *insertBeforeOp)
613  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
614  insertBeforeOp(insertBeforeOp) {}
615 
616  static bool classof(const IRRewrite *rewrite) {
617  return rewrite->getKind() == Kind::MoveOperation;
618  }
619 
620  void commit(RewriterBase &rewriter) override {
621  // The operation was already moved. Just inform the listener.
622  if (auto *listener = rewriter.getListener()) {
623  // Note: `previousIt` cannot be passed because this is a delayed
624  // notification and iterators into past IR state cannot be represented.
625  listener->notifyOperationInserted(
626  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
627  /*insertPt=*/{}));
628  }
629  }
630 
631  void rollback() override {
632  // Move the operation back to its original position.
633  Block::iterator before =
634  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
635  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
636  }
637 
638 private:
639  // The block in which this operation was previously contained.
640  Block *block;
641 
642  // The original successor of this operation before it was moved. "nullptr"
643  // if this operation was the only operation in the region.
644  Operation *insertBeforeOp;
645 };
646 
647 /// In-place modification of an op. This rewrite is immediately reflected in
648 /// the IR. The previous state of the operation is stored in this object.
649 class ModifyOperationRewrite : public OperationRewrite {
650 public:
651  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
652  Operation *op)
653  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
654  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
655  operands(op->operand_begin(), op->operand_end()),
656  successors(op->successor_begin(), op->successor_end()) {
657  if (OpaqueProperties prop = op->getPropertiesStorage()) {
658  // Make a copy of the properties.
659  propertiesStorage = operator new(op->getPropertiesStorageSize());
660  OpaqueProperties propCopy(propertiesStorage);
661  name.initOpProperties(propCopy, /*init=*/prop);
662  }
663  }
664 
665  static bool classof(const IRRewrite *rewrite) {
666  return rewrite->getKind() == Kind::ModifyOperation;
667  }
668 
669  ~ModifyOperationRewrite() override {
670  assert(!propertiesStorage &&
671  "rewrite was neither committed nor rolled back");
672  }
673 
674  void commit(RewriterBase &rewriter) override {
675  // Notify the listener that the operation was modified in-place.
676  if (auto *listener =
677  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
678  listener->notifyOperationModified(op);
679 
680  if (propertiesStorage) {
681  OpaqueProperties propCopy(propertiesStorage);
682  // Note: The operation may have been erased in the mean time, so
683  // OperationName must be stored in this object.
684  name.destroyOpProperties(propCopy);
685  operator delete(propertiesStorage);
686  propertiesStorage = nullptr;
687  }
688  }
689 
690  void rollback() override {
691  op->setLoc(loc);
692  op->setAttrs(attrs);
693  op->setOperands(operands);
694  for (const auto &it : llvm::enumerate(successors))
695  op->setSuccessor(it.value(), it.index());
696  if (propertiesStorage) {
697  OpaqueProperties propCopy(propertiesStorage);
698  op->copyProperties(propCopy);
699  name.destroyOpProperties(propCopy);
700  operator delete(propertiesStorage);
701  propertiesStorage = nullptr;
702  }
703  }
704 
705 private:
706  OperationName name;
707  LocationAttr loc;
708  DictionaryAttr attrs;
709  SmallVector<Value, 8> operands;
710  SmallVector<Block *, 2> successors;
711  void *propertiesStorage = nullptr;
712 };
713 
714 /// Replacing an operation. Erasing an operation is treated as a special case
715 /// with "null" replacements. This rewrite is not immediately reflected in the
716 /// IR. An internal IR mapping is updated, but values are not replaced and the
717 /// original op is not erased until the rewrite is committed.
718 class ReplaceOperationRewrite : public OperationRewrite {
719 public:
720  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
721  Operation *op, const TypeConverter *converter)
722  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
723  converter(converter) {}
724 
725  static bool classof(const IRRewrite *rewrite) {
726  return rewrite->getKind() == Kind::ReplaceOperation;
727  }
728 
729  void commit(RewriterBase &rewriter) override;
730 
731  void rollback() override;
732 
733  void cleanup(RewriterBase &rewriter) override;
734 
735 private:
736  /// An optional type converter that can be used to materialize conversions
737  /// between the new and old values if necessary.
738  const TypeConverter *converter;
739 };
740 
741 class CreateOperationRewrite : public OperationRewrite {
742 public:
743  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
744  Operation *op)
745  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
746 
747  static bool classof(const IRRewrite *rewrite) {
748  return rewrite->getKind() == Kind::CreateOperation;
749  }
750 
751  void commit(RewriterBase &rewriter) override {
752  // The operation was already created and inserted. Just inform the listener.
753  if (auto *listener = rewriter.getListener())
754  listener->notifyOperationInserted(op, /*previous=*/{});
755  }
756 
757  void rollback() override;
758 };
759 
/// Direction of a materialization.
enum MaterializationKind {
  /// Converts a value of an illegal type into a value of a legal type.
  Target = 0,

  /// Converts a value of a legal type back into a value of an illegal type.
  Source = 1
};
770 
771 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
772 /// op. Unresolved materializations are erased at the end of the dialect
773 /// conversion.
774 class UnresolvedMaterializationRewrite : public OperationRewrite {
775 public:
776  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777  UnrealizedConversionCastOp op,
778  const TypeConverter *converter,
779  MaterializationKind kind, Type originalType,
780  ValueVector mappedValues);
781 
782  static bool classof(const IRRewrite *rewrite) {
783  return rewrite->getKind() == Kind::UnresolvedMaterialization;
784  }
785 
786  void rollback() override;
787 
788  UnrealizedConversionCastOp getOperation() const {
789  return cast<UnrealizedConversionCastOp>(op);
790  }
791 
792  /// Return the type converter of this materialization (which may be null).
793  const TypeConverter *getConverter() const {
794  return converterAndKind.getPointer();
795  }
796 
797  /// Return the kind of this materialization.
798  MaterializationKind getMaterializationKind() const {
799  return converterAndKind.getInt();
800  }
801 
802  /// Return the original type of the SSA value.
803  Type getOriginalType() const { return originalType; }
804 
805 private:
806  /// The corresponding type converter to use when resolving this
807  /// materialization, and the kind of this materialization.
808  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
809  converterAndKind;
810 
811  /// The original type of the SSA value. Only used for target
812  /// materializations.
813  Type originalType;
814 
815  /// The values in the conversion value mapping that are being replaced by the
816  /// results of this unresolved materialization.
817  ValueVector mappedValues;
818 };
819 } // namespace
820 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto matchesOp = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getOperation() == op;
  };
  return any_of(std::forward<R>(rewrites), matchesOp);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto matchesBlock = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getBlock() == block;
  };
  return any_of(std::forward<R>(rewrites), matchesBlock);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
842 
843 //===----------------------------------------------------------------------===//
844 // ConversionPatternRewriterImpl
845 //===----------------------------------------------------------------------===//
846 namespace mlir {
847 namespace detail {
850  const ConversionConfig &config)
851  : context(ctx), eraseRewriter(ctx), config(config) {}
852 
853  //===--------------------------------------------------------------------===//
854  // State Management
855  //===--------------------------------------------------------------------===//
856 
857  /// Return the current state of the rewriter.
858  RewriterState getCurrentState();
859 
860  /// Apply all requested operation rewrites. This method is invoked when the
861  /// conversion process succeeds.
862  void applyRewrites();
863 
864  /// Reset the state of the rewriter to a previously saved point.
865  void resetState(RewriterState state);
866 
867  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
868  /// failure.
869  template <typename RewriteTy, typename... Args>
870  void appendRewrite(Args &&...args) {
871  rewrites.push_back(
872  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
873  }
874 
875  /// Undo the rewrites (motions, splits) one by one in reverse order until
876  /// "numRewritesToKeep" rewrites remain.
877  void undoRewrites(unsigned numRewritesToKeep = 0);
878 
879  /// Remap the given values to those with potentially different types. Returns
880  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
881  /// is the tag used when describing a value within a diagnostic, e.g.
882  /// "operand".
883  LogicalResult remapValues(StringRef valueDiagTag,
884  std::optional<Location> inputLoc,
885  PatternRewriter &rewriter, ValueRange values,
886  SmallVector<ValueVector> &remapped);
887 
888  /// Return "true" if the given operation is ignored, and does not need to be
889  /// converted.
890  bool isOpIgnored(Operation *op) const;
891 
892  /// Return "true" if the given operation was replaced or erased.
893  bool wasOpReplaced(Operation *op) const;
894 
895  //===--------------------------------------------------------------------===//
896  // Type Conversion
897  //===--------------------------------------------------------------------===//
898 
899  /// Convert the types of block arguments within the given region.
900  FailureOr<Block *>
901  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
902  const TypeConverter &converter,
903  TypeConverter::SignatureConversion *entryConversion);
904 
905  /// Apply the given signature conversion on the given block. The new block
906  /// containing the updated signature is returned. If no conversions were
907  /// necessary, e.g. if the block has no arguments, `block` is returned.
908  /// `converter` is used to generate any necessary cast operations that
909  /// translate between the origin argument types and those specified in the
910  /// signature conversion.
911  Block *applySignatureConversion(
912  ConversionPatternRewriter &rewriter, Block *block,
913  const TypeConverter *converter,
914  TypeConverter::SignatureConversion &signatureConversion);
915 
916  //===--------------------------------------------------------------------===//
917  // Materializations
918  //===--------------------------------------------------------------------===//
919 
920  /// Build an unresolved materialization operation given a range of output
921  /// types and a list of input operands. The input types are expected to
922  /// differ from the output types; otherwise no materialization is needed.
923  ///
924  /// If a cast op was built, it can optionally be returned with the `castOp`
925  /// output argument.
926  ///
927  /// If `valuesToMap` is non-empty, then those values are mapped to
928  /// the results of the unresolved materialization in the conversion value
929  /// mapping.
930  ValueRange buildUnresolvedMaterialization(
931  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
932  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
933  Type originalType, const TypeConverter *converter,
934  UnrealizedConversionCastOp *castOp = nullptr);
935 
936  /// Find a replacement value for the given SSA value in the conversion value
937  /// mapping. The replacement value must have the same type as the given SSA
938  /// value. If there is no replacement value with the correct type, find the
939  /// latest replacement value (regardless of the type) and build a source
940  /// materialization.
941  Value findOrBuildReplacementValue(Value value,
942  const TypeConverter *converter);
943 
944  //===--------------------------------------------------------------------===//
945  // Rewriter Notification Hooks
946  //===--------------------------------------------------------------------===//
947 
948  /// Notifies that an op was inserted.
949  void notifyOperationInserted(Operation *op,
950  OpBuilder::InsertPoint previous) override;
951 
952  /// Notifies that an op is about to be replaced with the given values.
953  void notifyOpReplaced(Operation *op,
954  SmallVector<SmallVector<Value>> &&newValues);
955 
956  /// Notifies that a block is about to be erased.
957  void notifyBlockIsBeingErased(Block *block);
958 
959  /// Notifies that a block was inserted.
960  void notifyBlockInserted(Block *block, Region *previous,
961  Region::iterator previousIt) override;
962 
963  /// Notifies that a block is being inlined into another block.
964  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
965  Block::iterator before);
966 
967  /// Notifies that a pattern match failed for the given reason.
968  void
969  notifyMatchFailure(Location loc,
970  function_ref<void(Diagnostic &)> reasonCallback) override;
971 
972  //===--------------------------------------------------------------------===//
973  // IR Erasure
974  //===--------------------------------------------------------------------===//
975 
976  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
977  /// operation or block is erased multiple times. This rewriter assumes that
978  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
980  public:
982  : RewriterBase(context, /*listener=*/this) {}
983 
984  /// Erase the given op (unless it was already erased).
985  void eraseOp(Operation *op) override {
986  if (wasErased(op))
987  return;
988  op->dropAllUses();
990  }
991 
992  /// Erase the given block (unless it was already erased).
993  void eraseBlock(Block *block) override {
994  if (wasErased(block))
995  return;
996  assert(block->empty() && "expected empty block");
997  block->dropAllDefinedValueUses();
999  }
1000 
1001  bool wasErased(void *ptr) const { return erased.contains(ptr); }
1002 
1003  void notifyOperationErased(Operation *op) override { erased.insert(op); }
1004 
1005  void notifyBlockErased(Block *block) override { erased.insert(block); }
1006 
1007  private:
1008  /// Pointers to all erased operations and blocks.
1009  DenseSet<void *> erased;
1010  };
1011 
1012  //===--------------------------------------------------------------------===//
1013  // State
1014  //===--------------------------------------------------------------------===//
1015 
1016  /// MLIR context.
1018 
1019  /// A rewriter that keeps track of ops/block that were already erased and
1020  /// skips duplicate op/block erasures. This rewriter is used during the
1021  /// "cleanup" phase.
1023 
1024  // Mapping between replaced values that differ in type. This happens when
1025  // replacing a value with one of a different type.
1026  ConversionValueMapping mapping;
1027 
1028  /// Ordered list of block operations (creations, splits, motions).
1030 
1031  /// A set of operations that should no longer be considered for legalization.
1032  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1033  /// tracked separately.
1035 
1036  /// A set of operations that were replaced/erased. Such ops are not erased
1037  /// immediately but only when the dialect conversion succeeds. In the mean
1038  /// time, they should no longer be considered for legalization and any attempt
1039  /// to modify/access them is invalid rewriter API usage.
1041 
1042  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1043  /// to the corresponding rewrite objects.
1046 
1047  /// The current type converter, or nullptr if no type converter is currently
1048  /// active.
1049  const TypeConverter *currentTypeConverter = nullptr;
1050 
1051  /// A mapping of regions to type converters that should be used when
1052  /// converting the arguments of blocks within that region.
1054 
1055  /// Dialect conversion configuration.
1057 
1058 #ifndef NDEBUG
1059  /// A set of operations that have pending updates. This tracking isn't
1060  /// strictly necessary, and is thus only active during debug builds for extra
1061  /// verification.
1063 
1064  /// A logger used to emit diagnostics during the conversion process.
1065  llvm::ScopedPrinter logger{llvm::dbgs()};
1066 #endif
1067 };
1068 } // namespace detail
1069 } // namespace mlir
1070 
1071 const ConversionConfig &IRRewrite::getConfig() const {
1072  return rewriterImpl.config;
1073 }
1074 
1075 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1076  // Inform the listener about all IR modifications that have already taken
1077  // place: References to the original block have been replaced with the new
1078  // block.
1079  if (auto *listener =
1080  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1081  for (Operation *op : getNewBlock()->getUsers())
1082  listener->notifyOperationModified(op);
1083 }
1084 
1085 void BlockTypeConversionRewrite::rollback() {
1086  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1087 }
1088 
1089 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1090  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1091  if (!repl)
1092  return;
1093 
1094  if (isa<BlockArgument>(repl)) {
1095  rewriter.replaceAllUsesWith(arg, repl);
1096  return;
1097  }
1098 
1099  // If the replacement value is an operation, we check to make sure that we
1100  // don't replace uses that are within the parent operation of the
1101  // replacement value.
1102  Operation *replOp = cast<OpResult>(repl).getOwner();
1103  Block *replBlock = replOp->getBlock();
1104  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1105  Operation *user = operand.getOwner();
1106  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1107  });
1108 }
1109 
1110 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1111 
1112 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1113  auto *listener =
1114  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1115 
1116  // Compute replacement values.
1117  SmallVector<Value> replacements =
1118  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1119  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1120  });
1121 
1122  // Notify the listener that the operation is about to be replaced.
1123  if (listener)
1124  listener->notifyOperationReplaced(op, replacements);
1125 
1126  // Replace all uses with the new values.
1127  for (auto [result, newValue] :
1128  llvm::zip_equal(op->getResults(), replacements))
1129  if (newValue)
1130  rewriter.replaceAllUsesWith(result, newValue);
1131 
1132  // The original op will be erased, so remove it from the set of unlegalized
1133  // ops.
1134  if (getConfig().unlegalizedOps)
1135  getConfig().unlegalizedOps->erase(op);
1136 
1137  // Notify the listener that the operation (and its nested operations) was
1138  // erased.
1139  if (listener) {
1141  [&](Operation *op) { listener->notifyOperationErased(op); });
1142  }
1143 
1144  // Do not erase the operation yet. It may still be referenced in `mapping`.
1145  // Just unlink it for now and erase it during cleanup.
1146  op->getBlock()->getOperations().remove(op);
1147 }
1148 
1149 void ReplaceOperationRewrite::rollback() {
1150  for (auto result : op->getResults())
1151  rewriterImpl.mapping.erase({result});
1152 }
1153 
1154 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1155  rewriter.eraseOp(op);
1156 }
1157 
1158 void CreateOperationRewrite::rollback() {
1159  for (Region &region : op->getRegions()) {
1160  while (!region.getBlocks().empty())
1161  region.getBlocks().remove(region.getBlocks().begin());
1162  }
1163  op->dropAllUses();
1164  op->erase();
1165 }
1166 
1167 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1168  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1169  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1170  ValueVector mappedValues)
1171  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1172  converterAndKind(converter, kind), originalType(originalType),
1173  mappedValues(std::move(mappedValues)) {
1174  assert((!originalType || kind == MaterializationKind::Target) &&
1175  "original type is valid only for target materializations");
1176  rewriterImpl.unresolvedMaterializations[op] = this;
1177 }
1178 
1179 void UnresolvedMaterializationRewrite::rollback() {
1180  if (!mappedValues.empty())
1181  rewriterImpl.mapping.erase(mappedValues);
1182  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1183  op->erase();
1184 }
1185 
1187  // Commit all rewrites.
1188  IRRewriter rewriter(context, config.listener);
1189  // Note: New rewrites may be added during the "commit" phase and the
1190  // `rewrites` vector may reallocate.
1191  for (size_t i = 0; i < rewrites.size(); ++i)
1192  rewrites[i]->commit(rewriter);
1193 
1194  // Clean up all rewrites.
1195  for (auto &rewrite : rewrites)
1196  rewrite->cleanup(eraseRewriter);
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // State Management
1201 //===----------------------------------------------------------------------===//
1202 
1204  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1205 }
1206 
1208  // Undo any rewrites.
1209  undoRewrites(state.numRewrites);
1210 
1211  // Pop all of the recorded ignored operations that are no longer valid.
1212  while (ignoredOps.size() != state.numIgnoredOperations)
1213  ignoredOps.pop_back();
1214 
1215  while (replacedOps.size() != state.numReplacedOps)
1216  replacedOps.pop_back();
1217 }
1218 
1219 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1220  for (auto &rewrite :
1221  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1222  rewrite->rollback();
1223  rewrites.resize(numRewritesToKeep);
1224 }
1225 
1227  StringRef valueDiagTag, std::optional<Location> inputLoc,
1228  PatternRewriter &rewriter, ValueRange values,
1229  SmallVector<ValueVector> &remapped) {
1230  remapped.reserve(llvm::size(values));
1231 
1232  for (const auto &it : llvm::enumerate(values)) {
1233  Value operand = it.value();
1234  Type origType = operand.getType();
1235  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1236 
1237  if (!currentTypeConverter) {
1238  // The current pattern does not have a type converter. I.e., it does not
1239  // distinguish between legal and illegal types. For each operand, simply
1240  // pass through the most recently mapped values.
1241  remapped.push_back(mapping.lookupOrDefault(operand));
1242  continue;
1243  }
1244 
1245  // If there is no legal conversion, fail to match this pattern.
1246  SmallVector<Type, 1> legalTypes;
1247  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1248  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1249  diag << "unable to convert type for " << valueDiagTag << " #"
1250  << it.index() << ", type was " << origType;
1251  });
1252  return failure();
1253  }
1254  // If a type is converted to 0 types, there is nothing to do.
1255  if (legalTypes.empty()) {
1256  remapped.push_back({});
1257  continue;
1258  }
1259 
1260  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1261  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1262  // Mapped values have the correct type or there is an existing
1263  // materialization. Or the operand is not mapped at all and has the
1264  // correct type.
1265  remapped.push_back(std::move(repl));
1266  continue;
1267  }
1268 
1269  // Create a materialization for the most recently mapped values.
1270  repl = mapping.lookupOrDefault(operand);
1272  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1273  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1274  /*originalType=*/origType, currentTypeConverter);
1275  remapped.push_back(castValues);
1276  }
1277  return success();
1278 }
1279 
1281  // Check to see if this operation is ignored or was replaced.
1282  return replacedOps.count(op) || ignoredOps.count(op);
1283 }
1284 
1286  // Check to see if this operation was replaced.
1287  return replacedOps.count(op);
1288 }
1289 
1290 //===----------------------------------------------------------------------===//
1291 // Type Conversion
1292 //===----------------------------------------------------------------------===//
1293 
1295  ConversionPatternRewriter &rewriter, Region *region,
1296  const TypeConverter &converter,
1297  TypeConverter::SignatureConversion *entryConversion) {
1298  regionToConverter[region] = &converter;
1299  if (region->empty())
1300  return nullptr;
1301 
1302  // Convert the arguments of each non-entry block within the region.
1303  for (Block &block :
1304  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1305  // Compute the signature for the block with the provided converter.
1306  std::optional<TypeConverter::SignatureConversion> conversion =
1307  converter.convertBlockSignature(&block);
1308  if (!conversion)
1309  return failure();
1310  // Convert the block with the computed signature.
1311  applySignatureConversion(rewriter, &block, &converter, *conversion);
1312  }
1313 
1314  // Convert the entry block. If an entry signature conversion was provided,
1315  // use that one. Otherwise, compute the signature with the type converter.
1316  if (entryConversion)
1317  return applySignatureConversion(rewriter, &region->front(), &converter,
1318  *entryConversion);
1319  std::optional<TypeConverter::SignatureConversion> conversion =
1320  converter.convertBlockSignature(&region->front());
1321  if (!conversion)
1322  return failure();
1323  return applySignatureConversion(rewriter, &region->front(), &converter,
1324  *conversion);
1325 }
1326 
1328  ConversionPatternRewriter &rewriter, Block *block,
1329  const TypeConverter *converter,
1330  TypeConverter::SignatureConversion &signatureConversion) {
1331 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1332  // A block cannot be converted multiple times.
1333  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1334  llvm::report_fatal_error("block was already converted");
1335 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1336 
1337  OpBuilder::InsertionGuard g(rewriter);
1338 
1339  // If no arguments are being changed or added, there is nothing to do.
1340  unsigned origArgCount = block->getNumArguments();
1341  auto convertedTypes = signatureConversion.getConvertedTypes();
1342  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1343  return block;
1344 
1345  // Compute the locations of all block arguments in the new block.
1346  SmallVector<Location> newLocs(convertedTypes.size(),
1347  rewriter.getUnknownLoc());
1348  for (unsigned i = 0; i < origArgCount; ++i) {
1349  auto inputMap = signatureConversion.getInputMapping(i);
1350  if (!inputMap || inputMap->replacedWithValues())
1351  continue;
1352  Location origLoc = block->getArgument(i).getLoc();
1353  for (unsigned j = 0; j < inputMap->size; ++j)
1354  newLocs[inputMap->inputNo + j] = origLoc;
1355  }
1356 
1357  // Insert a new block with the converted block argument types and move all ops
1358  // from the old block to the new block.
1359  Block *newBlock =
1360  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1361  convertedTypes, newLocs);
1362 
1363  // If a listener is attached to the dialect conversion, ops cannot be moved
1364  // to the destination block in bulk ("fast path"). This is because at the time
1365  // the notifications are sent, it is unknown which ops were moved. Instead,
1366  // ops should be moved one-by-one ("slow path"), so that a separate
1367  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1368  // a bit more efficient, so we try to do that when possible.
1369  bool fastPath = !config.listener;
1370  if (fastPath) {
1371  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1372  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1373  } else {
1374  while (!block->empty())
1375  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1376  }
1377 
1378  // Replace all uses of the old block with the new block.
1379  block->replaceAllUsesWith(newBlock);
1380 
1381  for (unsigned i = 0; i != origArgCount; ++i) {
1382  BlockArgument origArg = block->getArgument(i);
1383  Type origArgType = origArg.getType();
1384 
1385  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1386  signatureConversion.getInputMapping(i);
1387  if (!inputMap) {
1388  // This block argument was dropped and no replacement value was provided.
1389  // Materialize a replacement value "out of thin air".
1391  MaterializationKind::Source,
1392  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1393  /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1394  /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
1395  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1396  continue;
1397  }
1398 
1399  if (inputMap->replacedWithValues()) {
1400  // This block argument was dropped and replacement values were provided.
1401  assert(inputMap->size == 0 &&
1402  "invalid to provide a replacement value when the argument isn't "
1403  "dropped");
1404  mapping.map(origArg, inputMap->replacementValues);
1405  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1406  continue;
1407  }
1408 
1409  // This is a 1->1+ mapping.
1410  auto replArgs =
1411  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1412  ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1413  mapping.map(origArg, std::move(replArgVals));
1414  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1415  }
1416 
1417  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1418 
1419  // Erase the old block. (It is just unlinked for now and will be erased during
1420  // cleanup.)
1421  rewriter.eraseBlock(block);
1422 
1423  return newBlock;
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // Materializations
1428 //===----------------------------------------------------------------------===//
1429 
1430 /// Build an unresolved materialization operation given an output type and set
1431 /// of input operands.
1433  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1434  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1435  Type originalType, const TypeConverter *converter,
1436  UnrealizedConversionCastOp *castOp) {
1437  assert((!originalType || kind == MaterializationKind::Target) &&
1438  "original type is valid only for target materializations");
1439  assert(TypeRange(inputs) != outputTypes &&
1440  "materialization is not necessary");
1441 
1442  // Create an unresolved materialization. We use a new OpBuilder to avoid
1443  // tracking the materialization like we do for other operations.
1444  OpBuilder builder(outputTypes.front().getContext());
1445  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1446  auto convertOp =
1447  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1448  if (!valuesToMap.empty())
1449  mapping.map(valuesToMap, convertOp.getResults());
1450  if (castOp)
1451  *castOp = convertOp;
1452  appendRewrite<UnresolvedMaterializationRewrite>(
1453  convertOp, converter, kind, originalType, std::move(valuesToMap));
1454  return convertOp.getResults();
1455 }
1456 
1458  Value value, const TypeConverter *converter) {
1459  // Try to find a replacement value with the same type in the conversion value
1460  // mapping. This includes cached materializations. We try to reuse those
1461  // instead of generating duplicate IR.
1462  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1463  if (!repl.empty())
1464  return repl.front();
1465 
1466  // Check if the value is dead. No replacement value is needed in that case.
1467  // This is an approximate check that may have false negatives but does not
1468  // require computing and traversing an inverse mapping. (We may end up
1469  // building source materializations that are never used and that fold away.)
1470  if (llvm::all_of(value.getUsers(),
1471  [&](Operation *op) { return replacedOps.contains(op); }) &&
1472  !mapping.isMappedTo(value))
1473  return Value();
1474 
1475  // No replacement value was found. Get the latest replacement value
1476  // (regardless of the type) and build a source materialization to the
1477  // original type.
1478  repl = mapping.lookupOrNull(value);
1479  if (repl.empty()) {
1480  // No replacement value is registered in the mapping. This means that the
1481  // value is dropped and no longer needed. (If the value were still needed,
1482  // a source materialization producing a replacement value "out of thin air"
1483  // would have already been created during `replaceOp` or
1484  // `applySignatureConversion`.)
1485  return Value();
1486  }
1487 
1488  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1489  // which all values in `repl` are defined. It is important to emit the
1490  // materialization at that location because the same materialization may be
1491  // reused in a different context. (That's because materializations are cached
1492  // in the conversion value mapping.) The insertion point of the
1493  // materialization must be valid for all future users that may be created
1494  // later in the conversion process.
1495  Value castValue =
1496  buildUnresolvedMaterialization(MaterializationKind::Source,
1497  computeInsertPoint(repl), value.getLoc(),
1498  /*valuesToMap=*/repl, /*inputs=*/repl,
1499  /*outputTypes=*/value.getType(),
1500  /*originalType=*/Type(), converter)
1501  .front();
1502  return castValue;
1503 }
1504 
1505 //===----------------------------------------------------------------------===//
1506 // Rewriter Notification Hooks
1507 //===----------------------------------------------------------------------===//
1508 
1510  Operation *op, OpBuilder::InsertPoint previous) {
1511  LLVM_DEBUG({
1512  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1513  << ")\n";
1514  });
1515  assert(!wasOpReplaced(op->getParentOp()) &&
1516  "attempting to insert into a block within a replaced/erased op");
1517 
1518  if (!previous.isSet()) {
1519  // This is a newly created op.
1520  appendRewrite<CreateOperationRewrite>(op);
1521  return;
1522  }
1523  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1524  ? nullptr
1525  : &*previous.getPoint();
1526  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1527 }
1528 
1530  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1531  assert(newValues.size() == op->getNumResults());
1532  assert(!ignoredOps.contains(op) && "operation was already replaced");
1533 
1534  // Check if replaced op is an unresolved materialization, i.e., an
1535  // unrealized_conversion_cast op that was created by the conversion driver.
1536  bool isUnresolvedMaterialization = false;
1537  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1538  if (unresolvedMaterializations.contains(castOp))
1539  isUnresolvedMaterialization = true;
1540 
1541  // Create mappings for each of the new result values.
1542  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1543  if (repl.empty()) {
1544  // This result was dropped and no replacement value was provided.
1545  if (isUnresolvedMaterialization) {
1546  // Do not create another materializations if we are erasing a
1547  // materialization.
1548  continue;
1549  }
1550 
1551  // Materialize a replacement value "out of thin air".
1553  MaterializationKind::Source, computeInsertPoint(result),
1554  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1555  /*outputTypes=*/result.getType(), /*originalType=*/Type(),
1557  continue;
1558  } else {
1559  // Make sure that the user does not mess with unresolved materializations
1560  // that were inserted by the conversion driver. We keep track of these
1561  // ops in internal data structures. Erasing them must be allowed because
1562  // this can happen when the user is erasing an entire block (including
1563  // its body). But replacing them with another value should be forbidden
1564  // to avoid problems with the `mapping`.
1565  assert(!isUnresolvedMaterialization &&
1566  "attempting to replace an unresolved materialization");
1567  }
1568 
1569  // Remap result to replacement value.
1570  if (repl.empty())
1571  continue;
1572  mapping.map(static_cast<Value>(result), std::move(repl));
1573  }
1574 
1575  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1576  // Mark this operation and all nested ops as replaced.
1577  op->walk([&](Operation *op) { replacedOps.insert(op); });
1578 }
1579 
1581  appendRewrite<EraseBlockRewrite>(block);
1582 }
1583 
1585  Block *block, Region *previous, Region::iterator previousIt) {
1586  assert(!wasOpReplaced(block->getParentOp()) &&
1587  "attempting to insert into a region within a replaced/erased op");
1588  LLVM_DEBUG(
1589  {
1590  Operation *parent = block->getParentOp();
1591  if (parent) {
1592  logger.startLine() << "** Insert Block into : '" << parent->getName()
1593  << "'(" << parent << ")\n";
1594  } else {
1595  logger.startLine()
1596  << "** Insert Block into detached Region (nullptr parent op)'";
1597  }
1598  });
1599 
1600  if (!previous) {
1601  // This is a newly created block.
1602  appendRewrite<CreateBlockRewrite>(block);
1603  return;
1604  }
1605  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1606  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1607 }
1608 
1610  Block *block, Block *srcBlock, Block::iterator before) {
1611  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1612 }
1613 
1615  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1616  LLVM_DEBUG({
1618  reasonCallback(diag);
1619  logger.startLine() << "** Failure : " << diag.str() << "\n";
1620  if (config.notifyCallback)
1622  });
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // ConversionPatternRewriter
1627 //===----------------------------------------------------------------------===//
1628 
1629 ConversionPatternRewriter::ConversionPatternRewriter(
1630  MLIRContext *ctx, const ConversionConfig &config)
1631  : PatternRewriter(ctx),
1632  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1633  setListener(impl.get());
1634 }
1635 
1637 
1639  assert(op && newOp && "expected non-null op");
1640  replaceOp(op, newOp->getResults());
1641 }
1642 
1644  assert(op->getNumResults() == newValues.size() &&
1645  "incorrect # of replacement values");
1646  LLVM_DEBUG({
1647  impl->logger.startLine()
1648  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1649  });
1651  llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
1652  return v ? SmallVector<Value>{v} : SmallVector<Value>();
1653  });
1654  impl->notifyOpReplaced(op, std::move(newVals));
1655 }
1656 
1658  Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1659  assert(op->getNumResults() == newValues.size() &&
1660  "incorrect # of replacement values");
1661  LLVM_DEBUG({
1662  impl->logger.startLine()
1663  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1664  });
1665  impl->notifyOpReplaced(op, std::move(newValues));
1666 }
1667 
1669  LLVM_DEBUG({
1670  impl->logger.startLine()
1671  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1672  });
1673  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
1674  impl->notifyOpReplaced(op, std::move(nullRepls));
1675 }
1676 
1678  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1679  "attempting to erase a block within a replaced/erased op");
1680 
1681  // Mark all ops for erasure.
1682  for (Operation &op : *block)
1683  eraseOp(&op);
1684 
1685  // Unlink the block from its parent region. The block is kept in the rewrite
1686  // object and will be actually destroyed when rewrites are applied. This
1687  // allows us to keep the operations in the block live and undo the removal by
1688  // re-inserting the block.
1689  impl->notifyBlockIsBeingErased(block);
1690  block->getParent()->getBlocks().remove(block);
1691 }
1692 
1694  Block *block, TypeConverter::SignatureConversion &conversion,
1695  const TypeConverter *converter) {
1696  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1697  "attempting to apply a signature conversion to a block within a "
1698  "replaced/erased op");
1699  return impl->applySignatureConversion(*this, block, converter, conversion);
1700 }
1701 
1703  Region *region, const TypeConverter &converter,
1704  TypeConverter::SignatureConversion *entryConversion) {
1705  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1706  "attempting to apply a signature conversion to a block within a "
1707  "replaced/erased op");
1708  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1709 }
1710 
1712  Value to) {
1713  LLVM_DEBUG({
1714  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1715  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1716  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1717  << "' (" << parentOp << ")\n";
1718  } else {
1719  impl->logger.getOStream() << " (unlinked block)\n";
1720  }
1721  });
1722  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1723  impl->currentTypeConverter);
1724  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1725 }
1726 
1728  SmallVector<ValueVector> remappedValues;
1729  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1730  remappedValues)))
1731  return nullptr;
1732  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1733  return remappedValues.front().front();
1734 }
1735 
1736 LogicalResult
1738  SmallVectorImpl<Value> &results) {
1739  if (keys.empty())
1740  return success();
1741  SmallVector<ValueVector> remapped;
1742  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1743  remapped)))
1744  return failure();
1745  for (const auto &values : remapped) {
1746  assert(values.size() == 1 && "1:N conversion not supported");
1747  results.push_back(values.front());
1748  }
1749  return success();
1750 }
1751 
1753  Block::iterator before,
1754  ValueRange argValues) {
1755 #ifndef NDEBUG
1756  assert(argValues.size() == source->getNumArguments() &&
1757  "incorrect # of argument replacement values");
1758  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1759  "attempting to inline a block from a replaced/erased op");
1760  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1761  "attempting to inline a block into a replaced/erased op");
1762  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1763  // The source block will be deleted, so it should not have any users (i.e.,
1764  // there should be no predecessors).
1765  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1766  "expected 'source' to have no predecessors");
1767 #endif // NDEBUG
1768 
1769  // If a listener is attached to the dialect conversion, ops cannot be moved
1770  // to the destination block in bulk ("fast path"). This is because at the time
1771  // the notifications are sent, it is unknown which ops were moved. Instead,
1772  // ops should be moved one-by-one ("slow path"), so that a separate
1773  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1774  // a bit more efficient, so we try to do that when possible.
1775  bool fastPath = !impl->config.listener;
1776 
1777  if (fastPath)
1778  impl->notifyBlockBeingInlined(dest, source, before);
1779 
1780  // Replace all uses of block arguments.
1781  for (auto it : llvm::zip(source->getArguments(), argValues))
1782  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1783 
1784  if (fastPath) {
1785  // Move all ops at once.
1786  dest->getOperations().splice(before, source->getOperations());
1787  } else {
1788  // Move op by op.
1789  while (!source->empty())
1790  moveOpBefore(&source->front(), dest, before);
1791  }
1792 
1793  // Erase the source block.
1794  eraseBlock(source);
1795 }
1796 
1798  assert(!impl->wasOpReplaced(op) &&
1799  "attempting to modify a replaced/erased op");
1800 #ifndef NDEBUG
1801  impl->pendingRootUpdates.insert(op);
1802 #endif
1803  impl->appendRewrite<ModifyOperationRewrite>(op);
1804 }
1805 
1807  assert(!impl->wasOpReplaced(op) &&
1808  "attempting to modify a replaced/erased op");
1810  // There is nothing to do here, we only need to track the operation at the
1811  // start of the update.
1812 #ifndef NDEBUG
1813  assert(impl->pendingRootUpdates.erase(op) &&
1814  "operation did not have a pending in-place update");
1815 #endif
1816 }
1817 
1819 #ifndef NDEBUG
1820  assert(impl->pendingRootUpdates.erase(op) &&
1821  "operation did not have a pending in-place update");
1822 #endif
1823  // Erase the last update for this operation.
1824  auto it = llvm::find_if(
1825  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1826  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1827  return modifyRewrite && modifyRewrite->getOperation() == op;
1828  });
1829  assert(it != impl->rewrites.rend() && "no root update started on op");
1830  (*it)->rollback();
1831  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1832  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1833 }
1834 
1836  return *impl;
1837 }
1838 
1839 //===----------------------------------------------------------------------===//
1840 // ConversionPattern
1841 //===----------------------------------------------------------------------===//
1842 
1844  ArrayRef<ValueRange> operands) const {
1845  SmallVector<Value> oneToOneOperands;
1846  oneToOneOperands.reserve(operands.size());
1847  for (ValueRange operand : operands) {
1848  if (operand.size() != 1)
1849  llvm::report_fatal_error("pattern '" + getDebugName() +
1850  "' does not support 1:N conversion");
1851  oneToOneOperands.push_back(operand.front());
1852  }
1853  return oneToOneOperands;
1854 }
1855 
1856 LogicalResult
1858  PatternRewriter &rewriter) const {
1859  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1860  auto &rewriterImpl = dialectRewriter.getImpl();
1861 
1862  // Track the current conversion pattern type converter in the rewriter.
1863  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1864  getTypeConverter());
1865 
1866  // Remap the operands of the operation.
1867  SmallVector<ValueVector> remapped;
1868  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1869  op->getOperands(), remapped))) {
1870  return failure();
1871  }
1872  SmallVector<ValueRange> remappedAsRange =
1873  llvm::to_vector_of<ValueRange>(remapped);
1874  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1875 }
1876 
1877 //===----------------------------------------------------------------------===//
1878 // OperationLegalizer
1879 //===----------------------------------------------------------------------===//
1880 
1881 namespace {
1882 /// A set of rewrite patterns that can be used to legalize a given operation.
1883 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1884 
1885 /// This class defines a recursive operation legalizer.
1886 class OperationLegalizer {
1887 public:
1888  using LegalizationAction = ConversionTarget::LegalizationAction;
1889 
1890  OperationLegalizer(const ConversionTarget &targetInfo,
1892  const ConversionConfig &config);
1893 
1894  /// Returns true if the given operation is known to be illegal on the target.
1895  bool isIllegal(Operation *op) const;
1896 
1897  /// Attempt to legalize the given operation. Returns success if the operation
1898  /// was legalized, failure otherwise.
1899  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1900 
1901  /// Returns the conversion target in use by the legalizer.
1902  const ConversionTarget &getTarget() { return target; }
1903 
1904 private:
1905  /// Attempt to legalize the given operation by folding it.
1906  LogicalResult legalizeWithFold(Operation *op,
1907  ConversionPatternRewriter &rewriter);
1908 
1909  /// Attempt to legalize the given operation by applying a pattern. Returns
1910  /// success if the operation was legalized, failure otherwise.
1911  LogicalResult legalizeWithPattern(Operation *op,
1912  ConversionPatternRewriter &rewriter);
1913 
1914  /// Return true if the given pattern may be applied to the given operation,
1915  /// false otherwise.
1916  bool canApplyPattern(Operation *op, const Pattern &pattern,
1917  ConversionPatternRewriter &rewriter);
1918 
1919  /// Legalize the resultant IR after successfully applying the given pattern.
1920  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1921  ConversionPatternRewriter &rewriter,
1922  RewriterState &curState);
1923 
1924  /// Legalizes the actions registered during the execution of a pattern.
1925  LogicalResult
1926  legalizePatternBlockRewrites(Operation *op,
1927  ConversionPatternRewriter &rewriter,
1929  RewriterState &state, RewriterState &newState);
1930  LogicalResult legalizePatternCreatedOperations(
1932  RewriterState &state, RewriterState &newState);
1933  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1935  RewriterState &state,
1936  RewriterState &newState);
1937 
1938  //===--------------------------------------------------------------------===//
1939  // Cost Model
1940  //===--------------------------------------------------------------------===//
1941 
1942  /// Build an optimistic legalization graph given the provided patterns. This
1943  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1944  /// patterns for operations that are not directly legal, but may be
1945  /// transitively legal for the current target given the provided patterns.
1946  void buildLegalizationGraph(
1947  LegalizationPatterns &anyOpLegalizerPatterns,
1949 
1950  /// Compute the benefit of each node within the computed legalization graph.
1951  /// This orders the patterns within 'legalizerPatterns' based upon two
1952  /// criteria:
1953  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1954  /// represent the more direct mapping to the target.
1955  /// 2) When comparing patterns with the same legalization depth, prefer the
1956  /// pattern with the highest PatternBenefit. This allows for users to
1957  /// prefer specific legalizations over others.
1958  void computeLegalizationGraphBenefit(
1959  LegalizationPatterns &anyOpLegalizerPatterns,
1961 
1962  /// Compute the legalization depth when legalizing an operation of the given
1963  /// type.
1964  unsigned computeOpLegalizationDepth(
1965  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1967 
1968  /// Apply the conversion cost model to the given set of patterns, and return
1969  /// the smallest legalization depth of any of the patterns. See
1970  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1971  unsigned applyCostModelToPatterns(
1972  LegalizationPatterns &patterns,
1973  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1975 
1976  /// The current set of patterns that have been applied.
1977  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1978 
1979  /// The legalization information provided by the target.
1980  const ConversionTarget &target;
1981 
1982  /// The pattern applicator to use for conversions.
1983  PatternApplicator applicator;
1984 
1985  /// Dialect conversion configuration.
1986  const ConversionConfig &config;
1987 };
1988 } // namespace
1989 
1990 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1992  const ConversionConfig &config)
1993  : target(targetInfo), applicator(patterns), config(config) {
1994  // The set of patterns that can be applied to illegal operations to transform
1995  // them into legal ones.
1997  LegalizationPatterns anyOpLegalizerPatterns;
1998 
1999  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2000  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2001 }
2002 
2003 bool OperationLegalizer::isIllegal(Operation *op) const {
2004  return target.isIllegal(op);
2005 }
2006 
2007 LogicalResult
2008 OperationLegalizer::legalize(Operation *op,
2009  ConversionPatternRewriter &rewriter) {
2010 #ifndef NDEBUG
2011  const char *logLineComment =
2012  "//===-------------------------------------------===//\n";
2013 
2014  auto &logger = rewriter.getImpl().logger;
2015 #endif
2016  LLVM_DEBUG({
2017  logger.getOStream() << "\n";
2018  logger.startLine() << logLineComment;
2019  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2020  << op << ") {\n";
2021  logger.indent();
2022 
2023  // If the operation has no regions, just print it here.
2024  if (op->getNumRegions() == 0) {
2025  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2026  logger.getOStream() << "\n\n";
2027  }
2028  });
2029 
2030  // Check if this operation is legal on the target.
2031  if (auto legalityInfo = target.isLegal(op)) {
2032  LLVM_DEBUG({
2033  logSuccess(
2034  logger, "operation marked legal by the target{0}",
2035  legalityInfo->isRecursivelyLegal
2036  ? "; NOTE: operation is recursively legal; skipping internals"
2037  : "");
2038  logger.startLine() << logLineComment;
2039  });
2040 
2041  // If this operation is recursively legal, mark its children as ignored so
2042  // that we don't consider them for legalization.
2043  if (legalityInfo->isRecursivelyLegal) {
2044  op->walk([&](Operation *nested) {
2045  if (op != nested)
2046  rewriter.getImpl().ignoredOps.insert(nested);
2047  });
2048  }
2049 
2050  return success();
2051  }
2052 
2053  // Check to see if the operation is ignored and doesn't need to be converted.
2054  if (rewriter.getImpl().isOpIgnored(op)) {
2055  LLVM_DEBUG({
2056  logSuccess(logger, "operation marked 'ignored' during conversion");
2057  logger.startLine() << logLineComment;
2058  });
2059  return success();
2060  }
2061 
2062  // If the operation isn't legal, try to fold it in-place.
2063  // TODO: Should we always try to do this, even if the op is
2064  // already legal?
2065  if (succeeded(legalizeWithFold(op, rewriter))) {
2066  LLVM_DEBUG({
2067  logSuccess(logger, "operation was folded");
2068  logger.startLine() << logLineComment;
2069  });
2070  return success();
2071  }
2072 
2073  // Otherwise, we need to apply a legalization pattern to this operation.
2074  if (succeeded(legalizeWithPattern(op, rewriter))) {
2075  LLVM_DEBUG({
2076  logSuccess(logger, "");
2077  logger.startLine() << logLineComment;
2078  });
2079  return success();
2080  }
2081 
2082  LLVM_DEBUG({
2083  logFailure(logger, "no matched legalization pattern");
2084  logger.startLine() << logLineComment;
2085  });
2086  return failure();
2087 }
2088 
2089 LogicalResult
2090 OperationLegalizer::legalizeWithFold(Operation *op,
2091  ConversionPatternRewriter &rewriter) {
2092  auto &rewriterImpl = rewriter.getImpl();
2093  RewriterState curState = rewriterImpl.getCurrentState();
2094 
2095  LLVM_DEBUG({
2096  rewriterImpl.logger.startLine() << "* Fold {\n";
2097  rewriterImpl.logger.indent();
2098  });
2099 
2100  // Try to fold the operation.
2101  SmallVector<Value, 2> replacementValues;
2102  rewriter.setInsertionPoint(op);
2103  if (failed(rewriter.tryFold(op, replacementValues))) {
2104  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2105  return failure();
2106  }
2107  // An empty list of replacement values indicates that the fold was in-place.
2108  // As the operation changed, a new legalization needs to be attempted.
2109  if (replacementValues.empty())
2110  return legalize(op, rewriter);
2111 
2112  // Recursively legalize any new constant operations.
2113  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2114  i != e; ++i) {
2115  auto *createOp =
2116  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2117  if (!createOp)
2118  continue;
2119  if (failed(legalize(createOp->getOperation(), rewriter))) {
2120  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2121  "failed to legalize generated constant '{0}'",
2122  createOp->getOperation()->getName()));
2123  rewriterImpl.resetState(curState);
2124  return failure();
2125  }
2126  }
2127 
2128  // Insert a replacement for 'op' with the folded replacement values.
2129  rewriter.replaceOp(op, replacementValues);
2130 
2131  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2132  return success();
2133 }
2134 
2135 LogicalResult
2136 OperationLegalizer::legalizeWithPattern(Operation *op,
2137  ConversionPatternRewriter &rewriter) {
2138  auto &rewriterImpl = rewriter.getImpl();
2139 
2140  // Functor that returns if the given pattern may be applied.
2141  auto canApply = [&](const Pattern &pattern) {
2142  bool canApply = canApplyPattern(op, pattern, rewriter);
2143  if (canApply && config.listener)
2144  config.listener->notifyPatternBegin(pattern, op);
2145  return canApply;
2146  };
2147 
2148  // Functor that cleans up the rewriter state after a pattern failed to match.
2149  RewriterState curState = rewriterImpl.getCurrentState();
2150  auto onFailure = [&](const Pattern &pattern) {
2151  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2152  LLVM_DEBUG({
2153  logFailure(rewriterImpl.logger, "pattern failed to match");
2154  if (rewriterImpl.config.notifyCallback) {
2156  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2157  << "\" on op:\n"
2158  << *op;
2159  rewriterImpl.config.notifyCallback(diag);
2160  }
2161  });
2162  if (config.listener)
2163  config.listener->notifyPatternEnd(pattern, failure());
2164  rewriterImpl.resetState(curState);
2165  appliedPatterns.erase(&pattern);
2166  };
2167 
2168  // Functor that performs additional legalization when a pattern is
2169  // successfully applied.
2170  auto onSuccess = [&](const Pattern &pattern) {
2171  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2172  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2173  appliedPatterns.erase(&pattern);
2174  if (failed(result))
2175  rewriterImpl.resetState(curState);
2176  if (config.listener)
2177  config.listener->notifyPatternEnd(pattern, result);
2178  return result;
2179  };
2180 
2181  // Try to match and rewrite a pattern on this operation.
2182  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2183  onSuccess);
2184 }
2185 
2186 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2187  ConversionPatternRewriter &rewriter) {
2188  LLVM_DEBUG({
2189  auto &os = rewriter.getImpl().logger;
2190  os.getOStream() << "\n";
2191  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2192  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2193  os.getOStream() << ")' {\n";
2194  os.indent();
2195  });
2196 
2197  // Ensure that we don't cycle by not allowing the same pattern to be
2198  // applied twice in the same recursion stack if it is not known to be safe.
2199  if (!pattern.hasBoundedRewriteRecursion() &&
2200  !appliedPatterns.insert(&pattern).second) {
2201  LLVM_DEBUG(
2202  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2203  return false;
2204  }
2205  return true;
2206 }
2207 
2208 LogicalResult
2209 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2210  ConversionPatternRewriter &rewriter,
2211  RewriterState &curState) {
2212  auto &impl = rewriter.getImpl();
2213  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2214 
2215 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2216  // Check that the root was either replaced or updated in place.
2217  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2218  auto replacedRoot = [&] {
2219  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2220  };
2221  auto updatedRootInPlace = [&] {
2222  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2223  };
2224  if (!replacedRoot() && !updatedRootInPlace())
2225  llvm::report_fatal_error("expected pattern to replace the root operation");
2226 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2227 
2228  // Legalize each of the actions registered during application.
2229  RewriterState newState = impl.getCurrentState();
2230  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2231  newState)) ||
2232  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2233  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2234  newState))) {
2235  return failure();
2236  }
2237 
2238  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2239  return success();
2240 }
2241 
2242 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2243  Operation *op, ConversionPatternRewriter &rewriter,
2244  ConversionPatternRewriterImpl &impl, RewriterState &state,
2245  RewriterState &newState) {
2246  SmallPtrSet<Operation *, 16> operationsToIgnore;
2247 
2248  // If the pattern moved or created any blocks, make sure the types of block
2249  // arguments get legalized.
2250  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2251  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2252  if (!rewrite)
2253  continue;
2254  Block *block = rewrite->getBlock();
2255  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2256  ReplaceBlockArgRewrite>(rewrite))
2257  continue;
2258  // Only check blocks outside of the current operation.
2259  Operation *parentOp = block->getParentOp();
2260  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2261  continue;
2262 
2263  // If the region of the block has a type converter, try to convert the block
2264  // directly.
2265  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2266  std::optional<TypeConverter::SignatureConversion> conversion =
2267  converter->convertBlockSignature(block);
2268  if (!conversion) {
2269  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2270  "block"));
2271  return failure();
2272  }
2273  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2274  continue;
2275  }
2276 
2277  // Otherwise, check that this operation isn't one generated by this pattern.
2278  // This is because we will attempt to legalize the parent operation, and
2279  // blocks in regions created by this pattern will already be legalized later
2280  // on. If we haven't built the set yet, build it now.
2281  if (operationsToIgnore.empty()) {
2282  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2283  ++i) {
2284  auto *createOp =
2285  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2286  if (!createOp)
2287  continue;
2288  operationsToIgnore.insert(createOp->getOperation());
2289  }
2290  }
2291 
2292  // If this operation should be considered for re-legalization, try it.
2293  if (operationsToIgnore.insert(parentOp).second &&
2294  failed(legalize(parentOp, rewriter))) {
2295  LLVM_DEBUG(logFailure(impl.logger,
2296  "operation '{0}'({1}) became illegal after rewrite",
2297  parentOp->getName(), parentOp));
2298  return failure();
2299  }
2300  }
2301  return success();
2302 }
2303 
2304 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2306  RewriterState &state, RewriterState &newState) {
2307  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2308  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2309  if (!createOp)
2310  continue;
2311  Operation *op = createOp->getOperation();
2312  if (failed(legalize(op, rewriter))) {
2313  LLVM_DEBUG(logFailure(impl.logger,
2314  "failed to legalize generated operation '{0}'({1})",
2315  op->getName(), op));
2316  return failure();
2317  }
2318  }
2319  return success();
2320 }
2321 
2322 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2324  RewriterState &state, RewriterState &newState) {
2325  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2326  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2327  if (!rewrite)
2328  continue;
2329  Operation *op = rewrite->getOperation();
2330  if (failed(legalize(op, rewriter))) {
2331  LLVM_DEBUG(logFailure(
2332  impl.logger, "failed to legalize operation updated in-place '{0}'",
2333  op->getName()));
2334  return failure();
2335  }
2336  }
2337  return success();
2338 }
2339 
2340 //===----------------------------------------------------------------------===//
2341 // Cost Model
2342 //===----------------------------------------------------------------------===//
2343 
2344 void OperationLegalizer::buildLegalizationGraph(
2345  LegalizationPatterns &anyOpLegalizerPatterns,
2346  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2347  // A mapping between an operation and a set of operations that can be used to
2348  // generate it.
2350  // A mapping between an operation and any currently invalid patterns it has.
2352  // A worklist of patterns to consider for legality.
2353  SetVector<const Pattern *> patternWorklist;
2354 
2355  // Build the mapping from operations to the parent ops that may generate them.
2356  applicator.walkAllPatterns([&](const Pattern &pattern) {
2357  std::optional<OperationName> root = pattern.getRootKind();
2358 
2359  // If the pattern has no specific root, we can't analyze the relationship
2360  // between the root op and generated operations. Given that, add all such
2361  // patterns to the legalization set.
2362  if (!root) {
2363  anyOpLegalizerPatterns.push_back(&pattern);
2364  return;
2365  }
2366 
2367  // Skip operations that are always known to be legal.
2368  if (target.getOpAction(*root) == LegalizationAction::Legal)
2369  return;
2370 
2371  // Add this pattern to the invalid set for the root op and record this root
2372  // as a parent for any generated operations.
2373  invalidPatterns[*root].insert(&pattern);
2374  for (auto op : pattern.getGeneratedOps())
2375  parentOps[op].insert(*root);
2376 
2377  // Add this pattern to the worklist.
2378  patternWorklist.insert(&pattern);
2379  });
2380 
2381  // If there are any patterns that don't have a specific root kind, we can't
2382  // make direct assumptions about what operations will never be legalized.
2383  // Note: Technically we could, but it would require an analysis that may
2384  // recurse into itself. It would be better to perform this kind of filtering
2385  // at a higher level than here anyways.
2386  if (!anyOpLegalizerPatterns.empty()) {
2387  for (const Pattern *pattern : patternWorklist)
2388  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2389  return;
2390  }
2391 
2392  while (!patternWorklist.empty()) {
2393  auto *pattern = patternWorklist.pop_back_val();
2394 
2395  // Check to see if any of the generated operations are invalid.
2396  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2397  std::optional<LegalizationAction> action = target.getOpAction(op);
2398  return !legalizerPatterns.count(op) &&
2399  (!action || action == LegalizationAction::Illegal);
2400  }))
2401  continue;
2402 
2403  // Otherwise, if all of the generated operation are valid, this op is now
2404  // legal so add all of the child patterns to the worklist.
2405  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2406  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2407 
2408  // Add any invalid patterns of the parent operations to see if they have now
2409  // become legal.
2410  for (auto op : parentOps[*pattern->getRootKind()])
2411  patternWorklist.set_union(invalidPatterns[op]);
2412  }
2413 }
2414 
2415 void OperationLegalizer::computeLegalizationGraphBenefit(
2416  LegalizationPatterns &anyOpLegalizerPatterns,
2417  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2418  // The smallest pattern depth, when legalizing an operation.
2419  DenseMap<OperationName, unsigned> minOpPatternDepth;
2420 
2421  // For each operation that is transitively legal, compute a cost for it.
2422  for (auto &opIt : legalizerPatterns)
2423  if (!minOpPatternDepth.count(opIt.first))
2424  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2425  legalizerPatterns);
2426 
2427  // Apply the cost model to the patterns that can match any operation. Those
2428  // with a specific operation type are already resolved when computing the op
2429  // legalization depth.
2430  if (!anyOpLegalizerPatterns.empty())
2431  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2432  legalizerPatterns);
2433 
2434  // Apply a cost model to the pattern applicator. We order patterns first by
2435  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2436  // decreasing benefit.
2437  applicator.applyCostModel([&](const Pattern &pattern) {
2438  ArrayRef<const Pattern *> orderedPatternList;
2439  if (std::optional<OperationName> rootName = pattern.getRootKind())
2440  orderedPatternList = legalizerPatterns[*rootName];
2441  else
2442  orderedPatternList = anyOpLegalizerPatterns;
2443 
2444  // If the pattern is not found, then it was removed and cannot be matched.
2445  auto *it = llvm::find(orderedPatternList, &pattern);
2446  if (it == orderedPatternList.end())
2448 
2449  // Patterns found earlier in the list have higher benefit.
2450  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2451  });
2452 }
2453 
2454 unsigned OperationLegalizer::computeOpLegalizationDepth(
2455  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2456  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2457  // Check for existing depth.
2458  auto depthIt = minOpPatternDepth.find(op);
2459  if (depthIt != minOpPatternDepth.end())
2460  return depthIt->second;
2461 
2462  // If a mapping for this operation does not exist, then this operation
2463  // is always legal. Return 0 as the depth for a directly legal operation.
2464  auto opPatternsIt = legalizerPatterns.find(op);
2465  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2466  return 0u;
2467 
2468  // Record this initial depth in case we encounter this op again when
2469  // recursively computing the depth.
2470  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2471 
2472  // Apply the cost model to the operation patterns, and update the minimum
2473  // depth.
2474  unsigned minDepth = applyCostModelToPatterns(
2475  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2476  minOpPatternDepth[op] = minDepth;
2477  return minDepth;
2478 }
2479 
2480 unsigned OperationLegalizer::applyCostModelToPatterns(
2481  LegalizationPatterns &patterns,
2482  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2483  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2484  unsigned minDepth = std::numeric_limits<unsigned>::max();
2485 
2486  // Compute the depth for each pattern within the set.
2487  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2488  patternsByDepth.reserve(patterns.size());
2489  for (const Pattern *pattern : patterns) {
2490  unsigned depth = 1;
2491  for (auto generatedOp : pattern->getGeneratedOps()) {
2492  unsigned generatedOpDepth = computeOpLegalizationDepth(
2493  generatedOp, minOpPatternDepth, legalizerPatterns);
2494  depth = std::max(depth, generatedOpDepth + 1);
2495  }
2496  patternsByDepth.emplace_back(pattern, depth);
2497 
2498  // Update the minimum depth of the pattern list.
2499  minDepth = std::min(minDepth, depth);
2500  }
2501 
2502  // If the operation only has one legalization pattern, there is no need to
2503  // sort them.
2504  if (patternsByDepth.size() == 1)
2505  return minDepth;
2506 
2507  // Sort the patterns by those likely to be the most beneficial.
2508  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2509  [](const std::pair<const Pattern *, unsigned> &lhs,
2510  const std::pair<const Pattern *, unsigned> &rhs) {
2511  // First sort by the smaller pattern legalization
2512  // depth.
2513  if (lhs.second != rhs.second)
2514  return lhs.second < rhs.second;
2515 
2516  // Then sort by the larger pattern benefit.
2517  auto lhsBenefit = lhs.first->getBenefit();
2518  auto rhsBenefit = rhs.first->getBenefit();
2519  return lhsBenefit > rhsBenefit;
2520  });
2521 
2522  // Update the legalization pattern to use the new sorted list.
2523  patterns.clear();
2524  for (auto &patternIt : patternsByDepth)
2525  patterns.push_back(patternIt.first);
2526  return minDepth;
2527 }
2528 
2529 //===----------------------------------------------------------------------===//
2530 // OperationConverter
2531 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the operation converter reacts when an operation cannot be
/// legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR alongside legal ones.
  Partial = 0,

  /// Every operation must end up legal for the target for the conversion to
  /// succeed.
  Full = 1,

  /// Operations are only analyzed for legality; no rewrites are applied to
  /// the operations on success.
  Analysis = 2,
};
} // namespace
2547 
2548 namespace mlir {
2549 // This class converts operations to a given conversion target via a set of
2550 // rewrite patterns. The conversion behaves differently depending on the
2551 // conversion mode.
2553  explicit OperationConverter(const ConversionTarget &target,
2555  const ConversionConfig &config,
2556  OpConversionMode mode)
2557  : config(config), opLegalizer(target, patterns, this->config),
2558  mode(mode) {}
2559 
2560  /// Converts the given operations to the conversion target.
2561  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2562 
2563 private:
2564  /// Converts an operation with the given rewriter.
2565  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2566 
2567  /// Dialect conversion configuration.
2568  ConversionConfig config;
2569 
2570  /// The legalizer to use when converting operations.
2571  OperationLegalizer opLegalizer;
2572 
2573  /// The conversion mode to use when legalizing operations.
2574  OpConversionMode mode;
2575 };
2576 } // namespace mlir
2577 
2578 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2579  Operation *op) {
2580  // Legalize the given operation.
2581  if (failed(opLegalizer.legalize(op, rewriter))) {
2582  // Handle the case of a failed conversion for each of the different modes.
2583  // Full conversions expect all operations to be converted.
2584  if (mode == OpConversionMode::Full)
2585  return op->emitError()
2586  << "failed to legalize operation '" << op->getName() << "'";
2587  // Partial conversions allow conversions to fail iff the operation was not
2588  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2589  // set, non-legalizable ops are added to that set.
2590  if (mode == OpConversionMode::Partial) {
2591  if (opLegalizer.isIllegal(op))
2592  return op->emitError()
2593  << "failed to legalize operation '" << op->getName()
2594  << "' that was explicitly marked illegal";
2595  if (config.unlegalizedOps)
2596  config.unlegalizedOps->insert(op);
2597  }
2598  } else if (mode == OpConversionMode::Analysis) {
2599  // Analysis conversions don't fail if any operations fail to legalize,
2600  // they are only interested in the operations that were successfully
2601  // legalized.
2602  if (config.legalizableOps)
2603  config.legalizableOps->insert(op);
2604  }
2605  return success();
2606 }
2607 
2608 static LogicalResult
2610  UnresolvedMaterializationRewrite *rewrite) {
2611  UnrealizedConversionCastOp op = rewrite->getOperation();
2612  assert(!op.use_empty() &&
2613  "expected that dead materializations have already been DCE'd");
2614  Operation::operand_range inputOperands = op.getOperands();
2615 
2616  // Try to materialize the conversion.
2617  if (const TypeConverter *converter = rewrite->getConverter()) {
2618  rewriter.setInsertionPoint(op);
2619  SmallVector<Value> newMaterialization;
2620  switch (rewrite->getMaterializationKind()) {
2621  case MaterializationKind::Target:
2622  newMaterialization = converter->materializeTargetConversion(
2623  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2624  rewrite->getOriginalType());
2625  break;
2626  case MaterializationKind::Source:
2627  assert(op->getNumResults() == 1 && "expected single result");
2628  Value sourceMat = converter->materializeSourceConversion(
2629  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2630  if (sourceMat)
2631  newMaterialization.push_back(sourceMat);
2632  break;
2633  }
2634  if (!newMaterialization.empty()) {
2635 #ifndef NDEBUG
2636  ValueRange newMaterializationRange(newMaterialization);
2637  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2638  "materialization callback produced value of incorrect type");
2639 #endif // NDEBUG
2640  rewriter.replaceOp(op, newMaterialization);
2641  return success();
2642  }
2643  }
2644 
2645  InFlightDiagnostic diag = op->emitError()
2646  << "failed to legalize unresolved materialization "
2647  "from ("
2648  << inputOperands.getTypes() << ") to ("
2649  << op.getResultTypes()
2650  << ") that remained live after conversion";
2651  diag.attachNote(op->getUsers().begin()->getLoc())
2652  << "see existing live user here: " << *op->getUsers().begin();
2653  return failure();
2654 }
2655 
2657  if (ops.empty())
2658  return success();
2659  const ConversionTarget &target = opLegalizer.getTarget();
2660 
2661  // Compute the set of operations and blocks to convert.
2662  SmallVector<Operation *> toConvert;
2663  for (auto *op : ops) {
2665  [&](Operation *op) {
2666  toConvert.push_back(op);
2667  // Don't check this operation's children for conversion if the
2668  // operation is recursively legal.
2669  auto legalityInfo = target.isLegal(op);
2670  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2671  return WalkResult::skip();
2672  return WalkResult::advance();
2673  });
2674  }
2675 
2676  // Convert each operation and discard rewrites on failure.
2677  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2678  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2679 
2680  for (auto *op : toConvert)
2681  if (failed(convert(rewriter, op)))
2682  return rewriterImpl.undoRewrites(), failure();
2683 
2684  // After a successful conversion, apply rewrites.
2685  rewriterImpl.applyRewrites();
2686 
2687  // Gather all unresolved materializations.
2690  &materializations = rewriterImpl.unresolvedMaterializations;
2691  for (auto it : materializations) {
2692  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2693  continue;
2694  allCastOps.push_back(it.first);
2695  }
2696 
2697  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2698  // dialect conversion frameworks. (Not the one that were inserted by
2699  // patterns.)
2700  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2701  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2702 
2703  // Try to legalize all unresolved materializations.
2704  if (config.buildMaterializations) {
2705  IRRewriter rewriter(rewriterImpl.context, config.listener);
2706  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2707  auto it = materializations.find(castOp);
2708  assert(it != materializations.end() && "inconsistent state");
2709  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2710  return failure();
2711  }
2712  }
2713 
2714  return success();
2715 }
2716 
2717 //===----------------------------------------------------------------------===//
2718 // Reconcile Unrealized Casts
2719 //===----------------------------------------------------------------------===//
2720 
2723  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2724  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2725  castOps.end());
2726  // This set is maintained only if `remainingCastOps` is provided.
2727  DenseSet<Operation *> erasedOps;
2728 
2729  // Helper function that adds all operands to the worklist that are an
2730  // unrealized_conversion_cast op result.
2731  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2732  for (Value v : castOp.getInputs())
2733  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2734  worklist.insert(inputCastOp);
2735  };
2736 
2737  // Helper function that return the unrealized_conversion_cast op that
2738  // defines all inputs of the given op (in the same order). Return "nullptr"
2739  // if there is no such op.
2740  auto getInputCast =
2741  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2742  if (castOp.getInputs().empty())
2743  return {};
2744  auto inputCastOp =
2745  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2746  if (!inputCastOp)
2747  return {};
2748  if (inputCastOp.getOutputs() != castOp.getInputs())
2749  return {};
2750  return inputCastOp;
2751  };
2752 
2753  // Process ops in the worklist bottom-to-top.
2754  while (!worklist.empty()) {
2755  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2756  if (castOp->use_empty()) {
2757  // DCE: If the op has no users, erase it. Add the operands to the
2758  // worklist to find additional DCE opportunities.
2759  enqueueOperands(castOp);
2760  if (remainingCastOps)
2761  erasedOps.insert(castOp.getOperation());
2762  castOp->erase();
2763  continue;
2764  }
2765 
2766  // Traverse the chain of input cast ops to see if an op with the same
2767  // input types can be found.
2768  UnrealizedConversionCastOp nextCast = castOp;
2769  while (nextCast) {
2770  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2771  // Found a cast where the input types match the output types of the
2772  // matched op. We can directly use those inputs and the matched op can
2773  // be removed.
2774  enqueueOperands(castOp);
2775  castOp.replaceAllUsesWith(nextCast.getInputs());
2776  if (remainingCastOps)
2777  erasedOps.insert(castOp.getOperation());
2778  castOp->erase();
2779  break;
2780  }
2781  nextCast = getInputCast(nextCast);
2782  }
2783  }
2784 
2785  if (remainingCastOps)
2786  for (UnrealizedConversionCastOp op : castOps)
2787  if (!erasedOps.contains(op.getOperation()))
2788  remainingCastOps->push_back(op);
2789 }
2790 
2791 //===----------------------------------------------------------------------===//
2792 // Type Conversion
2793 //===----------------------------------------------------------------------===//
2794 
2796  ArrayRef<Type> types) {
2797  assert(!types.empty() && "expected valid types");
2798  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2799  addInputs(types);
2800 }
2801 
2803  assert(!types.empty() &&
2804  "1->0 type remappings don't need to be added explicitly");
2805  argTypes.append(types.begin(), types.end());
2806 }
2807 
2808 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2809  unsigned newInputNo,
2810  unsigned newInputCount) {
2811  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2812  assert(newInputCount != 0 && "expected valid input count");
2813  remappedInputs[origInputNo] =
2814  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2815 }
2816 
2818  unsigned origInputNo, ArrayRef<Value> replacements) {
2819  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2820  remappedInputs[origInputNo] = InputMapping{
2821  origInputNo, /*size=*/0,
2822  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2823 }
2824 
2826  SmallVectorImpl<Type> &results) const {
2827  assert(t && "expected non-null type");
2828 
2829  {
2830  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2831  std::defer_lock);
2833  cacheReadLock.lock();
2834  auto existingIt = cachedDirectConversions.find(t);
2835  if (existingIt != cachedDirectConversions.end()) {
2836  if (existingIt->second)
2837  results.push_back(existingIt->second);
2838  return success(existingIt->second != nullptr);
2839  }
2840  auto multiIt = cachedMultiConversions.find(t);
2841  if (multiIt != cachedMultiConversions.end()) {
2842  results.append(multiIt->second.begin(), multiIt->second.end());
2843  return success();
2844  }
2845  }
2846  // Walk the added converters in reverse order to apply the most recently
2847  // registered first.
2848  size_t currentCount = results.size();
2849 
2850  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2851  std::defer_lock);
2852 
2853  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2854  if (std::optional<LogicalResult> result = converter(t, results)) {
2856  cacheWriteLock.lock();
2857  if (!succeeded(*result)) {
2858  cachedDirectConversions.try_emplace(t, nullptr);
2859  return failure();
2860  }
2861  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2862  if (newTypes.size() == 1)
2863  cachedDirectConversions.try_emplace(t, newTypes.front());
2864  else
2865  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2866  return success();
2867  }
2868  }
2869  return failure();
2870 }
2871 
2873  // Use the multi-type result version to convert the type.
2874  SmallVector<Type, 1> results;
2875  if (failed(convertType(t, results)))
2876  return nullptr;
2877 
2878  // Check to ensure that only one type was produced.
2879  return results.size() == 1 ? results.front() : nullptr;
2880 }
2881 
2882 LogicalResult
2884  SmallVectorImpl<Type> &results) const {
2885  for (Type type : types)
2886  if (failed(convertType(type, results)))
2887  return failure();
2888  return success();
2889 }
2890 
2891 bool TypeConverter::isLegal(Type type) const {
2892  return convertType(type) == type;
2893 }
2895  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2896 }
2897 
2898 bool TypeConverter::isLegal(Region *region) const {
2899  return llvm::all_of(*region, [this](Block &block) {
2900  return isLegal(block.getArgumentTypes());
2901  });
2902 }
2903 
2904 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2905  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2906 }
2907 
2908 LogicalResult
2910  SignatureConversion &result) const {
2911  // Try to convert the given input type.
2912  SmallVector<Type, 1> convertedTypes;
2913  if (failed(convertType(type, convertedTypes)))
2914  return failure();
2915 
2916  // If this argument is being dropped, there is nothing left to do.
2917  if (convertedTypes.empty())
2918  return success();
2919 
2920  // Otherwise, add the new inputs.
2921  result.addInputs(inputNo, convertedTypes);
2922  return success();
2923 }
2924 LogicalResult
2926  SignatureConversion &result,
2927  unsigned origInputOffset) const {
2928  for (unsigned i = 0, e = types.size(); i != e; ++i)
2929  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2930  return failure();
2931  return success();
2932 }
2933 
2935  Location loc,
2936  Type resultType,
2937  ValueRange inputs) const {
2938  for (const MaterializationCallbackFn &fn :
2939  llvm::reverse(argumentMaterializations))
2940  if (Value result = fn(builder, resultType, inputs, loc))
2941  return result;
2942  return nullptr;
2943 }
2944 
2946  Location loc, Type resultType,
2947  ValueRange inputs) const {
2948  for (const MaterializationCallbackFn &fn :
2949  llvm::reverse(sourceMaterializations))
2950  if (Value result = fn(builder, resultType, inputs, loc))
2951  return result;
2952  return nullptr;
2953 }
2954 
2956  Location loc, Type resultType,
2957  ValueRange inputs,
2958  Type originalType) const {
2960  builder, loc, TypeRange(resultType), inputs, originalType);
2961  if (result.empty())
2962  return nullptr;
2963  assert(result.size() == 1 && "expected single result");
2964  return result.front();
2965 }
2966 
2968  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2969  Type originalType) const {
2970  for (const TargetMaterializationCallbackFn &fn :
2971  llvm::reverse(targetMaterializations)) {
2972  SmallVector<Value> result =
2973  fn(builder, resultTypes, inputs, loc, originalType);
2974  if (result.empty())
2975  continue;
2976  assert(TypeRange(ValueRange(result)) == resultTypes &&
2977  "callback produced incorrect number of values or values with "
2978  "incorrect types");
2979  return result;
2980  }
2981  return {};
2982 }
2983 
2984 std::optional<TypeConverter::SignatureConversion>
2986  SignatureConversion conversion(block->getNumArguments());
2987  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2988  return std::nullopt;
2989  return conversion;
2990 }
2991 
2992 //===----------------------------------------------------------------------===//
2993 // Type attribute conversion
2994 //===----------------------------------------------------------------------===//
2997  return AttributeConversionResult(attr, resultTag);
2998 }
2999 
3002  return AttributeConversionResult(nullptr, naTag);
3003 }
3004 
3007  return AttributeConversionResult(nullptr, abortTag);
3008 }
3009 
3011  return impl.getInt() == resultTag;
3012 }
3013 
3015  return impl.getInt() == naTag;
3016 }
3017 
3019  return impl.getInt() == abortTag;
3020 }
3021 
3023  assert(hasResult() && "Cannot get result from N/A or abort");
3024  return impl.getPointer();
3025 }
3026 
3027 std::optional<Attribute>
3029  for (const TypeAttributeConversionCallbackFn &fn :
3030  llvm::reverse(typeAttributeConversions)) {
3031  AttributeConversionResult res = fn(type, attr);
3032  if (res.hasResult())
3033  return res.getResult();
3034  if (res.isAbort())
3035  return std::nullopt;
3036  }
3037  return std::nullopt;
3038 }
3039 
3040 //===----------------------------------------------------------------------===//
3041 // FunctionOpInterfaceSignatureConversion
3042 //===----------------------------------------------------------------------===//
3043 
3044 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3045  const TypeConverter &typeConverter,
3046  ConversionPatternRewriter &rewriter) {
3047  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3048  if (!type)
3049  return failure();
3050 
3051  // Convert the original function types.
3052  TypeConverter::SignatureConversion result(type.getNumInputs());
3053  SmallVector<Type, 1> newResults;
3054  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3055  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3056  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3057  typeConverter, &result)))
3058  return failure();
3059 
3060  // Update the function signature in-place.
3061  auto newType = FunctionType::get(rewriter.getContext(),
3062  result.getConvertedTypes(), newResults);
3063 
3064  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3065 
3066  return success();
3067 }
3068 
3069 /// Create a default conversion pattern that rewrites the type signature of a
3070 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3071 /// represent their type.
3072 namespace {
3073 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3074  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3075  MLIRContext *ctx,
3076  const TypeConverter &converter)
3077  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3078 
3079  LogicalResult
3080  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3081  ConversionPatternRewriter &rewriter) const override {
3082  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3083  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3084  }
3085 };
3086 
3087 struct AnyFunctionOpInterfaceSignatureConversion
3088  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3090 
3091  LogicalResult
3092  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3093  ConversionPatternRewriter &rewriter) const override {
3094  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3095  }
3096 };
3097 } // namespace
3098 
3099 FailureOr<Operation *>
3101  const TypeConverter &converter,
3102  ConversionPatternRewriter &rewriter) {
3103  assert(op && "Invalid op");
3104  Location loc = op->getLoc();
3105  if (converter.isLegal(op))
3106  return rewriter.notifyMatchFailure(loc, "op already legal");
3107 
3108  OperationState newOp(loc, op->getName());
3109  newOp.addOperands(operands);
3110 
3111  SmallVector<Type> newResultTypes;
3112  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3113  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3114 
3115  newOp.addTypes(newResultTypes);
3116  newOp.addAttributes(op->getAttrs());
3117  return rewriter.create(newOp);
3118 }
3119 
3121  StringRef functionLikeOpName, RewritePatternSet &patterns,
3122  const TypeConverter &converter) {
3123  patterns.add<FunctionOpInterfaceSignatureConversion>(
3124  functionLikeOpName, patterns.getContext(), converter);
3125 }
3126 
3128  RewritePatternSet &patterns, const TypeConverter &converter) {
3129  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3130  converter, patterns.getContext());
3131 }
3132 
3133 //===----------------------------------------------------------------------===//
3134 // ConversionTarget
3135 //===----------------------------------------------------------------------===//
3136 
3138  LegalizationAction action) {
3139  legalOperations[op].action = action;
3140 }
3141 
3143  LegalizationAction action) {
3144  for (StringRef dialect : dialectNames)
3145  legalDialects[dialect] = action;
3146 }
3147 
3149  -> std::optional<LegalizationAction> {
3150  std::optional<LegalizationInfo> info = getOpInfo(op);
3151  return info ? info->action : std::optional<LegalizationAction>();
3152 }
3153 
3155  -> std::optional<LegalOpDetails> {
3156  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3157  if (!info)
3158  return std::nullopt;
3159 
3160  // Returns true if this operation instance is known to be legal.
3161  auto isOpLegal = [&] {
3162  // Handle dynamic legality either with the provided legality function.
3163  if (info->action == LegalizationAction::Dynamic) {
3164  std::optional<bool> result = info->legalityFn(op);
3165  if (result)
3166  return *result;
3167  }
3168 
3169  // Otherwise, the operation is only legal if it was marked 'Legal'.
3170  return info->action == LegalizationAction::Legal;
3171  };
3172  if (!isOpLegal())
3173  return std::nullopt;
3174 
3175  // This operation is legal, compute any additional legality information.
3176  LegalOpDetails legalityDetails;
3177  if (info->isRecursivelyLegal) {
3178  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3179  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3180  legalityDetails.isRecursivelyLegal =
3181  legalityFnIt->second(op).value_or(true);
3182  } else {
3183  legalityDetails.isRecursivelyLegal = true;
3184  }
3185  }
3186  return legalityDetails;
3187 }
3188 
3190  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3191  if (!info)
3192  return false;
3193 
3194  if (info->action == LegalizationAction::Dynamic) {
3195  std::optional<bool> result = info->legalityFn(op);
3196  if (!result)
3197  return false;
3198 
3199  return !(*result);
3200  }
3201 
3202  return info->action == LegalizationAction::Illegal;
3203 }
3204 
3208  if (!oldCallback)
3209  return newCallback;
3210 
3211  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3212  Operation *op) -> std::optional<bool> {
3213  if (std::optional<bool> result = newCl(op))
3214  return *result;
3215 
3216  return oldCl(op);
3217  };
3218  return chain;
3219 }
3220 
3221 void ConversionTarget::setLegalityCallback(
3222  OperationName name, const DynamicLegalityCallbackFn &callback) {
3223  assert(callback && "expected valid legality callback");
3224  auto *infoIt = legalOperations.find(name);
3225  assert(infoIt != legalOperations.end() &&
3226  infoIt->second.action == LegalizationAction::Dynamic &&
3227  "expected operation to already be marked as dynamically legal");
3228  infoIt->second.legalityFn =
3229  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3230 }
3231 
3233  OperationName name, const DynamicLegalityCallbackFn &callback) {
3234  auto *infoIt = legalOperations.find(name);
3235  assert(infoIt != legalOperations.end() &&
3236  infoIt->second.action != LegalizationAction::Illegal &&
3237  "expected operation to already be marked as legal");
3238  infoIt->second.isRecursivelyLegal = true;
3239  if (callback)
3240  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3241  std::move(opRecursiveLegalityFns[name]), callback);
3242  else
3243  opRecursiveLegalityFns.erase(name);
3244 }
3245 
3246 void ConversionTarget::setLegalityCallback(
3247  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3248  assert(callback && "expected valid legality callback");
3249  for (StringRef dialect : dialects)
3250  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3251  std::move(dialectLegalityFns[dialect]), callback);
3252 }
3253 
3254 void ConversionTarget::setLegalityCallback(
3255  const DynamicLegalityCallbackFn &callback) {
3256  assert(callback && "expected valid legality callback");
3257  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3258 }
3259 
3260 auto ConversionTarget::getOpInfo(OperationName op) const
3261  -> std::optional<LegalizationInfo> {
3262  // Check for info for this specific operation.
3263  const auto *it = legalOperations.find(op);
3264  if (it != legalOperations.end())
3265  return it->second;
3266  // Check for info for the parent dialect.
3267  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3268  if (dialectIt != legalDialects.end()) {
3269  DynamicLegalityCallbackFn callback;
3270  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3271  if (dialectFn != dialectLegalityFns.end())
3272  callback = dialectFn->second;
3273  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3274  callback};
3275  }
3276  // Otherwise, check if we mark unknown operations as dynamic.
3277  if (unknownLegalityFn)
3278  return LegalizationInfo{LegalizationAction::Dynamic,
3279  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3280  return std::nullopt;
3281 }
3282 
3283 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3284 //===----------------------------------------------------------------------===//
3285 // PDL Configuration
3286 //===----------------------------------------------------------------------===//
3287 
3289  auto &rewriterImpl =
3290  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3291  rewriterImpl.currentTypeConverter = getTypeConverter();
3292 }
3293 
3295  auto &rewriterImpl =
3296  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3297  rewriterImpl.currentTypeConverter = nullptr;
3298 }
3299 
3300 /// Remap the given value using the rewriter and the type converter in the
3301 /// provided config.
3302 static FailureOr<SmallVector<Value>>
3304  SmallVector<Value> mappedValues;
3305  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3306  return failure();
3307  return std::move(mappedValues);
3308 }
3309 
3311  patterns.getPDLPatterns().registerRewriteFunction(
3312  "convertValue",
3313  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3314  auto results = pdllConvertValues(
3315  static_cast<ConversionPatternRewriter &>(rewriter), value);
3316  if (failed(results))
3317  return failure();
3318  return results->front();
3319  });
3320  patterns.getPDLPatterns().registerRewriteFunction(
3321  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3322  return pdllConvertValues(
3323  static_cast<ConversionPatternRewriter &>(rewriter), values);
3324  });
3325  patterns.getPDLPatterns().registerRewriteFunction(
3326  "convertType",
3327  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3328  auto &rewriterImpl =
3329  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3330  if (const TypeConverter *converter =
3331  rewriterImpl.currentTypeConverter) {
3332  if (Type newType = converter->convertType(type))
3333  return newType;
3334  return failure();
3335  }
3336  return type;
3337  });
3338  patterns.getPDLPatterns().registerRewriteFunction(
3339  "convertTypes",
3340  [](PatternRewriter &rewriter,
3341  TypeRange types) -> FailureOr<SmallVector<Type>> {
3342  auto &rewriterImpl =
3343  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3344  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3345  if (!converter)
3346  return SmallVector<Type>(types);
3347 
3348  SmallVector<Type> remappedTypes;
3349  if (failed(converter->convertTypes(types, remappedTypes)))
3350  return failure();
3351  return std::move(remappedTypes);
3352  });
3353 }
3354 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3355 
3356 //===----------------------------------------------------------------------===//
3357 // Op Conversion Entry Points
3358 //===----------------------------------------------------------------------===//
3359 
3360 //===----------------------------------------------------------------------===//
3361 // Partial Conversion
3362 //===----------------------------------------------------------------------===//
3363 
3365  ArrayRef<Operation *> ops, const ConversionTarget &target,
3367  OperationConverter opConverter(target, patterns, config,
3368  OpConversionMode::Partial);
3369  return opConverter.convertOperations(ops);
3370 }
3371 LogicalResult
3375  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3376 }
3377 
3378 //===----------------------------------------------------------------------===//
3379 // Full Conversion
3380 //===----------------------------------------------------------------------===//
3381 
3383  const ConversionTarget &target,
3386  OperationConverter opConverter(target, patterns, config,
3387  OpConversionMode::Full);
3388  return opConverter.convertOperations(ops);
3389 }
3391  const ConversionTarget &target,
3394  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3395 }
3396 
3397 //===----------------------------------------------------------------------===//
3398 // Analysis Conversion
3399 //===----------------------------------------------------------------------===//
3400 
3401 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3402 /// op is a top-level module op (which is expected to be isolated from above),
3403 /// return that op.
3405  // Check if there is a top-level operation within `ops`. If so, return that
3406  // op.
3407  for (Operation *op : ops) {
3408  if (!op->getParentOp()) {
3409 #ifndef NDEBUG
3410  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3411  "expected top-level op to be isolated from above");
3412  for (Operation *other : ops)
3413  assert(op->isAncestor(other) &&
3414  "expected ops to have a common ancestor");
3415 #endif // NDEBUG
3416  return op;
3417  }
3418  }
3419 
3420  // No top-level op. Find a common ancestor.
3421  Operation *commonAncestor =
3422  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3423  for (Operation *op : ops.drop_front()) {
3424  while (!commonAncestor->isProperAncestor(op)) {
3425  commonAncestor =
3427  assert(commonAncestor &&
3428  "expected to find a common isolated from above ancestor");
3429  }
3430  }
3431 
3432  return commonAncestor;
3433 }
3434 
3438 #ifndef NDEBUG
3439  if (config.legalizableOps)
3440  assert(config.legalizableOps->empty() && "expected empty set");
3441 #endif // NDEBUG
3442 
3443  // Clone closted common ancestor that is isolated from above.
3444  Operation *commonAncestor = findCommonAncestor(ops);
3445  IRMapping mapping;
3446  Operation *clonedAncestor = commonAncestor->clone(mapping);
3447  // Compute inverse IR mapping.
3448  DenseMap<Operation *, Operation *> inverseOperationMap;
3449  for (auto &it : mapping.getOperationMap())
3450  inverseOperationMap[it.second] = it.first;
3451 
3452  // Convert the cloned operations. The original IR will remain unchanged.
3453  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3454  ops, [&](Operation *op) { return mapping.lookup(op); });
3455  OperationConverter opConverter(target, patterns, config,
3456  OpConversionMode::Analysis);
3457  LogicalResult status = opConverter.convertOperations(opsToConvert);
3458 
3459  // Remap `legalizableOps`, so that they point to the original ops and not the
3460  // cloned ops.
3461  if (config.legalizableOps) {
3462  DenseSet<Operation *> originalLegalizableOps;
3463  for (Operation *op : *config.legalizableOps)
3464  originalLegalizableOps.insert(inverseOperationMap[op]);
3465  *config.legalizableOps = std::move(originalLegalizableOps);
3466  }
3467 
3468  // Erase the cloned IR.
3469  clonedAncestor->erase();
3470  return status;
3471 }
3472 
3473 LogicalResult
3477  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3478 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
union mlir::linalg::@1182::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:734
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:897
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:606
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
void notifyOpReplaced(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Notifies that an op is about to be replaced with the given values.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.