MLIR  21.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 #define DEBUG_TYPE "dialect-conversion"
32 
33 /// A utility function to log a successful result for the given reason.
34 template <typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36  LLVM_DEBUG({
37  os.unindent();
38  os.startLine() << "} -> SUCCESS";
39  if (!fmt.empty())
40  os.getOStream() << " : "
41  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42  os.getOStream() << "\n";
43  });
44 }
45 
46 /// A utility function to log a failure result for the given reason.
47 template <typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49  LLVM_DEBUG({
50  os.unindent();
51  os.startLine() << "} -> FAILURE : "
52  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53  << "\n";
54  });
55 }
56 
57 /// Helper function that computes an insertion point where the given value is
58 /// defined and can be used without a dominance violation.
60  Block *insertBlock = value.getParentBlock();
61  Block::iterator insertPt = insertBlock->begin();
62  if (OpResult inputRes = dyn_cast<OpResult>(value))
63  insertPt = ++inputRes.getOwner()->getIterator();
64  return OpBuilder::InsertPoint(insertBlock, insertPt);
65 }
66 
67 /// Helper function that computes an insertion point where the given values are
68 /// defined and can be used without a dominance violation.
70  assert(!vals.empty() && "expected at least one value");
71  DominanceInfo domInfo;
72  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
73  for (Value v : vals.drop_front()) {
74  // Choose the "later" insertion point.
76  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
77  nextPt.getPoint())) {
78  // pt is before nextPt => choose nextPt.
79  pt = nextPt;
80  } else {
81 #ifndef NDEBUG
82  // nextPt should be before pt => choose pt.
83  // If pt, nextPt are no dominance relationship, then there is no valid
84  // insertion point at which all given values are defined.
85  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
86  pt.getBlock(), pt.getPoint());
87  assert(dom && "unable to find valid insertion point");
88 #endif // NDEBUG
89  }
90  }
91  return pt;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // ConversionValueMapping
96 //===----------------------------------------------------------------------===//
97 
98 /// A vector of SSA values, optimized for the most common case of a single
99 /// value.
100 using ValueVector = SmallVector<Value, 1>;
101 
102 namespace {
103 
104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105 struct ValueVectorMapInfo {
106  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108  static ::llvm::hash_code getHashValue(const ValueVector &val) {
109  return ::llvm::hash_combine_range(val.begin(), val.end());
110  }
111  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112  return LHS == RHS;
113  }
114 };
115 
116 /// This class wraps a IRMapping to provide recursive lookup
117 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
118 struct ConversionValueMapping {
119  /// Return "true" if an SSA value is mapped to the given value. May return
120  /// false positives.
121  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
122 
123  /// Lookup the most recently mapped values with the desired types in the
124  /// mapping.
125  ///
126  /// Special cases:
127  /// - If the desired type range is empty, simply return the most recently
128  /// mapped values.
129  /// - If there is no mapping to the desired types, also return the most
130  /// recently mapped values.
131  /// - If there is no mapping for the given values at all, return the given
132  /// value.
133  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134 
135  /// Lookup the given value within the map, or return an empty vector if the
136  /// value is not mapped. If it is mapped, this follows the same behavior
137  /// as `lookupOrDefault`.
138  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139 
140  template <typename T>
141  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 
143  /// Map a value vector to the one provided.
144  template <typename OldVal, typename NewVal>
145  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146  map(OldVal &&oldVal, NewVal &&newVal) {
147  LLVM_DEBUG({
148  ValueVector next(newVal);
149  while (true) {
150  assert(next != oldVal && "inserting cyclic mapping");
151  auto it = mapping.find(next);
152  if (it == mapping.end())
153  break;
154  next = it->second;
155  }
156  });
157  for (Value v : newVal)
158  mappedTo.insert(v);
159 
160  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
161  }
162 
163  /// Map a value vector or single value to the one provided.
164  template <typename OldVal, typename NewVal>
165  std::enable_if_t<!IsValueVector<OldVal>::value ||
166  !IsValueVector<NewVal>::value>
167  map(OldVal &&oldVal, NewVal &&newVal) {
168  if constexpr (IsValueVector<OldVal>{}) {
169  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
170  } else if constexpr (IsValueVector<NewVal>{}) {
171  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
172  } else {
173  map(ValueVector{oldVal}, ValueVector{newVal});
174  }
175  }
176 
177  /// Drop the last mapping for the given values.
178  void erase(const ValueVector &value) { mapping.erase(value); }
179 
180 private:
181  /// Current value mappings.
183 
184  /// All SSA values that are mapped to. May contain false positives.
185  DenseSet<Value> mappedTo;
186 };
187 } // namespace
188 
190 ConversionValueMapping::lookupOrDefault(Value from,
191  TypeRange desiredTypes) const {
192  // Try to find the deepest values that have the desired types. If there is no
193  // such mapping, simply return the deepest values.
194  ValueVector desiredValue;
195  ValueVector current{from};
196  do {
197  // Store the current value if the types match.
198  if (TypeRange(ValueRange(current)) == desiredTypes)
199  desiredValue = current;
200 
201  // If possible, Replace each value with (one or multiple) mapped values.
202  ValueVector next;
203  for (Value v : current) {
204  auto it = mapping.find({v});
205  if (it != mapping.end()) {
206  llvm::append_range(next, it->second);
207  } else {
208  next.push_back(v);
209  }
210  }
211  if (next != current) {
212  // If at least one value was replaced, continue the lookup from there.
213  current = std::move(next);
214  continue;
215  }
216 
217  // Otherwise: Check if there is a mapping for the entire vector. Such
218  // mappings are materializations. (N:M mapping are not supported for value
219  // replacements.)
220  //
221  // Note: From a correctness point of view, materializations do not have to
222  // be stored (and looked up) in the mapping. But for performance reasons,
223  // we choose to reuse existing IR (when possible) instead of creating it
224  // multiple times.
225  auto it = mapping.find(current);
226  if (it == mapping.end()) {
227  // No mapping found: The lookup stops here.
228  break;
229  }
230  current = it->second;
231  } while (true);
232 
233  // If the desired values were found use them, otherwise default to the leaf
234  // values.
235  // Note: If `desiredTypes` is empty, this function always returns `current`.
236  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
237 }
238 
239 ValueVector ConversionValueMapping::lookupOrNull(Value from,
240  TypeRange desiredTypes) const {
241  ValueVector result = lookupOrDefault(from, desiredTypes);
242  if (result == ValueVector{from} ||
243  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
244  return {};
245  return result;
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Rewriter and Translation State
250 //===----------------------------------------------------------------------===//
251 namespace {
/// Snapshot of the bookkeeping counters of the conversion rewriter. A saved
/// snapshot identifies a point to which a set of rewrites can be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites that had been recorded at snapshot time.
  unsigned numRewrites;

  /// Number of ignored operations at snapshot time.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
269 
270 //===----------------------------------------------------------------------===//
271 // IR rewrites
272 //===----------------------------------------------------------------------===//
273 
274 /// An IR rewrite that can be committed (upon success) or rolled back (upon
275 /// failure).
276 ///
277 /// The dialect conversion keeps track of IR modifications (requested by the
278 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
279 /// are directly applied to the IR as the rewriter API is used, some are applied
280 /// partially, and some are delayed until the `IRRewrite` objects are committed.
281 class IRRewrite {
282 public:
283  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
284  /// Enum values are ordered, so that they can be used in `classof`: first all
285  /// block rewrites, then all operation rewrites.
286  enum class Kind {
287  // Block rewrites
288  CreateBlock,
289  EraseBlock,
290  InlineBlock,
291  MoveBlock,
292  BlockTypeConversion,
293  ReplaceBlockArg,
294  // Operation rewrites
295  MoveOperation,
296  ModifyOperation,
297  ReplaceOperation,
298  CreateOperation,
299  UnresolvedMaterialization
300  };
301 
302  virtual ~IRRewrite() = default;
303 
304  /// Roll back the rewrite. Operations may be erased during rollback.
305  virtual void rollback() = 0;
306 
307  /// Commit the rewrite. At this point, it is certain that the dialect
308  /// conversion will succeed. All IR modifications, except for operation/block
309  /// erasure, must be performed through the given rewriter.
310  ///
311  /// Instead of erasing operations/blocks, they should merely be unlinked
312  /// commit phase and finally be erased during the cleanup phase. This is
313  /// because internal dialect conversion state (such as `mapping`) may still
314  /// be using them.
315  ///
316  /// Any IR modification that was already performed before the commit phase
317  /// (e.g., insertion of an op) must be communicated to the listener that may
318  /// be attached to the given rewriter.
319  virtual void commit(RewriterBase &rewriter) {}
320 
321  /// Cleanup operations/blocks. Cleanup is called after commit.
322  virtual void cleanup(RewriterBase &rewriter) {}
323 
324  Kind getKind() const { return kind; }
325 
326  static bool classof(const IRRewrite *rewrite) { return true; }
327 
328 protected:
329  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
330  : kind(kind), rewriterImpl(rewriterImpl) {}
331 
332  const ConversionConfig &getConfig() const;
333 
334  const Kind kind;
335  ConversionPatternRewriterImpl &rewriterImpl;
336 };
337 
338 /// A block rewrite.
339 class BlockRewrite : public IRRewrite {
340 public:
341  /// Return the block that this rewrite operates on.
342  Block *getBlock() const { return block; }
343 
344  static bool classof(const IRRewrite *rewrite) {
345  return rewrite->getKind() >= Kind::CreateBlock &&
346  rewrite->getKind() <= Kind::ReplaceBlockArg;
347  }
348 
349 protected:
350  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
351  Block *block)
352  : IRRewrite(kind, rewriterImpl), block(block) {}
353 
354  // The block that this rewrite operates on.
355  Block *block;
356 };
357 
358 /// Creation of a block. Block creations are immediately reflected in the IR.
359 /// There is no extra work to commit the rewrite. During rollback, the newly
360 /// created block is erased.
361 class CreateBlockRewrite : public BlockRewrite {
362 public:
363  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
364  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
365 
366  static bool classof(const IRRewrite *rewrite) {
367  return rewrite->getKind() == Kind::CreateBlock;
368  }
369 
370  void commit(RewriterBase &rewriter) override {
371  // The block was already created and inserted. Just inform the listener.
372  if (auto *listener = rewriter.getListener())
373  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
374  }
375 
376  void rollback() override {
377  // Unlink all of the operations within this block, they will be deleted
378  // separately.
379  auto &blockOps = block->getOperations();
380  while (!blockOps.empty())
381  blockOps.remove(blockOps.begin());
382  block->dropAllUses();
383  if (block->getParent())
384  block->erase();
385  else
386  delete block;
387  }
388 };
389 
390 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
391 /// blocks are immediately unlinked, but only erased during cleanup. This makes
392 /// it easier to rollback a block erasure: the block is simply inserted into its
393 /// original location.
394 class EraseBlockRewrite : public BlockRewrite {
395 public:
396  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
397  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
398  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
399 
400  static bool classof(const IRRewrite *rewrite) {
401  return rewrite->getKind() == Kind::EraseBlock;
402  }
403 
404  ~EraseBlockRewrite() override {
405  assert(!block &&
406  "rewrite was neither rolled back nor committed/cleaned up");
407  }
408 
409  void rollback() override {
410  // The block (owned by this rewrite) was not actually erased yet. It was
411  // just unlinked. Put it back into its original position.
412  assert(block && "expected block");
413  auto &blockList = region->getBlocks();
414  Region::iterator before = insertBeforeBlock
415  ? Region::iterator(insertBeforeBlock)
416  : blockList.end();
417  blockList.insert(before, block);
418  block = nullptr;
419  }
420 
421  void commit(RewriterBase &rewriter) override {
422  // Erase the block.
423  assert(block && "expected block");
424  assert(block->empty() && "expected empty block");
425 
426  // Notify the listener that the block is about to be erased.
427  if (auto *listener =
428  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
429  listener->notifyBlockErased(block);
430  }
431 
432  void cleanup(RewriterBase &rewriter) override {
433  // Erase the block.
434  block->dropAllDefinedValueUses();
435  delete block;
436  block = nullptr;
437  }
438 
439 private:
440  // The region in which this block was previously contained.
441  Region *region;
442 
443  // The original successor of this block before it was unlinked. "nullptr" if
444  // this block was the only block in the region.
445  Block *insertBeforeBlock;
446 };
447 
448 /// Inlining of a block. This rewrite is immediately reflected in the IR.
449 /// Note: This rewrite represents only the inlining of the operations. The
450 /// erasure of the inlined block is a separate rewrite.
451 class InlineBlockRewrite : public BlockRewrite {
452 public:
453  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
454  Block *sourceBlock, Block::iterator before)
455  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
456  sourceBlock(sourceBlock),
457  firstInlinedInst(sourceBlock->empty() ? nullptr
458  : &sourceBlock->front()),
459  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
460  // If a listener is attached to the dialect conversion, ops must be moved
461  // one-by-one. When they are moved in bulk, notifications cannot be sent
462  // because the ops that used to be in the source block at the time of the
463  // inlining (before the "commit" phase) are unknown at the time when
464  // notifications are sent (which is during the "commit" phase).
465  assert(!getConfig().listener &&
466  "InlineBlockRewrite not supported if listener is attached");
467  }
468 
469  static bool classof(const IRRewrite *rewrite) {
470  return rewrite->getKind() == Kind::InlineBlock;
471  }
472 
473  void rollback() override {
474  // Put the operations from the destination block (owned by the rewrite)
475  // back into the source block.
476  if (firstInlinedInst) {
477  assert(lastInlinedInst && "expected operation");
478  sourceBlock->getOperations().splice(sourceBlock->begin(),
479  block->getOperations(),
480  Block::iterator(firstInlinedInst),
481  ++Block::iterator(lastInlinedInst));
482  }
483  }
484 
485 private:
486  // The block that originally contained the operations.
487  Block *sourceBlock;
488 
489  // The first inlined operation.
490  Operation *firstInlinedInst;
491 
492  // The last inlined operation.
493  Operation *lastInlinedInst;
494 };
495 
496 /// Moving of a block. This rewrite is immediately reflected in the IR.
497 class MoveBlockRewrite : public BlockRewrite {
498 public:
499  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
500  Region *region, Block *insertBeforeBlock)
501  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
502  insertBeforeBlock(insertBeforeBlock) {}
503 
504  static bool classof(const IRRewrite *rewrite) {
505  return rewrite->getKind() == Kind::MoveBlock;
506  }
507 
508  void commit(RewriterBase &rewriter) override {
509  // The block was already moved. Just inform the listener.
510  if (auto *listener = rewriter.getListener()) {
511  // Note: `previousIt` cannot be passed because this is a delayed
512  // notification and iterators into past IR state cannot be represented.
513  listener->notifyBlockInserted(block, /*previous=*/region,
514  /*previousIt=*/{});
515  }
516  }
517 
518  void rollback() override {
519  // Move the block back to its original position.
520  Region::iterator before =
521  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
522  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
523  }
524 
525 private:
526  // The region in which this block was previously contained.
527  Region *region;
528 
529  // The original successor of this block before it was moved. "nullptr" if
530  // this block was the only block in the region.
531  Block *insertBeforeBlock;
532 };
533 
534 /// Block type conversion. This rewrite is partially reflected in the IR.
535 class BlockTypeConversionRewrite : public BlockRewrite {
536 public:
537  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
538  Block *origBlock, Block *newBlock)
539  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
540  newBlock(newBlock) {}
541 
542  static bool classof(const IRRewrite *rewrite) {
543  return rewrite->getKind() == Kind::BlockTypeConversion;
544  }
545 
546  Block *getOrigBlock() const { return block; }
547 
548  Block *getNewBlock() const { return newBlock; }
549 
550  void commit(RewriterBase &rewriter) override;
551 
552  void rollback() override;
553 
554 private:
555  /// The new block that was created as part of this signature conversion.
556  Block *newBlock;
557 };
558 
559 /// Replacing a block argument. This rewrite is not immediately reflected in the
560 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
561 /// until the rewrite is committed.
562 class ReplaceBlockArgRewrite : public BlockRewrite {
563 public:
564  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
565  Block *block, BlockArgument arg,
566  const TypeConverter *converter)
567  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
568  converter(converter) {}
569 
570  static bool classof(const IRRewrite *rewrite) {
571  return rewrite->getKind() == Kind::ReplaceBlockArg;
572  }
573 
574  void commit(RewriterBase &rewriter) override;
575 
576  void rollback() override;
577 
578 private:
579  BlockArgument arg;
580 
581  /// The current type converter when the block argument was replaced.
582  const TypeConverter *converter;
583 };
584 
585 /// An operation rewrite.
586 class OperationRewrite : public IRRewrite {
587 public:
588  /// Return the operation that this rewrite operates on.
589  Operation *getOperation() const { return op; }
590 
591  static bool classof(const IRRewrite *rewrite) {
592  return rewrite->getKind() >= Kind::MoveOperation &&
593  rewrite->getKind() <= Kind::UnresolvedMaterialization;
594  }
595 
596 protected:
597  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
598  Operation *op)
599  : IRRewrite(kind, rewriterImpl), op(op) {}
600 
601  // The operation that this rewrite operates on.
602  Operation *op;
603 };
604 
605 /// Moving of an operation. This rewrite is immediately reflected in the IR.
606 class MoveOperationRewrite : public OperationRewrite {
607 public:
608  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
609  Operation *op, Block *block, Operation *insertBeforeOp)
610  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
611  insertBeforeOp(insertBeforeOp) {}
612 
613  static bool classof(const IRRewrite *rewrite) {
614  return rewrite->getKind() == Kind::MoveOperation;
615  }
616 
617  void commit(RewriterBase &rewriter) override {
618  // The operation was already moved. Just inform the listener.
619  if (auto *listener = rewriter.getListener()) {
620  // Note: `previousIt` cannot be passed because this is a delayed
621  // notification and iterators into past IR state cannot be represented.
622  listener->notifyOperationInserted(
623  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
624  /*insertPt=*/{}));
625  }
626  }
627 
628  void rollback() override {
629  // Move the operation back to its original position.
630  Block::iterator before =
631  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
632  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
633  }
634 
635 private:
636  // The block in which this operation was previously contained.
637  Block *block;
638 
639  // The original successor of this operation before it was moved. "nullptr"
640  // if this operation was the only operation in the region.
641  Operation *insertBeforeOp;
642 };
643 
644 /// In-place modification of an op. This rewrite is immediately reflected in
645 /// the IR. The previous state of the operation is stored in this object.
646 class ModifyOperationRewrite : public OperationRewrite {
647 public:
648  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
649  Operation *op)
650  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
651  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
652  operands(op->operand_begin(), op->operand_end()),
653  successors(op->successor_begin(), op->successor_end()) {
654  if (OpaqueProperties prop = op->getPropertiesStorage()) {
655  // Make a copy of the properties.
656  propertiesStorage = operator new(op->getPropertiesStorageSize());
657  OpaqueProperties propCopy(propertiesStorage);
658  name.initOpProperties(propCopy, /*init=*/prop);
659  }
660  }
661 
662  static bool classof(const IRRewrite *rewrite) {
663  return rewrite->getKind() == Kind::ModifyOperation;
664  }
665 
666  ~ModifyOperationRewrite() override {
667  assert(!propertiesStorage &&
668  "rewrite was neither committed nor rolled back");
669  }
670 
671  void commit(RewriterBase &rewriter) override {
672  // Notify the listener that the operation was modified in-place.
673  if (auto *listener =
674  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
675  listener->notifyOperationModified(op);
676 
677  if (propertiesStorage) {
678  OpaqueProperties propCopy(propertiesStorage);
679  // Note: The operation may have been erased in the mean time, so
680  // OperationName must be stored in this object.
681  name.destroyOpProperties(propCopy);
682  operator delete(propertiesStorage);
683  propertiesStorage = nullptr;
684  }
685  }
686 
687  void rollback() override {
688  op->setLoc(loc);
689  op->setAttrs(attrs);
690  op->setOperands(operands);
691  for (const auto &it : llvm::enumerate(successors))
692  op->setSuccessor(it.value(), it.index());
693  if (propertiesStorage) {
694  OpaqueProperties propCopy(propertiesStorage);
695  op->copyProperties(propCopy);
696  name.destroyOpProperties(propCopy);
697  operator delete(propertiesStorage);
698  propertiesStorage = nullptr;
699  }
700  }
701 
702 private:
703  OperationName name;
704  LocationAttr loc;
705  DictionaryAttr attrs;
706  SmallVector<Value, 8> operands;
707  SmallVector<Block *, 2> successors;
708  void *propertiesStorage = nullptr;
709 };
710 
711 /// Replacing an operation. Erasing an operation is treated as a special case
712 /// with "null" replacements. This rewrite is not immediately reflected in the
713 /// IR. An internal IR mapping is updated, but values are not replaced and the
714 /// original op is not erased until the rewrite is committed.
715 class ReplaceOperationRewrite : public OperationRewrite {
716 public:
717  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
718  Operation *op, const TypeConverter *converter)
719  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
720  converter(converter) {}
721 
722  static bool classof(const IRRewrite *rewrite) {
723  return rewrite->getKind() == Kind::ReplaceOperation;
724  }
725 
726  void commit(RewriterBase &rewriter) override;
727 
728  void rollback() override;
729 
730  void cleanup(RewriterBase &rewriter) override;
731 
732 private:
733  /// An optional type converter that can be used to materialize conversions
734  /// between the new and old values if necessary.
735  const TypeConverter *converter;
736 };
737 
738 class CreateOperationRewrite : public OperationRewrite {
739 public:
740  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
741  Operation *op)
742  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
743 
744  static bool classof(const IRRewrite *rewrite) {
745  return rewrite->getKind() == Kind::CreateOperation;
746  }
747 
748  void commit(RewriterBase &rewriter) override {
749  // The operation was already created and inserted. Just inform the listener.
750  if (auto *listener = rewriter.getListener())
751  listener->notifyOperationInserted(op, /*previous=*/{});
752  }
753 
754  void rollback() override;
755 };
756 
/// Direction of a materialization.
enum MaterializationKind {
  /// Converts a value from an illegal type to a legal one.
  Target = 0,

