MLIR  20.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dominance.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Iterators.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 #define DEBUG_TYPE "dialect-conversion"
32 
33 /// A utility function to log a successful result for the given reason.
34 template <typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36  LLVM_DEBUG({
37  os.unindent();
38  os.startLine() << "} -> SUCCESS";
39  if (!fmt.empty())
40  os.getOStream() << " : "
41  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42  os.getOStream() << "\n";
43  });
44 }
45 
46 /// A utility function to log a failure result for the given reason.
47 template <typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49  LLVM_DEBUG({
50  os.unindent();
51  os.startLine() << "} -> FAILURE : "
52  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
53  << "\n";
54  });
55 }
56 
57 /// Helper function that computes an insertion point where the given value is
58 /// defined and can be used without a dominance violation.
60  Block *insertBlock = value.getParentBlock();
61  Block::iterator insertPt = insertBlock->begin();
62  if (OpResult inputRes = dyn_cast<OpResult>(value))
63  insertPt = ++inputRes.getOwner()->getIterator();
64  return OpBuilder::InsertPoint(insertBlock, insertPt);
65 }
66 
67 /// Helper function that computes an insertion point where the given values are
68 /// defined and can be used without a dominance violation.
70  assert(!vals.empty() && "expected at least one value");
71  DominanceInfo domInfo;
72  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
73  for (Value v : vals.drop_front()) {
74  // Choose the "later" insertion point.
76  if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
77  nextPt.getPoint())) {
78  // pt is before nextPt => choose nextPt.
79  pt = nextPt;
80  } else {
81 #ifndef NDEBUG
82  // nextPt should be before pt => choose pt.
83  // If pt, nextPt are no dominance relationship, then there is no valid
84  // insertion point at which all given values are defined.
85  bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
86  pt.getBlock(), pt.getPoint());
87  assert(dom && "unable to find valid insertion point");
88 #endif // NDEBUG
89  }
90  }
91  return pt;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // ConversionValueMapping
96 //===----------------------------------------------------------------------===//
97 
98 /// A vector of SSA values, optimized for the most common case of a single
99 /// value.
100 using ValueVector = SmallVector<Value, 1>;
101 
102 namespace {
103 
104 /// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
105 struct ValueVectorMapInfo {
106  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
107  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
108  static ::llvm::hash_code getHashValue(const ValueVector &val) {
109  return ::llvm::hash_combine_range(val.begin(), val.end());
110  }
111  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
112  return LHS == RHS;
113  }
114 };
115 
116 /// This class wraps a IRMapping to provide recursive lookup
117 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
118 struct ConversionValueMapping {
119  /// Return "true" if an SSA value is mapped to the given value. May return
120  /// false positives.
121  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
122 
123  /// Lookup the most recently mapped values with the desired types in the
124  /// mapping.
125  ///
126  /// Special cases:
127  /// - If the desired type range is empty, simply return the most recently
128  /// mapped values.
129  /// - If there is no mapping to the desired types, also return the most
130  /// recently mapped values.
131  /// - If there is no mapping for the given values at all, return the given
132  /// value.
133  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134 
135  /// Lookup the given value within the map, or return an empty vector if the
136  /// value is not mapped. If it is mapped, this follows the same behavior
137  /// as `lookupOrDefault`.
138  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
139 
140  template <typename T>
141  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 
143  /// Map a value vector to the one provided.
144  template <typename OldVal, typename NewVal>
145  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146  map(OldVal &&oldVal, NewVal &&newVal) {
147  LLVM_DEBUG({
148  ValueVector next(newVal);
149  while (true) {
150  assert(next != oldVal && "inserting cyclic mapping");
151  auto it = mapping.find(next);
152  if (it == mapping.end())
153  break;
154  next = it->second;
155  }
156  });
157  for (Value v : newVal)
158  mappedTo.insert(v);
159 
160  mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
161  }
162 
163  /// Map a value vector or single value to the one provided.
164  template <typename OldVal, typename NewVal>
165  std::enable_if_t<!IsValueVector<OldVal>::value ||
166  !IsValueVector<NewVal>::value>
167  map(OldVal &&oldVal, NewVal &&newVal) {
168  if constexpr (IsValueVector<OldVal>{}) {
169  map(std::forward<OldVal>(oldVal), ValueVector{newVal});
170  } else if constexpr (IsValueVector<NewVal>{}) {
171  map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
172  } else {
173  map(ValueVector{oldVal}, ValueVector{newVal});
174  }
175  }
176 
177  /// Drop the last mapping for the given values.
178  void erase(const ValueVector &value) { mapping.erase(value); }
179 
180 private:
181  /// Current value mappings.
183 
184  /// All SSA values that are mapped to. May contain false positives.
185  DenseSet<Value> mappedTo;
186 };
187 } // namespace
188 
190 ConversionValueMapping::lookupOrDefault(Value from,
191  TypeRange desiredTypes) const {
192  // Try to find the deepest values that have the desired types. If there is no
193  // such mapping, simply return the deepest values.
194  ValueVector desiredValue;
195  ValueVector current{from};
196  do {
197  // Store the current value if the types match.
198  if (TypeRange(ValueRange(current)) == desiredTypes)
199  desiredValue = current;
200 
201  // If possible, Replace each value with (one or multiple) mapped values.
202  ValueVector next;
203  for (Value v : current) {
204  auto it = mapping.find({v});
205  if (it != mapping.end()) {
206  llvm::append_range(next, it->second);
207  } else {
208  next.push_back(v);
209  }
210  }
211  if (next != current) {
212  // If at least one value was replaced, continue the lookup from there.
213  current = std::move(next);
214  continue;
215  }
216 
217  // Otherwise: Check if there is a mapping for the entire vector. Such
218  // mappings are materializations. (N:M mapping are not supported for value
219  // replacements.)
220  //
221  // Note: From a correctness point of view, materializations do not have to
222  // be stored (and looked up) in the mapping. But for performance reasons,
223  // we choose to reuse existing IR (when possible) instead of creating it
224  // multiple times.
225  auto it = mapping.find(current);
226  if (it == mapping.end()) {
227  // No mapping found: The lookup stops here.
228  break;
229  }
230  current = it->second;
231  } while (true);
232 
233  // If the desired values were found use them, otherwise default to the leaf
234  // values.
235  // Note: If `desiredTypes` is empty, this function always returns `current`.
236  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
237 }
238 
239 ValueVector ConversionValueMapping::lookupOrNull(Value from,
240  TypeRange desiredTypes) const {
241  ValueVector result = lookupOrDefault(from, desiredTypes);
242  if (result == ValueVector{from} ||
243  (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
244  return {};
245  return result;
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Rewriter and Translation State
250 //===----------------------------------------------------------------------===//
251 namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A snapshot can
/// be saved before a set of rewrites and used later to undo them.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
269 
270 //===----------------------------------------------------------------------===//
271 // IR rewrites
272 //===----------------------------------------------------------------------===//
273 
274 /// An IR rewrite that can be committed (upon success) or rolled back (upon
275 /// failure).
276 ///
277 /// The dialect conversion keeps track of IR modifications (requested by the
278 /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
279 /// are directly applied to the IR as the rewriter API is used, some are applied
280 /// partially, and some are delayed until the `IRRewrite` objects are committed.
281 class IRRewrite {
282 public:
283  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
284  /// Enum values are ordered, so that they can be used in `classof`: first all
285  /// block rewrites, then all operation rewrites.
286  enum class Kind {
287  // Block rewrites
288  CreateBlock,
289  EraseBlock,
290  InlineBlock,
291  MoveBlock,
292  BlockTypeConversion,
293  ReplaceBlockArg,
294  // Operation rewrites
295  MoveOperation,
296  ModifyOperation,
297  ReplaceOperation,
298  CreateOperation,
299  UnresolvedMaterialization
300  };
301 
302  virtual ~IRRewrite() = default;
303 
304  /// Roll back the rewrite. Operations may be erased during rollback.
305  virtual void rollback() = 0;
306 
307  /// Commit the rewrite. At this point, it is certain that the dialect
308  /// conversion will succeed. All IR modifications, except for operation/block
309  /// erasure, must be performed through the given rewriter.
310  ///
311  /// Instead of erasing operations/blocks, they should merely be unlinked
312  /// commit phase and finally be erased during the cleanup phase. This is
313  /// because internal dialect conversion state (such as `mapping`) may still
314  /// be using them.
315  ///
316  /// Any IR modification that was already performed before the commit phase
317  /// (e.g., insertion of an op) must be communicated to the listener that may
318  /// be attached to the given rewriter.
319  virtual void commit(RewriterBase &rewriter) {}
320 
321  /// Cleanup operations/blocks. Cleanup is called after commit.
322  virtual void cleanup(RewriterBase &rewriter) {}
323 
324  Kind getKind() const { return kind; }
325 
326  static bool classof(const IRRewrite *rewrite) { return true; }
327 
328 protected:
329  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
330  : kind(kind), rewriterImpl(rewriterImpl) {}
331 
332  const ConversionConfig &getConfig() const;
333 
334  const Kind kind;
335  ConversionPatternRewriterImpl &rewriterImpl;
336 };
337 
338 /// A block rewrite.
339 class BlockRewrite : public IRRewrite {
340 public:
341  /// Return the block that this rewrite operates on.
342  Block *getBlock() const { return block; }
343 
344  static bool classof(const IRRewrite *rewrite) {
345  return rewrite->getKind() >= Kind::CreateBlock &&
346  rewrite->getKind() <= Kind::ReplaceBlockArg;
347  }
348 
349 protected:
350  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
351  Block *block)
352  : IRRewrite(kind, rewriterImpl), block(block) {}
353 
354  // The block that this rewrite operates on.
355  Block *block;
356 };
357 
358 /// Creation of a block. Block creations are immediately reflected in the IR.
359 /// There is no extra work to commit the rewrite. During rollback, the newly
360 /// created block is erased.
361 class CreateBlockRewrite : public BlockRewrite {
362 public:
363  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
364  : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
365 
366  static bool classof(const IRRewrite *rewrite) {
367  return rewrite->getKind() == Kind::CreateBlock;
368  }
369 
370  void commit(RewriterBase &rewriter) override {
371  // The block was already created and inserted. Just inform the listener.
372  if (auto *listener = rewriter.getListener())
373  listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
374  }
375 
376  void rollback() override {
377  // Unlink all of the operations within this block, they will be deleted
378  // separately.
379  auto &blockOps = block->getOperations();
380  while (!blockOps.empty())
381  blockOps.remove(blockOps.begin());
382  block->dropAllUses();
383  if (block->getParent())
384  block->erase();
385  else
386  delete block;
387  }
388 };
389 
390 /// Erasure of a block. Block erasures are partially reflected in the IR. Erased
391 /// blocks are immediately unlinked, but only erased during cleanup. This makes
392 /// it easier to rollback a block erasure: the block is simply inserted into its
393 /// original location.
394 class EraseBlockRewrite : public BlockRewrite {
395 public:
396  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
397  : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
398  region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
399 
400  static bool classof(const IRRewrite *rewrite) {
401  return rewrite->getKind() == Kind::EraseBlock;
402  }
403 
404  ~EraseBlockRewrite() override {
405  assert(!block &&
406  "rewrite was neither rolled back nor committed/cleaned up");
407  }
408 
409  void rollback() override {
410  // The block (owned by this rewrite) was not actually erased yet. It was
411  // just unlinked. Put it back into its original position.
412  assert(block && "expected block");
413  auto &blockList = region->getBlocks();
414  Region::iterator before = insertBeforeBlock
415  ? Region::iterator(insertBeforeBlock)
416  : blockList.end();
417  blockList.insert(before, block);
418  block = nullptr;
419  }
420 
421  void commit(RewriterBase &rewriter) override {
422  // Erase the block.
423  assert(block && "expected block");
424  assert(block->empty() && "expected empty block");
425 
426  // Notify the listener that the block is about to be erased.
427  if (auto *listener =
428  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
429  listener->notifyBlockErased(block);
430  }
431 
432  void cleanup(RewriterBase &rewriter) override {
433  // Erase the block.
434  block->dropAllDefinedValueUses();
435  delete block;
436  block = nullptr;
437  }
438 
439 private:
440  // The region in which this block was previously contained.
441  Region *region;
442 
443  // The original successor of this block before it was unlinked. "nullptr" if
444  // this block was the only block in the region.
445  Block *insertBeforeBlock;
446 };
447 
448 /// Inlining of a block. This rewrite is immediately reflected in the IR.
449 /// Note: This rewrite represents only the inlining of the operations. The
450 /// erasure of the inlined block is a separate rewrite.
451 class InlineBlockRewrite : public BlockRewrite {
452 public:
453  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
454  Block *sourceBlock, Block::iterator before)
455  : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
456  sourceBlock(sourceBlock),
457  firstInlinedInst(sourceBlock->empty() ? nullptr
458  : &sourceBlock->front()),
459  lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
460  // If a listener is attached to the dialect conversion, ops must be moved
461  // one-by-one. When they are moved in bulk, notifications cannot be sent
462  // because the ops that used to be in the source block at the time of the
463  // inlining (before the "commit" phase) are unknown at the time when
464  // notifications are sent (which is during the "commit" phase).
465  assert(!getConfig().listener &&
466  "InlineBlockRewrite not supported if listener is attached");
467  }
468 
469  static bool classof(const IRRewrite *rewrite) {
470  return rewrite->getKind() == Kind::InlineBlock;
471  }
472 
473  void rollback() override {
474  // Put the operations from the destination block (owned by the rewrite)
475  // back into the source block.
476  if (firstInlinedInst) {
477  assert(lastInlinedInst && "expected operation");
478  sourceBlock->getOperations().splice(sourceBlock->begin(),
479  block->getOperations(),
480  Block::iterator(firstInlinedInst),
481  ++Block::iterator(lastInlinedInst));
482  }
483  }
484 
485 private:
486  // The block that originally contained the operations.
487  Block *sourceBlock;
488 
489  // The first inlined operation.
490  Operation *firstInlinedInst;
491 
492  // The last inlined operation.
493  Operation *lastInlinedInst;
494 };
495 
496 /// Moving of a block. This rewrite is immediately reflected in the IR.
497 class MoveBlockRewrite : public BlockRewrite {
498 public:
499  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
500  Region *region, Block *insertBeforeBlock)
501  : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
502  insertBeforeBlock(insertBeforeBlock) {}
503 
504  static bool classof(const IRRewrite *rewrite) {
505  return rewrite->getKind() == Kind::MoveBlock;
506  }
507 
508  void commit(RewriterBase &rewriter) override {
509  // The block was already moved. Just inform the listener.
510  if (auto *listener = rewriter.getListener()) {
511  // Note: `previousIt` cannot be passed because this is a delayed
512  // notification and iterators into past IR state cannot be represented.
513  listener->notifyBlockInserted(block, /*previous=*/region,
514  /*previousIt=*/{});
515  }
516  }
517 
518  void rollback() override {
519  // Move the block back to its original position.
520  Region::iterator before =
521  insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
522  region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
523  }
524 
525 private:
526  // The region in which this block was previously contained.
527  Region *region;
528 
529  // The original successor of this block before it was moved. "nullptr" if
530  // this block was the only block in the region.
531  Block *insertBeforeBlock;
532 };
533 
534 /// Block type conversion. This rewrite is partially reflected in the IR.
535 class BlockTypeConversionRewrite : public BlockRewrite {
536 public:
537  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
538  Block *origBlock, Block *newBlock)
539  : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
540  newBlock(newBlock) {}
541 
542  static bool classof(const IRRewrite *rewrite) {
543  return rewrite->getKind() == Kind::BlockTypeConversion;
544  }
545 
546  Block *getOrigBlock() const { return block; }
547 
548  Block *getNewBlock() const { return newBlock; }
549 
550  void commit(RewriterBase &rewriter) override;
551 
552  void rollback() override;
553 
554 private:
555  /// The new block that was created as part of this signature conversion.
556  Block *newBlock;
557 };
558 
559 /// Replacing a block argument. This rewrite is not immediately reflected in the
560 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
561 /// until the rewrite is committed.
562 class ReplaceBlockArgRewrite : public BlockRewrite {
563 public:
564  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
565  Block *block, BlockArgument arg,
566  const TypeConverter *converter)
567  : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
568  converter(converter) {}
569 
570  static bool classof(const IRRewrite *rewrite) {
571  return rewrite->getKind() == Kind::ReplaceBlockArg;
572  }
573 
574  void commit(RewriterBase &rewriter) override;
575 
576  void rollback() override;
577 
578 private:
579  BlockArgument arg;
580 
581  /// The current type converter when the block argument was replaced.
582  const TypeConverter *converter;
583 };
584 
585 /// An operation rewrite.
586 class OperationRewrite : public IRRewrite {
587 public:
588  /// Return the operation that this rewrite operates on.
589  Operation *getOperation() const { return op; }
590 
591  static bool classof(const IRRewrite *rewrite) {
592  return rewrite->getKind() >= Kind::MoveOperation &&
593  rewrite->getKind() <= Kind::UnresolvedMaterialization;
594  }
595 
596 protected:
597  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
598  Operation *op)
599  : IRRewrite(kind, rewriterImpl), op(op) {}
600 
601  // The operation that this rewrite operates on.
602  Operation *op;
603 };
604 
605 /// Moving of an operation. This rewrite is immediately reflected in the IR.
606 class MoveOperationRewrite : public OperationRewrite {
607 public:
608  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
609  Operation *op, Block *block, Operation *insertBeforeOp)
610  : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
611  insertBeforeOp(insertBeforeOp) {}
612 
613  static bool classof(const IRRewrite *rewrite) {
614  return rewrite->getKind() == Kind::MoveOperation;
615  }
616 
617  void commit(RewriterBase &rewriter) override {
618  // The operation was already moved. Just inform the listener.
619  if (auto *listener = rewriter.getListener()) {
620  // Note: `previousIt` cannot be passed because this is a delayed
621  // notification and iterators into past IR state cannot be represented.
622  listener->notifyOperationInserted(
623  op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
624  /*insertPt=*/{}));
625  }
626  }
627 
628  void rollback() override {
629  // Move the operation back to its original position.
630  Block::iterator before =
631  insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
632  block->getOperations().splice(before, op->getBlock()->getOperations(), op);
633  }
634 
635 private:
636  // The block in which this operation was previously contained.
637  Block *block;
638 
639  // The original successor of this operation before it was moved. "nullptr"
640  // if this operation was the only operation in the region.
641  Operation *insertBeforeOp;
642 };
643 
644 /// In-place modification of an op. This rewrite is immediately reflected in
645 /// the IR. The previous state of the operation is stored in this object.
646 class ModifyOperationRewrite : public OperationRewrite {
647 public:
648  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
649  Operation *op)
650  : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
651  name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
652  operands(op->operand_begin(), op->operand_end()),
653  successors(op->successor_begin(), op->successor_end()) {
654  if (OpaqueProperties prop = op->getPropertiesStorage()) {
655  // Make a copy of the properties.
656  propertiesStorage = operator new(op->getPropertiesStorageSize());
657  OpaqueProperties propCopy(propertiesStorage);
658  name.initOpProperties(propCopy, /*init=*/prop);
659  }
660  }
661 
662  static bool classof(const IRRewrite *rewrite) {
663  return rewrite->getKind() == Kind::ModifyOperation;
664  }
665 
666  ~ModifyOperationRewrite() override {
667  assert(!propertiesStorage &&
668  "rewrite was neither committed nor rolled back");
669  }
670 
671  void commit(RewriterBase &rewriter) override {
672  // Notify the listener that the operation was modified in-place.
673  if (auto *listener =
674  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
675  listener->notifyOperationModified(op);
676 
677  if (propertiesStorage) {
678  OpaqueProperties propCopy(propertiesStorage);
679  // Note: The operation may have been erased in the mean time, so
680  // OperationName must be stored in this object.
681  name.destroyOpProperties(propCopy);
682  operator delete(propertiesStorage);
683  propertiesStorage = nullptr;
684  }
685  }
686 
687  void rollback() override {
688  op->setLoc(loc);
689  op->setAttrs(attrs);
690  op->setOperands(operands);
691  for (const auto &it : llvm::enumerate(successors))
692  op->setSuccessor(it.value(), it.index());
693  if (propertiesStorage) {
694  OpaqueProperties propCopy(propertiesStorage);
695  op->copyProperties(propCopy);
696  name.destroyOpProperties(propCopy);
697  operator delete(propertiesStorage);
698  propertiesStorage = nullptr;
699  }
700  }
701 
702 private:
703  OperationName name;
704  LocationAttr loc;
705  DictionaryAttr attrs;
706  SmallVector<Value, 8> operands;
707  SmallVector<Block *, 2> successors;
708  void *propertiesStorage = nullptr;
709 };
710 
711 /// Replacing an operation. Erasing an operation is treated as a special case
712 /// with "null" replacements. This rewrite is not immediately reflected in the
713 /// IR. An internal IR mapping is updated, but values are not replaced and the
714 /// original op is not erased until the rewrite is committed.
715 class ReplaceOperationRewrite : public OperationRewrite {
716 public:
717  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
718  Operation *op, const TypeConverter *converter)
719  : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
720  converter(converter) {}
721 
722  static bool classof(const IRRewrite *rewrite) {
723  return rewrite->getKind() == Kind::ReplaceOperation;
724  }
725 
726  void commit(RewriterBase &rewriter) override;
727 
728  void rollback() override;
729 
730  void cleanup(RewriterBase &rewriter) override;
731 
732 private:
733  /// An optional type converter that can be used to materialize conversions
734  /// between the new and old values if necessary.
735  const TypeConverter *converter;
736 };
737 
738 class CreateOperationRewrite : public OperationRewrite {
739 public:
740  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
741  Operation *op)
742  : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
743 
744  static bool classof(const IRRewrite *rewrite) {
745  return rewrite->getKind() == Kind::CreateOperation;
746  }
747 
748  void commit(RewriterBase &rewriter) override {
749  // The operation was already created and inserted. Just inform the listener.
750  if (auto *listener = rewriter.getListener())
751  listener->notifyOperationInserted(op, /*previous=*/{});
752  }
753 
754  void rollback() override;
755 };
756 
/// Direction of a materialization.
enum MaterializationKind {
  /// Conversion from an illegal type to a legal one.
  Target,

  /// Conversion from a legal type back to an illegal one.
  Source
};
767 
768 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
769 /// op. Unresolved materializations are erased at the end of the dialect
770 /// conversion.
771 class UnresolvedMaterializationRewrite : public OperationRewrite {
772 public:
773  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
774  UnrealizedConversionCastOp op,
775  const TypeConverter *converter,
776  MaterializationKind kind, Type originalType,
777  ValueVector mappedValues);
778 
779  static bool classof(const IRRewrite *rewrite) {
780  return rewrite->getKind() == Kind::UnresolvedMaterialization;
781  }
782 
783  void rollback() override;
784 
785  UnrealizedConversionCastOp getOperation() const {
786  return cast<UnrealizedConversionCastOp>(op);
787  }
788 
789  /// Return the type converter of this materialization (which may be null).
790  const TypeConverter *getConverter() const {
791  return converterAndKind.getPointer();
792  }
793 
794  /// Return the kind of this materialization.
795  MaterializationKind getMaterializationKind() const {
796  return converterAndKind.getInt();
797  }
798 
799  /// Return the original type of the SSA value.
800  Type getOriginalType() const { return originalType; }
801 
802 private:
803  /// The corresponding type converter to use when resolving this
804  /// materialization, and the kind of this materialization.
805  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
806  converterAndKind;
807 
808  /// The original type of the SSA value. Only used for target
809  /// materializations.
810  Type originalType;
811 
812  /// The values in the conversion value mapping that are being replaced by the
813  /// results of this unresolved materialization.
814  ValueVector mappedValues;
815 };
816 } // namespace
817 
818 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
819 /// Return "true" if there is an operation rewrite that matches the specified
820 /// rewrite type and operation among the given rewrites.
821 template <typename RewriteTy, typename R>
822 static bool hasRewrite(R &&rewrites, Operation *op) {
823  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
824  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
825  return rewriteTy && rewriteTy->getOperation() == op;
826  });
827 }
828 
829 /// Return "true" if there is a block rewrite that matches the specified
830 /// rewrite type and block among the given rewrites.
831 template <typename RewriteTy, typename R>
832 static bool hasRewrite(R &&rewrites, Block *block) {
833  return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
834  auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
835  return rewriteTy && rewriteTy->getBlock() == block;
836  });
837 }
838 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
839 
840 //===----------------------------------------------------------------------===//
841 // ConversionPatternRewriterImpl
842 //===----------------------------------------------------------------------===//
843 namespace mlir {
844 namespace detail {
847  const ConversionConfig &config)
848  : context(ctx), eraseRewriter(ctx), config(config) {}
849 
850  //===--------------------------------------------------------------------===//
851  // State Management
852  //===--------------------------------------------------------------------===//
853 
854  /// Return the current state of the rewriter.
855  RewriterState getCurrentState();
856 
857  /// Apply all requested operation rewrites. This method is invoked when the
858  /// conversion process succeeds.
859  void applyRewrites();
860 
861  /// Reset the state of the rewriter to a previously saved point.
862  void resetState(RewriterState state);
863 
864  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
865  /// failure.
866  template <typename RewriteTy, typename... Args>
867  void appendRewrite(Args &&...args) {
868  rewrites.push_back(
869  std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
870  }
871 
872  /// Undo the rewrites (motions, splits) one by one in reverse order until
873  /// "numRewritesToKeep" rewrites remain.
874  void undoRewrites(unsigned numRewritesToKeep = 0);
875 
876  /// Remap the given values to those with potentially different types. Returns
877  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
878  /// is the tag used when describing a value within a diagnostic, e.g.
879  /// "operand".
880  LogicalResult remapValues(StringRef valueDiagTag,
881  std::optional<Location> inputLoc,
882  PatternRewriter &rewriter, ValueRange values,
883  SmallVector<ValueVector> &remapped);
884 
885  /// Return "true" if the given operation is ignored, and does not need to be
886  /// converted.
887  bool isOpIgnored(Operation *op) const;
888 
889  /// Return "true" if the given operation was replaced or erased.
890  bool wasOpReplaced(Operation *op) const;
891 
892  //===--------------------------------------------------------------------===//
893  // Type Conversion
894  //===--------------------------------------------------------------------===//
895 
896  /// Convert the types of block arguments within the given region.
897  FailureOr<Block *>
898  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
899  const TypeConverter &converter,
900  TypeConverter::SignatureConversion *entryConversion);
901 
902  /// Apply the given signature conversion on the given block. The new block
903  /// containing the updated signature is returned. If no conversions were
904  /// necessary, e.g. if the block has no arguments, `block` is returned.
905  /// `converter` is used to generate any necessary cast operations that
906  /// translate between the origin argument types and those specified in the
907  /// signature conversion.
908  Block *applySignatureConversion(
909  ConversionPatternRewriter &rewriter, Block *block,
910  const TypeConverter *converter,
911  TypeConverter::SignatureConversion &signatureConversion);
912 
913  //===--------------------------------------------------------------------===//
914  // Materializations
915  //===--------------------------------------------------------------------===//
916 
917  /// Build an unresolved materialization operation given a range of output
918  /// types and a list of input operands. The input types must differ from the
919  /// output types (this is asserted). Returns the results of the built cast op.
920  ///
921  /// If a cast op was built, it can optionally be returned with the `castOp`
922  /// output argument.
923  ///
924  /// If `valuesToMap` is non-empty, then those values are mapped to the
925  /// results of the unresolved materialization in the conversion value
926  /// mapping.
927  ValueRange buildUnresolvedMaterialization(
928  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
929  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
930  Type originalType, const TypeConverter *converter,
931  UnrealizedConversionCastOp *castOp = nullptr);
932 
933  /// Find a replacement value for the given SSA value in the conversion value
934  /// mapping. The replacement value must have the same type as the given SSA
935  /// value. If there is no replacement value with the correct type, find the
936  /// latest replacement value (regardless of the type) and build a source
937  /// materialization.
938  Value findOrBuildReplacementValue(Value value,
939  const TypeConverter *converter);
940 
941  //===--------------------------------------------------------------------===//
942  // Rewriter Notification Hooks
943  //===--------------------------------------------------------------------===//
944 
945  /// Notifies that an op was inserted.
946  void notifyOperationInserted(Operation *op,
947  OpBuilder::InsertPoint previous) override;
948 
949  /// Notifies that an op is about to be replaced with the given values.
950  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
951 
952  /// Notifies that a block is about to be erased.
953  void notifyBlockIsBeingErased(Block *block);
954 
955  /// Notifies that a block was inserted.
956  void notifyBlockInserted(Block *block, Region *previous,
957  Region::iterator previousIt) override;
958 
959  /// Notifies that a block is being inlined into another block.
960  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
961  Block::iterator before);
962 
963  /// Notifies that a pattern match failed for the given reason.
964  void
965  notifyMatchFailure(Location loc,
966  function_ref<void(Diagnostic &)> reasonCallback) override;
967 
968  //===--------------------------------------------------------------------===//
969  // IR Erasure
970  //===--------------------------------------------------------------------===//
971 
972  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
973  /// operation or block is erased multiple times. This rewriter assumes that
974  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
976  public:
978  : RewriterBase(context, /*listener=*/this) {}
979 
980  /// Erase the given op (unless it was already erased).
981  void eraseOp(Operation *op) override {
982  if (wasErased(op))
983  return;
984  op->dropAllUses();
986  }
987 
988  /// Erase the given block (unless it was already erased).
989  void eraseBlock(Block *block) override {
990  if (wasErased(block))
991  return;
992  assert(block->empty() && "expected empty block");
993  block->dropAllDefinedValueUses();
995  }
996 
997  bool wasErased(void *ptr) const { return erased.contains(ptr); }
998 
999  void notifyOperationErased(Operation *op) override { erased.insert(op); }
1000 
1001  void notifyBlockErased(Block *block) override { erased.insert(block); }
1002 
1003  private:
1004  /// Pointers to all erased operations and blocks.
1005  DenseSet<void *> erased;
1006  };
1007 
1008  //===--------------------------------------------------------------------===//
1009  // State
1010  //===--------------------------------------------------------------------===//
1011 
1012  /// MLIR context.
1014 
1015  /// A rewriter that keeps track of ops/block that were already erased and
1016  /// skips duplicate op/block erasures. This rewriter is used during the
1017  /// "cleanup" phase.
1019 
1020  // Mapping between replaced values that differ in type. This happens when
1021  // replacing a value with one of a different type.
1022  ConversionValueMapping mapping;
1023 
1024  /// Ordered list of block operations (creations, splits, motions).
1026 
1027  /// A set of operations that should no longer be considered for legalization.
1028  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1029  /// tracked separately.
1031 
1032  /// A set of operations that were replaced/erased. Such ops are not erased
1033  /// immediately but only when the dialect conversion succeeds. In the mean
1034  /// time, they should no longer be considered for legalization and any attempt
1035  /// to modify/access them is invalid rewriter API usage.
1037 
1038  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1039  /// to the corresponding rewrite objects.
1042 
1043  /// The current type converter, or nullptr if no type converter is currently
1044  /// active.
1045  const TypeConverter *currentTypeConverter = nullptr;
1046 
1047  /// A mapping of regions to type converters that should be used when
1048  /// converting the arguments of blocks within that region.
1050 
1051  /// Dialect conversion configuration.
1053 
1054 #ifndef NDEBUG
1055  /// A set of operations that have pending updates. This tracking isn't
1056  /// strictly necessary, and is thus only active during debug builds for extra
1057  /// verification.
1059 
1060  /// A logger used to emit diagnostics during the conversion process.
1061  llvm::ScopedPrinter logger{llvm::dbgs()};
1062 #endif
1063 };
1064 } // namespace detail
1065 } // namespace mlir
1066 
1067 const ConversionConfig &IRRewrite::getConfig() const {
1068  return rewriterImpl.config;
1069 }
1070 
1071 void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1072  // Inform the listener about all IR modifications that have already taken
1073  // place: References to the original block have been replaced with the new
1074  // block.
1075  if (auto *listener =
1076  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1077  for (Operation *op : getNewBlock()->getUsers())
1078  listener->notifyOperationModified(op);
1079 }
1080 
1081 void BlockTypeConversionRewrite::rollback() {
1082  getNewBlock()->replaceAllUsesWith(getOrigBlock());
1083 }
1084 
1085 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1086  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1087  if (!repl)
1088  return;
1089 
1090  if (isa<BlockArgument>(repl)) {
1091  rewriter.replaceAllUsesWith(arg, repl);
1092  return;
1093  }
1094 
1095  // If the replacement value is an operation, we check to make sure that we
1096  // don't replace uses that are within the parent operation of the
1097  // replacement value.
1098  Operation *replOp = cast<OpResult>(repl).getOwner();
1099  Block *replBlock = replOp->getBlock();
1100  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1101  Operation *user = operand.getOwner();
1102  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1103  });
1104 }
1105 
1106 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1107 
1108 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1109  auto *listener =
1110  dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1111 
1112  // Compute replacement values.
1113  SmallVector<Value> replacements =
1114  llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1115  return rewriterImpl.findOrBuildReplacementValue(result, converter);
1116  });
1117 
1118  // Notify the listener that the operation is about to be replaced.
1119  if (listener)
1120  listener->notifyOperationReplaced(op, replacements);
1121 
1122  // Replace all uses with the new values.
1123  for (auto [result, newValue] :
1124  llvm::zip_equal(op->getResults(), replacements))
1125  if (newValue)
1126  rewriter.replaceAllUsesWith(result, newValue);
1127 
1128  // The original op will be erased, so remove it from the set of unlegalized
1129  // ops.
1130  if (getConfig().unlegalizedOps)
1131  getConfig().unlegalizedOps->erase(op);
1132 
1133  // Notify the listener that the operation (and its nested operations) was
1134  // erased.
1135  if (listener) {
1137  [&](Operation *op) { listener->notifyOperationErased(op); });
1138  }
1139 
1140  // Do not erase the operation yet. It may still be referenced in `mapping`.
1141  // Just unlink it for now and erase it during cleanup.
1142  op->getBlock()->getOperations().remove(op);
1143 }
1144 
1145 void ReplaceOperationRewrite::rollback() {
1146  for (auto result : op->getResults())
1147  rewriterImpl.mapping.erase({result});
1148 }
1149 
1150 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1151  rewriter.eraseOp(op);
1152 }
1153 
1154 void CreateOperationRewrite::rollback() {
1155  for (Region &region : op->getRegions()) {
1156  while (!region.getBlocks().empty())
1157  region.getBlocks().remove(region.getBlocks().begin());
1158  }
1159  op->dropAllUses();
1160  op->erase();
1161 }
1162 
1163 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1164  ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1165  const TypeConverter *converter, MaterializationKind kind, Type originalType,
1166  ValueVector mappedValues)
1167  : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1168  converterAndKind(converter, kind), originalType(originalType),
1169  mappedValues(std::move(mappedValues)) {
1170  assert((!originalType || kind == MaterializationKind::Target) &&
1171  "original type is valid only for target materializations");
1172  rewriterImpl.unresolvedMaterializations[op] = this;
1173 }
1174 
1175 void UnresolvedMaterializationRewrite::rollback() {
1176  if (!mappedValues.empty())
1177  rewriterImpl.mapping.erase(mappedValues);
1178  rewriterImpl.unresolvedMaterializations.erase(getOperation());
1179  op->erase();
1180 }
1181 
1183  // Commit all rewrites.
1184  IRRewriter rewriter(context, config.listener);
1185  // Note: New rewrites may be added during the "commit" phase and the
1186  // `rewrites` vector may reallocate.
1187  for (size_t i = 0; i < rewrites.size(); ++i)
1188  rewrites[i]->commit(rewriter);
1189 
1190  // Clean up all rewrites.
1191  for (auto &rewrite : rewrites)
1192  rewrite->cleanup(eraseRewriter);
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // State Management
1197 
1199  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1200 }
1201 
1203  // Undo any rewrites.
1204  undoRewrites(state.numRewrites);
1205 
1206  // Pop all of the recorded ignored operations that are no longer valid.
1207  while (ignoredOps.size() != state.numIgnoredOperations)
1208  ignoredOps.pop_back();
1209 
1210  while (replacedOps.size() != state.numReplacedOps)
1211  replacedOps.pop_back();
1212 }
1213 
1214 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
1215  for (auto &rewrite :
1216  llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1217  rewrite->rollback();
1218  rewrites.resize(numRewritesToKeep);
1219 }
1220 
1222  StringRef valueDiagTag, std::optional<Location> inputLoc,
1223  PatternRewriter &rewriter, ValueRange values,
1224  SmallVector<ValueVector> &remapped) {
1225  remapped.reserve(llvm::size(values));
1226 
1227  for (const auto &it : llvm::enumerate(values)) {
1228  Value operand = it.value();
1229  Type origType = operand.getType();
1230  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1231 
1232  if (!currentTypeConverter) {
1233  // The current pattern does not have a type converter. I.e., it does not
1234  // distinguish between legal and illegal types. For each operand, simply
1235  // pass through the most recently mapped values.
1236  remapped.push_back(mapping.lookupOrDefault(operand));
1237  continue;
1238  }
1239 
1240  // If there is no legal conversion, fail to match this pattern.
1241  SmallVector<Type, 1> legalTypes;
1242  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1243  notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1244  diag << "unable to convert type for " << valueDiagTag << " #"
1245  << it.index() << ", type was " << origType;
1246  });
1247  return failure();
1248  }
1249  // If a type is converted to 0 types, there is nothing to do.
1250  if (legalTypes.empty()) {
1251  remapped.push_back({});
1252  continue;
1253  }
1254 
1255  ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
1256  if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1257  // Mapped values have the correct type or there is an existing
1258  // materialization. Or the operand is not mapped at all and has the
1259  // correct type.
1260  remapped.push_back(std::move(repl));
1261  continue;
1262  }
1263 
1264  // Create a materialization for the most recently mapped values.
1265  repl = mapping.lookupOrDefault(operand);
1267  MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1268  /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1269  /*originalType=*/origType, currentTypeConverter);
1270  remapped.push_back(castValues);
1271  }
1272  return success();
1273 }
1274 
1276  // Check to see if this operation is ignored or was replaced.
1277  return replacedOps.count(op) || ignoredOps.count(op);
1278 }
1279 
1281  // Check to see if this operation was replaced.
1282  return replacedOps.count(op);
1283 }
1284 
1285 //===----------------------------------------------------------------------===//
1286 // Type Conversion
1287 
1289  ConversionPatternRewriter &rewriter, Region *region,
1290  const TypeConverter &converter,
1291  TypeConverter::SignatureConversion *entryConversion) {
1292  regionToConverter[region] = &converter;
1293  if (region->empty())
1294  return nullptr;
1295 
1296  // Convert the arguments of each non-entry block within the region.
1297  for (Block &block :
1298  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1299  // Compute the signature for the block with the provided converter.
1300  std::optional<TypeConverter::SignatureConversion> conversion =
1301  converter.convertBlockSignature(&block);
1302  if (!conversion)
1303  return failure();
1304  // Convert the block with the computed signature.
1305  applySignatureConversion(rewriter, &block, &converter, *conversion);
1306  }
1307 
1308  // Convert the entry block. If an entry signature conversion was provided,
1309  // use that one. Otherwise, compute the signature with the type converter.
1310  if (entryConversion)
1311  return applySignatureConversion(rewriter, &region->front(), &converter,
1312  *entryConversion);
1313  std::optional<TypeConverter::SignatureConversion> conversion =
1314  converter.convertBlockSignature(&region->front());
1315  if (!conversion)
1316  return failure();
1317  return applySignatureConversion(rewriter, &region->front(), &converter,
1318  *conversion);
1319 }
1320 
1322  ConversionPatternRewriter &rewriter, Block *block,
1323  const TypeConverter *converter,
1324  TypeConverter::SignatureConversion &signatureConversion) {
1325 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1326  // A block cannot be converted multiple times.
1327  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1328  llvm::report_fatal_error("block was already converted");
1329 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1330 
1331  OpBuilder::InsertionGuard g(rewriter);
1332 
1333  // If no arguments are being changed or added, there is nothing to do.
1334  unsigned origArgCount = block->getNumArguments();
1335  auto convertedTypes = signatureConversion.getConvertedTypes();
1336  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1337  return block;
1338 
1339  // Compute the locations of all block arguments in the new block.
1340  SmallVector<Location> newLocs(convertedTypes.size(),
1341  rewriter.getUnknownLoc());
1342  for (unsigned i = 0; i < origArgCount; ++i) {
1343  auto inputMap = signatureConversion.getInputMapping(i);
1344  if (!inputMap || inputMap->replacementValue)
1345  continue;
1346  Location origLoc = block->getArgument(i).getLoc();
1347  for (unsigned j = 0; j < inputMap->size; ++j)
1348  newLocs[inputMap->inputNo + j] = origLoc;
1349  }
1350 
1351  // Insert a new block with the converted block argument types and move all ops
1352  // from the old block to the new block.
1353  Block *newBlock =
1354  rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1355  convertedTypes, newLocs);
1356 
1357  // If a listener is attached to the dialect conversion, ops cannot be moved
1358  // to the destination block in bulk ("fast path"). This is because at the time
1359  // the notifications are sent, it is unknown which ops were moved. Instead,
1360  // ops should be moved one-by-one ("slow path"), so that a separate
1361  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1362  // a bit more efficient, so we try to do that when possible.
1363  bool fastPath = !config.listener;
1364  if (fastPath) {
1365  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1366  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1367  } else {
1368  while (!block->empty())
1369  rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1370  }
1371 
1372  // Replace all uses of the old block with the new block.
1373  block->replaceAllUsesWith(newBlock);
1374 
1375  for (unsigned i = 0; i != origArgCount; ++i) {
1376  BlockArgument origArg = block->getArgument(i);
1377  Type origArgType = origArg.getType();
1378 
1379  std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1380  signatureConversion.getInputMapping(i);
1381  if (!inputMap) {
1382  // This block argument was dropped and no replacement value was provided.
1383  // Materialize a replacement value "out of thin air".
1385  MaterializationKind::Source,
1386  OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1387  /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1388  /*outputType=*/origArgType, /*originalType=*/Type(), converter);
1389  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1390  continue;
1391  }
1392 
1393  if (Value repl = inputMap->replacementValue) {
1394  // This block argument was dropped and a replacement value was provided.
1395  assert(inputMap->size == 0 &&
1396  "invalid to provide a replacement value when the argument isn't "
1397  "dropped");
1398  mapping.map(origArg, repl);
1399  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1400  continue;
1401  }
1402 
1403  // This is a 1->1+ mapping.
1404  auto replArgs =
1405  newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1406  ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1407  mapping.map(origArg, std::move(replArgVals));
1408  appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1409  }
1410 
1411  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1412 
1413  // Erase the old block. (It is just unlinked for now and will be erased during
1414  // cleanup.)
1415  rewriter.eraseBlock(block);
1416 
1417  return newBlock;
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Materializations
1422 //===----------------------------------------------------------------------===//
1423 
1424 /// Build an unresolved materialization operation given an output type and set
1425 /// of input operands.
1427  MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1428  ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1429  Type originalType, const TypeConverter *converter,
1430  UnrealizedConversionCastOp *castOp) {
1431  assert((!originalType || kind == MaterializationKind::Target) &&
1432  "original type is valid only for target materializations");
1433  assert(TypeRange(inputs) != outputTypes &&
1434  "materialization is not necessary");
1435 
1436  // Create an unresolved materialization. We use a new OpBuilder to avoid
1437  // tracking the materialization like we do for other operations.
1438  OpBuilder builder(outputTypes.front().getContext());
1439  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1440  auto convertOp =
1441  builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1442  if (!valuesToMap.empty())
1443  mapping.map(valuesToMap, convertOp.getResults());
1444  if (castOp)
1445  *castOp = convertOp;
1446  appendRewrite<UnresolvedMaterializationRewrite>(
1447  convertOp, converter, kind, originalType, std::move(valuesToMap));
1448  return convertOp.getResults();
1449 }
1450 
1452  Value value, const TypeConverter *converter) {
1453  // Try to find a replacement value with the same type in the conversion value
1454  // mapping. This includes cached materializations. We try to reuse those
1455  // instead of generating duplicate IR.
1456  ValueVector repl = mapping.lookupOrNull(value, value.getType());
1457  if (!repl.empty())
1458  return repl.front();
1459 
1460  // Check if the value is dead. No replacement value is needed in that case.
1461  // This is an approximate check that may have false negatives but does not
1462  // require computing and traversing an inverse mapping. (We may end up
1463  // building source materializations that are never used and that fold away.)
1464  if (llvm::all_of(value.getUsers(),
1465  [&](Operation *op) { return replacedOps.contains(op); }) &&
1466  !mapping.isMappedTo(value))
1467  return Value();
1468 
1469  // No replacement value was found. Get the latest replacement value
1470  // (regardless of the type) and build a source materialization to the
1471  // original type.
1472  repl = mapping.lookupOrNull(value);
1473  if (repl.empty()) {
1474  // No replacement value is registered in the mapping. This means that the
1475  // value is dropped and no longer needed. (If the value were still needed,
1476  // a source materialization producing a replacement value "out of thin air"
1477  // would have already been created during `replaceOp` or
1478  // `applySignatureConversion`.)
1479  return Value();
1480  }
1481 
1482  // Note: `computeInsertPoint` computes the "earliest" insertion point at
1483  // which all values in `repl` are defined. It is important to emit the
1484  // materialization at that location because the same materialization may be
1485  // reused in a different context. (That's because materializations are cached
1486  // in the conversion value mapping.) The insertion point of the
1487  // materialization must be valid for all future users that may be created
1488  // later in the conversion process.
1489  Value castValue =
1490  buildUnresolvedMaterialization(MaterializationKind::Source,
1491  computeInsertPoint(repl), value.getLoc(),
1492  /*valuesToMap=*/repl, /*inputs=*/repl,
1493  /*outputType=*/value.getType(),
1494  /*originalType=*/Type(), converter)
1495  .front();
1496  return castValue;
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // Rewriter Notification Hooks
1501 
1503  Operation *op, OpBuilder::InsertPoint previous) {
1504  LLVM_DEBUG({
1505  logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1506  << ")\n";
1507  });
1508  assert(!wasOpReplaced(op->getParentOp()) &&
1509  "attempting to insert into a block within a replaced/erased op");
1510 
1511  if (!previous.isSet()) {
1512  // This is a newly created op.
1513  appendRewrite<CreateOperationRewrite>(op);
1514  return;
1515  }
1516  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
1517  ? nullptr
1518  : &*previous.getPoint();
1519  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
1520 }
1521 
1523  Operation *op, ArrayRef<ValueRange> newValues) {
1524  assert(newValues.size() == op->getNumResults());
1525  assert(!ignoredOps.contains(op) && "operation was already replaced");
1526 
1527  // Check if replaced op is an unresolved materialization, i.e., an
1528  // unrealized_conversion_cast op that was created by the conversion driver.
1529  bool isUnresolvedMaterialization = false;
1530  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1531  if (unresolvedMaterializations.contains(castOp))
1532  isUnresolvedMaterialization = true;
1533 
1534  // Create mappings for each of the new result values.
1535  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1536  if (repl.empty()) {
1537  // This result was dropped and no replacement value was provided.
1538  if (isUnresolvedMaterialization) {
1539  // Do not create another materializations if we are erasing a
1540  // materialization.
1541  continue;
1542  }
1543 
1544  // Materialize a replacement value "out of thin air".
1546  MaterializationKind::Source, computeInsertPoint(result),
1547  result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1548  /*outputType=*/result.getType(), /*originalType=*/Type(),
1550  continue;
1551  } else {
1552  // Make sure that the user does not mess with unresolved materializations
1553  // that were inserted by the conversion driver. We keep track of these
1554  // ops in internal data structures. Erasing them must be allowed because
1555  // this can happen when the user is erasing an entire block (including
1556  // its body). But replacing them with another value should be forbidden
1557  // to avoid problems with the `mapping`.
1558  assert(!isUnresolvedMaterialization &&
1559  "attempting to replace an unresolved materialization");
1560  }
1561 
1562  // Remap result to replacement value.
1563  if (repl.empty())
1564  continue;
1565  mapping.map(result, repl);
1566  }
1567 
1568  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1569  // Mark this operation and all nested ops as replaced.
1570  op->walk([&](Operation *op) { replacedOps.insert(op); });
1571 }
1572 
1574  appendRewrite<EraseBlockRewrite>(block);
1575 }
1576 
1578  Block *block, Region *previous, Region::iterator previousIt) {
1579  assert(!wasOpReplaced(block->getParentOp()) &&
1580  "attempting to insert into a region within a replaced/erased op");
1581  LLVM_DEBUG(
1582  {
1583  Operation *parent = block->getParentOp();
1584  if (parent) {
1585  logger.startLine() << "** Insert Block into : '" << parent->getName()
1586  << "'(" << parent << ")\n";
1587  } else {
1588  logger.startLine()
1589  << "** Insert Block into detached Region (nullptr parent op)'";
1590  }
1591  });
1592 
1593  if (!previous) {
1594  // This is a newly created block.
1595  appendRewrite<CreateBlockRewrite>(block);
1596  return;
1597  }
1598  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1599  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1600 }
1601 
1603  Block *block, Block *srcBlock, Block::iterator before) {
1604  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1605 }
1606 
1608  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1609  LLVM_DEBUG({
1611  reasonCallback(diag);
1612  logger.startLine() << "** Failure : " << diag.str() << "\n";
1613  if (config.notifyCallback)
1615  });
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // ConversionPatternRewriter
1620 //===----------------------------------------------------------------------===//
1621 
1622 ConversionPatternRewriter::ConversionPatternRewriter(
1623  MLIRContext *ctx, const ConversionConfig &config)
1624  : PatternRewriter(ctx),
1625  impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1626  setListener(impl.get());
1627 }
1628 
1630 
1632  assert(op && newOp && "expected non-null op");
1633  replaceOp(op, newOp->getResults());
1634 }
1635 
1637  assert(op->getNumResults() == newValues.size() &&
1638  "incorrect # of replacement values");
1639  LLVM_DEBUG({
1640  impl->logger.startLine()
1641  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1642  });
1643  SmallVector<ValueRange> newVals;
1644  for (size_t i = 0; i < newValues.size(); ++i) {
1645  if (newValues[i]) {
1646  newVals.push_back(newValues.slice(i, 1));
1647  } else {
1648  newVals.push_back(ValueRange());
1649  }
1650  }
1651  impl->notifyOpReplaced(op, newVals);
1652 }
1653 
1655  Operation *op, ArrayRef<ValueRange> newValues) {
1656  assert(op->getNumResults() == newValues.size() &&
1657  "incorrect # of replacement values");
1658  LLVM_DEBUG({
1659  impl->logger.startLine()
1660  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1661  });
1662  impl->notifyOpReplaced(op, newValues);
1663 }
1664 
1666  LLVM_DEBUG({
1667  impl->logger.startLine()
1668  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1669  });
1670  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
1671  impl->notifyOpReplaced(op, nullRepls);
1672 }
1673 
1675  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1676  "attempting to erase a block within a replaced/erased op");
1677 
1678  // Mark all ops for erasure.
1679  for (Operation &op : *block)
1680  eraseOp(&op);
1681 
1682  // Unlink the block from its parent region. The block is kept in the rewrite
1683  // object and will be actually destroyed when rewrites are applied. This
1684  // allows us to keep the operations in the block live and undo the removal by
1685  // re-inserting the block.
1686  impl->notifyBlockIsBeingErased(block);
1687  block->getParent()->getBlocks().remove(block);
1688 }
1689 
1691  Block *block, TypeConverter::SignatureConversion &conversion,
1692  const TypeConverter *converter) {
1693  assert(!impl->wasOpReplaced(block->getParentOp()) &&
1694  "attempting to apply a signature conversion to a block within a "
1695  "replaced/erased op");
1696  return impl->applySignatureConversion(*this, block, converter, conversion);
1697 }
1698 
1700  Region *region, const TypeConverter &converter,
1701  TypeConverter::SignatureConversion *entryConversion) {
1702  assert(!impl->wasOpReplaced(region->getParentOp()) &&
1703  "attempting to apply a signature conversion to a block within a "
1704  "replaced/erased op");
1705  return impl->convertRegionTypes(*this, region, converter, entryConversion);
1706 }
1707 
1709  Value to) {
1710  LLVM_DEBUG({
1711  Operation *parentOp = from.getOwner()->getParentOp();
1712  impl->logger.startLine() << "** Replace Argument : '" << from
1713  << "'(in region of '" << parentOp->getName()
1714  << "'(" << from.getOwner()->getParentOp() << ")\n";
1715  });
1716  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1717  impl->currentTypeConverter);
1718  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1719 }
1720 
1722  SmallVector<ValueVector> remappedValues;
1723  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1724  remappedValues)))
1725  return nullptr;
1726  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1727  return remappedValues.front().front();
1728 }
1729 
1730 LogicalResult
1732  SmallVectorImpl<Value> &results) {
1733  if (keys.empty())
1734  return success();
1735  SmallVector<ValueVector> remapped;
1736  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1737  remapped)))
1738  return failure();
1739  for (const auto &values : remapped) {
1740  assert(values.size() == 1 && "1:N conversion not supported");
1741  results.push_back(values.front());
1742  }
1743  return success();
1744 }
1745 
1747  Block::iterator before,
1748  ValueRange argValues) {
1749 #ifndef NDEBUG
1750  assert(argValues.size() == source->getNumArguments() &&
1751  "incorrect # of argument replacement values");
1752  assert(!impl->wasOpReplaced(source->getParentOp()) &&
1753  "attempting to inline a block from a replaced/erased op");
1754  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
1755  "attempting to inline a block into a replaced/erased op");
1756  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1757  // The source block will be deleted, so it should not have any users (i.e.,
1758  // there should be no predecessors).
1759  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1760  "expected 'source' to have no predecessors");
1761 #endif // NDEBUG
1762 
1763  // If a listener is attached to the dialect conversion, ops cannot be moved
1764  // to the destination block in bulk ("fast path"). This is because at the time
1765  // the notifications are sent, it is unknown which ops were moved. Instead,
1766  // ops should be moved one-by-one ("slow path"), so that a separate
1767  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1768  // a bit more efficient, so we try to do that when possible.
1769  bool fastPath = !impl->config.listener;
1770 
1771  if (fastPath)
1772  impl->notifyBlockBeingInlined(dest, source, before);
1773 
1774  // Replace all uses of block arguments.
1775  for (auto it : llvm::zip(source->getArguments(), argValues))
1776  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1777 
1778  if (fastPath) {
1779  // Move all ops at once.
1780  dest->getOperations().splice(before, source->getOperations());
1781  } else {
1782  // Move op by op.
1783  while (!source->empty())
1784  moveOpBefore(&source->front(), dest, before);
1785  }
1786 
1787  // Erase the source block.
1788  eraseBlock(source);
1789 }
1790 
1792  assert(!impl->wasOpReplaced(op) &&
1793  "attempting to modify a replaced/erased op");
1794 #ifndef NDEBUG
1795  impl->pendingRootUpdates.insert(op);
1796 #endif
1797  impl->appendRewrite<ModifyOperationRewrite>(op);
1798 }
1799 
1801  assert(!impl->wasOpReplaced(op) &&
1802  "attempting to modify a replaced/erased op");
1804  // There is nothing to do here, we only need to track the operation at the
1805  // start of the update.
1806 #ifndef NDEBUG
1807  assert(impl->pendingRootUpdates.erase(op) &&
1808  "operation did not have a pending in-place update");
1809 #endif
1810 }
1811 
1813 #ifndef NDEBUG
1814  assert(impl->pendingRootUpdates.erase(op) &&
1815  "operation did not have a pending in-place update");
1816 #endif
1817  // Erase the last update for this operation.
1818  auto it = llvm::find_if(
1819  llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1820  auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1821  return modifyRewrite && modifyRewrite->getOperation() == op;
1822  });
1823  assert(it != impl->rewrites.rend() && "no root update started on op");
1824  (*it)->rollback();
1825  int updateIdx = std::prev(impl->rewrites.rend()) - it;
1826  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
1827 }
1828 
1830  return *impl;
1831 }
1832 
1833 //===----------------------------------------------------------------------===//
1834 // ConversionPattern
1835 //===----------------------------------------------------------------------===//
1836 
1838  ArrayRef<ValueRange> operands) const {
1839  SmallVector<Value> oneToOneOperands;
1840  oneToOneOperands.reserve(operands.size());
1841  for (ValueRange operand : operands) {
1842  if (operand.size() != 1)
1843  llvm::report_fatal_error("pattern '" + getDebugName() +
1844  "' does not support 1:N conversion");
1845  oneToOneOperands.push_back(operand.front());
1846  }
1847  return oneToOneOperands;
1848 }
1849 
1850 LogicalResult
1852  PatternRewriter &rewriter) const {
1853  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1854  auto &rewriterImpl = dialectRewriter.getImpl();
1855 
1856  // Track the current conversion pattern type converter in the rewriter.
1857  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1858  getTypeConverter());
1859 
1860  // Remap the operands of the operation.
1861  SmallVector<ValueVector> remapped;
1862  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1863  op->getOperands(), remapped))) {
1864  return failure();
1865  }
1866  SmallVector<ValueRange> remappedAsRange =
1867  llvm::to_vector_of<ValueRange>(remapped);
1868  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // OperationLegalizer
1873 //===----------------------------------------------------------------------===//
1874 
1875 namespace {
1876 /// A set of rewrite patterns that can be used to legalize a given operation.
1877 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1878 
1879 /// This class defines a recursive operation legalizer.
1880 class OperationLegalizer {
1881 public:
1882  using LegalizationAction = ConversionTarget::LegalizationAction;
1883 
 /// Construct a legalizer for the given conversion target and configuration.
 /// The constructor also builds the legalization graph and applies the cost
 /// model to the pattern applicator.
 // NOTE(review): part of this parameter list is not visible in this text;
 // confirm the full signature against the original file.
1884  OperationLegalizer(const ConversionTarget &targetInfo,
1886  const ConversionConfig &config);
1887 
1888  /// Returns true if the given operation is known to be illegal on the target.
1889  bool isIllegal(Operation *op) const;
1890 
1891  /// Attempt to legalize the given operation. Returns success if the operation
1892  /// was legalized, failure otherwise.
1893  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1894 
1895  /// Returns the conversion target in use by the legalizer.
1896  const ConversionTarget &getTarget() { return target; }
1897 
1898 private:
1899  /// Attempt to legalize the given operation by folding it.
1900  LogicalResult legalizeWithFold(Operation *op,
1901  ConversionPatternRewriter &rewriter);
1902 
1903  /// Attempt to legalize the given operation by applying a pattern. Returns
1904  /// success if the operation was legalized, failure otherwise.
1905  LogicalResult legalizeWithPattern(Operation *op,
1906  ConversionPatternRewriter &rewriter);
1907 
1908  /// Return true if the given pattern may be applied to the given operation,
1909  /// false otherwise.
1910  bool canApplyPattern(Operation *op, const Pattern &pattern,
1911  ConversionPatternRewriter &rewriter);
1912 
1913  /// Legalize the resultant IR after successfully applying the given pattern.
1914  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1915  ConversionPatternRewriter &rewriter,
1916  RewriterState &curState);
1917 
1918  /// Legalizes the actions registered during the execution of a pattern.
 /// This variant handles the block-related rewrites recorded between `state`
 /// and `newState`.
1919  LogicalResult
1920  legalizePatternBlockRewrites(Operation *op,
1921  ConversionPatternRewriter &rewriter,
1923  RewriterState &state, RewriterState &newState);
 /// Legalizes every operation that was recorded as created between `state`
 /// and `newState`.
1924  LogicalResult legalizePatternCreatedOperations(
1926  RewriterState &state, RewriterState &newState);
 /// Legalizes every operation that was recorded as modified in place between
 /// `state` and `newState`.
1927  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1929  RewriterState &state,
1930  RewriterState &newState);
1931 
1932  //===--------------------------------------------------------------------===//
1933  // Cost Model
1934  //===--------------------------------------------------------------------===//
1935 
1936  /// Build an optimistic legalization graph given the provided patterns. This
1937  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1938  /// patterns for operations that are not directly legal, but may be
1939  /// transitively legal for the current target given the provided patterns.
1940  void buildLegalizationGraph(
1941  LegalizationPatterns &anyOpLegalizerPatterns,
1943 
1944  /// Compute the benefit of each node within the computed legalization graph.
1945  /// This orders the patterns within 'legalizerPatterns' based upon two
1946  /// criteria:
1947  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1948  /// represent the more direct mapping to the target.
1949  /// 2) When comparing patterns with the same legalization depth, prefer the
1950  /// pattern with the highest PatternBenefit. This allows for users to
1951  /// prefer specific legalizations over others.
1952  void computeLegalizationGraphBenefit(
1953  LegalizationPatterns &anyOpLegalizerPatterns,
1955 
1956  /// Compute the legalization depth when legalizing an operation of the given
1957  /// type.
1958  unsigned computeOpLegalizationDepth(
1959  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1961 
1962  /// Apply the conversion cost model to the given set of patterns, and return
1963  /// the smallest legalization depth of any of the patterns. See
1964  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1965  unsigned applyCostModelToPatterns(
1966  LegalizationPatterns &patterns,
1967  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1969 
1970  /// The current set of patterns that have been applied.
 /// Used by `canApplyPattern` to reject a pattern that is already on the
 /// current recursion stack, unless the pattern declares bounded recursion.
1971  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1972 
1973  /// The legalization information provided by the target.
1974  const ConversionTarget &target;
1975 
1976  /// The pattern applicator to use for conversions.
1977  PatternApplicator applicator;
1978 
1979  /// Dialect conversion configuration.
1980  const ConversionConfig &config;
1981 };
1982 } // namespace
1983 
/// Set up the legalizer: remember the target and configuration, hand the
/// patterns to the applicator, then build the legalization graph and rank the
/// patterns with the cost model.
// NOTE(review): `patterns` and `legalizerPatterns` are used below but their
// declarations are not visible in this text; confirm against the original
// file.
1984 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1986  const ConversionConfig &config)
1987  : target(targetInfo), applicator(patterns), config(config) {
1988  // The set of patterns that can be applied to illegal operations to transform
1989  // them into legal ones.
1991  LegalizationPatterns anyOpLegalizerPatterns;
1992 
 // First collect the candidate patterns, then order them by the cost model.
1993  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1994  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1995 }
1996 
1997 bool OperationLegalizer::isIllegal(Operation *op) const {
1998  return target.isIllegal(op);
1999 }
2000 
2001 LogicalResult
2002 OperationLegalizer::legalize(Operation *op,
2003  ConversionPatternRewriter &rewriter) {
2004 #ifndef NDEBUG
2005  const char *logLineComment =
2006  "//===-------------------------------------------===//\n";
2007 
2008  auto &logger = rewriter.getImpl().logger;
2009 #endif
2010  LLVM_DEBUG({
2011  logger.getOStream() << "\n";
2012  logger.startLine() << logLineComment;
2013  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
2014  << op << ") {\n";
2015  logger.indent();
2016 
2017  // If the operation has no regions, just print it here.
2018  if (op->getNumRegions() == 0) {
2019  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2020  logger.getOStream() << "\n\n";
2021  }
2022  });
2023 
2024  // Check if this operation is legal on the target.
2025  if (auto legalityInfo = target.isLegal(op)) {
2026  LLVM_DEBUG({
2027  logSuccess(
2028  logger, "operation marked legal by the target{0}",
2029  legalityInfo->isRecursivelyLegal
2030  ? "; NOTE: operation is recursively legal; skipping internals"
2031  : "");
2032  logger.startLine() << logLineComment;
2033  });
2034 
2035  // If this operation is recursively legal, mark its children as ignored so
2036  // that we don't consider them for legalization.
2037  if (legalityInfo->isRecursivelyLegal) {
2038  op->walk([&](Operation *nested) {
2039  if (op != nested)
2040  rewriter.getImpl().ignoredOps.insert(nested);
2041  });
2042  }
2043 
2044  return success();
2045  }
2046 
2047  // Check to see if the operation is ignored and doesn't need to be converted.
2048  if (rewriter.getImpl().isOpIgnored(op)) {
2049  LLVM_DEBUG({
2050  logSuccess(logger, "operation marked 'ignored' during conversion");
2051  logger.startLine() << logLineComment;
2052  });
2053  return success();
2054  }
2055 
2056  // If the operation isn't legal, try to fold it in-place.
2057  // TODO: Should we always try to do this, even if the op is
2058  // already legal?
2059  if (succeeded(legalizeWithFold(op, rewriter))) {
2060  LLVM_DEBUG({
2061  logSuccess(logger, "operation was folded");
2062  logger.startLine() << logLineComment;
2063  });
2064  return success();
2065  }
2066 
2067  // Otherwise, we need to apply a legalization pattern to this operation.
2068  if (succeeded(legalizeWithPattern(op, rewriter))) {
2069  LLVM_DEBUG({
2070  logSuccess(logger, "");
2071  logger.startLine() << logLineComment;
2072  });
2073  return success();
2074  }
2075 
2076  LLVM_DEBUG({
2077  logFailure(logger, "no matched legalization pattern");
2078  logger.startLine() << logLineComment;
2079  });
2080  return failure();
2081 }
2082 
2083 LogicalResult
2084 OperationLegalizer::legalizeWithFold(Operation *op,
2085  ConversionPatternRewriter &rewriter) {
2086  auto &rewriterImpl = rewriter.getImpl();
2087  RewriterState curState = rewriterImpl.getCurrentState();
2088 
2089  LLVM_DEBUG({
2090  rewriterImpl.logger.startLine() << "* Fold {\n";
2091  rewriterImpl.logger.indent();
2092  });
2093 
2094  // Try to fold the operation.
2095  SmallVector<Value, 2> replacementValues;
2096  rewriter.setInsertionPoint(op);
2097  if (failed(rewriter.tryFold(op, replacementValues))) {
2098  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2099  return failure();
2100  }
2101  // An empty list of replacement values indicates that the fold was in-place.
2102  // As the operation changed, a new legalization needs to be attempted.
2103  if (replacementValues.empty())
2104  return legalize(op, rewriter);
2105 
2106  // Insert a replacement for 'op' with the folded replacement values.
2107  rewriter.replaceOp(op, replacementValues);
2108 
2109  // Recursively legalize any new constant operations.
2110  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2111  i != e; ++i) {
2112  auto *createOp =
2113  dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2114  if (!createOp)
2115  continue;
2116  if (failed(legalize(createOp->getOperation(), rewriter))) {
2117  LLVM_DEBUG(logFailure(rewriterImpl.logger,
2118  "failed to legalize generated constant '{0}'",
2119  createOp->getOperation()->getName()));
2120  rewriterImpl.resetState(curState);
2121  return failure();
2122  }
2123  }
2124 
2125  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2126  return success();
2127 }
2128 
/// Try to legalize `op` by running the pattern applicator on it.
///
/// Three callbacks are handed to the applicator:
///  - `canApply` decides whether a pattern may be tried and, if so, notifies
///    the configured listener that the pattern begins.
///  - `onFailure` resets the rewriter to the state captured before matching.
///  - `onSuccess` legalizes what the pattern produced and resets the rewriter
///    if that fails.
/// The result of the applicator is returned as is.
2129 LogicalResult
2130 OperationLegalizer::legalizeWithPattern(Operation *op,
2131  ConversionPatternRewriter &rewriter) {
2132  auto &rewriterImpl = rewriter.getImpl();
2133 
2134  // Functor that returns if the given pattern may be applied.
2135  auto canApply = [&](const Pattern &pattern) {
2136  bool canApply = canApplyPattern(op, pattern, rewriter);
 // The listener only hears about patterns that are actually attempted.
2137  if (canApply && config.listener)
2138  config.listener->notifyPatternBegin(pattern, op);
2139  return canApply;
2140  };
2141 
2142  // Functor that cleans up the rewriter state after a pattern failed to match.
2143  RewriterState curState = rewriterImpl.getCurrentState();
2144  auto onFailure = [&](const Pattern &pattern) {
2145  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2146  LLVM_DEBUG({
2147  logFailure(rewriterImpl.logger, "pattern failed to match");
2148  if (rewriterImpl.config.notifyCallback) {
 // NOTE(review): `diag` is used below but its declaration is not visible in
 // this text; confirm against the original file.
2150  diag << "Failed to apply pattern \"" << pattern.getDebugName()
2151  << "\" on op:\n"
2152  << *op;
2153  rewriterImpl.config.notifyCallback(diag);
2154  }
2155  });
2156  if (config.listener)
2157  config.listener->notifyPatternEnd(pattern, failure());
 // Roll back to the snapshot and allow the pattern to be tried again later.
2158  rewriterImpl.resetState(curState);
2159  appliedPatterns.erase(&pattern);
2160  };
2161 
2162  // Functor that performs additional legalization when a pattern is
2163  // successfully applied.
2164  auto onSuccess = [&](const Pattern &pattern) {
2165  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2166  auto result = legalizePatternResult(op, pattern, rewriter, curState);
 // The pattern leaves the applied set whether or not its result was legal.
2167  appliedPatterns.erase(&pattern);
2168  if (failed(result))
2169  rewriterImpl.resetState(curState);
2170  if (config.listener)
2171  config.listener->notifyPatternEnd(pattern, result);
2172  return result;
2173  };
2174 
2175  // Try to match and rewrite a pattern on this operation.
2176  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2177  onSuccess);
2178 }
2179 
2180 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2181  ConversionPatternRewriter &rewriter) {
2182  LLVM_DEBUG({
2183  auto &os = rewriter.getImpl().logger;
2184  os.getOStream() << "\n";
2185  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2186  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2187  os.getOStream() << ")' {\n";
2188  os.indent();
2189  });
2190 
2191  // Ensure that we don't cycle by not allowing the same pattern to be
2192  // applied twice in the same recursion stack if it is not known to be safe.
2193  if (!pattern.hasBoundedRewriteRecursion() &&
2194  !appliedPatterns.insert(&pattern).second) {
2195  LLVM_DEBUG(
2196  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2197  return false;
2198  }
2199  return true;
2200 }
2201 
2202 LogicalResult
2203 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2204  ConversionPatternRewriter &rewriter,
2205  RewriterState &curState) {
2206  auto &impl = rewriter.getImpl();
2207  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2208 
2209 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2210  // Check that the root was either replaced or updated in place.
2211  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2212  auto replacedRoot = [&] {
2213  return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2214  };
2215  auto updatedRootInPlace = [&] {
2216  return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2217  };
2218  if (!replacedRoot() && !updatedRootInPlace())
2219  llvm::report_fatal_error("expected pattern to replace the root operation");
2220 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2221 
2222  // Legalize each of the actions registered during application.
2223  RewriterState newState = impl.getCurrentState();
2224  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2225  newState)) ||
2226  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2227  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2228  newState))) {
2229  return failure();
2230  }
2231 
2232  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2233  return success();
2234 }
2235 
2236 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2237  Operation *op, ConversionPatternRewriter &rewriter,
2238  ConversionPatternRewriterImpl &impl, RewriterState &state,
2239  RewriterState &newState) {
2240  SmallPtrSet<Operation *, 16> operationsToIgnore;
2241 
2242  // If the pattern moved or created any blocks, make sure the types of block
2243  // arguments get legalized.
2244  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2245  BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2246  if (!rewrite)
2247  continue;
2248  Block *block = rewrite->getBlock();
2249  if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2250  ReplaceBlockArgRewrite>(rewrite))
2251  continue;
2252  // Only check blocks outside of the current operation.
2253  Operation *parentOp = block->getParentOp();
2254  if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2255  continue;
2256 
2257  // If the region of the block has a type converter, try to convert the block
2258  // directly.
2259  if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2260  std::optional<TypeConverter::SignatureConversion> conversion =
2261  converter->convertBlockSignature(block);
2262  if (!conversion) {
2263  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2264  "block"));
2265  return failure();
2266  }
2267  impl.applySignatureConversion(rewriter, block, converter, *conversion);
2268  continue;
2269  }
2270 
2271  // Otherwise, check that this operation isn't one generated by this pattern.
2272  // This is because we will attempt to legalize the parent operation, and
2273  // blocks in regions created by this pattern will already be legalized later
2274  // on. If we haven't built the set yet, build it now.
2275  if (operationsToIgnore.empty()) {
2276  for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2277  ++i) {
2278  auto *createOp =
2279  dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2280  if (!createOp)
2281  continue;
2282  operationsToIgnore.insert(createOp->getOperation());
2283  }
2284  }
2285 
2286  // If this operation should be considered for re-legalization, try it.
2287  if (operationsToIgnore.insert(parentOp).second &&
2288  failed(legalize(parentOp, rewriter))) {
2289  LLVM_DEBUG(logFailure(impl.logger,
2290  "operation '{0}'({1}) became illegal after rewrite",
2291  parentOp->getName(), parentOp));
2292  return failure();
2293  }
2294  }
2295  return success();
2296 }
2297 
/// Legalize every operation that a CreateOperationRewrite recorded in the
/// rewrite range [state.numRewrites, newState.numRewrites). Stops and returns
/// failure at the first operation that cannot be legalized.
// NOTE(review): part of this parameter list is not visible in this text
// (`rewriter` and `impl` are used below); confirm against the original file.
2298 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2300  RewriterState &state, RewriterState &newState) {
2301  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
 // Skip every rewrite that is not an operation creation.
2302  auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2303  if (!createOp)
2304  continue;
2305  Operation *op = createOp->getOperation();
2306  if (failed(legalize(op, rewriter))) {
2307  LLVM_DEBUG(logFailure(impl.logger,
2308  "failed to legalize generated operation '{0}'({1})",
2309  op->getName(), op));
2310  return failure();
2311  }
2312  }
2313  return success();
2314 }
2315 
/// Legalize every operation that a ModifyOperationRewrite recorded in the
/// rewrite range [state.numRewrites, newState.numRewrites). Stops and returns
/// failure at the first operation that cannot be legalized.
// NOTE(review): part of this parameter list is not visible in this text
// (`rewriter` and `impl` are used below); confirm against the original file.
2316 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2318  RewriterState &state, RewriterState &newState) {
2319  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
 // Skip every rewrite that is not an in-place modification.
2320  auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2321  if (!rewrite)
2322  continue;
2323  Operation *op = rewrite->getOperation();
2324  if (failed(legalize(op, rewriter))) {
2325  LLVM_DEBUG(logFailure(
2326  impl.logger, "failed to legalize operation updated in-place '{0}'",
2327  op->getName()));
2328  return failure();
2329  }
2330  }
2331  return success();
2332 }
2333 
2334 //===----------------------------------------------------------------------===//
2335 // Cost Model
2336 
/// Sort the applicator's patterns into `anyOpLegalizerPatterns` (patterns
/// without a specific root) and `legalizerPatterns` (per root operation).
///
/// If at least one root-less pattern exists, every collected pattern is
/// accepted without further filtering. Otherwise a worklist keeps only those
/// patterns whose generated operations are all either covered by an accepted
/// pattern or not rejected by the target.
2337 void OperationLegalizer::buildLegalizationGraph(
2338  LegalizationPatterns &anyOpLegalizerPatterns,
2339  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
 // NOTE(review): `parentOps` and `invalidPatterns` are used below but their
 // declarations are not visible in this text (only the two comments that
 // describe them remain); confirm against the original file.
2340  // A mapping between an operation and a set of operations that can be used to
2341  // generate it.
2343  // A mapping between an operation and any currently invalid patterns it has.
2345  // A worklist of patterns to consider for legality.
2346  SetVector<const Pattern *> patternWorklist;
2347 
2348  // Build the mapping from operations to the parent ops that may generate them.
2349  applicator.walkAllPatterns([&](const Pattern &pattern) {
2350  std::optional<OperationName> root = pattern.getRootKind();
2351 
2352  // If the pattern has no specific root, we can't analyze the relationship
2353  // between the root op and generated operations. Given that, add all such
2354  // patterns to the legalization set.
2355  if (!root) {
2356  anyOpLegalizerPatterns.push_back(&pattern);
2357  return;
2358  }
2359 
2360  // Skip operations that are always known to be legal.
2361  if (target.getOpAction(*root) == LegalizationAction::Legal)
2362  return;
2363 
2364  // Add this pattern to the invalid set for the root op and record this root
2365  // as a parent for any generated operations.
2366  invalidPatterns[*root].insert(&pattern);
2367  for (auto op : pattern.getGeneratedOps())
2368  parentOps[op].insert(*root);
2369 
2370  // Add this pattern to the worklist.
2371  patternWorklist.insert(&pattern);
2372  });
2373 
2374  // If there are any patterns that don't have a specific root kind, we can't
2375  // make direct assumptions about what operations will never be legalized.
2376  // Note: Technically we could, but it would require an analysis that may
2377  // recurse into itself. It would be better to perform this kind of filtering
2378  // at a higher level than here anyways.
2379  if (!anyOpLegalizerPatterns.empty()) {
2380  for (const Pattern *pattern : patternWorklist)
2381  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2382  return;
2383  }
2384 
 // Fixed-point iteration: accepting a pattern may make the patterns of its
 // parent operations acceptable as well, so those are queued again.
2385  while (!patternWorklist.empty()) {
2386  auto *pattern = patternWorklist.pop_back_val();
2387 
2388  // Check to see if any of the generated operations are invalid.
 // A generated op counts as invalid when no accepted pattern exists for it
 // and the target has no action for it or marks it Illegal.
2389  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2390  std::optional<LegalizationAction> action = target.getOpAction(op);
2391  return !legalizerPatterns.count(op) &&
2392  (!action || action == LegalizationAction::Illegal);
2393  }))
2394  continue;
2395 
2396  // Otherwise, if all of the generated operation are valid, this op is now
2397  // legal so add all of the child patterns to the worklist.
2398  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2399  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2400 
2401  // Add any invalid patterns of the parent operations to see if they have now
2402  // become legal.
2403  for (auto op : parentOps[*pattern->getRootKind()])
2404  patternWorklist.set_union(invalidPatterns[op]);
2405  }
2406 }
2407 
/// Rank the collected patterns and install the ranking in the applicator.
///
/// First a legalization depth is computed for every operation that has
/// patterns in `legalizerPatterns` (this also orders each per-op list). The
/// root-less patterns are ordered next. Finally the applicator receives a cost
/// model that derives a pattern's benefit from its position in the ordered
/// list it belongs to.
2408 void OperationLegalizer::computeLegalizationGraphBenefit(
2409  LegalizationPatterns &anyOpLegalizerPatterns,
2410  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2411  // The smallest pattern depth, when legalizing an operation.
2412  DenseMap<OperationName, unsigned> minOpPatternDepth;
2413 
2414  // For each operation that is transitively legal, compute a cost for it.
2415  for (auto &opIt : legalizerPatterns)
2416  if (!minOpPatternDepth.count(opIt.first))
2417  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2418  legalizerPatterns);
2419 
2420  // Apply the cost model to the patterns that can match any operation. Those
2421  // with a specific operation type are already resolved when computing the op
2422  // legalization depth.
2423  if (!anyOpLegalizerPatterns.empty())
2424  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2425  legalizerPatterns);
2426 
2427  // Apply a cost model to the pattern applicator. We order patterns first by
2428  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2429  // decreasing benefit.
2430  applicator.applyCostModel([&](const Pattern &pattern) {
2431  ArrayRef<const Pattern *> orderedPatternList;
2432  if (std::optional<OperationName> rootName = pattern.getRootKind())
2433  orderedPatternList = legalizerPatterns[*rootName];
2434  else
2435  orderedPatternList = anyOpLegalizerPatterns;
2436 
2437  // If the pattern is not found, then it was removed and cannot be matched.
2438  auto *it = llvm::find(orderedPatternList, &pattern);
 // NOTE(review): the statement controlled by the following `if` is not
 // visible in this text. As written, the `if` would govern the `return`
 // further down; confirm against the original file.
2439  if (it == orderedPatternList.end())
2441 
2442  // Patterns found earlier in the list have higher benefit.
 // The benefit is the distance to the end of the list, so the first entry
 // gets the list size and the last entry gets 1.
2443  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2444  });
2445 }
2446 
2447 unsigned OperationLegalizer::computeOpLegalizationDepth(
2448  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2449  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2450  // Check for existing depth.
2451  auto depthIt = minOpPatternDepth.find(op);
2452  if (depthIt != minOpPatternDepth.end())
2453  return depthIt->second;
2454 
2455  // If a mapping for this operation does not exist, then this operation
2456  // is always legal. Return 0 as the depth for a directly legal operation.
2457  auto opPatternsIt = legalizerPatterns.find(op);
2458  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2459  return 0u;
2460 
2461  // Record this initial depth in case we encounter this op again when
2462  // recursively computing the depth.
2463  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2464 
2465  // Apply the cost model to the operation patterns, and update the minimum
2466  // depth.
2467  unsigned minDepth = applyCostModelToPatterns(
2468  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2469  minOpPatternDepth[op] = minDepth;
2470  return minDepth;
2471 }
2472 
2473 unsigned OperationLegalizer::applyCostModelToPatterns(
2474  LegalizationPatterns &patterns,
2475  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2476  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2477  unsigned minDepth = std::numeric_limits<unsigned>::max();
2478 
2479  // Compute the depth for each pattern within the set.
2480  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2481  patternsByDepth.reserve(patterns.size());
2482  for (const Pattern *pattern : patterns) {
2483  unsigned depth = 1;
2484  for (auto generatedOp : pattern->getGeneratedOps()) {
2485  unsigned generatedOpDepth = computeOpLegalizationDepth(
2486  generatedOp, minOpPatternDepth, legalizerPatterns);
2487  depth = std::max(depth, generatedOpDepth + 1);
2488  }
2489  patternsByDepth.emplace_back(pattern, depth);
2490 
2491  // Update the minimum depth of the pattern list.
2492  minDepth = std::min(minDepth, depth);
2493  }
2494 
2495  // If the operation only has one legalization pattern, there is no need to
2496  // sort them.
2497  if (patternsByDepth.size() == 1)
2498  return minDepth;
2499 
2500  // Sort the patterns by those likely to be the most beneficial.
2501  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2502  [](const std::pair<const Pattern *, unsigned> &lhs,
2503  const std::pair<const Pattern *, unsigned> &rhs) {
2504  // First sort by the smaller pattern legalization
2505  // depth.
2506  if (lhs.second != rhs.second)
2507  return lhs.second < rhs.second;
2508 
2509  // Then sort by the larger pattern benefit.
2510  auto lhsBenefit = lhs.first->getBenefit();
2511  auto rhsBenefit = rhs.first->getBenefit();
2512  return lhsBenefit > rhsBenefit;
2513  });
2514 
2515  // Update the legalization pattern to use the new sorted list.
2516  patterns.clear();
2517  for (auto &patternIt : patternsByDepth)
2518  patterns.push_back(patternIt.first);
2519  return minDepth;
2520 }
2521 
2522 //===----------------------------------------------------------------------===//
2523 // OperationConverter
2524 //===----------------------------------------------------------------------===//
namespace {
/// Selects how the conversion driver treats operations that fail to legalize.
///
/// This is a scoped enumeration: enumerators must be spelled
/// `OpConversionMode::X` (as every use in this file already does), and there
/// is no implicit conversion to an integer.
enum class OpConversionMode {
  /// In this mode, the conversion will ignore failed conversions to allow
  /// illegal operations to co-exist in the IR.
  Partial,