  /// Converts a value from a legal type back to an illegal one.
  Source = 1
};
767 
768 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
769 /// op. Unresolved materializations are erased at the end of the dialect
770 /// conversion.
771 class UnresolvedMaterializationRewrite : public OperationRewrite {
772 public:
773  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
774  UnrealizedConversionCastOp op,
775  const TypeConverter *converter,
776  MaterializationKind kind, Type originalType,
777  ValueVector mappedValues);
778 
779  static bool classof(const IRRewrite *rewrite) {
780  return rewrite->getKind() == Kind::UnresolvedMaterialization;
781  }
782 
783  void rollback() override;
784 
785  UnrealizedConversionCastOp getOperation() const {
786  return cast<UnrealizedConversionCastOp>(op);
787  }
788 
789  /// Return the type converter of this materialization (which may be null).
790  const TypeConverter *getConverter() const {
791  return converterAndKind.getPointer();
792  }
793 
794  /// Return the kind of this materialization.
795  MaterializationKind getMaterializationKind() const {
796  return converterAndKind.getInt();
797  }
798 
799  /// Return the original type of the SSA value.
800  Type getOriginalType() const { return originalType; }
801 
802 private:
803  /// The corresponding type converter to use when resolving this
804  /// materialization, and the kind of this materialization.
805  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
806  converterAndKind;
807 
808  /// The original type of the SSA value. Only used for target
809  /// materializations.
810  Type originalType;
811 
812  /// The values in the conversion value mapping that are being replaced by the
813  /// results of this unresolved materialization.
814  ValueVector mappedValues;
815 };
816 } // namespace
817 
818 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
819 /// Return "true" if there is an operation rewrite that matches the specified
820 /// rewrite type and operation among the given rewrites.
821 template <typename RewriteTy, typename R>
822 static bool hasRewrite(R &&rewrites, Operation *op) {
823  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
824  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
825  return rewriteTy && rewriteTy->getOperation() == op;
826  });
827 }
828 
829 /// Return "true" if there is a block rewrite that matches the specified
830 /// rewrite type and block among the given rewrites.
831 template <typename RewriteTy, typename R>
832 static bool hasRewrite(R &&rewrites, Block *block) {
833  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
834  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
835  return rewriteTy && rewriteTy->getBlock() == block;
836  });
837 }
838 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
839 
840 //===----------------------------------------------------------------------===//
841 // ConversionPatternRewriterImpl
842 //===----------------------------------------------------------------------===//
843 namespace mlir {
844 namespace detail {
847  const ConversionConfig &config)
848  : context(ctx), eraseRewriter(ctx), config(config) {}
849 
850  //===--------------------------------------------------------------------===//
851  // State Management
852  //===--------------------------------------------------------------------===//
853 
854  /// Return the current state of the rewriter.
855  RewriterState getCurrentState();
856 
857  /// Apply all requested operation rewrites. This method is invoked when the
858  /// conversion process succeeds.
859  void applyRewrites();
860 
861  /// Reset the state of the rewriter to a previously saved point.
862  void resetState(RewriterState state);
863 
864  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
865  /// failure.
866  template <typename RewriteTy, typename... Args>
867  void appendRewrite(Args &&...args) {
868  rewrites.push_back(
869  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
870  }
871 
872  /// Undo the rewrites one by one in reverse order until only the first
873  /// "numRewritesToKeep" rewrites remain.
874  void undoRewrites(unsigned numRewritesToKeep = 0);
875 
876  /// Remap the given values to those with potentially different types. Returns
877  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
878  /// is the tag used when describing a value within a diagnostic, e.g.
879  /// "operand".
880  LogicalResult remapValues(StringRef valueDiagTag,
881  std::optional<Location> inputLoc,
882  PatternRewriter &rewriter, ValueRange values,
883  SmallVector<ValueVector> &remapped);
884 
885  /// Return "true" if the given operation is ignored, and does not need to be
886  /// converted.
887  bool isOpIgnored(Operation *op) const;
888 
889  /// Return "true" if the given operation was replaced or erased.
890  bool wasOpReplaced(Operation *op) const;
891 
892  //===--------------------------------------------------------------------===//
893  // Type Conversion
894  //===--------------------------------------------------------------------===//
895 
896  /// Convert the types of block arguments within the given region.
897  FailureOr<Block *>
898  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
899  const TypeConverter &converter,
900  TypeConverter::SignatureConversion *entryConversion);
901 
902  /// Apply the given signature conversion on the given block. The new block
903  /// containing the updated signature is returned. If no conversions were
904  /// necessary, e.g. if the block has no arguments, `block` is returned.
905  /// `converter` is used to generate any necessary cast operations that
906  /// translate between the origin argument types and those specified in the
907  /// signature conversion.
908  Block *applySignatureConversion(
909  ConversionPatternRewriter &rewriter, Block *block,
910  const TypeConverter *converter,
911  TypeConverter::SignatureConversion &signatureConversion);
912 
913  //===--------------------------------------------------------------------===//
914  // Materializations
915  //===--------------------------------------------------------------------===//
916 
917  /// Build an unresolved materialization operation given a range of output
918  /// types and a list of input operands, and return the results of the newly
919  /// built cast op. The input types must not match the output types.
920  ///
921  /// The built cast op can optionally be returned with the `castOp` output
922  /// argument.
923  ///
924  /// If `valuesToMap` is non-empty, then those values are mapped to the
925  /// results of the unresolved materialization in the conversion value
926  /// mapping.
927  ValueRange buildUnresolvedMaterialization(
928  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
929  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
930  Type originalType, const TypeConverter *converter,
931  UnrealizedConversionCastOp *castOp = nullptr);
932 
933  /// Find a replacement value for the given SSA value in the conversion value
934  /// mapping. The replacement value must have the same type as the given SSA
935  /// value. If there is no replacement value with the correct type, find the
936  /// latest replacement value (regardless of the type) and build a source
937  /// materialization.
938  Value findOrBuildReplacementValue(Value value,
939  const TypeConverter *converter);
940 
941  //===--------------------------------------------------------------------===//
942  // Rewriter Notification Hooks
943  //===--------------------------------------------------------------------===//
944 
945  /// Notifies that an op was inserted.
946  void notifyOperationInserted(Operation *op,
947  OpBuilder::InsertPoint previous) override;
948 
949  /// Notifies that an op is about to be replaced with the given values.
950  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
951 
952  /// Notifies that a block is about to be erased.
953  void notifyBlockIsBeingErased(Block *block);
954 
955  /// Notifies that a block was inserted.
956  void notifyBlockInserted(Block *block, Region *previous,
957  Region::iterator previousIt) override;
958 
959  /// Notifies that a block is being inlined into another block.
960  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
961  Block::iterator before);
962 
963  /// Notifies that a pattern match failed for the given reason.
964  void
965  notifyMatchFailure(Location loc,
966  function_ref<void(Diagnostic &)> reasonCallback) override;
967 
968  //===--------------------------------------------------------------------===//
969  // IR Erasure
970  //===--------------------------------------------------------------------===//
971 
972  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
973  /// operation or block is erased multiple times. This rewriter assumes that
974  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
976  public:
978  : RewriterBase(context, /*listener=*/this) {}
979 
980  /// Erase the given op (unless it was already erased).
981  void eraseOp(Operation *op) override {
982  if (wasErased(op))
983  return;
984  op->dropAllUses();
986  }
987 
988  /// Erase the given block (unless it was already erased).
989  void eraseBlock(Block *block) override {
990  if (wasErased(block))
991  return;
992  assert(block->empty() && "expected empty block");
993  block->dropAllDefinedValueUses();
995  }
996 
997  bool wasErased(void *ptr) const { return erased.contains(ptr); }
998 
999  void notifyOperationErased(Operation *op) override { erased.insert(op); }
1000 
1001  void notifyBlockErased(Block *block) override { erased.insert(block); }
1002 
1003  private:
1004  /// Pointers to all erased operations and blocks.
1005  DenseSet<void *> erased;
1006  };
1007 
1008  //===--------------------------------------------------------------------===//
1009  // State
1010  //===--------------------------------------------------------------------===//
1011 
1012  /// MLIR context.
1014 
1015  /// A rewriter that keeps track of ops/blocks that were already erased and
1016  /// skips duplicate op/block erasures. This rewriter is used during the
1017  /// "cleanup" phase.
1019 
1020  // Mapping between replaced values that differ in type. This happens when
1021  // replacing a value with one of a different type.
1022  ConversionValueMapping mapping;
1023 
1024  /// Ordered list of all IR rewrites (op/block creations, moves, replacements, materializations, etc.).
1026 
1027  /// A set of operations that should no longer be considered for legalization.
1028  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1029  /// tracked separately.
1031 
1032  /// A set of operations that were replaced/erased. Such ops are not erased
1033  /// immediately but only when the dialect conversion succeeds. In the
1034  /// meantime, they should no longer be considered for legalization and any attempt
1035  /// to modify/access them is invalid rewriter API usage.
1037 
1038  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1039  /// to the corresponding rewrite objects.
1042 
1043  /// The current type converter, or nullptr if no type converter is currently
1044  /// active.
1045  const TypeConverter *currentTypeConverter = nullptr;
1046 
1047  /// A mapping of regions to type converters that should be used when
1048  /// converting the arguments of blocks within that region.
1050 
1051  /// Dialect conversion configuration.
1053 
1054 #ifndef NDEBUG
1055  /// A set of operations that have pending updates. This tracking isn't
1056  /// strictly necessary, and is thus only active during debug builds for extra
1057  /// verification.
1059 
1060  /// A logger used to emit diagnostics during the conversion process.
1061  llvm::ScopedPrinter logger{llvm::dbgs()};
1062 #endif
1063 };
1064 } // namespace detail
1065 } // namespace mlir
1066 
1067 const ConversionConfig &IRRewrite::getConfig() const {
1068  return rewriterImpl.config;
1069 }
1070 
1071 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1072  // Inform the listener about all IR modifications that have already taken
1073  // place: References to the original block have been replaced with the new
1074  // block.
1075  if (auto *listener =
1076  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1077  for (Operation *op : getNewBlock()->getUsers())
1078  listener->notifyOperationModified(op);
1079 }
1080 
1081 void BlockTypeConversionRewrite::rollback() {
1082  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1083 }
1084 
1085 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1086  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1087  if (!repl)
1088  return;
1089 
1090  if (isa<BlockArgument>(repl)) {
1091  rewriter.replaceAllUsesWith(arg, repl);
1092  return;
1093  }
1094 
1095  // If the replacement value is an operation, we check to make sure that we
1096  // don't replace uses that are within the parent operation of the
1097  // replacement value.
1098  Operation *replOp = cast<OpResult>(repl).getOwner();
1099  Block *replBlock = replOp->getBlock();
1100  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1101  Operation *user = operand.getOwner();
1102  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1103  });
1104 }
1105 
1106 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1107 
1108 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1109  auto *listener =
1110  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1111 
1112  // Compute replacement values.
1113  SmallVector<Value> replacements =
1114  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1115  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1116  });
1117 
1118  // Notify the listener that the operation is about to be replaced.
1119  if (listener)
1120  listener->notifyOperationReplaced(op, replacements);
1121 
1122  // Replace all uses with the new values.
1123  for (auto [result, newValue] :
1124  llvm::zip_equal(op->getResults(), replacements))
1125  if (newValue)
1126  rewriter.replaceAllUsesWith(result, newValue);
1127 
1128  // The original op will be erased, so remove it from the set of unlegalized
1129  // ops.
1130  if (getConfig().unlegalizedOps)
1131  getConfig().unlegalizedOps->erase(op);
1132 
1133  // Notify the listener that the operation (and its nested operations) was
1134  // erased.
1135  if (listener) {
1137  [&](Operation *op) { listener->notifyOperationErased(op); });
1138  }
1139 
1140  // Do not erase the operation yet. It may still be referenced in `mapping`.
1141  // Just unlink it for now and erase it during cleanup.
1142  op->getBlock()->getOperations().remove(op);
1143 }
1144 
1145 void ReplaceOperationRewrite::rollback() {
1146  for (auto result : op->getResults())
1147  rewriterImpl.mapping.erase({result});
1148 }
1149 
1150 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1151  rewriter.eraseOp(op);
1152 }
1153 
1154 void CreateOperationRewrite::rollback() {
1155  for (Region &region : op->getRegions()) {
1156  while (!region.getBlocks().empty())
1157  region.getBlocks().remove(region.getBlocks().begin());
1158  }
1159  op->dropAllUses();
1160  op->erase();
1161 }
1162 
1163 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1164  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1165  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1166  ValueVector mappedValues)
1167  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1168  converterAndKind(converter, kind), originalType(originalType),
1169  mappedValues(std::move(mappedValues)) {
1170  assert((!originalType || kind == MaterializationKind::Target) &&
1171  "original type is valid only for target materializations");
1172  rewriterImpl.unresolvedMaterializations[op] = this;
1173 }
1174 
1175 void UnresolvedMaterializationRewrite::rollback() {
1176  if (!mappedValues.empty())
1177  rewriterImpl.mapping.erase(mappedValues);
1178  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1179  op->erase();
1180 }
1181 
1183  // Commit all rewrites.
1184  IRRewriter rewriter(context, config.listener);
1185  // Note: New rewrites may be added during the "commit" phase and the
1186  // `rewrites` vector may reallocate.
1187  for (size_t i = 0; i < rewrites.size(); ++i)
1188  rewrites[i]->commit(rewriter);
1189 
1190  // Clean up all rewrites.
1191  for (auto &rewrite : rewrites)
1192  rewrite->cleanup(eraseRewriter);
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // State Management
1197 
1199  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1200 }
1201 
1203  // Undo any rewrites.
1204  undoRewrites(state.numRewrites);
1205 
1206  // Pop all of the recorded ignored operations that are no longer valid.
1207  while (ignoredOps.size() != state.numIgnoredOperations)
1208  ignoredOps.pop_back();
1209 
1210  while (replacedOps.size() != state.numReplacedOps)
1211  replacedOps.pop_back();
1212 }
1213 
1214 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1215  for (auto &rewrite :
1216  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1217  rewrite->rollback();
1218  rewrites.resize(numRewritesToKeep);
1219 }
1220 
1222  StringRef valueDiagTag, std::optional<Location> inputLoc,
1223  PatternRewriter &rewriter, ValueRange values,
1224  SmallVector<ValueVector> &remapped) {
1225  remapped.reserve(llvm::size(values));
1226 
1227  for (const auto &it : llvm::enumerate(values)) {
1228  Value operand = it.value();
1229  Type origType = operand.getType();
1230  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1231 
1232  if (!currentTypeConverter) {
1233  // The current pattern does not have a type converter. I.e., it does not
1234  // distinguish between legal and illegal types. For each operand, simply
1235  // pass through the most recently mapped values.
1236  remapped.push_back(mapping.lookupOrDefault(operand));
1237  continue;
1238  }
1239 
1240  // If there is no legal conversion, fail to match this pattern.
1241  SmallVector<Type, 1> legalTypes;
1242  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1243  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1244  diag << "unable to convert type for " << valueDiagTag << " #"
1245  << it.index() << ", type was " << origType;
1246  });
1247  return failure();
1248  }
1249  // If a type is converted to 0 types, there is nothing to do.
1250  if (legalTypes.empty()) {
1251  remapped.push_back({});
1252  continue;
1253  }
1254 
1255  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1256  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1257  // Mapped values have the correct type or there is an existing
1258  // materialization. Or the operand is not mapped at all and has the
1259  // correct type.
1260  remapped.push_back(std::move(repl));
1261  continue;
1262  }
1263 
1264  // Create a materialization for the most recently mapped values.
1265  repl = mapping.lookupOrDefault(operand);
1267  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1268  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1269  /*originalType=*/origType, currentTypeConverter);
1270  remapped.push_back(castValues);
1271  }
1272  return success();
1273 }
1274 
1276  // Check to see if this operation is ignored or was replaced.
1277  return replacedOps.count(op) || ignoredOps.count(op);
1278 }
1279 
1281  // Check to see if this operation was replaced.
1282  return replacedOps.count(op);
1283 }
1284 
1285 //===----------------------------------------------------------------------===//
1286 // Type Conversion
1287 
1289  ConversionPatternRewriter &rewriter, Region *region,
1290  const TypeConverter &converter,
1291  TypeConverter::SignatureConversion *entryConversion) {
1292  regionToConverter[region] = &converter;
1293  if (region->empty())
1294  return nullptr;
1295 
1296  // Convert the arguments of each non-entry block within the region.
1297  for (Block &block :
1298  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1299  // Compute the signature for the block with the provided converter.
1300  std::optional<TypeConverter::SignatureConversion> conversion =
1301  converter.convertBlockSignature(&block);
1302  if (!conversion)
1303  return failure();
1304  // Convert the block with the computed signature.
1305  applySignatureConversion(rewriter, &block, &converter, *conversion);
1306  }
1307 
1308  // Convert the entry block. If an entry signature conversion was provided,
1309  // use that one. Otherwise, compute the signature with the type converter.
1310  if (entryConversion)
1311  return applySignatureConversion(rewriter, &region->front(), &converter,
1312  *entryConversion);
1313  std::optional<TypeConverter::SignatureConversion> conversion =
1314  converter.convertBlockSignature(&region->front());
1315  if (!conversion)
1316  return failure();
1317  return applySignatureConversion(rewriter, &region->front(), &converter,
1318  *conversion);
1319 }
1320 
1322  ConversionPatternRewriter &rewriter, Block *block,
1323  const TypeConverter *converter,
1324  TypeConverter::SignatureConversion &signatureConversion) {
1325 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1326  // A block cannot be converted multiple times.
1327  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1328  llvm::report_fatal_error("block was already converted");
1329 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1330 
1331  OpBuilder::InsertionGuard g(rewriter);
1332 
1333  // If no arguments are being changed or added, there is nothing to do.
1334  unsigned origArgCount = block->getNumArguments();
1335  auto convertedTypes = signatureConversion.getConvertedTypes();
1336  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1337  return block;
1338 
1339  // Compute the locations of all block arguments in the new block.
1340  SmallVector<Location> newLocs(convertedTypes.size(),
1341  rewriter.getUnknownLoc());
1342  for (unsigned i = 0; i < origArgCount; ++i) {
1343  auto inputMap = signatureConversion.getInputMapping(i);
1344  if (!inputMap || inputMap->replacedWithValues())
1345  continue;
1346  Location origLoc = block->getArgument(i).getLoc();
1347  for (unsigned j = 0; j < inputMap->size; ++j)
1348  newLocs[inputMap->inputNo + j] = origLoc;
1349  }
1350 
1351  // Insert a new block with the converted block argument types and move all ops
1352  // from the old block to the new block.
1353  Block *newBlock =
1354  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1355  convertedTypes, newLocs);
1356 
1357  // If a listener is attached to the dialect conversion, ops cannot be moved
1358  // to the destination block in bulk ("fast path"). This is because at the time
1359  // the notifications are sent, it is unknown which ops were moved. Instead,
1360  // ops should be moved one-by-one ("slow path"), so that a separate
1361  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1362  // a bit more efficient, so we try to do that when possible.
1363  bool fastPath = !config.listener;
1364  if (fastPath) {
1365  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1366  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1367  } else {
1368  while (!block->empty())
1369  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1370  }
1371 
1372  // Replace all uses of the old block with the new block.
1373  block->replaceAllUsesWith(newBlock);
1374 
1375  for (unsigned i = 0; i != origArgCount; ++i) {
1376  BlockArgument origArg = block->getArgument(i);
1377  Type origArgType = origArg.getType();
1378 
1379  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1380  signatureConversion.getInputMapping(i);
1381  if (!inputMap) {
1382  // This block argument was dropped and no replacement value was provided.
1383  // Materialize a replacement value "out of thin air".
1385  MaterializationKind::Source,
1386  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1387  /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1388  /*outputType=*/origArgType, /*originalType=*/Type(), converter);
1389  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1390  continue;
1391  }
1392 
1393  if (inputMap->replacedWithValues()) {
1394  // This block argument was dropped and replacement values were provided.
1395  assert(inputMap->size == 0 &&
1396  "invalid to provide a replacement value when the argument isn't "
1397  "dropped");
1398  mapping.map(origArg, inputMap->replacementValues);
1399  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1400  continue;
1401  }
1402 
1403  // This is a 1->1+ mapping.
1404  auto replArgs =
1405  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1406  ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1407  mapping.map(origArg, std::move(replArgVals));
1408  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1409  }
1410 
1411  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1412 
1413  // Erase the old block. (It is just unlinked for now and will be erased during
1414  // cleanup.)
1415  rewriter.eraseBlock(block);
1416 
1417  return newBlock;
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Materializations
1422 //===----------------------------------------------------------------------===//
1423 
1424 /// Build an unresolved materialization operation given a range of output types
1425 /// and a list of input operands.
1427  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1428  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1429  Type originalType, const TypeConverter *converter,
1430  UnrealizedConversionCastOp *castOp) {
1431  assert((!originalType || kind == MaterializationKind::Target) &&
1432  "original type is valid only for target materializations");
1433  assert(TypeRange(inputs) != outputTypes &&
1434  "materialization is not necessary");
1435 
1436  // Create an unresolved materialization. We use a new OpBuilder to avoid
1437  // tracking the materialization like we do for other operations.
1438  OpBuilder builder(outputTypes.front().getContext());
1439  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1440  auto convertOp =
1441  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1442  if (!valuesToMap.empty())
1443  mapping.map(valuesToMap, convertOp.getResults());
1444  if (castOp)
1445  *castOp = convertOp;
1446  appendRewrite<UnresolvedMaterializationRewrite>(
1447  convertOp, converter, kind, originalType, std::move(valuesToMap));
1448  return convertOp.getResults();
1449 }
1450 
1452  Value value, const TypeConverter *converter) {
1453  // Try to find a replacement value with the same type in the conversion value
1454  // mapping. This includes cached materializations. We try to reuse those
1455  // instead of generating duplicate IR.
1456  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1457  if (!repl.empty())
1458  return repl.front();
1459 
1460  // Check if the value is dead. No replacement value is needed in that case.
1461  // This is an approximate check that may have false negatives but does not
1462  // require computing and traversing an inverse mapping. (We may end up
1463  // building source materializations that are never used and that fold away.)
1464  if (llvm::all_of(value.getUsers(),
1465  [&](Operation *op) { return replacedOps.contains(op); }) &&
1466  !mapping.isMappedTo(value))
1467  return Value();
1468 
1469  // No replacement value was found. Get the latest replacement value
1470  // (regardless of the type) and build a source materialization to the
1471  // original type.
1472  repl = mapping.lookupOrNull(value);
1473  if (repl.empty()) {
1474  // No replacement value is registered in the mapping. This means that the
1475  // value is dropped and no longer needed. (If the value were still needed,
1476  // a source materialization producing a replacement value "out of thin air"
1477  // would have already been created during `replaceOp` or
1478  // `applySignatureConversion`.)
1479  return Value();
1480  }
1481 
1482  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1483  // which all values in `repl` are defined. It is important to emit the
1484  // materialization at that location because the same materialization may be
1485  // reused in a different context. (That's because materializations are cached
1486  // in the conversion value mapping.) The insertion point of the
1487  // materialization must be valid for all future users that may be created
1488  // later in the conversion process.
1489  Value castValue =
1490  buildUnresolvedMaterialization(MaterializationKind::Source,
1491  computeInsertPoint(repl), value.getLoc(),
1492  /*valuesToMap=*/repl, /*inputs=*/repl,
1493  /*outputType=*/value.getType(),
1494  /*originalType=*/Type(), converter)
1495  .front();
1496  return castValue;
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // Rewriter Notification Hooks
1501 
1503  Operation *op, OpBuilder::InsertPoint previous) {
1504  LLVM_DEBUG({
1505  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1506  << ")\n";
1507  });
1508  assert(!wasOpReplaced(op->getParentOp()) &&
1509  "attempting to insert into a block within a replaced/erased op");
1510 
1511  if (!previous.isSet()) {
1512  // This is a newly created op.
1513  appendRewrite<CreateOperationRewrite>(op);
1514  return;
1515  }
1516  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1517  ? nullptr
1518  : &*previous.getPoint();
1519  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1520 }
1521 
1523  Operation *op, ArrayRef<ValueRange> newValues) {
1524  assert(newValues.size() == op->getNumResults());
1525  assert(!ignoredOps.contains(op) && "operation was already replaced");
1526 
1527  // Check if replaced op is an unresolved materialization, i.e., an
1528  // unrealized_conversion_cast op that was created by the conversion driver.
1529  bool isUnresolvedMaterialization = false;
1530  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1531  if (unresolvedMaterializations.contains(castOp))
1532  isUnresolvedMaterialization = true;
1533 
1534  // Create mappings for each of the new result values.
1535  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1536  if (repl.empty()) {
1537  // This result was dropped and no replacement value was provided.
1538  if (isUnresolvedMaterialization) {
1539  // Do not create another materializations if we are erasing a
1540  // materialization.
1541  continue;
1542  }
1543 
1544  // Materialize a replacement value "out of thin air".
1546  MaterializationKind::Source, computeInsertPoint(result),
1547  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1548  /*outputType=*/result.getType(), /*originalType=*/Type(),
1550  continue;
1551  } else {
1552  // Make sure that the user does not mess with unresolved materializations
1553  // that were inserted by the conversion driver. We keep track of these
1554  // ops in internal data structures. Erasing them must be allowed because
1555  // this can happen when the user is erasing an entire block (including
1556  // its body). But replacing them with another value should be forbidden
1557  // to avoid problems with the `mapping`.
1558  assert(!isUnresolvedMaterialization &&
1559  "attempting to replace an unresolved materialization");
1560  }
1561 
1562  // Remap result to replacement value.
1563  if (repl.empty())
1564  continue;
1565  mapping.map(result, repl);
1566  }
1567 
1568  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1569  // Mark this operation and all nested ops as replaced.
1570  op->walk([&](Operation *op) { replacedOps.insert(op); });
1571 }
1572 
1574  appendRewrite<EraseBlockRewrite>(block);
1575 }
1576 
1578  Block *block, Region *previous, Region::iterator previousIt) {
1579  assert(!wasOpReplaced(block->getParentOp()) &&
1580  "attempting to insert into a region within a replaced/erased op");
1581  LLVM_DEBUG(
1582  {
1583  Operation *parent = block->getParentOp();
1584  if (parent) {
1585  logger.startLine() << "** Insert Block into : '" << parent->getName()
1586  << "'(" << parent << ")\n";
1587  } else {
1588  logger.startLine()
1589  << "** Insert Block into detached Region (nullptr parent op)'";
1590  }
1591  });
1592 
1593  if (!previous) {
1594  // This is a newly created block.
1595  appendRewrite<CreateBlockRewrite>(block);
1596  return;
1597  }
1598  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1599  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1600 }
1601 
1603  Block *block, Block *srcBlock, Block::iterator before) {
1604  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1605 }
1606 
1608  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1609  LLVM_DEBUG({
1611  reasonCallback(diag);
1612  logger.startLine() << "** Failure : " << diag.str() << "\n";
1613  if (config.notifyCallback)
1615  });
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // ConversionPatternRewriter
1620 //===----------------------------------------------------------------------===//
1621 
1622 ConversionPatternRewriter::ConversionPatternRewriter(
1623  MLIRContext *ctx, const ConversionConfig &config)
1624  : PatternRewriter(ctx),
1625  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1626  setListener(impl.get());
1627 }
1628 
1630 
1632  assert(op && newOp && "expected non-null op");
1633  replaceOp(op, newOp->getResults());
1634 }
1635 
1637  assert(op->getNumResults() == newValues.size() &&
1638  "incorrect # of replacement values");
1639  LLVM_DEBUG({
1640  impl->logger.startLine()
1641  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1642  });
1643  SmallVector<ValueRange> newVals;
1644  for (size_t i = 0; i < newValues.size(); ++i) {
1645  if (newValues[i]) {
1646  newVals.push_back(newValues.slice(i, 1));
1647  } else {
1648  newVals.push_back(ValueRange());
1649  }
1650  }
1651  impl->notifyOpReplaced(op, newVals);
1652 }
1653 
1655  Operation *op, ArrayRef<ValueRange> newValues) {
1656  assert(op->getNumResults() == newValues.size() &&
1657  "incorrect # of replacement values");
1658  LLVM_DEBUG({
1659  impl->logger.startLine()
1660  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1661  });
1662  impl->notifyOpReplaced(op, newValues);
1663 }
1664 
1666  LLVM_DEBUG({
1667  impl->logger.startLine()
1668  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1669  });
1670  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
1671  impl->notifyOpReplaced(op, nullRepls);
1672 }
1673 
// Erase `block`: every op in it is marked for erasure, and the block is then
// unlinked from its region (but not destroyed, see below).
1675  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1676  "attempting to erase a block within a replaced/erased op");
1677 
1678  // Mark all ops for erasure.
// NOTE(review): iterating *block while calling eraseOp relies on eraseOp only
// recording the erasure (it calls notifyOpReplaced) rather than unlinking the
// op immediately — confirm if notifyOpReplaced ever changes.
1679  for (Operation &op : *block)
1680  eraseOp(&op);
1681 
1682  // Unlink the block from its parent region. The block is kept in the rewrite
1683  // object and will be actually destroyed when rewrites are applied. This
1684  // allows us to keep the operations in the block live and undo the removal by
1685  // re-inserting the block.
1686  impl->notifyBlockIsBeingErased(block);
1687  block->getParent()->getBlocks().remove(block);
1688 }
1689 
1691  Block *block, TypeConverter::SignatureConversion &conversion,
1692  const TypeConverter *converter) {
// Apply `conversion` to the arguments of `block`, using `converter` as the
// type converter. The work is delegated to the rewriter implementation; this
// wrapper only checks (in debug builds) that the block's parent op is still
// live.
1693  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1694  "attempting to apply a signature conversion to a block within a "
1695  "replaced/erased op");
1696  return impl->applySignatureConversion(*this, block, converter, conversion);
1697 }
1698 
1700  Region *region, const TypeConverter &converter,
1701  TypeConverter::SignatureConversion *entryConversion) {
// Convert the block argument types of `region` with `converter`.
// `entryConversion` is an optional (nullable) explicit conversion forwarded
// to the implementation, presumably for the entry block — confirm in
// ConversionPatternRewriterImpl::convertRegionTypes.
1702  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1703  "attempting to apply a signature conversion to a block within a "
1704  "replaced/erased op");
1705  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1706 }
1707 
1709  Value to) {
// Replace uses of block argument `from` with `to`. The replacement is
// recorded as a ReplaceBlockArgRewrite and expressed through the value
// mapping: whatever `from` currently resolves to (lookupOrDefault) is mapped
// onto `to`.
1710  LLVM_DEBUG({
// NOTE(review): the "(in region of '...' (ptr)" message below never closes
// its outer parenthesis; debug output only.
1711  impl->logger.startLine() << "** Replace Argument : '" << from << "'";
1712  if (Operation *parentOp = from.getOwner()->getParentOp()) {
1713  impl->logger.getOStream() << " (in region of '" << parentOp->getName()
1714  << "' (" << parentOp << ")\n";
1715  } else {
1716  impl->logger.getOStream() << " (unlinked block)\n";
1717  }
1718  });
1719  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1720  impl->currentTypeConverter);
1721  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1722 }
1723 
// Return the single value that `key` is currently remapped to, or nullptr if
// remapping fails. 1:N remappings are rejected by an assertion.
1725  SmallVector<ValueVector> remappedValues;
1726  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1727  remappedValues)))
1728  return nullptr;
// Assumes remapValues produced one entry for the single key — confirm.
1729  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1730  return remappedValues.front().front();
1731 }
1732 
1733 LogicalResult
1735  SmallVectorImpl<Value> &results) {
// Append the 1:1 remapped value of every key in `keys` to `results`.
// Returns failure if remapping fails; an empty `keys` succeeds immediately.
// 1:N remappings are rejected by an assertion.
1736  if (keys.empty())
1737  return success();
1738  SmallVector<ValueVector> remapped;
1739  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1740  remapped)))
1741  return failure();
1742  for (const auto &values : remapped) {
1743  assert(values.size() == 1 && "1:N conversion not supported");
1744  results.push_back(values.front());
1745  }
1746  return success();
1747 }
1748 
1750  Block::iterator before,
1751  ValueRange argValues) {
// Move all operations of `source` into `dest` before `before`, replacing the
// uses of each block argument of `source` with the matching entry of
// `argValues`, and finally erase `source`.
1752 #ifndef NDEBUG
1753  assert(argValues.size() == source->getNumArguments() &&
1754  "incorrect # of argument replacement values");
1755  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1756  "attempting to inline a block from a replaced/erased op");
1757  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1758  "attempting to inline a block into a replaced/erased op");
1759  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1760  // The source block will be deleted, so it should not have any users (i.e.,
1761  // there should be no predecessors).
// Users that are ignored ops are tolerated by this check.
1762  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1763  "expected 'source' to have no predecessors");
1764 #endif // NDEBUG
1765 
1766  // If a listener is attached to the dialect conversion, ops cannot be moved
1767  // to the destination block in bulk ("fast path"). This is because at the time
1768  // the notifications are sent, it is unknown which ops were moved. Instead,
1769  // ops should be moved one-by-one ("slow path"), so that a separate
1770  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1771  // a bit more efficient, so we try to do that when possible.
1772  bool fastPath = !impl->config.listener;
1773 
// The fast path records the whole inlining as one notification up front.
1774  if (fastPath)
1775  impl->notifyBlockBeingInlined(dest, source, before);
1776 
1777  // Replace all uses of block arguments.
1778  for (auto it : llvm::zip(source->getArguments(), argValues))
1779  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1780 
1781  if (fastPath) {
1782  // Move all ops at once.
1783  dest->getOperations().splice(before, source->getOperations());
1784  } else {
1785  // Move op by op.
1786  while (!source->empty())
1787  moveOpBefore(&source->front(), dest, before);
1788  }
1789 
1790  // Erase the source block.
1791  eraseBlock(source);
1792 }
1793 
// Begin an in-place modification of `op`: record a ModifyOperationRewrite for
// it. In debug builds, also register `op` as having a pending update so that
// finalizeOpModification / cancelOpModification can verify the pairing.
1795  assert(!impl->wasOpReplaced(op) &&
1796  "attempting to modify a replaced/erased op");
1797 #ifndef NDEBUG
1798  impl->pendingRootUpdates.insert(op);
1799 #endif
1800  impl->appendRewrite<ModifyOperationRewrite>(op);
1801 }
1802 
// Finish an in-place modification of `op` previously started with
// startOpModification.
1804  assert(!impl->wasOpReplaced(op) &&
1805  "attempting to modify a replaced/erased op");
1807  // There is nothing to do here, we only need to track the operation at the
1808  // start of the update.
// The erase below has a side effect inside `assert`. That is only sound
// because it is guarded by the same NDEBUG condition that enables `assert`.
1809 #ifndef NDEBUG
1810  assert(impl->pendingRootUpdates.erase(op) &&
1811  "operation did not have a pending in-place update");
1812 #endif
1813 }
1814 
// Abort an in-place modification of `op`: roll back the most recent
// ModifyOperationRewrite recorded for `op` and remove it from the rewrite
// list.
1816 #ifndef NDEBUG
1817  assert(impl->pendingRootUpdates.erase(op) &&
1818  "operation did not have a pending in-place update");
1819 #endif
1820  // Erase the last update for this operation.
1821  auto it = llvm::find_if(
1822  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1823  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1824  return modifyRewrite && modifyRewrite->getOperation() == op;
1825  });
1826  assert(it != impl->rewrites.rend() && "no root update started on op");
1827  (*it)->rollback();
// Convert the reverse iterator into a forward index: prev(rend()) denotes
// element 0, so the distance from it to `it` is the index of *it.
1828  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1829  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1830 }
1831 
// Return a reference to the implementation object held in `impl`.
1833  return *impl;
1834 }
1835 
1836 //===----------------------------------------------------------------------===//
1837 // ConversionPattern
1838 //===----------------------------------------------------------------------===//
1839 
1841  ArrayRef<ValueRange> operands) const {
// Flatten 1:N adaptor operands into a 1:1 list. Every entry must contain
// exactly one value; otherwise this is a fatal error naming the pattern,
// because it does not support 1:N conversions.
1842  SmallVector<Value> oneToOneOperands;
1843  oneToOneOperands.reserve(operands.size());
1844  for (ValueRange operand : operands) {
1845  if (operand.size() != 1)
1846  llvm::report_fatal_error("pattern '" + getDebugName() +
1847  "' does not support 1:N conversion");
1848  oneToOneOperands.push_back(operand.front());
1849  }
1850  return oneToOneOperands;
1851 }
1852 
1853 LogicalResult
1855  PatternRewriter &rewriter) const {
// Generic entry point: remap the op's operands through the conversion
// mapping, then invoke the adaptor-based matchAndRewrite overload.
// NOTE(review): the static_cast is unchecked; it assumes `rewriter` really is
// a ConversionPatternRewriter (i.e. the pattern is run by dialect conversion).
1856  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1857  auto &rewriterImpl = dialectRewriter.getImpl();
1858 
1859  // Track the current conversion pattern type converter in the rewriter.
1860  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1861  getTypeConverter());
1862 
1863  // Remap the operands of the operation.
1864  SmallVector<ValueVector> remapped;
1865  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1866  op->getOperands(), remapped))) {
1867  return failure();
1868  }
// The ValueRanges are views into `remapped`, which stays alive until the
// call below returns.
1869  SmallVector<ValueRange> remappedAsRange =
1870  llvm::to_vector_of<ValueRange>(remapped);
1871  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1872 }
1873 
1874 //===----------------------------------------------------------------------===//
1875 // OperationLegalizer
1876 //===----------------------------------------------------------------------===//
1877 
1878 namespace {
1879 /// A set of rewrite patterns that can be used to legalize a given operation.
1880 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1881 
1882 /// This class defines a recursive operation legalizer.
/// An operation is legalized by (in order) checking target legality,
/// checking whether it was ignored, folding, or applying patterns. Any
/// operations those steps create or modify are legalized recursively.
1883 class OperationLegalizer {
1884 public:
1885  using LegalizationAction = ConversionTarget::LegalizationAction;
1886 
/// The constructor builds the legalization graph and orders the patterns
/// with the cost model.
1887  OperationLegalizer(const ConversionTarget &targetInfo,
1889  const ConversionConfig &config);
1890 
1891  /// Returns true if the given operation is known to be illegal on the target.
1892  bool isIllegal(Operation *op) const;
1893 
1894  /// Attempt to legalize the given operation. Returns success if the operation
1895  /// was legalized, failure otherwise.
1896  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1897 
1898  /// Returns the conversion target in use by the legalizer.
1899  const ConversionTarget &getTarget() { return target; }
1900 
1901 private:
1902  /// Attempt to legalize the given operation by folding it.
1903  LogicalResult legalizeWithFold(Operation *op,
1904  ConversionPatternRewriter &rewriter);
1905 
1906  /// Attempt to legalize the given operation by applying a pattern. Returns
1907  /// success if the operation was legalized, failure otherwise.
1908  LogicalResult legalizeWithPattern(Operation *op,
1909  ConversionPatternRewriter &rewriter);
1910 
1911  /// Return true if the given pattern may be applied to the given operation,
1912  /// false otherwise.
1913  bool canApplyPattern(Operation *op, const Pattern &pattern,
1914  ConversionPatternRewriter &rewriter);
1915 
1916  /// Legalize the resultant IR after successfully applying the given pattern.
1917  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1918  ConversionPatternRewriter &rewriter,
1919  RewriterState &curState);
1920 
1921  /// Legalizes the actions registered during the execution of a pattern.
/// The three helpers below each scan the rewrites recorded between `state`
/// and `newState`: block rewrites, newly created operations, and in-place
/// root updates, respectively.
1922  LogicalResult
1923  legalizePatternBlockRewrites(Operation *op,
1924  ConversionPatternRewriter &rewriter,
1926  RewriterState &state, RewriterState &newState);
1927  LogicalResult legalizePatternCreatedOperations(
1929  RewriterState &state, RewriterState &newState);
1930  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1932  RewriterState &state,
1933  RewriterState &newState);
1934 
1935  //===--------------------------------------------------------------------===//
1936  // Cost Model
1937  //===--------------------------------------------------------------------===//
1938 
1939  /// Build an optimistic legalization graph given the provided patterns. This
1940  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1941  /// patterns for operations that are not directly legal, but may be
1942  /// transitively legal for the current target given the provided patterns.
1943  void buildLegalizationGraph(
1944  LegalizationPatterns &anyOpLegalizerPatterns,
1946 
1947  /// Compute the benefit of each node within the computed legalization graph.
1948  /// This orders the patterns within 'legalizerPatterns' based upon two
1949  /// criteria:
1950  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1951  /// represent the more direct mapping to the target.
1952  /// 2) When comparing patterns with the same legalization depth, prefer the
1953  /// pattern with the highest PatternBenefit. This allows for users to
1954  /// prefer specific legalizations over others.
1955  void computeLegalizationGraphBenefit(
1956  LegalizationPatterns &anyOpLegalizerPatterns,
1958 
1959  /// Compute the legalization depth when legalizing an operation of the given
1960  /// type.
1961  unsigned computeOpLegalizationDepth(
1962  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1964 
1965  /// Apply the conversion cost model to the given set of patterns, and return
1966  /// the smallest legalization depth of any of the patterns. See
1967  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1968  unsigned applyCostModelToPatterns(
1969  LegalizationPatterns &patterns,
1970  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1972 
1973  /// The current set of patterns that have been applied.
/// Patterns are inserted by canApplyPattern and erased once their
/// application finishes, so this holds the patterns on the current recursion
/// stack. It is used to reject re-application of patterns without bounded
/// rewrite recursion.
1974  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1975 
1976  /// The legalization information provided by the target.
1977  const ConversionTarget &target;
1978 
1979  /// The pattern applicator to use for conversions.
1980  PatternApplicator applicator;
1981 
1982  /// Dialect conversion configuration.
1983  const ConversionConfig &config;
1984 };
1985 } // namespace
1986 
1987 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1989  const ConversionConfig &config)
1990  : target(targetInfo), applicator(patterns), config(config) {
// Build the pattern legalization graph once, then rank the patterns: this
// ranking becomes the cost model of `applicator`.
1991  // The set of patterns that can be applied to illegal operations to transform
1992  // them into legal ones.
1994  LegalizationPatterns anyOpLegalizerPatterns;
1995 
1996  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1997  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1998 }
1999 
2000 bool OperationLegalizer::isIllegal(Operation *op) const {
2001  return target.isIllegal(op);
2002 }
2003 
2004 LogicalResult
2005 OperationLegalizer::legalize(Operation *op,
2006  ConversionPatternRewriter &rewriter) {
2007 #ifndef NDEBUG
2008  const char *logLineComment =
2009  "//===-------------------------------------------===//\n";
2010 
2011  auto &logger = rewriter.getImpl().logger;
2012 #endif
2013  LLVM_DEBUG({
2014  logger.getOStream() << "\n";
2015  logger.startLine() << logLineComment;
2016  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2017  << op << ") {\n";
2018  logger.indent();
2019 
2020  // If the operation has no regions, just print it here.
2021  if (op->getNumRegions() == 0) {
2022  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2023  logger.getOStream() << "\n\n";
2024  }
2025  });
2026 
2027  // Check if this operation is legal on the target.
2028  if (auto legalityInfo = target.isLegal(op)) {
2029  LLVM_DEBUG({
2030  logSuccess(
2031  logger, "operation marked legal by the target{0}",
2032  legalityInfo->isRecursivelyLegal
2033  ? "; NOTE: operation is recursively legal; skipping internals"
2034  : "");
2035  logger.startLine() << logLineComment;
2036  });
2037 
2038  // If this operation is recursively legal, mark its children as ignored so
2039  // that we don't consider them for legalization.
2040  if (legalityInfo->isRecursivelyLegal) {
2041  op->walk([&](Operation *nested) {
2042  if (op != nested)
2043  rewriter.getImpl().ignoredOps.insert(nested);
2044  });
2045  }
2046 
2047  return success();
2048  }
2049 
2050  // Check to see if the operation is ignored and doesn't need to be converted.
2051  if (rewriter.getImpl().isOpIgnored(op)) {
2052  LLVM_DEBUG({
2053  logSuccess(logger, "operation marked 'ignored' during conversion");
2054  logger.startLine() << logLineComment;
2055  });
2056  return success();
2057  }
2058 
2059  // If the operation isn't legal, try to fold it in-place.
2060  // TODO: Should we always try to do this, even if the op is
2061  // already legal?
2062  if (succeeded(legalizeWithFold(op, rewriter))) {
2063  LLVM_DEBUG({
2064  logSuccess(logger, "operation was folded");
2065  logger.startLine() << logLineComment;
2066  });
2067  return success();
2068  }
2069 
2070  // Otherwise, we need to apply a legalization pattern to this operation.
2071  if (succeeded(legalizeWithPattern(op, rewriter))) {
2072  LLVM_DEBUG({
2073  logSuccess(logger, "");
2074  logger.startLine() << logLineComment;
2075  });
2076  return success();
2077  }
2078 
2079  LLVM_DEBUG({
2080  logFailure(logger, "no matched legalization pattern");
2081  logger.startLine() << logLineComment;
2082  });
2083  return failure();
2084 }
2085 
2086 LogicalResult
2087 OperationLegalizer::legalizeWithFold(Operation *op,
2088  ConversionPatternRewriter &rewriter) {
2089  auto &rewriterImpl = rewriter.getImpl();
2090  RewriterState curState = rewriterImpl.getCurrentState();
2091 
2092  LLVM_DEBUG({
2093  rewriterImpl.logger.startLine() << "* Fold {\n";
2094  rewriterImpl.logger.indent();
2095  });
2096 
2097  // Try to fold the operation.
2098  SmallVector<Value, 2> replacementValues;
2099  rewriter.setInsertionPoint(op);
2100  if (failed(rewriter.tryFold(op, replacementValues))) {
2101  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2102  return failure();
2103  }
2104  // An empty list of replacement values indicates that the fold was in-place.
2105  // As the operation changed, a new legalization needs to be attempted.
2106  if (replacementValues.empty())
2107  return legalize(op, rewriter);
2108 
2109  // Insert a replacement for 'op' with the folded replacement values.
2110  rewriter.replaceOp(op, replacementValues);
2111 
2112  // Recursively legalize any new constant operations.
2113  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2114  i != e; ++i) {
2115  auto *createOp =
2116  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2117  if (!createOp)
2118  continue;
2119  if (failed(legalize(createOp->getOperation(), rewriter))) {
2120  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2121  "failed to legalize generated constant '{0}'",
2122  createOp->getOperation()->getName()));
2123  rewriterImpl.resetState(curState);
2124  return failure();
2125  }
2126  }
2127 
2128  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2129  return success();
2130 }
2131 
/// Attempts to legalize `op` by running the pattern applicator on it, with
/// three hooks:
///  - canApply: filters candidate patterns through canApplyPattern and
///    notifies the config listener (if any) that a pattern begins;
///  - onFailure: logs the failure, resets the rewriter to the state captured
///    before matching, and pops the pattern from `appliedPatterns`;
///  - onSuccess: legalizes whatever the pattern produced and resets the state
///    if that fails.
2132 LogicalResult
2133 OperationLegalizer::legalizeWithPattern(Operation *op,
2134  ConversionPatternRewriter &rewriter) {
2135  auto &rewriterImpl = rewriter.getImpl();
2136 
2137  // Functor that returns if the given pattern may be applied.
2138  auto canApply = [&](const Pattern &pattern) {
2139  bool canApply = canApplyPattern(op, pattern, rewriter);
2140  if (canApply && config.listener)
2141  config.listener->notifyPatternBegin(pattern, op);
2142  return canApply;
2143  };
2144 
2145  // Functor that cleans up the rewriter state after a pattern failed to match.
2146  RewriterState curState = rewriterImpl.getCurrentState();
2147  auto onFailure = [&](const Pattern &pattern) {
2148  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2149  LLVM_DEBUG({
2150  logFailure(rewriterImpl.logger, "pattern failed to match");
2151  if (rewriterImpl.config.notifyCallback) {
2153  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2154  << "\" on op:\n"
2155  << *op;
2156  rewriterImpl.config.notifyCallback(diag);
2157  }
2158  });
2159  if (config.listener)
2160  config.listener->notifyPatternEnd(pattern, failure());
2161  rewriterImpl.resetState(curState);
// Erasing is harmless for bounded-recursion patterns that were never inserted.
2162  appliedPatterns.erase(&pattern);
2163  };
2164 
2165  // Functor that performs additional legalization when a pattern is
2166  // successfully applied.
2167  auto onSuccess = [&](const Pattern &pattern) {
2168  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2169  auto result = legalizePatternResult(op, pattern, rewriter, curState);
2170  appliedPatterns.erase(&pattern);
2171  if (failed(result))
2172  rewriterImpl.resetState(curState);
2173  if (config.listener)
2174  config.listener->notifyPatternEnd(pattern, result);
2175  return result;
2176  };
2177 
2178  // Try to match and rewrite a pattern on this operation.
2179  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2180  onSuccess);
2181 }
2182 
2183 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2184  ConversionPatternRewriter &rewriter) {
2185  LLVM_DEBUG({
2186  auto &os = rewriter.getImpl().logger;
2187  os.getOStream() << "\n";
2188  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2189  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2190  os.getOStream() << ")' {\n";
2191  os.indent();
2192  });
2193 
2194  // Ensure that we don't cycle by not allowing the same pattern to be
2195  // applied twice in the same recursion stack if it is not known to be safe.
2196  if (!pattern.hasBoundedRewriteRecursion() &&
2197  !appliedPatterns.insert(&pattern).second) {
2198  LLVM_DEBUG(
2199  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2200  return false;
2201  }
2202  return true;
2203 }
2204 
2205 LogicalResult
2206 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2207  ConversionPatternRewriter &rewriter,
2208  RewriterState &curState) {
2209  auto &impl = rewriter.getImpl();
2210  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2211 
2212 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2213  // Check that the root was either replaced or updated in place.
2214  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2215  auto replacedRoot = [&] {
2216  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2217  };
2218  auto updatedRootInPlace = [&] {
2219  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2220  };
2221  if (!replacedRoot() && !updatedRootInPlace())
2222  llvm::report_fatal_error("expected pattern to replace the root operation");
2223 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2224 
2225  // Legalize each of the actions registered during application.
2226  RewriterState newState = impl.getCurrentState();
2227  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2228  newState)) ||
2229  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2230  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2231  newState))) {
2232  return failure();
2233  }
2234 
2235  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2236  return success();
2237 }
2238 
2239 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2240  Operation *op, ConversionPatternRewriter &rewriter,
2241  ConversionPatternRewriterImpl &impl, RewriterState &state,
2242  RewriterState &newState) {
2243  SmallPtrSet<Operation *, 16> operationsToIgnore;
2244 
2245  // If the pattern moved or created any blocks, make sure the types of block
2246  // arguments get legalized.
2247  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2248  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2249  if (!rewrite)
2250  continue;
2251  Block *block = rewrite->getBlock();
2252  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2253  ReplaceBlockArgRewrite>(rewrite))
2254  continue;
2255  // Only check blocks outside of the current operation.
2256  Operation *parentOp = block->getParentOp();
2257  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2258  continue;
2259 
2260  // If the region of the block has a type converter, try to convert the block
2261  // directly.
2262  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2263  std::optional<TypeConverter::SignatureConversion> conversion =
2264  converter->convertBlockSignature(block);
2265  if (!conversion) {
2266  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2267  "block"));
2268  return failure();
2269  }
2270  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2271  continue;
2272  }
2273 
2274  // Otherwise, check that this operation isn't one generated by this pattern.
2275  // This is because we will attempt to legalize the parent operation, and
2276  // blocks in regions created by this pattern will already be legalized later
2277  // on. If we haven't built the set yet, build it now.
2278  if (operationsToIgnore.empty()) {
2279  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2280  ++i) {
2281  auto *createOp =
2282  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2283  if (!createOp)
2284  continue;
2285  operationsToIgnore.insert(createOp->getOperation());
2286  }
2287  }
2288 
2289  // If this operation should be considered for re-legalization, try it.
2290  if (operationsToIgnore.insert(parentOp).second &&
2291  failed(legalize(parentOp, rewriter))) {
2292  LLVM_DEBUG(logFailure(impl.logger,
2293  "operation '{0}'({1}) became illegal after rewrite",
2294  parentOp->getName(), parentOp));
2295  return failure();
2296  }
2297  }
2298  return success();
2299 }
2300 
/// Legalizes every operation created by a pattern, i.e. each
/// CreateOperationRewrite recorded in [state.numRewrites,
/// newState.numRewrites). Stops at the first operation that fails to legalize.
2301 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2303  RewriterState &state, RewriterState &newState) {
// The loop bound is fixed at newState, so rewrites appended by the nested
// legalize calls are not revisited. Indexing re-reads impl.rewrites on each
// iteration, so growth of the vector does not invalidate this loop.
2304  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2305  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2306  if (!createOp)
2307  continue;
2308  Operation *op = createOp->getOperation();
2309  if (failed(legalize(op, rewriter))) {
2310  LLVM_DEBUG(logFailure(impl.logger,
2311  "failed to legalize generated operation '{0}'({1})",
2312  op->getName(), op));
2313  return failure();
2314  }
2315  }
2316  return success();
2317 }
2318 
/// Re-legalizes every operation that a pattern modified in place, i.e. each
/// ModifyOperationRewrite recorded in [state.numRewrites,
/// newState.numRewrites). Stops at the first operation that fails to legalize.
2319 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2321  RewriterState &state, RewriterState &newState) {
2322  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2323  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2324  if (!rewrite)
2325  continue;
2326  Operation *op = rewrite->getOperation();
2327  if (failed(legalize(op, rewriter))) {
2328  LLVM_DEBUG(logFailure(
2329  impl.logger, "failed to legalize operation updated in-place '{0}'",
2330  op->getName()));
2331  return failure();
2332  }
2333  }
2334  return success();
2335 }
2336 
2337 //===----------------------------------------------------------------------===//
2338 // Cost Model
//===----------------------------------------------------------------------===//
2339 
/// Computes, optimistically, which patterns can lead to a legal result.
///
/// Patterns without a root kind go into `anyOpLegalizerPatterns`. If any such
/// pattern exists, no pruning is possible and every rooted pattern (other
/// than those rooted at always-legal ops) is added to `legalizerPatterns`.
/// Otherwise a worklist fixpoint runs: a pattern is accepted once none of its
/// generated ops is illegal or of unknown legality without having accepted
/// patterns itself. Accepting a pattern re-queues the still-invalid patterns
/// of the ops that may generate its root.
2340 void OperationLegalizer::buildLegalizationGraph(
2341  LegalizationPatterns &anyOpLegalizerPatterns,
2342  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2343  // A mapping between an operation and a set of operations that can be used to
2344  // generate it.
2346  // A mapping between an operation and any currently invalid patterns it has.
2348  // A worklist of patterns to consider for legality.
2349  SetVector<const Pattern *> patternWorklist;
2350 
2351  // Build the mapping from operations to the parent ops that may generate them.
2352  applicator.walkAllPatterns([&](const Pattern &pattern) {
2353  std::optional<OperationName> root = pattern.getRootKind();
2354 
2355  // If the pattern has no specific root, we can't analyze the relationship
2356  // between the root op and generated operations. Given that, add all such
2357  // patterns to the legalization set.
2358  if (!root) {
2359  anyOpLegalizerPatterns.push_back(&pattern);
2360  return;
2361  }
2362 
2363  // Skip operations that are always known to be legal.
2364  if (target.getOpAction(*root) == LegalizationAction::Legal)
2365  return;
2366 
2367  // Add this pattern to the invalid set for the root op and record this root
2368  // as a parent for any generated operations.
2369  invalidPatterns[*root].insert(&pattern);
2370  for (auto op : pattern.getGeneratedOps())
2371  parentOps[op].insert(*root);
2372 
2373  // Add this pattern to the worklist.
2374  patternWorklist.insert(&pattern);
2375  });
2376 
2377  // If there are any patterns that don't have a specific root kind, we can't
2378  // make direct assumptions about what operations will never be legalized.
2379  // Note: Technically we could, but it would require an analysis that may
2380  // recurse into itself. It would be better to perform this kind of filtering
2381  // at a higher level than here anyways.
2382  if (!anyOpLegalizerPatterns.empty()) {
2383  for (const Pattern *pattern : patternWorklist)
2384  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2385  return;
2386  }
2387 
2388  while (!patternWorklist.empty()) {
2389  auto *pattern = patternWorklist.pop_back_val();
2390 
2391  // Check to see if any of the generated operations are invalid.
// A generated op counts as invalid when it has no accepted patterns yet and
// the target either has no action for it or marks it Illegal.
2392  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2393  std::optional<LegalizationAction> action = target.getOpAction(op);
2394  return !legalizerPatterns.count(op) &&
2395  (!action || action == LegalizationAction::Illegal);
2396  }))
2397  continue;
2398 
2399  // Otherwise, if all of the generated operations are valid, this op is now
2400  // legal so add all of the child patterns to the worklist.
2401  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2402  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2403 
2404  // Add any invalid patterns of the parent operations to see if they have now
2405  // become legal.
2406  for (auto op : parentOps[*pattern->getRootKind()])
2407  patternWorklist.set_union(invalidPatterns[op]);
2408  }
2409 }
2410 
/// Ranks the patterns found by buildLegalizationGraph and installs that
/// ranking as the cost model of `applicator`.
/// Per-op pattern lists are sorted while their legalization depths are
/// computed. The root-less patterns are sorted separately. Each pattern's
/// final benefit is its distance from the end of its sorted list, so earlier
/// entries get higher benefit.
2411 void OperationLegalizer::computeLegalizationGraphBenefit(
2412  LegalizationPatterns &anyOpLegalizerPatterns,
2413  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2414  // The smallest pattern depth, when legalizing an operation.
2415  DenseMap<OperationName, unsigned> minOpPatternDepth;
2416 
2417  // For each operation that is transitively legal, compute a cost for it.
// computeOpLegalizationDepth only performs lookups in legalizerPatterns, so
// iterating that map here is not invalidated by the call.
2418  for (auto &opIt : legalizerPatterns)
2419  if (!minOpPatternDepth.count(opIt.first))
2420  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2421  legalizerPatterns);
2422 
2423  // Apply the cost model to the patterns that can match any operation. Those
2424  // with a specific operation type are already resolved when computing the op
2425  // legalization depth.
2426  if (!anyOpLegalizerPatterns.empty())
2427  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2428  legalizerPatterns);
2429 
2430  // Apply a cost model to the pattern applicator. We order patterns first by
2431  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2432  // decreasing benefit.
2433  applicator.applyCostModel([&](const Pattern &pattern) {
// NOTE(review): operator[] default-inserts an empty list for roots that
// have no accepted patterns; such patterns are then not found below.
2434  ArrayRef<const Pattern *> orderedPatternList;
2435  if (std::optional<OperationName> rootName = pattern.getRootKind())
2436  orderedPatternList = legalizerPatterns[*rootName];
2437  else
2438  orderedPatternList = anyOpLegalizerPatterns;
2439 
2440  // If the pattern is not found, then it was removed and cannot be matched.
2441  auto *it = llvm::find(orderedPatternList, &pattern);
2442  if (it == orderedPatternList.end())
2444 
2445  // Patterns found earlier in the list have higher benefit.
2446  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2447  });
2448 }
2449 
2450 unsigned OperationLegalizer::computeOpLegalizationDepth(
2451  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2452  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2453  // Check for existing depth.
2454  auto depthIt = minOpPatternDepth.find(op);
2455  if (depthIt != minOpPatternDepth.end())
2456  return depthIt->second;
2457 
2458  // If a mapping for this operation does not exist, then this operation
2459  // is always legal. Return 0 as the depth for a directly legal operation.
2460  auto opPatternsIt = legalizerPatterns.find(op);
2461  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2462  return 0u;
2463 
2464  // Record this initial depth in case we encounter this op again when
2465  // recursively computing the depth.
2466  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2467 
2468  // Apply the cost model to the operation patterns, and update the minimum
2469  // depth.
2470  unsigned minDepth = applyCostModelToPatterns(
2471  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2472  minOpPatternDepth[op] = minDepth;
2473  return minDepth;
2474 }
2475 
2476 unsigned OperationLegalizer::applyCostModelToPatterns(
2477  LegalizationPatterns &patterns,
2478  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2479  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2480  unsigned minDepth = std::numeric_limits<unsigned>::max();
2481 
2482  // Compute the depth for each pattern within the set.
2483  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2484  patternsByDepth.reserve(patterns.size());
2485  for (const Pattern *pattern : patterns) {
2486  unsigned depth = 1;
2487  for (auto generatedOp : pattern->getGeneratedOps()) {
2488  unsigned generatedOpDepth = computeOpLegalizationDepth(
2489  generatedOp, minOpPatternDepth, legalizerPatterns);
2490  depth = std::max(depth, generatedOpDepth + 1);
2491  }
2492  patternsByDepth.emplace_back(pattern, depth);
2493 
2494  // Update the minimum depth of the pattern list.
2495  minDepth = std::min(minDepth, depth);
2496  }
2497 
2498  // If the operation only has one legalization pattern, there is no need to
2499  // sort them.
2500  if (patternsByDepth.size() == 1)
2501  return minDepth;
2502 
2503  // Sort the patterns by those likely to be the most beneficial.
2504  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2505  [](const std::pair<const Pattern *, unsigned> &lhs,
2506  const std::pair<const Pattern *, unsigned> &rhs) {
2507  // First sort by the smaller pattern legalization
2508  // depth.
2509  if (lhs.second != rhs.second)
2510  return lhs.second < rhs.second;
2511 
2512  // Then sort by the larger pattern benefit.
2513  auto lhsBenefit = lhs.first->getBenefit();
2514  auto rhsBenefit = rhs.first->getBenefit();
2515  return lhsBenefit > rhsBenefit;
2516  });
2517 
2518  // Update the legalization pattern to use the new sorted list.
2519  patterns.clear();
2520  for (auto &patternIt : patternsByDepth)
2521  patterns.push_back(patternIt.first);
2522  return minDepth;
2523 }
2524 
2525 //===----------------------------------------------------------------------===//
2526 // OperationConverter
2527 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the OperationConverter treats operations that it fails to
/// legalize.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR next to converted ones (operations explicitly marked illegal
  /// still cause an error).
  Partial,

  /// The conversion only succeeds if every operation ends up legal for the
  /// given target.
  Full,

  /// Operations are only analyzed for legality; no actual rewrites are
  /// applied to the operations on success.
  Analysis,
};
} // namespace
2543 
2544 namespace mlir {
2545 // This class converts operations to a given conversion target via a set of
2546 // rewrite patterns. The conversion behaves differently depending on the
2547 // conversion mode.
2549  explicit OperationConverter(const ConversionTarget &target,
2551  const ConversionConfig &config,
2552  OpConversionMode mode)
2553  : config(config), opLegalizer(target, patterns, this->config),
2554  mode(mode) {}
2555 
2556  /// Converts the given operations to the conversion target.
2557  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2558 
2559 private:
2560  /// Converts an operation with the given rewriter.
2561  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2562 
2563  /// Dialect conversion configuration.
2564  ConversionConfig config;
2565 
2566  /// The legalizer to use when converting operations.
2567  OperationLegalizer opLegalizer;
2568 
2569  /// The conversion mode to use when legalizing operations.
2570  OpConversionMode mode;
2571 };
2572 } // namespace mlir
2573 
2574 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2575  Operation *op) {
2576  // Legalize the given operation.
2577  if (failed(opLegalizer.legalize(op, rewriter))) {
2578  // Handle the case of a failed conversion for each of the different modes.
2579  // Full conversions expect all operations to be converted.
2580  if (mode == OpConversionMode::Full)
2581  return op->emitError()
2582  << "failed to legalize operation '" << op->getName() << "'";
2583  // Partial conversions allow conversions to fail iff the operation was not
2584  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2585  // set, non-legalizable ops are added to that set.
2586  if (mode == OpConversionMode::Partial) {
2587  if (opLegalizer.isIllegal(op))
2588  return op->emitError()
2589  << "failed to legalize operation '" << op->getName()
2590  << "' that was explicitly marked illegal";
2591  if (config.unlegalizedOps)
2592  config.unlegalizedOps->insert(op);
2593  }
2594  } else if (mode == OpConversionMode::Analysis) {
2595  // Analysis conversions don't fail if any operations fail to legalize,
2596  // they are only interested in the operations that were successfully
2597  // legalized.
2598  if (config.legalizableOps)
2599  config.legalizableOps->insert(op);
2600  }
2601  return success();
2602 }
2603 
2604 static LogicalResult
2606  UnresolvedMaterializationRewrite *rewrite) {
2607  UnrealizedConversionCastOp op = rewrite->getOperation();
2608  assert(!op.use_empty() &&
2609  "expected that dead materializations have already been DCE'd");
2610  Operation::operand_range inputOperands = op.getOperands();
2611 
2612  // Try to materialize the conversion.
2613  if (const TypeConverter *converter = rewrite->getConverter()) {
2614  rewriter.setInsertionPoint(op);
2615  SmallVector<Value> newMaterialization;
2616  switch (rewrite->getMaterializationKind()) {
2617  case MaterializationKind::Target:
2618  newMaterialization = converter->materializeTargetConversion(
2619  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2620  rewrite->getOriginalType());
2621  break;
2622  case MaterializationKind::Source:
2623  assert(op->getNumResults() == 1 && "expected single result");
2624  Value sourceMat = converter->materializeSourceConversion(
2625  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2626  if (sourceMat)
2627  newMaterialization.push_back(sourceMat);
2628  break;
2629  }
2630  if (!newMaterialization.empty()) {
2631 #ifndef NDEBUG
2632  ValueRange newMaterializationRange(newMaterialization);
2633  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2634  "materialization callback produced value of incorrect type");
2635 #endif // NDEBUG
2636  rewriter.replaceOp(op, newMaterialization);
2637  return success();
2638  }
2639  }
2640 
2641  InFlightDiagnostic diag = op->emitError()
2642  << "failed to legalize unresolved materialization "
2643  "from ("
2644  << inputOperands.getTypes() << ") to ("
2645  << op.getResultTypes()
2646  << ") that remained live after conversion";
2647  diag.attachNote(op->getUsers().begin()->getLoc())
2648  << "see existing live user here: " << *op->getUsers().begin();
2649  return failure();
2650 }
2651 
2653  if (ops.empty())
2654  return success();
2655  const ConversionTarget &target = opLegalizer.getTarget();
2656 
2657  // Compute the set of operations and blocks to convert.
2658  SmallVector<Operation *> toConvert;
2659  for (auto *op : ops) {
2661  [&](Operation *op) {
2662  toConvert.push_back(op);
2663  // Don't check this operation's children for conversion if the
2664  // operation is recursively legal.
2665  auto legalityInfo = target.isLegal(op);
2666  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2667  return WalkResult::skip();
2668  return WalkResult::advance();
2669  });
2670  }
2671 
2672  // Convert each operation and discard rewrites on failure.
2673  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2674  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2675 
2676  for (auto *op : toConvert)
2677  if (failed(convert(rewriter, op)))
2678  return rewriterImpl.undoRewrites(), failure();
2679 
2680  // After a successful conversion, apply rewrites.
2681  rewriterImpl.applyRewrites();
2682 
2683  // Gather all unresolved materializations.
2686  &materializations = rewriterImpl.unresolvedMaterializations;
2687  for (auto it : materializations) {
2688  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2689  continue;
2690  allCastOps.push_back(it.first);
2691  }
2692 
2693  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2694  // dialect conversion frameworks. (Not the one that were inserted by
2695  // patterns.)
2696  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2697  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2698 
2699  // Try to legalize all unresolved materializations.
2700  if (config.buildMaterializations) {
2701  IRRewriter rewriter(rewriterImpl.context, config.listener);
2702  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2703  auto it = materializations.find(castOp);
2704  assert(it != materializations.end() && "inconsistent state");
2705  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2706  return failure();
2707  }
2708  }
2709 
2710  return success();
2711 }
2712 
2713 //===----------------------------------------------------------------------===//
2714 // Reconcile Unrealized Casts
2715 //===----------------------------------------------------------------------===//
2716 
2719  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2720  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2721  castOps.end());
2722  // This set is maintained only if `remainingCastOps` is provided.
2723  DenseSet<Operation *> erasedOps;
2724 
2725  // Helper function that adds all operands to the worklist that are an
2726  // unrealized_conversion_cast op result.
2727  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2728  for (Value v : castOp.getInputs())
2729  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2730  worklist.insert(inputCastOp);
2731  };
2732 
2733  // Helper function that return the unrealized_conversion_cast op that
2734  // defines all inputs of the given op (in the same order). Return "nullptr"
2735  // if there is no such op.
2736  auto getInputCast =
2737  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2738  if (castOp.getInputs().empty())
2739  return {};
2740  auto inputCastOp =
2741  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2742  if (!inputCastOp)
2743  return {};
2744  if (inputCastOp.getOutputs() != castOp.getInputs())
2745  return {};
2746  return inputCastOp;
2747  };
2748 
2749  // Process ops in the worklist bottom-to-top.
2750  while (!worklist.empty()) {
2751  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2752  if (castOp->use_empty()) {
2753  // DCE: If the op has no users, erase it. Add the operands to the
2754  // worklist to find additional DCE opportunities.
2755  enqueueOperands(castOp);
2756  if (remainingCastOps)
2757  erasedOps.insert(castOp.getOperation());
2758  castOp->erase();
2759  continue;
2760  }
2761 
2762  // Traverse the chain of input cast ops to see if an op with the same
2763  // input types can be found.
2764  UnrealizedConversionCastOp nextCast = castOp;
2765  while (nextCast) {
2766  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2767  // Found a cast where the input types match the output types of the
2768  // matched op. We can directly use those inputs and the matched op can
2769  // be removed.
2770  enqueueOperands(castOp);
2771  castOp.replaceAllUsesWith(nextCast.getInputs());
2772  if (remainingCastOps)
2773  erasedOps.insert(castOp.getOperation());
2774  castOp->erase();
2775  break;
2776  }
2777  nextCast = getInputCast(nextCast);
2778  }
2779  }
2780 
2781  if (remainingCastOps)
2782  for (UnrealizedConversionCastOp op : castOps)
2783  if (!erasedOps.contains(op.getOperation()))
2784  remainingCastOps->push_back(op);
2785 }
2786 
2787 //===----------------------------------------------------------------------===//
2788 // Type Conversion
2789 //===----------------------------------------------------------------------===//
2790 
2792  ArrayRef<Type> types) {
2793  assert(!types.empty() && "expected valid types");
2794  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2795  addInputs(types);
2796 }
2797 
2799  assert(!types.empty() &&
2800  "1->0 type remappings don't need to be added explicitly");
2801  argTypes.append(types.begin(), types.end());
2802 }
2803 
2804 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2805  unsigned newInputNo,
2806  unsigned newInputCount) {
2807  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2808  assert(newInputCount != 0 && "expected valid input count");
2809  remappedInputs[origInputNo] =
2810  InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
2811 }
2812 
2814  unsigned origInputNo, ArrayRef<Value> replacements) {
2815  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2816  remappedInputs[origInputNo] = InputMapping{
2817  origInputNo, /*size=*/0,
2818  SmallVector<Value, 1>(replacements.begin(), replacements.end())};
2819 }
2820 
2822  SmallVectorImpl<Type> &results) const {
2823  assert(t && "expected non-null type");
2824 
2825  {
2826  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2827  std::defer_lock);
2829  cacheReadLock.lock();
2830  auto existingIt = cachedDirectConversions.find(t);
2831  if (existingIt != cachedDirectConversions.end()) {
2832  if (existingIt->second)
2833  results.push_back(existingIt->second);
2834  return success(existingIt->second != nullptr);
2835  }
2836  auto multiIt = cachedMultiConversions.find(t);
2837  if (multiIt != cachedMultiConversions.end()) {
2838  results.append(multiIt->second.begin(), multiIt->second.end());
2839  return success();
2840  }
2841  }
2842  // Walk the added converters in reverse order to apply the most recently
2843  // registered first.
2844  size_t currentCount = results.size();
2845 
2846  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2847  std::defer_lock);
2848 
2849  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2850  if (std::optional<LogicalResult> result = converter(t, results)) {
2852  cacheWriteLock.lock();
2853  if (!succeeded(*result)) {
2854  cachedDirectConversions.try_emplace(t, nullptr);
2855  return failure();
2856  }
2857  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2858  if (newTypes.size() == 1)
2859  cachedDirectConversions.try_emplace(t, newTypes.front());
2860  else
2861  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2862  return success();
2863  }
2864  }
2865  return failure();
2866 }
2867 
2869  // Use the multi-type result version to convert the type.
2870  SmallVector<Type, 1> results;
2871  if (failed(convertType(t, results)))
2872  return nullptr;
2873 
2874  // Check to ensure that only one type was produced.
2875  return results.size() == 1 ? results.front() : nullptr;
2876 }
2877 
2878 LogicalResult
2880  SmallVectorImpl<Type> &results) const {
2881  for (Type type : types)
2882  if (failed(convertType(type, results)))
2883  return failure();
2884  return success();
2885 }
2886 
2887 bool TypeConverter::isLegal(Type type) const {
2888  return convertType(type) == type;
2889 }
2891  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2892 }
2893 
2894 bool TypeConverter::isLegal(Region *region) const {
2895  return llvm::all_of(*region, [this](Block &block) {
2896  return isLegal(block.getArgumentTypes());
2897  });
2898 }
2899 
2900 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2901  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2902 }
2903 
2904 LogicalResult
2906  SignatureConversion &result) const {
2907  // Try to convert the given input type.
2908  SmallVector<Type, 1> convertedTypes;
2909  if (failed(convertType(type, convertedTypes)))
2910  return failure();
2911 
2912  // If this argument is being dropped, there is nothing left to do.
2913  if (convertedTypes.empty())
2914  return success();
2915 
2916  // Otherwise, add the new inputs.
2917  result.addInputs(inputNo, convertedTypes);
2918  return success();
2919 }
2920 LogicalResult
2922  SignatureConversion &result,
2923  unsigned origInputOffset) const {
2924  for (unsigned i = 0, e = types.size(); i != e; ++i)
2925  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2926  return failure();
2927  return success();
2928 }
2929 
2931  Location loc,
2932  Type resultType,
2933  ValueRange inputs) const {
2934  for (const MaterializationCallbackFn &fn :
2935  llvm::reverse(argumentMaterializations))
2936  if (Value result = fn(builder, resultType, inputs, loc))
2937  return result;
2938  return nullptr;
2939 }
2940 
2942  Location loc, Type resultType,
2943  ValueRange inputs) const {
2944  for (const MaterializationCallbackFn &fn :
2945  llvm::reverse(sourceMaterializations))
2946  if (Value result = fn(builder, resultType, inputs, loc))
2947  return result;
2948  return nullptr;
2949 }
2950 
2952  Location loc, Type resultType,
2953  ValueRange inputs,
2954  Type originalType) const {
2956  builder, loc, TypeRange(resultType), inputs, originalType);
2957  if (result.empty())
2958  return nullptr;
2959  assert(result.size() == 1 && "expected single result");
2960  return result.front();
2961 }
2962 
2964  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2965  Type originalType) const {
2966  for (const TargetMaterializationCallbackFn &fn :
2967  llvm::reverse(targetMaterializations)) {
2968  SmallVector<Value> result =
2969  fn(builder, resultTypes, inputs, loc, originalType);
2970  if (result.empty())
2971  continue;
2972  assert(TypeRange(ValueRange(result)) == resultTypes &&
2973  "callback produced incorrect number of values or values with "
2974  "incorrect types");
2975  return result;
2976  }
2977  return {};
2978 }
2979 
2980 std::optional<TypeConverter::SignatureConversion>
2982  SignatureConversion conversion(block->getNumArguments());
2983  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2984  return std::nullopt;
2985  return conversion;
2986 }
2987 
2988 //===----------------------------------------------------------------------===//
2989 // Type attribute conversion
2990 //===----------------------------------------------------------------------===//
2993  return AttributeConversionResult(attr, resultTag);
2994 }
2995 
2998  return AttributeConversionResult(nullptr, naTag);
2999 }
3000 
3003  return AttributeConversionResult(nullptr, abortTag);
3004 }
3005 
3007  return impl.getInt() == resultTag;
3008 }
3009 
3011  return impl.getInt() == naTag;
3012 }
3013 
3015  return impl.getInt() == abortTag;
3016 }
3017 
3019  assert(hasResult() && "Cannot get result from N/A or abort");
3020  return impl.getPointer();
3021 }
3022 
3023 std::optional<Attribute>
3025  for (const TypeAttributeConversionCallbackFn &fn :
3026  llvm::reverse(typeAttributeConversions)) {
3027  AttributeConversionResult res = fn(type, attr);
3028  if (res.hasResult())
3029  return res.getResult();
3030  if (res.isAbort())
3031  return std::nullopt;
3032  }
3033  return std::nullopt;
3034 }
3035 
3036 //===----------------------------------------------------------------------===//
3037 // FunctionOpInterfaceSignatureConversion
3038 //===----------------------------------------------------------------------===//
3039 
3040 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3041  const TypeConverter &typeConverter,
3042  ConversionPatternRewriter &rewriter) {
3043  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3044  if (!type)
3045  return failure();
3046 
3047  // Convert the original function types.
3048  TypeConverter::SignatureConversion result(type.getNumInputs());
3049  SmallVector<Type, 1> newResults;
3050  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3051  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3052  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3053  typeConverter, &result)))
3054  return failure();
3055 
3056  // Update the function signature in-place.
3057  auto newType = FunctionType::get(rewriter.getContext(),
3058  result.getConvertedTypes(), newResults);
3059 
3060  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3061 
3062  return success();
3063 }
3064 
3065 /// Create a default conversion pattern that rewrites the type signature of a
3066 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3067 /// represent their type.
3068 namespace {
3069 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3070  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3071  MLIRContext *ctx,
3072  const TypeConverter &converter)
3073  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3074 
3075  LogicalResult
3076  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3077  ConversionPatternRewriter &rewriter) const override {
3078  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3079  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3080  }
3081 };
3082 
3083 struct AnyFunctionOpInterfaceSignatureConversion
3084  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3086 
3087  LogicalResult
3088  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3089  ConversionPatternRewriter &rewriter) const override {
3090  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3091  }
3092 };
3093 } // namespace
3094 
3095 FailureOr<Operation *>
3097  const TypeConverter &converter,
3098  ConversionPatternRewriter &rewriter) {
3099  assert(op && "Invalid op");
3100  Location loc = op->getLoc();
3101  if (converter.isLegal(op))
3102  return rewriter.notifyMatchFailure(loc, "op already legal");
3103 
3104  OperationState newOp(loc, op->getName());
3105  newOp.addOperands(operands);
3106 
3107  SmallVector<Type> newResultTypes;
3108  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3109  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3110 
3111  newOp.addTypes(newResultTypes);
3112  newOp.addAttributes(op->getAttrs());
3113  return rewriter.create(newOp);
3114 }
3115 
3117  StringRef functionLikeOpName, RewritePatternSet &patterns,
3118  const TypeConverter &converter) {
3119  patterns.add<FunctionOpInterfaceSignatureConversion>(
3120  functionLikeOpName, patterns.getContext(), converter);
3121 }
3122 
3124  RewritePatternSet &patterns, const TypeConverter &converter) {
3125  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3126  converter, patterns.getContext());
3127 }
3128 
3129 //===----------------------------------------------------------------------===//
3130 // ConversionTarget
3131 //===----------------------------------------------------------------------===//
3132 
3134  LegalizationAction action) {
3135  legalOperations[op].action = action;
3136 }
3137 
3139  LegalizationAction action) {
3140  for (StringRef dialect : dialectNames)
3141  legalDialects[dialect] = action;
3142 }
3143 
3145  -> std::optional<LegalizationAction> {
3146  std::optional<LegalizationInfo> info = getOpInfo(op);
3147  return info ? info->action : std::optional<LegalizationAction>();
3148 }
3149 
3151  -> std::optional<LegalOpDetails> {
3152  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3153  if (!info)
3154  return std::nullopt;
3155 
3156  // Returns true if this operation instance is known to be legal.
3157  auto isOpLegal = [&] {
3158  // Handle dynamic legality either with the provided legality function.
3159  if (info->action == LegalizationAction::Dynamic) {
3160  std::optional<bool> result = info->legalityFn(op);
3161  if (result)
3162  return *result;
3163  }
3164 
3165  // Otherwise, the operation is only legal if it was marked 'Legal'.
3166  return info->action == LegalizationAction::Legal;
3167  };
3168  if (!isOpLegal())
3169  return std::nullopt;
3170 
3171  // This operation is legal, compute any additional legality information.
3172  LegalOpDetails legalityDetails;
3173  if (info->isRecursivelyLegal) {
3174  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3175  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3176  legalityDetails.isRecursivelyLegal =
3177  legalityFnIt->second(op).value_or(true);
3178  } else {
3179  legalityDetails.isRecursivelyLegal = true;
3180  }
3181  }
3182  return legalityDetails;
3183 }
3184 
3186  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3187  if (!info)
3188  return false;
3189 
3190  if (info->action == LegalizationAction::Dynamic) {
3191  std::optional<bool> result = info->legalityFn(op);
3192  if (!result)
3193  return false;
3194 
3195  return !(*result);
3196  }
3197 
3198  return info->action == LegalizationAction::Illegal;
3199 }
3200 
3204  if (!oldCallback)
3205  return newCallback;
3206 
3207  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3208  Operation *op) -> std::optional<bool> {
3209  if (std::optional<bool> result = newCl(op))
3210  return *result;
3211 
3212  return oldCl(op);
3213  };
3214  return chain;
3215 }
3216 
3217 void ConversionTarget::setLegalityCallback(
3218  OperationName name, const DynamicLegalityCallbackFn &callback) {
3219  assert(callback && "expected valid legality callback");
3220  auto *infoIt = legalOperations.find(name);
3221  assert(infoIt != legalOperations.end() &&
3222  infoIt->second.action == LegalizationAction::Dynamic &&
3223  "expected operation to already be marked as dynamically legal");
3224  infoIt->second.legalityFn =
3225  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3226 }
3227 
3229  OperationName name, const DynamicLegalityCallbackFn &callback) {
3230  auto *infoIt = legalOperations.find(name);
3231  assert(infoIt != legalOperations.end() &&
3232  infoIt->second.action != LegalizationAction::Illegal &&
3233  "expected operation to already be marked as legal");
3234  infoIt->second.isRecursivelyLegal = true;
3235  if (callback)
3236  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3237  std::move(opRecursiveLegalityFns[name]), callback);
3238  else
3239  opRecursiveLegalityFns.erase(name);
3240 }
3241 
3242 void ConversionTarget::setLegalityCallback(
3243  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3244  assert(callback && "expected valid legality callback");
3245  for (StringRef dialect : dialects)
3246  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3247  std::move(dialectLegalityFns[dialect]), callback);
3248 }
3249 
3250 void ConversionTarget::setLegalityCallback(
3251  const DynamicLegalityCallbackFn &callback) {
3252  assert(callback && "expected valid legality callback");
3253  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3254 }
3255 
3256 auto ConversionTarget::getOpInfo(OperationName op) const
3257  -> std::optional<LegalizationInfo> {
3258  // Check for info for this specific operation.
3259  const auto *it = legalOperations.find(op);
3260  if (it != legalOperations.end())
3261  return it->second;
3262  // Check for info for the parent dialect.
3263  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3264  if (dialectIt != legalDialects.end()) {
3265  DynamicLegalityCallbackFn callback;
3266  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3267  if (dialectFn != dialectLegalityFns.end())
3268  callback = dialectFn->second;
3269  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3270  callback};
3271  }
3272  // Otherwise, check if we mark unknown operations as dynamic.
3273  if (unknownLegalityFn)
3274  return LegalizationInfo{LegalizationAction::Dynamic,
3275  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3276  return std::nullopt;
3277 }
3278 
3279 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3280 //===----------------------------------------------------------------------===//
3281 // PDL Configuration
3282 //===----------------------------------------------------------------------===//
3283 
3285  auto &rewriterImpl =
3286  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3287  rewriterImpl.currentTypeConverter = getTypeConverter();
3288 }
3289 
3291  auto &rewriterImpl =
3292  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3293  rewriterImpl.currentTypeConverter = nullptr;
3294 }
3295 
3296 /// Remap the given value using the rewriter and the type converter in the
3297 /// provided config.
3298 static FailureOr<SmallVector<Value>>
3300  SmallVector<Value> mappedValues;
3301  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3302  return failure();
3303  return std::move(mappedValues);
3304 }
3305 
3307  patterns.getPDLPatterns().registerRewriteFunction(
3308  "convertValue",
3309  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3310  auto results = pdllConvertValues(
3311  static_cast<ConversionPatternRewriter &>(rewriter), value);
3312  if (failed(results))
3313  return failure();
3314  return results->front();
3315  });
3316  patterns.getPDLPatterns().registerRewriteFunction(
3317  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3318  return pdllConvertValues(
3319  static_cast<ConversionPatternRewriter &>(rewriter), values);
3320  });
3321  patterns.getPDLPatterns().registerRewriteFunction(
3322  "convertType",
3323  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3324  auto &rewriterImpl =
3325  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3326  if (const TypeConverter *converter =
3327  rewriterImpl.currentTypeConverter) {
3328  if (Type newType = converter->convertType(type))
3329  return newType;
3330  return failure();
3331  }
3332  return type;
3333  });
3334  patterns.getPDLPatterns().registerRewriteFunction(
3335  "convertTypes",
3336  [](PatternRewriter &rewriter,
3337  TypeRange types) -> FailureOr<SmallVector<Type>> {
3338  auto &rewriterImpl =
3339  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3340  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
3341  if (!converter)
3342  return SmallVector<Type>(types);
3343 
3344  SmallVector<Type> remappedTypes;
3345  if (failed(converter->convertTypes(types, remappedTypes)))
3346  return failure();
3347  return std::move(remappedTypes);
3348  });
3349 }
3350 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3351 
3352 //===----------------------------------------------------------------------===//
3353 // Op Conversion Entry Points
3354 //===----------------------------------------------------------------------===//
3355 
3356 //===----------------------------------------------------------------------===//
3357 // Partial Conversion
3358 
3360  ArrayRef<Operation *> ops, const ConversionTarget &target,
3362  OperationConverter opConverter(target, patterns, config,
3363  OpConversionMode::Partial);
3364  return opConverter.convertOperations(ops);
3365 }
3366 LogicalResult
3370  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3371 }
3372 
3373 //===----------------------------------------------------------------------===//
3374 // Full Conversion
3375 
3377  const ConversionTarget &target,
3380  OperationConverter opConverter(target, patterns, config,
3381  OpConversionMode::Full);
3382  return opConverter.convertOperations(ops);
3383 }
3385  const ConversionTarget &target,
3388  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3389 }
3390 
3391 //===----------------------------------------------------------------------===//
3392 // Analysis Conversion
3393 
3394 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3395 /// op is a top-level module op (which is expected to be isolated from above),
3396 /// return that op.
3398  // Check if there is a top-level operation within `ops`. If so, return that
3399  // op.
3400  for (Operation *op : ops) {
3401  if (!op->getParentOp()) {
3402 #ifndef NDEBUG
3403  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3404  "expected top-level op to be isolated from above");
3405  for (Operation *other : ops)
3406  assert(op->isAncestor(other) &&
3407  "expected ops to have a common ancestor");
3408 #endif // NDEBUG
3409  return op;
3410  }
3411  }
3412 
3413  // No top-level op. Find a common ancestor.
3414  Operation *commonAncestor =
3415  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3416  for (Operation *op : ops.drop_front()) {
3417  while (!commonAncestor->isProperAncestor(op)) {
3418  commonAncestor =
3420  assert(commonAncestor &&
3421  "expected to find a common isolated from above ancestor");
3422  }
3423  }
3424 
3425  return commonAncestor;
3426 }
3427 
3431 #ifndef NDEBUG
3432  if (config.legalizableOps)
3433  assert(config.legalizableOps->empty() && "expected empty set");
3434 #endif // NDEBUG
3435 
3436  // Clone closted common ancestor that is isolated from above.
3437  Operation *commonAncestor = findCommonAncestor(ops);
3438  IRMapping mapping;
3439  Operation *clonedAncestor = commonAncestor->clone(mapping);
3440  // Compute inverse IR mapping.
3441  DenseMap<Operation *, Operation *> inverseOperationMap;
3442  for (auto &it : mapping.getOperationMap())
3443  inverseOperationMap[it.second] = it.first;
3444 
3445  // Convert the cloned operations. The original IR will remain unchanged.
3446  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3447  ops, [&](Operation *op) { return mapping.lookup(op); });
3448  OperationConverter opConverter(target, patterns, config,
3449  OpConversionMode::Analysis);
3450  LogicalResult status = opConverter.convertOperations(opsToConvert);
3451 
3452  // Remap `legalizableOps`, so that they point to the original ops and not the
3453  // cloned ops.
3454  if (config.legalizableOps) {
3455  DenseSet<Operation *> originalLegalizableOps;
3456  for (Operation *op : *config.legalizableOps)
3457  originalLegalizableOps.insert(inverseOperationMap[op]);
3458  *config.legalizableOps = std::move(originalLegalizableOps);
3459  }
3460 
3461  // Erase the cloned IR.
3462  clonedAncestor->erase();
3463  return status;
3464 }
3465 
3466 LogicalResult
3470  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3471 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
union mlir::linalg::@1181::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument `from` with the value `to`.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:784
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source location information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:897
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
Definition: PatternMatch.h:656
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of `from` and replace them with `to` if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacement values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a range of values that remaps an existing signature in...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/blocks that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
void notifyOpReplaced(Operation *op, ArrayRef< ValueRange > newValues)
Notifies that an op is about to be replaced with the given values.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.