  /// In this mode, all operations must be legal for the given target for the
  /// conversion to succeed.
  Full,

  /// In this mode, operations are analyzed for legality. No actual rewrites are
  /// applied to the operations on success.
  Analysis,
};
} // namespace
2540 
2541 namespace mlir {
2542 // This class converts operations to a given conversion target via a set of
2543 // rewrite patterns. The conversion behaves differently depending on the
2544 // conversion mode.
// NOTE(review): the class-head line and the constructor parameter that
// declares `patterns` (used in the initializer list below) are not visible in
// this listing -- confirm against the original file.
2546  explicit OperationConverter(const ConversionTarget &target,
2548  const ConversionConfig &config,
2549  OpConversionMode mode)
 // The legalizer is handed `this->config`, i.e. the copy stored in this
 // object, not the constructor argument of the same name.
2550  : config(config), opLegalizer(target, patterns, this->config),
2551  mode(mode) {}
2552 
2553  /// Converts the given operations to the conversion target.
2554  LogicalResult convertOperations(ArrayRef<Operation *> ops);
2555 
2556 private:
2557  /// Converts an operation with the given rewriter.
2558  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2559 
2560  /// Dialect conversion configuration.
2561  ConversionConfig config;
2562 
2563  /// The legalizer to use when converting operations.
2564  OperationLegalizer opLegalizer;
2565 
2566  /// The conversion mode to use when legalizing operations.
2567  OpConversionMode mode;
2568 };
2569 } // namespace mlir
2570 
2571 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2572  Operation *op) {
2573  // Legalize the given operation.
2574  if (failed(opLegalizer.legalize(op, rewriter))) {
2575  // Handle the case of a failed conversion for each of the different modes.
2576  // Full conversions expect all operations to be converted.
2577  if (mode == OpConversionMode::Full)
2578  return op->emitError()
2579  << "failed to legalize operation '" << op->getName() << "'";
2580  // Partial conversions allow conversions to fail iff the operation was not
2581  // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2582  // set, non-legalizable ops are added to that set.
2583  if (mode == OpConversionMode::Partial) {
2584  if (opLegalizer.isIllegal(op))
2585  return op->emitError()
2586  << "failed to legalize operation '" << op->getName()
2587  << "' that was explicitly marked illegal";
2588  if (config.unlegalizedOps)
2589  config.unlegalizedOps->insert(op);
2590  }
2591  } else if (mode == OpConversionMode::Analysis) {
2592  // Analysis conversions don't fail if any operations fail to legalize,
2593  // they are only interested in the operations that were successfully
2594  // legalized.
2595  if (config.legalizableOps)
2596  config.legalizableOps->insert(op);
2597  }
2598  return success();
2599 }
2600 
// Tries to replace a still-live unresolved materialization (an
// UnrealizedConversionCastOp) with values built by the materialization
// callbacks of the type converter recorded in `rewrite`.
//
// Returns success once the cast op has been replaced. If there is no
// converter, or the callbacks produce nothing, an error is emitted on the cast
// op with a note pointing at one of its users, and failure is returned.
// NOTE(review): the line carrying the function name and its first parameter
// (`rewriter`, used below) is not visible in this listing.
2601 static LogicalResult
2603  UnresolvedMaterializationRewrite *rewrite) {
2604  UnrealizedConversionCastOp op = rewrite->getOperation();
2605  assert(!op.use_empty() &&
2606  "expected that dead materializations have already been DCE'd");
2607  Operation::operand_range inputOperands = op.getOperands();
2608 
2609  // Try to materialize the conversion.
2610  if (const TypeConverter *converter = rewrite->getConverter()) {
 // New IR is created right before the cast op that it is meant to replace.
2611  rewriter.setInsertionPoint(op);
2612  SmallVector<Value> newMaterialization;
2613  switch (rewrite->getMaterializationKind()) {
2614  case MaterializationKind::Target:
2615  newMaterialization = converter->materializeTargetConversion(
2616  rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2617  rewrite->getOriginalType());
2618  break;
2619  case MaterializationKind::Source:
 // `sourceMat` is declared without an enclosing scope; no case label
 // follows this one within the switch.
2620  assert(op->getNumResults() == 1 && "expected single result");
2621  Value sourceMat = converter->materializeSourceConversion(
2622  rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2623  if (sourceMat)
2624  newMaterialization.push_back(sourceMat);
2625  break;
2626  }
 // An empty vector means that no callback produced a materialization.
2627  if (!newMaterialization.empty()) {
2628 #ifndef NDEBUG
2629  ValueRange newMaterializationRange(newMaterialization);
2630  assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
2631  "materialization callback produced value of incorrect type");
2632 #endif // NDEBUG
2633  rewriter.replaceOp(op, newMaterialization);
2634  return success();
2635  }
2636  }
2637 
 // Nothing could be materialized: report the cast and one of its live users.
2638  InFlightDiagnostic diag = op->emitError()
2639  << "failed to legalize unresolved materialization "
2640  "from ("
2641  << inputOperands.getTypes() << ") to ("
2642  << op.getResultTypes()
2643  << ") that remained live after conversion";
2644  diag.attachNote(op->getUsers().begin()->getLoc())
2645  << "see existing live user here: " << *op->getUsers().begin();
2646  return failure();
2647 }
2648 
// Driver body: collects the operations to convert, legalizes each of them,
// applies (or, on failure, undoes) the staged rewrites, and finally
// reconciles / legalizes the unrealized casts left behind by the framework.
// NOTE(review): several lines are not visible in this listing: the function
// header, the start of the walk call that receives the lambda below, and the
// declarations of `allCastOps` and `materializations`.
2650  if (ops.empty())
2651  return success();
2652  const ConversionTarget &target = opLegalizer.getTarget();
2653 
2654  // Compute the set of operations and blocks to convert.
2655  SmallVector<Operation *> toConvert;
2656  for (auto *op : ops) {
2658  [&](Operation *op) {
2659  toConvert.push_back(op);
2660  // Don't check this operation's children for conversion if the
2661  // operation is recursively legal.
2662  auto legalityInfo = target.isLegal(op);
2663  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2664  return WalkResult::skip();
2665  return WalkResult::advance();
2666  });
2667  }
2668 
2669  // Convert each operation and discard rewrites on failure.
2670  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2671  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2672 
2673  for (auto *op : toConvert)
2674  if (failed(convert(rewriter, op)))
 // Comma expression: first undo the staged rewrites, then return failure.
2675  return rewriterImpl.undoRewrites(), failure();
2676 
2677  // After a successful conversion, apply rewrites.
2678  rewriterImpl.applyRewrites();
2679 
2680  // Gather all unresolved materializations.
2683  &materializations = rewriterImpl.unresolvedMaterializations;
2684  for (auto it : materializations) {
 // Skip cast ops that were erased while the rewrites were applied.
2685  if (rewriterImpl.eraseRewriter.wasErased(it.first))
2686  continue;
2687  allCastOps.push_back(it.first);
2688  }
2689 
2690  // Reconcile all UnrealizedConversionCastOps that were inserted by the
2691  // dialect conversion framework. (Not the ones that were inserted by
2692  // patterns.)
2693  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2694  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
2695 
2696  // Try to legalize all unresolved materializations.
2697  if (config.buildMaterializations) {
 // This `rewriter` shadows the ConversionPatternRewriter declared above.
2698  IRRewriter rewriter(rewriterImpl.context, config.listener);
2699  for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2700  auto it = materializations.find(castOp);
2701  assert(it != materializations.end() && "inconsistent state");
2702  if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2703  return failure();
2704  }
2705  }
2706 
2707  return success();
2708 }
2709 
2710 //===----------------------------------------------------------------------===//
2711 // Reconcile Unrealized Casts
2712 //===----------------------------------------------------------------------===//
2713 
// Simplifies the given unrealized_conversion_cast ops: casts without users
// are erased, and a cast whose result types equal the input types of a cast
// further up its input chain is replaced by those inputs. If
// `remainingCastOps` is non-null, every op from `castOps` that was not erased
// is appended to it.
// NOTE(review): the first signature lines (function name and the `castOps`
// parameter) are not visible in this listing.
2716  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2717  SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2718  castOps.end());
2719  // This set is maintained only if `remainingCastOps` is provided.
2720  DenseSet<Operation *> erasedOps;
2721 
2722  // Helper function that adds all operands to the worklist that are an
2723  // unrealized_conversion_cast op result.
2724  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2725  for (Value v : castOp.getInputs())
2726  if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2727  worklist.insert(inputCastOp);
2728  };
2729 
2730  // Helper function that returns the unrealized_conversion_cast op that
2731  // defines all inputs of the given op (in the same order). Return "nullptr"
2732  // if there is no such op.
2733  auto getInputCast =
2734  [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2735  if (castOp.getInputs().empty())
2736  return {};
2737  auto inputCastOp =
2738  castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2739  if (!inputCastOp)
2740  return {};
 // All outputs of the producer must feed this op, in the same order.
2741  if (inputCastOp.getOutputs() != castOp.getInputs())
2742  return {};
2743  return inputCastOp;
2744  };
2745 
2746  // Process ops in the worklist bottom-to-top.
2747  while (!worklist.empty()) {
2748  UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2749  if (castOp->use_empty()) {
2750  // DCE: If the op has no users, erase it. Add the operands to the
2751  // worklist to find additional DCE opportunities.
2752  enqueueOperands(castOp);
2753  if (remainingCastOps)
2754  erasedOps.insert(castOp.getOperation());
2755  castOp->erase();
2756  continue;
2757  }
2758 
2759  // Traverse the chain of input cast ops to see if an op with the same
2760  // input types can be found.
 // The walk starts at `castOp` itself, so a cast whose input types already
 // equal its result types is matched on the first iteration.
2761  UnrealizedConversionCastOp nextCast = castOp;
2762  while (nextCast) {
2763  if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2764  // Found a cast where the input types match the output types of the
2765  // matched op. We can directly use those inputs and the matched op can
2766  // be removed.
2767  enqueueOperands(castOp);
2768  castOp.replaceAllUsesWith(nextCast.getInputs());
2769  if (remainingCastOps)
2770  erasedOps.insert(castOp.getOperation());
2771  castOp->erase();
2772  break;
2773  }
2774  nextCast = getInputCast(nextCast);
2775  }
2776  }
2777 
 // Report the surviving ops in the order in which they were passed in.
2778  if (remainingCastOps)
2779  for (UnrealizedConversionCastOp op : castOps)
2780  if (!erasedOps.contains(op.getOperation()))
2781  remainingCastOps->push_back(op);
2782 }
2783 
2784 //===----------------------------------------------------------------------===//
2785 // Type Conversion
2786 //===----------------------------------------------------------------------===//
2787 
// Remaps original input `origInputNo` to the new inputs `types`, which are
// appended at the current end of `argTypes`.
// NOTE(review): the first signature line (declaring `origInputNo`) is not
// visible in this listing.
2789  ArrayRef<Type> types) {
2790  assert(!types.empty() && "expected valid types");
 // The remapping must be recorded before appending: it uses the current size
 // of `argTypes` as the index of the first new input.
2791  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2792  addInputs(types);
2793 }
2794 
// Appends `types` to `argTypes` without recording any input remapping.
// NOTE(review): the signature line of this definition is not visible in this
// listing.
2796  assert(!types.empty() &&
2797  "1->0 type remappings don't need to be added explicitly");
2798  argTypes.append(types.begin(), types.end());
2799 }
2800 
2801 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2802  unsigned newInputNo,
2803  unsigned newInputCount) {
2804  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2805  assert(newInputCount != 0 && "expected valid input count");
2806  remappedInputs[origInputNo] =
2807  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2808 }
2809 
// Records that original input `origInputNo` is replaced by the existing value
// `replacementValue`; the stored mapping has size 0.
// NOTE(review): the first signature line (declaring `origInputNo`) is not
// visible in this listing.
2811  Value replacementValue) {
2812  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2813  remappedInputs[origInputNo] =
2814  InputMapping{origInputNo, /*size=*/0, replacementValue};
2815 }
2816 
// Converts type `t`, appending the converted type(s) to `results`. Cached
// answers are consulted first; otherwise the registered conversion callbacks
// are tried from most to least recently registered and the first definite
// answer is cached.
// NOTE(review): the first signature line (declaring `t`) and one line before
// each `.lock()` call are not visible in this listing. Both locks are created
// with std::defer_lock, which suggests the locking is conditional -- confirm.
2818  SmallVectorImpl<Type> &results) const {
2819  assert(t && "expected non-null type");
2820 
2821  {
2822  std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2823  std::defer_lock);
2825  cacheReadLock.lock();
2826  auto existingIt = cachedDirectConversions.find(t);
2827  if (existingIt != cachedDirectConversions.end()) {
 // A cached null type records an earlier failed conversion.
2828  if (existingIt->second)
2829  results.push_back(existingIt->second);
2830  return success(existingIt->second != nullptr);
2831  }
2832  auto multiIt = cachedMultiConversions.find(t);
2833  if (multiIt != cachedMultiConversions.end()) {
2834  results.append(multiIt->second.begin(), multiIt->second.end());
2835  return success();
2836  }
2837  }
2838  // Walk the added converters in reverse order to apply the most recently
2839  // registered first.
 // Remember how many entries `results` had so that only the newly appended
 // types are cached below.
2840  size_t currentCount = results.size();
2841 
2842  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2843  std::defer_lock);
2844 
2845  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
 // An empty optional means the callback does not handle `t`; try the next.
2846  if (std::optional<LogicalResult> result = converter(t, results)) {
2848  cacheWriteLock.lock();
2849  if (!succeeded(*result)) {
2850  cachedDirectConversions.try_emplace(t, nullptr);
2851  return failure();
2852  }
2853  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
 // Exactly one new type goes into the direct cache; zero or several go into
 // the multi-type cache.
2854  if (newTypes.size() == 1)
2855  cachedDirectConversions.try_emplace(t, newTypes.front());
2856  else
2857  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2858  return success();
2859  }
2860  }
 // No callback gave a definite answer; this outcome is not cached.
2861  return failure();
2862 }
2863 
// Single-type conversion: returns the converted type, or a null type when the
// conversion fails or does not produce exactly one type.
// NOTE(review): the signature line (declaring `t`) is not visible in this
// listing.
2865  // Use the multi-type result version to convert the type.
2866  SmallVector<Type, 1> results;
2867  if (failed(convertType(t, results)))
2868  return nullptr;
2869 
2870  // Check to ensure that only one type was produced.
2871  return results.size() == 1 ? results.front() : nullptr;
2872 }
2873 
// Converts every type in `types`, appending the converted types to `results`,
// and stops at the first failure. Types appended before a failure stay in
// `results`.
// NOTE(review): the signature line carrying the function name and the `types`
// parameter is not visible in this listing.
2874 LogicalResult
2876  SmallVectorImpl<Type> &results) const {
2877  for (Type type : types)
2878  if (failed(convertType(type, results)))
2879  return failure();
2880  return success();
2881 }
2882 
2883 bool TypeConverter::isLegal(Type type) const {
2884  return convertType(type) == type;
2885 }
// An operation is legal when both its operand types and its result types are.
// NOTE(review): the signature line (declaring `op`) is not visible in this
// listing.
2887  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2888 }
2889 
2890 bool TypeConverter::isLegal(Region *region) const {
2891  return llvm::all_of(*region, [this](Block &block) {
2892  return isLegal(block.getArgumentTypes());
2893  });
2894 }
2895 
2896 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
2897  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2898 }
2899 
// Converts the type of signature argument `inputNo` and records the outcome
// in `result`. An argument that converts to zero types is dropped: nothing is
// recorded and success is returned.
// NOTE(review): the signature line carrying the function name and the
// `inputNo` / `type` parameters is not visible in this listing.
2900 LogicalResult
2902  SignatureConversion &result) const {
2903  // Try to convert the given input type.
2904  SmallVector<Type, 1> convertedTypes;
2905  if (failed(convertType(type, convertedTypes)))
2906  return failure();
2907 
2908  // If this argument is being dropped, there is nothing left to do.
2909  if (convertedTypes.empty())
2910  return success();
2911 
2912  // Otherwise, add the new inputs.
2913  result.addInputs(inputNo, convertedTypes);
2914  return success();
2915 }
// Converts each type in `types` as signature argument `origInputOffset + i`
// and stops at the first failure.
// NOTE(review): the signature line carrying the function name and the `types`
// parameter is not visible in this listing.
2916 LogicalResult
2918  SignatureConversion &result,
2919  unsigned origInputOffset) const {
2920  for (unsigned i = 0, e = types.size(); i != e; ++i)
2921  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2922  return failure();
2923  return success();
2924 }
2925 
// Runs the registered argument materialization callbacks, most recently
// registered first, and returns the first non-null value produced. Returns a
// null value if no callback produces one.
// NOTE(review): the first signature line (function name and the `builder`
// parameter) is not visible in this listing.
2927  Location loc,
2928  Type resultType,
2929  ValueRange inputs) const {
2930  for (const MaterializationCallbackFn &fn :
2931  llvm::reverse(argumentMaterializations))
2932  if (Value result = fn(builder, resultType, inputs, loc))
2933  return result;
2934  return nullptr;
2935 }
2936 
// Runs the registered source materialization callbacks, most recently
// registered first, and returns the first non-null value produced. Returns a
// null value if no callback produces one.
// NOTE(review): the first signature line (function name and the `builder`
// parameter) is not visible in this listing.
2938  Location loc, Type resultType,
2939  ValueRange inputs) const {
2940  for (const MaterializationCallbackFn &fn :
2941  llvm::reverse(sourceMaterializations))
2942  if (Value result = fn(builder, resultType, inputs, loc))
2943  return result;
2944  return nullptr;
2945 }
2946 
// Single-result target materialization: forwards to a call taking
// `TypeRange(resultType)` and returns its only value, or a null value when
// that call produced nothing.
// NOTE(review): the first signature line and the line that declares `result`
// and names the callee are not visible in this listing.
2948  Location loc, Type resultType,
2949  ValueRange inputs,
2950  Type originalType) const {
2952  builder, loc, TypeRange(resultType), inputs, originalType);
2953  if (result.empty())
2954  return nullptr;
2955  assert(result.size() == 1 && "expected single result");
2956  return result.front();
2957 }
2958 
// Multi-result target materialization: runs the registered target
// materialization callbacks, most recently registered first, and returns the
// first non-empty result. Returns an empty vector if every callback declines.
// NOTE(review): the first signature line (return type and function name) is
// not visible in this listing.
2960  OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2961  Type originalType) const {
2962  for (const TargetMaterializationCallbackFn &fn :
2963  llvm::reverse(targetMaterializations)) {
2964  SmallVector<Value> result =
2965  fn(builder, resultTypes, inputs, loc, originalType);
 // An empty result means this callback does not apply; try the next one.
2966  if (result.empty())
2967  continue;
2968  assert(TypeRange(ValueRange(result)) == resultTypes &&
2969  "callback produced incorrect number of values or values with "
2970  "incorrect types");
2971  return result;
2972  }
2973  return {};
2974 }
2975 
// Builds a signature conversion for the arguments of `block`. Returns
// std::nullopt if any argument type fails to convert.
// NOTE(review): the signature line carrying the function name and the `block`
// parameter is not visible in this listing.
2976 std::optional<TypeConverter::SignatureConversion>
2978  SignatureConversion conversion(block->getNumArguments());
2979  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2980  return std::nullopt;
2981  return conversion;
2982 }
2983 
2984 //===----------------------------------------------------------------------===//
2985 // Type attribute conversion
2986 //===----------------------------------------------------------------------===//
// NOTE(review): the signature lines of the seven definitions below are not
// visible in this listing; only their bodies remain. The comments describe
// the visible bodies only.
//
// Factory body: wraps `attr` together with the "result" tag.
2989  return AttributeConversionResult(attr, resultTag);
2990 }
2991 
// Factory body: null attribute with the "not applicable" tag.
2994  return AttributeConversionResult(nullptr, naTag);
2995 }
2996 
// Factory body: null attribute with the "abort" tag.
2999  return AttributeConversionResult(nullptr, abortTag);
3000 }
3001 
// Predicate body: true when the stored tag is the "result" tag.
3003  return impl.getInt() == resultTag;
3004 }
3005 
// Predicate body: true when the stored tag is the "not applicable" tag.
3007  return impl.getInt() == naTag;
3008 }
3009 
// Predicate body: true when the stored tag is the "abort" tag.
3011  return impl.getInt() == abortTag;
3012 }
3013 
// Accessor body: returns the stored attribute; only valid when hasResult().
3015  assert(hasResult() && "Cannot get result from N/A or abort");
3016  return impl.getPointer();
3017 }
3018 
// Runs the registered type-attribute conversion callbacks, most recently
// registered first. Returns the first produced attribute. Returns
// std::nullopt as soon as a callback aborts, and also when every callback
// reports "not applicable".
// NOTE(review): the signature line carrying the function name and the
// `type` / `attr` parameters is not visible in this listing.
3019 std::optional<Attribute>
3021  for (const TypeAttributeConversionCallbackFn &fn :
3022  llvm::reverse(typeAttributeConversions)) {
3023  AttributeConversionResult res = fn(type, attr);
3024  if (res.hasResult())
3025  return res.getResult();
3026  if (res.isAbort())
3027  return std::nullopt;
 // Neither a result nor an abort: fall through to the next callback.
3028  }
3029  return std::nullopt;
3030 }
3031 
3032 //===----------------------------------------------------------------------===//
3033 // FunctionOpInterfaceSignatureConversion
3034 //===----------------------------------------------------------------------===//
3035 
3036 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3037  const TypeConverter &typeConverter,
3038  ConversionPatternRewriter &rewriter) {
3039  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3040  if (!type)
3041  return failure();
3042 
3043  // Convert the original function types.
3044  TypeConverter::SignatureConversion result(type.getNumInputs());
3045  SmallVector<Type, 1> newResults;
3046  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3047  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3048  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3049  typeConverter, &result)))
3050  return failure();
3051 
3052  // Update the function signature in-place.
3053  auto newType = FunctionType::get(rewriter.getContext(),
3054  result.getConvertedTypes(), newResults);
3055 
3056  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3057 
3058  return success();
3059 }
3060 
3061 /// Create a default conversion pattern that rewrites the type signature of a
3062 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3063 /// represent their type.
3064 namespace {
3065 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3066  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3067  MLIRContext *ctx,
3068  const TypeConverter &converter)
3069  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3070 
3071  LogicalResult
3072  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3073  ConversionPatternRewriter &rewriter) const override {
3074  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3075  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3076  }
3077 };
3078 
// Interface-based variant: matches any op implementing FunctionOpInterface
// and rewrites its signature via convertFuncOpTypes.
// NOTE(review): one line of the struct body (between the base-class line and
// matchAndRewrite) is not visible in this listing.
3079 struct AnyFunctionOpInterfaceSignatureConversion
3080  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3082 
3083  LogicalResult
3084  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3085  ConversionPatternRewriter &rewriter) const override {
3086  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3087  }
3088 };
3089 } // namespace
3090 
// Creates a new operation with the same name and attributes as `op`, the
// given `operands`, and result types converted by `converter`. Reports a
// match failure if `op` is already legal for `converter` or if its result
// types cannot be converted.
// NOTE(review): the signature line carrying the function name and the
// `op` / `operands` parameters is not visible in this listing.
3091 FailureOr<Operation *>
3093  const TypeConverter &converter,
3094  ConversionPatternRewriter &rewriter) {
3095  assert(op && "Invalid op");
3096  Location loc = op->getLoc();
3097  if (converter.isLegal(op))
3098  return rewriter.notifyMatchFailure(loc, "op already legal");
3099 
3100  OperationState newOp(loc, op->getName());
3101  newOp.addOperands(operands);
3102 
3103  SmallVector<Type> newResultTypes;
3104  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3105  return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3106 
3107  newOp.addTypes(newResultTypes);
3108  newOp.addAttributes(op->getAttrs());
3109  return rewriter.create(newOp);
3110 }
3111 
// Adds a FunctionOpInterfaceSignatureConversion pattern for the operation
// named `functionLikeOpName` to `patterns`.
// NOTE(review): the first signature line (return type and function name) is
// not visible in this listing.
3113  StringRef functionLikeOpName, RewritePatternSet &patterns,
3114  const TypeConverter &converter) {
3115  patterns.add<FunctionOpInterfaceSignatureConversion>(
3116  functionLikeOpName, patterns.getContext(), converter);
3117 }
3118 
// Adds the interface-based AnyFunctionOpInterfaceSignatureConversion pattern
// to `patterns`.
// NOTE(review): the first signature line (return type and function name) is
// not visible in this listing.
3120  RewritePatternSet &patterns, const TypeConverter &converter) {
3121  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3122  converter, patterns.getContext());
3123 }
3124 
3125 //===----------------------------------------------------------------------===//
3126 // ConversionTarget
3127 //===----------------------------------------------------------------------===//
3128 
// Sets the legalization action for operation `op`, creating the entry in
// `legalOperations` if it does not exist yet.
// NOTE(review): the first signature line (declaring `op`) is not visible in
// this listing.
3130  LegalizationAction action) {
3131  legalOperations[op].action = action;
3132 }
3133 
// Sets the legalization action for every dialect in `dialectNames`.
// NOTE(review): the first signature line (declaring `dialectNames`) is not
// visible in this listing.
3135  LegalizationAction action) {
3136  for (StringRef dialect : dialectNames)
3137  legalDialects[dialect] = action;
3138 }
3139 
// Returns the legalization action that getOpInfo reports for `op`, or an
// empty optional when there is no information for it.
// NOTE(review): the first signature line (function name and the `op`
// parameter) is not visible in this listing.
3141  -> std::optional<LegalizationAction> {
3142  std::optional<LegalizationInfo> info = getOpInfo(op);
3143  return info ? info->action : std::optional<LegalizationAction>();
3144 }
3145 
// Returns legality details for `op` if the target considers this operation
// instance legal, std::nullopt otherwise (including when the target has no
// information about the operation).
// NOTE(review): the first signature line (function name and the `op`
// parameter) is not visible in this listing.
3147  -> std::optional<LegalOpDetails> {
3148  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3149  if (!info)
3150  return std::nullopt;
3151 
3152  // Returns true if this operation instance is known to be legal.
3153  auto isOpLegal = [&] {
3154  // Handle dynamic legality with the provided legality function.
3155  if (info->action == LegalizationAction::Dynamic) {
3156  std::optional<bool> result = info->legalityFn(op);
3157  if (result)
3158  return *result;
 // An empty answer falls through to the check below, which yields false
 // for a dynamically legal operation.
3159  }
3160 
3161  // Otherwise, the operation is only legal if it was marked 'Legal'.
3162  return info->action == LegalizationAction::Legal;
3163  };
3164  if (!isOpLegal())
3165  return std::nullopt;
3166 
3167  // This operation is legal, compute any additional legality information.
3168  LegalOpDetails legalityDetails;
3169  if (info->isRecursivelyLegal) {
3170  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3171  if (legalityFnIt != opRecursiveLegalityFns.end()) {
 // A callback that returns no answer counts as recursively legal.
3172  legalityDetails.isRecursivelyLegal =
3173  legalityFnIt->second(op).value_or(true);
3174  } else {
3175  legalityDetails.isRecursivelyLegal = true;
3176  }
3177  }
3178  return legalityDetails;
3179 }
3180 
// Returns true if the target explicitly considers `op` illegal: either its
// action is Illegal, or its action is Dynamic and the legality callback
// answers false. Missing information and an unanswered dynamic callback both
// yield false.
// NOTE(review): the signature line (function name and the `op` parameter) is
// not visible in this listing.
3182  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3183  if (!info)
3184  return false;
3185 
3186  if (info->action == LegalizationAction::Dynamic) {
3187  std::optional<bool> result = info->legalityFn(op);
3188  if (!result)
3189  return false;
3190 
3191  return !(*result);
3192  }
3193 
3194  return info->action == LegalizationAction::Illegal;
3195 }
3196 
// Chains two legality callbacks. With no old callback the new one is returned
// unchanged; otherwise the returned callback asks the new callback first and
// consults the old one only when the new one gives no answer.
// NOTE(review): the signature lines (declaring `oldCallback` and
// `newCallback`) are not visible in this listing.
3200  if (!oldCallback)
3201  return newCallback;
3202 
 // Both callbacks are moved into the closure, which therefore owns them.
3203  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3204  Operation *op) -> std::optional<bool> {
3205  if (std::optional<bool> result = newCl(op))
3206  return *result;
3207 
3208  return oldCl(op);
3209  };
3210  return chain;
3211 }
3212 
3213 void ConversionTarget::setLegalityCallback(
3214  OperationName name, const DynamicLegalityCallbackFn &callback) {
3215  assert(callback && "expected valid legality callback");
3216  auto *infoIt = legalOperations.find(name);
3217  assert(infoIt != legalOperations.end() &&
3218  infoIt->second.action == LegalizationAction::Dynamic &&
3219  "expected operation to already be marked as dynamically legal");
3220  infoIt->second.legalityFn =
3221  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3222 }
3223 
// Marks the operation `name` as recursively legal. With a non-null `callback`
// the callback is chained in front of any existing recursive-legality
// callback for the operation; with a null `callback` any existing one is
// removed.
// NOTE(review): the first signature line (return type and function name) is
// not visible in this listing.
3225  OperationName name, const DynamicLegalityCallbackFn &callback) {
3226  auto *infoIt = legalOperations.find(name);
3227  assert(infoIt != legalOperations.end() &&
3228  infoIt->second.action != LegalizationAction::Illegal &&
3229  "expected operation to already be marked as legal");
3230  infoIt->second.isRecursivelyLegal = true;
3231  if (callback)
3232  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3233  std::move(opRecursiveLegalityFns[name]), callback);
3234  else
3235  opRecursiveLegalityFns.erase(name);
3236 }
3237 
3238 void ConversionTarget::setLegalityCallback(
3239  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3240  assert(callback && "expected valid legality callback");
3241  for (StringRef dialect : dialects)
3242  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3243  std::move(dialectLegalityFns[dialect]), callback);
3244 }
3245 
3246 void ConversionTarget::setLegalityCallback(
3247  const DynamicLegalityCallbackFn &callback) {
3248  assert(callback && "expected valid legality callback");
3249  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3250 }
3251 
3252 auto ConversionTarget::getOpInfo(OperationName op) const
3253  -> std::optional<LegalizationInfo> {
3254  // Check for info for this specific operation.
3255  const auto *it = legalOperations.find(op);
3256  if (it != legalOperations.end())
3257  return it->second;
3258  // Check for info for the parent dialect.
3259  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3260  if (dialectIt != legalDialects.end()) {
3261  DynamicLegalityCallbackFn callback;
3262  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3263  if (dialectFn != dialectLegalityFns.end())
3264  callback = dialectFn->second;
3265  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3266  callback};
3267  }
3268  // Otherwise, check if we mark unknown operations as dynamic.
3269  if (unknownLegalityFn)
3270  return LegalizationInfo{LegalizationAction::Dynamic,
3271  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3272  return std::nullopt;
3273 }
3274 
3275 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3276 //===----------------------------------------------------------------------===//
3277 // PDL Configuration
3278 //===----------------------------------------------------------------------===//
3279 
// NOTE(review): the signature lines of the two definitions below are not
// visible in this listing; both bodies assume that `rewriter` is actually a
// ConversionPatternRewriter (static_cast, no runtime check).
//
// First body: makes this config's type converter the rewriter's current one.
3281  auto &rewriterImpl =
3282  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3283  rewriterImpl.currentTypeConverter = getTypeConverter();
3284 }
3285 
// Second body: clears the rewriter's current type converter again.
3287  auto &rewriterImpl =
3288  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3289  rewriterImpl.currentTypeConverter = nullptr;
3290 }
3291 
3292 /// Remap the given value using the rewriter and the type converter in the
3293 /// provided config.
/// Returns failure if the rewriter cannot remap the values.
// NOTE(review): the signature line (function name and the `rewriter` /
// `values` parameters) is not visible in this listing.
3294 static FailureOr<SmallVector<Value>>
3296  SmallVector<Value> mappedValues;
3297  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3298  return failure();
 // std::move is used because the return type (FailureOr) differs from the
 // type of the local vector.
3299  return std::move(mappedValues);
3300 }
3301 
// Registers the PDL rewrite functions "convertValue", "convertValues",
// "convertType" and "convertTypes" on `patterns`. Every function casts the
// PatternRewriter it receives to a ConversionPatternRewriter.
// NOTE(review): the signature line (function name and the `patterns`
// parameter) is not visible in this listing.
3303  patterns.getPDLPatterns().registerRewriteFunction(
3304  "convertValue",
3305  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3306  auto results = pdllConvertValues(
3307  static_cast<ConversionPatternRewriter &>(rewriter), value);
3308  if (failed(results))
3309  return failure();
3310  return results->front();
3311  });
3312  patterns.getPDLPatterns().registerRewriteFunction(
3313  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3314  return pdllConvertValues(
3315  static_cast<ConversionPatternRewriter &>(rewriter), values);
3316  });
3317  patterns.getPDLPatterns().registerRewriteFunction(
3318  "convertType",
3319  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3320  auto &rewriterImpl =
3321  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3322  if (const TypeConverter *converter =
3323  rewriterImpl.currentTypeConverter) {
3324  if (Type newType = converter->convertType(type))
3325  return newType;
3326  return failure();
3327  }
 // Without a current type converter the type is returned unchanged.
3328  return type;
3329  });
3330  patterns.getPDLPatterns().registerRewriteFunction(
3331  "convertTypes",
3332  [](PatternRewriter &rewriter,
3333  TypeRange types) -> FailureOr<SmallVector<Type>> {
3334  auto &rewriterImpl =
3335  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3336  const TypeConverter *converter = rewriterImpl.currentTypeConverter;
 // Without a current type converter the types are returned unchanged.
3337  if (!converter)
3338  return SmallVector<Type>(types);
3339 
3340  SmallVector<Type> remappedTypes;
3341  if (failed(converter->convertTypes(types, remappedTypes)))
3342  return failure();
3343  return std::move(remappedTypes);
3344  });
3345 }
3346 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
3347 
3348 //===----------------------------------------------------------------------===//
3349 // Op Conversion Entry Points
3350 //===----------------------------------------------------------------------===//
3351 
3352 //===----------------------------------------------------------------------===//
3353 // Partial Conversion
3354 
// Partial conversion entry points: the first definition runs an
// OperationConverter in Partial mode over `ops`; the second forwards a single
// operation to the first.
// NOTE(review): several signature lines of both definitions (function names
// and the `patterns` / `config` / `op` parameters) are not visible in this
// listing.
3356  ArrayRef<Operation *> ops, const ConversionTarget &target,
3358  OperationConverter opConverter(target, patterns, config,
3359  OpConversionMode::Partial);
3360  return opConverter.convertOperations(ops);
3361 }
3362 LogicalResult
3366  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3367 }
3368 
3369 //===----------------------------------------------------------------------===//
3370 // Full Conversion
3371 
// Full conversion entry points: the first definition runs an
// OperationConverter in Full mode over `ops`; the second forwards a single
// operation to the first.
// NOTE(review): several signature lines of both definitions (function names
// and the `ops` / `op` / `patterns` / `config` parameters) are not visible in
// this listing.
3373  const ConversionTarget &target,
3376  OperationConverter opConverter(target, patterns, config,
3377  OpConversionMode::Full);
3378  return opConverter.convertOperations(ops);
3379 }
3381  const ConversionTarget &target,
3384  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3385 }
3386 
3387 //===----------------------------------------------------------------------===//
3388 // Analysis Conversion
3389 
3390 /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3391 /// op is a top-level module op (which is expected to be isolated from above),
3392 /// return that op.
// NOTE(review): the signature line and the right-hand side of the
// `commonAncestor =` assignment inside the while loop are not visible in this
// listing. `ops.front()` is evaluated without an emptiness check, so callers
// presumably never pass an empty list -- confirm.
3394  // Check if there is a top-level operation within `ops`. If so, return that
3395  // op.
3396  for (Operation *op : ops) {
3397  if (!op->getParentOp()) {
3398 #ifndef NDEBUG
3399  assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3400  "expected top-level op to be isolated from above");
3401  for (Operation *other : ops)
3402  assert(op->isAncestor(other) &&
3403  "expected ops to have a common ancestor");
3404 #endif // NDEBUG
3405  return op;
3406  }
3407  }
3408 
3409  // No top-level op. Find a common ancestor.
3410  Operation *commonAncestor =
3411  ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3412  for (Operation *op : ops.drop_front()) {
 // Move the candidate until it is a proper ancestor of `op`. The null check
 // below is an assert only, so it is compiled out when NDEBUG is defined.
3413  while (!commonAncestor->isProperAncestor(op)) {
3414  commonAncestor =
3416  assert(commonAncestor &&
3417  "expected to find a common isolated from above ancestor");
3418  }
3419  }
3420 
3421  return commonAncestor;
3422 }
3423 
// Analysis conversion over `ops`: the conversion runs on a clone of the
// common ancestor found by findCommonAncestor, `config.legalizableOps` (if
// provided) is translated back to the original operations, and the clone is
// erased before returning.
// NOTE(review): the signature lines (function name and the `ops` / `target` /
// `patterns` / `config` parameters) are not visible in this listing.
3427 #ifndef NDEBUG
3428  if (config.legalizableOps)
3429  assert(config.legalizableOps->empty() && "expected empty set");
3430 #endif // NDEBUG
3431 
3432  // Clone closest common ancestor that is isolated from above.
3433  Operation *commonAncestor = findCommonAncestor(ops);
3434  IRMapping mapping;
3435  Operation *clonedAncestor = commonAncestor->clone(mapping);
3436  // Compute inverse IR mapping.
3437  DenseMap<Operation *, Operation *> inverseOperationMap;
3438  for (auto &it : mapping.getOperationMap())
3439  inverseOperationMap[it.second] = it.first;
3440 
3441  // Convert the cloned operations. The original IR will remain unchanged.
3442  SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3443  ops, [&](Operation *op) { return mapping.lookup(op); });
3444  OperationConverter opConverter(target, patterns, config,
3445  OpConversionMode::Analysis);
3446  LogicalResult status = opConverter.convertOperations(opsToConvert);
3447 
3448  // Remap `legalizableOps`, so that they point to the original ops and not the
3449  // cloned ops.
3450  if (config.legalizableOps) {
3451  DenseSet<Operation *> originalLegalizableOps;
3452  for (Operation *op : *config.legalizableOps)
3453  originalLegalizableOps.insert(inverseOperationMap[op]);
3454  *config.legalizableOps = std::move(originalLegalizableOps);
3455  }
3456 
3457  // Erase the cloned IR.
3458  clonedAncestor->erase();
3459  return status;
3460 }
3461 
// Single-operation convenience wrapper: forwards `op` as a one-element list.
// NOTE(review): the signature lines (function name and parameters) are not
// visible in this listing.
3462 LogicalResult
3466  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3467 }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:96
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition: IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:274
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition: Builders.h:325
Block::iterator getPoint() const
Definition: Builders.h:338
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
Block * getBlock() const
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
Definition: Builders.cpp:468
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setSuccessor(Block *block, unsigned index)
Definition: Operation.cpp:605
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
int getPropertiesStorageSize() const
Returns the properties storage size.
Definition: Operation.h:897
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:901
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
Definition: Operation.cpp:366
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:129
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:90
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
This struct represents a range of new types or a single value that remaps an existing signature input...
A rewriter that keeps track of erased ops and blocks.
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input opera...
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
void notifyOpReplaced(Operation *op, ArrayRef< ValueRange > newValues)
Notifies that an op is about to be replaced with the given values.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.